"""
DeepCausalMMM model implementation combining GRU, DAG, and interaction components.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any
import numpy as np
from deepcausalmmm.core.dag_model import NodeToEdge, EdgeToNode, DAGConstraint
from deepcausalmmm.core.seasonality import DetectSeasonality
import logging
logger = logging.getLogger('deepcausalmmm')
[docs]
class DeepCausalMMM(nn.Module):
"""
Deep Causal Marketing Mix Model with DAG structure and channel interactions.
This model combines deep learning with causal inference to understand the impact
of marketing channels on business KPIs while learning causal relationships between
channels through a Directed Acyclic Graph (DAG).
The model features:
- GRU-based temporal modeling for time-varying coefficients
- Learnable coefficient bounds for realistic attribution
- DAG learning for causal channel interactions (triangular mask or opt-in NOTEARS)
- Adstock and saturation transformations
- Multi-region support with shared and region-specific parameters
- Zero hardcoding philosophy - all parameters are learnable or configurable
Parameters
----------
n_media : int, default=10
Number of media channels in the dataset
ctrl_dim : int, default=15
Number of control variables (weather, events, etc.)
hidden : int, default=32
Hidden dimension size for GRU and MLP layers
n_regions : int, default=2
Number of geographic regions or DMAs
dropout : float, default=0.1
Dropout rate for regularization during training
sparsity_weight : float, default=0.01
Weight for sparsity regularization on coefficients
enable_dag : bool, default=True
Whether to enable DAG learning for channel interactions
enable_interactions : bool, default=True
Whether to enable channel interaction modeling
l1_weight : float, default=0.001
L1 regularization weight for coefficient sparsity
l2_weight : float, default=0.001
L2 regularization weight for coefficient smoothness
burn_in_weeks : int, default=4
Number of initial weeks for GRU stabilization
use_coefficient_momentum : bool, default=True
Whether to use momentum for coefficient stabilization
momentum_decay : float, default=0.9
Decay rate for coefficient momentum
use_warm_start : bool, default=True
Whether to use warm start training initialization
warm_start_epochs : int, default=50
Number of epochs for warm start phase
stabilization_method : str, default="exponential"
Method for coefficient stabilization ("linear", "exponential", "sigmoid")
coeff_l2_weight : float, default=0.1
L2 regularization specifically for media coefficients
coeff_gen_l2_weight : float, default=0.05
L2 regularization for coefficient generation layers
gru_layers : int, default=1
Number of GRU layers (configured, not hardcoded)
ctrl_hidden_ratio : float, default=0.5
Control hidden size as ratio of main hidden dimension
dag_mode : str, default="triangular"
DAG acyclicity mode: ``"triangular"`` (upper-triangular mask, default)
or ``"notears"`` (continuous NOTEARS penalty; Zheng et al., 2018)
notears_lambda1 : float, default=0.01
L1 sparsity on the full adjacency in NOTEARS mode
notears_rho_init : float, default=1.0
Initial augmented-Lagrangian penalty ``rho``
notears_alpha_init : float, default=0.0
Initial dual variable ``alpha`` for NOTEARS
notears_rho_max : float, default=1e16
Upper cap on ``rho`` for numerical safety
dag_temperature : float, default=1.0
Sigmoid temperature for DAG edge weights (``< 1`` sharpens toward {0, 1})
notears_group_l1 : float, default=0.0
Column-group L1 over adjacency columns (NOTEARS only)
Attributes
----------
media_coeffs : torch.nn.Parameter
Time-varying coefficients for media channels
ctrl_coeffs : torch.nn.Parameter
Coefficients for control variables
dag_matrix : torch.nn.Parameter
Learnable DAG adjacency matrix for channel interactions
region_baseline : torch.nn.Parameter
Region-specific baseline contributions
seasonal_coeff : torch.nn.Parameter
Learnable coefficient for seasonal component
Examples
--------
>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>>
>>> # Initialize model
>>> model = DeepCausalMMM(
... n_media=5,
... ctrl_dim=3,
... n_regions=2,
... hidden=64
... )
>>>
>>> # Prepare data tensors
>>> media_data = torch.randn(2, 104, 5) # [regions, weeks, channels]
>>> control_data = torch.randn(2, 104, 3) # [regions, weeks, controls]
>>> regions = torch.arange(2).unsqueeze(1).repeat(1, 104)
>>>
>>> # Forward pass
>>> predictions, media_coeffs, media_contributions, outputs = model(
... media_data, control_data, regions
... )
>>>
>>> print(f"Predictions shape: {predictions.shape}")
>>> print(f"Media contributions: {outputs['contributions'].shape}")
"""
[docs]
def __init__(
self,
n_media: int = 10,
ctrl_dim: int = 15,
hidden: int = 32, # Smaller default
n_regions: int = 2,
dropout: float = 0.1,
sparsity_weight: float = 0.01, # Smaller default
enable_dag: bool = True,
enable_interactions: bool = True,
l1_weight: float = 0.001, # Smaller default
l2_weight: float = 0.001, # Smaller default
burn_in_weeks: int = 4, # NEW: Number of weeks to stabilize GRU
# NEW: Advanced stabilization parameters
use_coefficient_momentum: bool = True,
momentum_decay: float = 0.9,
use_warm_start: bool = True,
warm_start_epochs: int = 50,
stabilization_method: str = "exponential", # "linear", "exponential", "sigmoid"
# COEFFICIENT REGULARIZATION: Prevent coefficient explosion
coeff_l2_weight: float = 0.1,
coeff_gen_l2_weight: float = 0.05,
# NEW: Config-driven parameters (no hardcoding!)
gru_layers: int = 1,
ctrl_hidden_ratio: float = 0.5, # Control hidden size as ratio of main hidden size
# NOTEARS DAG learning (Zheng et al., 2018). Default keeps prior behaviour.
dag_mode: str = "triangular", # "triangular" | "notears"
notears_lambda1: float = 0.01, # L1 sparsity weight on full W in NOTEARS mode
notears_rho_init: float = 1.0, # Initial penalty parameter rho
notears_alpha_init: float = 0.0, # Initial dual variable alpha
notears_rho_max: float = 1e16, # Cap on rho for numerical safety
dag_temperature: float = 1.0, # <1.0 sharpens sigmoid edges toward {0,1}
notears_group_l1: float = 0.0, # L2,1 group L1 over adj columns (NOTEARS only)
):
super().__init__()
# DAG-edge sigmoid temperature (used in dag_interaction()).
self.dag_temperature = float(dag_temperature)
self.notears_group_l1 = float(notears_group_l1)
# Store dimensions and flags
self.n_media = n_media
self.ctrl_dim = ctrl_dim
self.hidden_size = hidden
self.n_regions = n_regions
self.enable_dag = enable_dag
self.enable_interactions = enable_interactions
self.l1_weight = l1_weight
self.l2_weight = l2_weight
self.burn_in_weeks = burn_in_weeks # NEW: Store burn-in period
# COEFFICIENT REGULARIZATION: Store parameters to prevent explosion
self.coeff_l2_weight = coeff_l2_weight
self.coeff_gen_l2_weight = coeff_gen_l2_weight
# NEW: Advanced stabilization parameters
self.use_coefficient_momentum = use_coefficient_momentum
self.momentum_decay = momentum_decay
self.use_warm_start = use_warm_start
self.warm_start_epochs = warm_start_epochs
self.stabilization_method = stabilization_method
# Coefficient momentum tracking
if self.use_coefficient_momentum:
self.register_buffer('media_coeff_momentum', torch.zeros(n_media))
self.register_buffer('ctrl_coeff_momentum', torch.zeros(ctrl_dim))
self.register_buffer('momentum_step', torch.tensor(0))
# Adstock parameters - STABILIZED initialization
self.alpha = nn.Parameter(torch.ones(n_media) * 0.8) # Start with reasonable adstock
# STABILIZED HILL - Initialize for proper saturation curves
# a (slope) should be >= 2.0 for clear diminishing returns
# Initialize to inverse_softplus(2.5) so after softplus + clamp we get ~2.5
# Initialize hill_a so that softplus(hill_a) >= 2.0 naturally (without clamp floor)
# softplus(2.5) ≈ 2.58, giving room to learn both up and down within [2.0, 5.0]
self.hill_a = nn.Parameter(torch.ones(n_media) * 2.5) # softplus(2.5) ≈ 2.58
# hill_g will be initialized per-channel based on SOV data in initialize_hill_from_data()
self.hill_g = nn.Parameter(torch.rand(n_media) * 0.2 + 0.1) # Default: 0.1-0.3 (will be overwritten)
# CAUSAL DAG components - discover meaningful relationships
if enable_dag and enable_interactions:
if dag_mode not in ("triangular", "notears"):
raise ValueError(
f"dag_mode must be 'triangular' or 'notears', got {dag_mode!r}"
)
self.dag_mode = dag_mode
if dag_mode == "triangular":
# Existing behaviour: acyclicity guaranteed by construction via
# an upper-triangular mask. Requires a fixed channel ordering.
self.adj_logits = nn.Parameter(torch.randn(n_media, n_media) * 0.3 - 0.1)
mask = torch.triu(torch.ones(n_media, n_media), diagonal=1)
self.register_buffer('tri_mask', mask)
else:
# NOTEARS mode: acyclicity enforced via a smooth scalar penalty
# h(W) = tr(exp(W ⊙ W)) − d. The mask only zeros the diagonal
# so the model can learn arbitrary DAG topology from data.
# Match the triangular-mode init scale so edges are active from
# epoch 0 and contribute to predictions immediately. The initial
# h(W) will be large but that is fine: warmup keeps the penalty
# off until Huber has established a fit, then the augmented
# Lagrangian outer loop drives h(W) → 0.
self.adj_logits = nn.Parameter(torch.randn(n_media, n_media) * 0.3 - 0.1)
mask = torch.ones(n_media, n_media) - torch.eye(n_media) # zero diagonal only
self.register_buffer('tri_mask', mask)
# Augmented Lagrangian state. Buffers (not parameters) so they
# are not touched by the optimiser; updated in notears_update_duals().
self.register_buffer('notears_rho', torch.tensor(notears_rho_init))
self.register_buffer('notears_alpha', torch.tensor(notears_alpha_init))
self.notears_lambda1 = notears_lambda1
self.notears_rho_max = notears_rho_max
self._notears_h_prev = float('inf')
# Warmup gate: when False, get_dag_loss() and notears_update_duals()
# behave as if NOTEARS were disabled. The trainer flips this to True
# after `notears_warmup_epochs` to let Huber loss establish a good
# data fit before the acyclicity penalty starts pulling weights around.
self.register_buffer('notears_active', torch.tensor(True))
# Per-channel DAG-strength: each target channel learns how much
# of its effective input comes from its causal parents vs itself.
# Initialised at sigmoid(-1.4) ≈ 0.2 so the DAG starts with a
# modest 20% blend, then the model can grow per-channel mixing
# where the parents help and shrink it where they don't.
self.interaction_weight = nn.Parameter(torch.ones(n_media) * -1.4)
else:
# Sensible default so attribute always exists
self.dag_mode = dag_mode
# FIXED control processing - use config-driven dimensions (NO HARDCODING!)
self.ctrl_hidden = int(hidden * ctrl_hidden_ratio) # Config-driven control hidden size
self.ctrl_mlp = nn.Sequential(
nn.Linear(ctrl_dim, self.ctrl_hidden),
nn.Tanh(), # Bounded activation
nn.Dropout(dropout)
)
# CONFIG-DRIVEN GRU - NO HARDCODING!
gru_input_size = n_media + self.ctrl_hidden
self.gru_layers = gru_layers # Use config value, no hardcoding!
self.gru = nn.GRU(
input_size=gru_input_size,
hidden_size=hidden,
num_layers=self.gru_layers, # Config-driven layers
batch_first=True,
dropout=dropout if self.gru_layers > 1 else 0 # Conditional dropout based on layers
)
# DISABLE residual connections for stability
self.use_residual = False # REVERT - was causing instability
# ENHANCED coefficient generator for ultra-low RMSE
self.coeff_gen = nn.Sequential(
nn.Linear(hidden, hidden), # FULL capacity first layer
nn.ReLU(), # Better gradient flow than Tanh
nn.Linear(hidden, hidden // 2),
nn.ReLU(),
nn.Linear(hidden // 2, n_media)
)
# Initialize coefficient generator carefully
for layer in self.coeff_gen:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight, gain=0.1)
nn.init.zeros_(layer.bias)
# ENHANCED control coefficient generator for ultra-low RMSE
self.ctrl_coeff_gen = nn.Sequential(
nn.Linear(hidden, hidden), # FULL capacity first layer
nn.ReLU(), # Better gradient flow than Tanh
nn.Linear(hidden, hidden // 2),
nn.ReLU(),
nn.Linear(hidden // 2, ctrl_dim)
)
# Initialize control coefficients
for layer in self.ctrl_coeff_gen:
if isinstance(layer, nn.Linear):
nn.init.xavier_uniform_(layer.weight, gain=0.1)
nn.init.zeros_(layer.bias)
# ATTRIBUTION PRIOR: Target media contribution (business knowledge)
self.media_contribution_prior = 0.40 # 40% is typical for MMM
self.attribution_reg_weight = 0.5 # Balance between prediction and attribution (0.5 = equal priority)
# SEASONAL REGULARIZATION: Prevent seasonal_coeff from being suppressed to zero
self.seasonal_prior = 1.0 # Target seasonal coefficient value
self.seasonal_reg_weight = 0.1 # Weight for seasonal regularization (lower than attribution)
# REGION-SPECIFIC baselines - IMPROVED initialization
self.region_baseline = nn.Parameter(torch.randn(n_regions) * 0.1)
# Global bias and prediction scaling
self.global_bias = nn.Parameter(torch.zeros(1))
self.prediction_scale = nn.Parameter(torch.ones(1))
# NEW: Time trend component to handle growth patterns
# FROZEN: Set to zero and not learnable (disabled per user request)
self.register_buffer('time_trend_weight', torch.zeros(1)) # FROZEN at 0
self.register_buffer('time_trend_bias', torch.zeros(1)) # FROZEN at 0
# NEW: Seasonal component with learnable coefficient
self.seasonal_coeff = nn.Parameter(torch.ones(1)) # Learnable seasonal coefficient
self.seasonal_components = None # Will be initialized with actual data from decomposition
self.seasonality_detector = DetectSeasonality() # For seasonal decomposition
# FULLY LEARNABLE: No hardcoded bounds - let model discover everything
self.coeff_range_raw = nn.Parameter(torch.tensor(0.0)) # exp(0) = 1.0, no upper bound
self.ctrl_coeff_range_raw = nn.Parameter(torch.tensor(0.0)) # exp(0) = 1.0, no upper bound
# FROZEN: Trend damping (disabled - trend is frozen at zero)
self.register_buffer('trend_damping_raw', torch.tensor(0.0)) # FROZEN - not learnable
# HYBRID APPROACH: Fixed regularization weights (handled in trainer)
# Only core model parameters are learnable - not loss balancing weights
# FULLY LEARNABLE: All initialization scaling factors - no hardcoded multipliers
self.stable_coeff_scale_raw = nn.Parameter(torch.tensor(-2.3)) # exp(-2.3) ≈ 0.1, but can learn optimal
self.region_baseline_scale_raw = nn.Parameter(torch.tensor(-2.3)) # exp(-2.3) ≈ 0.1, but can learn optimal
# LEARNABLE COEFFICIENT BOUNDS: Each channel learns its optimal maximum coefficient
self.media_coeff_max_raw = nn.Parameter(torch.ones(n_media) * 1.0) # Start at ~2.7 (softplus(1.0))
self.ctrl_coeff_max_raw = nn.Parameter(torch.ones(ctrl_dim) * 1.5) # Start at ~4.5 (softplus(1.5))
# SIMPLE TRAINABLE CONSTANTS: Global bounds only (prevent overfitting)
# Channel-specific bounds caused severe overfitting (R² dropped from 0.924 to 0.340)
# Using global bounds with natural activation functions for optimal generalization
# GRU hidden state initialization - STABILIZED
self.h0 = nn.Parameter(torch.randn(1, 1, hidden) * 0.01)
# NEW: Stable coefficient reference for burn-in stabilization
# Initialize to small values - will be set properly from data
self.stable_media_coeff = nn.Parameter(torch.zeros(n_media))
self.stable_ctrl_coeff = nn.Parameter(torch.zeros(ctrl_dim))
[docs]
def initialize_baseline(self, y_data: torch.Tensor):
"""Initialize baseline to match target data statistics.
CRITICAL: y_data is ALREADY in scaled space (y/y_mean per region)!
Extract ALL parameters directly from the actual data distribution.
IMPORTANT: Skip padding weeks to avoid baseline bias!
"""
with torch.no_grad():
# y_data shape: [n_regions, n_timesteps] - already in scaled space
# CRITICAL: Remove padding weeks (first 4 weeks) to get true data statistics
y_no_padding = y_data[:, self.burn_in_weeks:] if y_data.shape[1] > self.burn_in_weeks else y_data
y_numpy = y_no_padding.cpu().numpy()
n_regions, n_timesteps = y_numpy.shape
# 1. GLOBAL BASELINE: Keep learnable but start small to let model learn the right scale
# The model will learn the appropriate global offset during training
self.global_bias.data = torch.FloatTensor([0.0])
# 2. REGION BASELINES: Use actual per-region means as absolute baselines (not deviations)
# This ensures each region starts with a positive baseline equal to its historical mean
region_means_scaled = y_numpy.mean(axis=1) # [n_regions] - actual region means
self.region_baseline.data = torch.FloatTensor(region_means_scaled)
# 3. PREDICTION SCALE: Initialize to 1.0 (neutral scaling)
# The model will learn the appropriate scaling during training
# softplus(0) = 1.0, so no initial scaling applied
self.prediction_scale.data = torch.FloatTensor([0.0]) # softplus(0) = 1.0
# 4. TIME TREND: DISABLED - Frozen at zero (user request)
# Trend is a buffer (not a parameter), so it won't be learned during training
# No need to set - already initialized to zero in __init__
logger.info(f" Trend initialization: DISABLED (FROZEN at zero - not learnable)")
# Get original scale means for display
if hasattr(self, 'scaling_constants'):
y_mean_per_region = self.scaling_constants.get('y_mean_per_region')
if y_mean_per_region is not None:
avg_original_visits = (region_means_scaled.mean() * y_mean_per_region.mean().item())
else:
avg_original_visits = None
else:
avg_original_visits = None
logger.info(f"Initialized baselines (Linear scaled - y/mean):")
logger.info(f" Region baselines range: [{region_means_scaled.min():.3f}, {region_means_scaled.max():.3f}]")
logger.info(f" Global bias (Linear scale): {self.global_bias.item():.3f} [LEARNABLE - constrained ≥ 0 via softplus]")
if avg_original_visits is not None:
logger.info(f" Expected prediction baseline: {region_means_scaled.mean():.3f} (normalized) -> ~{avg_original_visits:.0f} visits")
else:
logger.info(f" Expected prediction baseline: {region_means_scaled.mean():.3f} (normalized)")
# Initialize seasonal components using actual data decomposition
self._initialize_seasonal_components(y_data)
def _initialize_seasonal_components(self, y_data: torch.Tensor):
"""
Initialize seasonal components using multiplicative decomposition per region.
Args:
y_data: Target data [n_regions, n_timesteps] in linear scaled space (y/mean)
"""
with torch.no_grad():
# Skip padding weeks for seasonal decomposition
y_no_padding = y_data[:, self.burn_in_weeks:] if y_data.shape[1] > self.burn_in_weeks else y_data
# Convert to original scale by multiplying by region means
if hasattr(self, 'scaling_constants'):
y_mean_per_region = self.scaling_constants.get('y_mean_per_region')
if y_mean_per_region is not None:
# y_original = y_scaled * y_mean
# y_mean_per_region shape: [n_regions, 1], broadcasts to [n_regions, n_timesteps]
y_original_scale = (y_no_padding * y_mean_per_region).cpu().numpy()
else:
# Fallback: assume data is already reasonably scaled
y_original_scale = y_no_padding.cpu().numpy()
else:
# Fallback: assume data is already reasonably scaled
y_original_scale = y_no_padding.cpu().numpy()
# Extract seasonal components per region
seasonal_components = self.seasonality_detector.extract_seasonal_components_per_region(
y_original_scale, start_week=0
)
# Apply Min-Max scaling per region to bring seasonality to [0, 1] range
# This ensures seasonality can only ADD to baseline, never subtract
n_regions, n_weeks = seasonal_components.shape
seasonal_normalized = torch.zeros_like(seasonal_components)
for region_idx in range(n_regions):
region_seasonal = seasonal_components[region_idx, :]
region_min = region_seasonal.min()
region_max = region_seasonal.max()
if region_max > region_min:
# Min-Max scaling: (x - min) / (max - min) -> [0, 1] range
seasonal_normalized[region_idx, :] = (region_seasonal - region_min) / (region_max - region_min)
else:
# If no variation, set to middle of range
seasonal_normalized[region_idx, :] = 0.5
# Store seasonal components (will be used in forward pass)
self.seasonal_components = seasonal_normalized
logger.info(f" Initialized seasonal components:")
logger.info(f" Seasonal coefficient: {self.seasonal_coeff.item():.3f} [LEARNABLE - constrained ≥ 0 via softplus]")
logger.info(f" Components range: [{seasonal_normalized.min():.3f}, {seasonal_normalized.max():.3f}] (Min-Max scaled [0, 1])")
logger.info(f" Components mean: {seasonal_normalized.mean():.3f}")
logger.info(f" Scaling: Min-Max per region - seasonality can only ADD to baseline")
[docs]
def initialize_hill_from_data(self, Xm: torch.Tensor):
"""
Initialize hill_g based on per-channel SOV (Share of Voice) distribution.
For each channel, hill_g is set to the 60th percentile of its SOV values,
ensuring the inflection point matches where the channel typically operates.
Args:
Xm: Media data [n_regions, n_timesteps, n_channels] in SOV-scaled space [0, 1]
"""
with torch.no_grad():
B, T, n_media = Xm.shape
# Skip padding weeks
Xm_no_padding = Xm[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xm
logger.info(f"\nInitializing Hill g parameters from SOV data:")
for i in range(n_media):
# Get all SOV values for this channel
channel_sov = Xm_no_padding[:, :, i].flatten()
# Filter out zeros (weeks with no spend)
channel_sov_nonzero = channel_sov[channel_sov > 1e-4]
if len(channel_sov_nonzero) > 10:
# Use 60th percentile as inflection point
# This ensures the channel shows good S-curve behavior around typical spend levels
p60 = torch.quantile(channel_sov_nonzero, 0.6).item()
# Clamp to reasonable range [0.05, 0.9]
# - Below 0.05: Too early saturation
# - Above 0.9: Inflection too late (most data will be linear)
g_target = max(0.05, min(0.9, p60))
# Convert to raw value (before softplus)
# softplus(x) = log(1 + exp(x)), so inverse: x = log(exp(g) - 1)
# For numerical stability, use: x = g + log(exp(-g) + 1) - log(2) for g > 1
if g_target > 0.5:
g_raw = np.log(np.exp(g_target) - 1.0)
else:
g_raw = np.log(g_target / (1.0 - g_target + 1e-8))
self.hill_g.data[i] = g_raw
logger.info(f" Channel {i+1:2d}: p60={p60:.4f} → g_target={g_target:.4f} (raw={g_raw:.4f})")
else:
# Not enough data, keep default
logger.info(f" Channel {i+1:2d}: Insufficient data, keeping default g")
logger.info(f"\nHill g initialization complete")
g_after_softplus = torch.nn.functional.softplus(self.hill_g)
logger.info(f" Range after softplus: [{g_after_softplus.min():.4f}, {g_after_softplus.max():.4f}]")
[docs]
def initialize_stable_coefficients_from_data(self, Xm: torch.Tensor, Xc: torch.Tensor, y: torch.Tensor):
"""
Initialize stable coefficients based on simple linear regression on the data.
This provides domain-informed starting points for coefficient stabilization.
"""
with torch.no_grad():
B, T, n_media = Xm.shape
_, _, n_control = Xc.shape
# CRITICAL: Remove padding weeks to avoid biasing the regression
Xm_no_padding = Xm[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xm
Xc_no_padding = Xc[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xc
y_no_padding = y[:, self.burn_in_weeks:] if T > self.burn_in_weeks else y
# MEDIA COEFFICIENTS - Flatten data for regression (no padding)
X_media_flat = Xm_no_padding.reshape(-1, n_media) # [B*(T-burn_in), n_media]
y_flat = y_no_padding.reshape(-1) # [B*(T-burn_in)]
# Simple ridge regression to get initial media coefficients
XtX = torch.mm(X_media_flat.t(), X_media_flat)
lambda_reg = 0.01 * torch.eye(n_media, device=Xm.device)
XtX_reg = XtX + lambda_reg
Xty = torch.mv(X_media_flat.t(), y_flat)
try:
beta_media = torch.linalg.solve(XtX_reg, Xty)
# POSITIVE-ONLY INITIALIZATION: Use abs and sigmoid to ensure non-negative media coefficients
beta_media_raw = torch.abs(beta_media) # Ensure positive
beta_media_positive = torch.sigmoid(beta_media_raw) # Range: [0, 1] - POSITIVE ONLY!
# Update stable media coefficients
data_scale = y_flat.std() / X_media_flat.std()
self.stable_media_coeff.data = beta_media_positive * data_scale
logger.info(f"Initialized stable coefficients from data:")
logger.info(f" Media coeff range (POSITIVE-ONLY): [{beta_media_positive.min().item():.4f}, {beta_media_positive.max().item():.4f}]")
except torch.linalg.LinAlgError:
logger.warning("Warning: Could not solve for media coefficients, using correlation fallback")
correlations = torch.zeros(n_media, device=Xm.device)
for i in range(n_media):
if X_media_flat[:, i].std() > 1e-8:
correlations[i] = torch.corrcoef(torch.stack([X_media_flat[:, i], y_flat]))[0, 1]
correlations = torch.nan_to_num(correlations, 0.0)
stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling
self.stable_media_coeff.data = correlations * stable_coeff_scale
# CONTROL COEFFICIENTS - Allow negative effects (no padding)
X_control_flat = Xc_no_padding.reshape(-1, n_control) # [B*(T-burn_in), n_control]
# Ridge regression for control coefficients
XtX_ctrl = torch.mm(X_control_flat.t(), X_control_flat)
lambda_reg_ctrl = 0.01 * torch.eye(n_control, device=Xc.device)
XtX_reg_ctrl = XtX_ctrl + lambda_reg_ctrl
Xty_ctrl = torch.mv(X_control_flat.t(), y_flat)
try:
beta_control = torch.linalg.solve(XtX_reg_ctrl, Xty_ctrl)
# Allow both positive and negative control effects - NO clipping
# Control variables should be able to have strong negative effects
# Scale control coefficients appropriately
ctrl_data_scale = y_flat.std() / X_control_flat.std()
stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling
self.stable_ctrl_coeff.data = beta_control * ctrl_data_scale * stable_coeff_scale
logger.info(f" Control coeff range: [{beta_control.min().item():.4f}, {beta_control.max().item():.4f}]")
except torch.linalg.LinAlgError:
logger.warning("Warning: Could not solve for control coefficients, using correlation fallback")
correlations_ctrl = torch.zeros(n_control, device=Xc.device)
for i in range(n_control):
if X_control_flat[:, i].std() > 1e-8:
correlations_ctrl[i] = torch.corrcoef(torch.stack([X_control_flat[:, i], y_flat]))[0, 1]
correlations_ctrl = torch.nan_to_num(correlations_ctrl, 0.0)
# Allow negative correlations for controls
stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling
self.stable_ctrl_coeff.data = correlations_ctrl * stable_coeff_scale
[docs]
def warm_start_training(self, Xm: torch.Tensor, Xc: torch.Tensor, R: torch.Tensor, y: torch.Tensor,
optimizer: torch.optim.Optimizer, epochs: int = None):
"""
Warm-start training phase to stabilize GRU coefficients before main training.
Uses only stable coefficients and focuses on learning good hidden state initialization.
"""
if not self.use_warm_start:
return
epochs = epochs or self.warm_start_epochs
logger.info(f"Starting warm-start training for {epochs} epochs...")
# Initialize stable coefficients from data
self.initialize_stable_coefficients_from_data(Xm, Xc, y)
# Save original parameters
original_coeff_range_raw = self.coeff_range_raw.data.clone()
original_ctrl_coeff_range_raw = self.ctrl_coeff_range_raw.data.clone()
original_burn_in = self.burn_in_weeks
# Temporarily use only stable coefficients (no dynamic variation)
self.coeff_range_raw.data = torch.tensor(-10.0) # sigmoid(-10) ≈ 0, so range ≈ 1.0 (minimal)
self.ctrl_coeff_range_raw.data = torch.tensor(-10.0) # sigmoid(-10) ≈ 0, so range ≈ 1.0 (minimal)
self.burn_in_weeks = 999 # Force all weeks to use stable coefficients
# Freeze coefficient generators during warm-start
for param in self.coeff_gen.parameters():
param.requires_grad = False
for param in self.ctrl_coeff_gen.parameters():
param.requires_grad = False
# Train only GRU, baselines, and stable coefficients
warm_start_params = [
self.h0, self.stable_media_coeff, self.stable_ctrl_coeff,
self.region_baseline, self.global_bias, self.prediction_scale
] + list(self.gru.parameters())
# Use config learning rate but scaled UP for faster warm-start convergence
config_lr = getattr(self, 'config_lr', 0.005) # Default fallback
warm_start_lr = config_lr * 2.0 # 2x main LR for faster coefficient stabilization
warm_optimizer = torch.optim.Adam(warm_start_params, lr=warm_start_lr) # Scaled from config LR
self.train()
for epoch in range(epochs):
warm_optimizer.zero_grad()
y_pred, _, _, outputs = self.forward(Xm, Xc, R)
loss = F.mse_loss(y_pred, y)
# Add stronger regularization to prevent extreme values
reg_loss = 0.01 * (self.h0.pow(2).mean() + self.stable_media_coeff.pow(2).mean())
total_loss = loss + reg_loss
total_loss.backward()
# Aggressive gradient clipping for regional scaling stability
torch.nn.utils.clip_grad_norm_(warm_start_params, max_norm=0.1) # Much more aggressive
warm_optimizer.step()
if epoch % 10 == 0:
logger.info(f" Warm-start epoch {epoch}/{epochs}, Loss: {loss.item():.6f}")
# Restore original parameters
self.coeff_range_raw.data = original_coeff_range_raw
self.ctrl_coeff_range_raw.data = original_ctrl_coeff_range_raw
self.burn_in_weeks = original_burn_in
# Unfreeze coefficient generators
for param in self.coeff_gen.parameters():
param.requires_grad = True
for param in self.ctrl_coeff_gen.parameters():
param.requires_grad = True
logger.info(f" Warm-start training completed. GRU initialized for stable coefficients.")
[docs]
def adstock(self, x: torch.Tensor) -> torch.Tensor:
"""STABILIZED adstock transformation."""
B, T, C = x.shape
alpha = torch.sigmoid(self.alpha).view(1, 1, -1)
alpha = torch.clamp(alpha, 0, 0.8) # Cap at 0.8 for stability
out_list = [x[:, 0:1]]
for t in range(1, T):
prev_adstock = out_list[-1]
current = x[:, t:t+1] + alpha * prev_adstock
# Clip to prevent explosion
current = torch.clamp(current, 0, 10)
out_list.append(current)
return torch.cat(out_list, dim=1)
[docs]
def hill(self, x: torch.Tensor) -> torch.Tensor:
"""STABILIZED Hill saturation transformation."""
a = F.softplus(self.hill_a).view(1, 1, -1)
g = F.softplus(self.hill_g).view(1, 1, -1)
# Ensure positive input
x_safe = F.relu(x) + 1e-8
# Stabilized Hill with clipping
# CRITICAL: a (slope) must be >= 2.0 for proper saturation curves
a = torch.clamp(a, 2.0, 5.0) # Changed from [0.1, 2.0] to [2.0, 5.0]
g = torch.clamp(g, 0.01, 1.0)
num = torch.pow(x_safe, a)
denom = num + torch.pow(g, a)
result = num / (denom + 1e-8)
return torch.clamp(result, 0, 1) # Ensure bounded output
[docs]
def dag_interaction(self, x: torch.Tensor) -> torch.Tensor:
"""Load-bearing DAG-driven channel interactions.
Each target channel j's effective input becomes a learned blend of
itself and a DAG-driven aggregation of its causal parents:
parents_j = sum_i adj[i, j] * x_i (column-j of adj * x)
x_j_new = (1 - mix_j) * x_j + mix_j * parents_j
Two design choices make NOTEARS actually structure-learning here
rather than decorative:
* `mix` is a *per-channel* learnable scalar in (0, 1). The model can
turn the DAG up where it helps prediction (strong causal parents)
and down where it doesn't, so each adj column receives genuine
per-channel gradient signal rather than a globally-averaged one.
* adj uses a temperature-controlled sigmoid (``dag_temperature``),
so the model is encouraged toward {near-0, near-1} edges instead
of a soft cluster around 0.5 / sigmoid floor — this is the
standard trick for making sigmoid gates bi-modal.
With the previous additive form ``x + scalar * matmul(x, adj)`` the
loss was satisfied even with a uniform adjacency at the L1 floor,
which is why the learned graph collapsed to ~equal edges. The
blended form forces the model to either pick informative parents or
keep ``mix_j`` close to 0 (effectively ignoring the DAG for that
channel).
"""
if not (self.enable_dag and self.enable_interactions):
return x
T_dag = float(getattr(self, 'dag_temperature', 1.0))
# Lower temperature -> sharper {0,1} edges; T_dag=1.0 reproduces
# the standard sigmoid behaviour.
adj_probs = torch.sigmoid(self.adj_logits / max(T_dag, 1e-3))
adj = adj_probs * self.tri_mask
# [B,T,C] @ [C,C] -> [B,T,C]; column j of adj weights parents of j.
parents = torch.matmul(x, adj)
# Per-target mix: shape [n_channels], broadcast over batch and time.
mix = torch.sigmoid(self.interaction_weight)
return (1.0 - mix) * x + mix * parents
[docs]
def apply_burn_in_stabilization(self, coeffs: torch.Tensor, stable_coeff: torch.Tensor) -> torch.Tensor:
"""
Advanced burn-in stabilization with multiple transition methods.
Args:
coeffs: Time-varying coefficients [B, T, dim]
stable_coeff: Stable reference coefficients [dim]
Returns:
Stabilized coefficients with smooth burn-in transition
"""
B, T, dim = coeffs.shape
if T <= self.burn_in_weeks:
# If sequence is shorter than burn-in, use stable coefficients
return stable_coeff.unsqueeze(0).unsqueeze(0).expand(B, T, -1)
# Create stabilization weights based on method
stabilized_coeffs = coeffs.clone()
for week in range(self.burn_in_weeks):
# Calculate transition weight based on method
if self.stabilization_method == "linear":
# Linear transition from 1.0 (fully stable) to 0.0 (fully dynamic)
stable_weight = 1.0 - (week / self.burn_in_weeks)
elif self.stabilization_method == "exponential":
# Exponential decay - slower initial transition, faster later
stable_weight = float(torch.exp(torch.tensor(-3.0 * week / self.burn_in_weeks)))
elif self.stabilization_method == "sigmoid":
# Sigmoid transition - smooth S-curve
x = (week - self.burn_in_weeks/2) / (self.burn_in_weeks/4)
stable_weight = float(1.0 / (1.0 + torch.exp(torch.tensor(x))))
else:
# Default to linear
stable_weight = 1.0 - (week / self.burn_in_weeks)
dynamic_weight = 1.0 - stable_weight
# Blend stable and dynamic coefficients
stabilized_coeffs[:, week, :] = (
stable_weight * stable_coeff.unsqueeze(0).expand(B, -1) +
dynamic_weight * coeffs[:, week, :]
)
# DISABLED: Coefficient momentum system removed to prevent gradient blocking
# The momentum system was using .detach() which blocked gradients
return stabilized_coeffs
[docs]
def forward(
self,
Xm: torch.Tensor,
Xc: torch.Tensor,
R: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Forward pass through the DeepCausalMMM model.
Processes media and control variables through the neural network to generate
predictions, time-varying coefficients, per-channel media contributions, and
a detailed ``outputs`` dict (baseline, seasonality, control contributions, DAG, etc.).
Parameters
----------
Xm : torch.Tensor
Media data tensor of shape [batch_size, time_steps, n_media]
Should be SOV-scaled (Share of Voice) normalized to [0, 1] range
Xc : torch.Tensor
Control variables tensor of shape [batch_size, time_steps, ctrl_dim]
Should be standardized (z-score normalized)
R : torch.Tensor
Region indicators tensor of shape [batch_size, time_steps]
Integer values representing region/DMA IDs
Returns
-------
predictions : torch.Tensor
Model predictions (scaled KPI), shape typically ``[batch_size, time_steps]``
(broadcast with components before ``prediction_scale``).
media_coefficients : torch.Tensor
Time-varying media coefficients, shape ``[batch_size, time_steps, n_media]``.
media_contributions : torch.Tensor
Per-channel media contributions ``X_processed * media_coefficients``,
shape ``[batch_size, time_steps, n_media]``.
outputs : Dict[str, Any]
Detailed tensors, including:
- ``contributions``: same as ``media_contributions`` (media channel breakdown)
- ``coefficients``: same as returned ``media_coefficients``
- ``control_contributions``: control variable contributions ``[batch, time, ctrl_dim]``
- ``control_coefficients``: control coefficients ``[batch, time, ctrl_dim]``
- ``baseline``: baseline without seasonality (for waterfall-style splits)
- ``seasonal_contribution``: seasonal term ``[batch, time]``
- ``dag_matrix`` (when DAG enabled): adjacency ``[n_media, n_media]``
- ``adstocked_media``, ``media_hill``, ``media_dag`` (when applicable): media pipeline stages
Examples
--------
>>> import torch
>>> model = DeepCausalMMM(n_media=3, ctrl_dim=2, n_regions=2)
>>>
>>> # Prepare input tensors
>>> media = torch.rand(2, 52, 3) # 2 regions, 52 weeks, 3 channels
>>> control = torch.randn(2, 52, 2) # 2 control variables
>>> regions = torch.tensor([[0]*52, [1]*52]) # Region indicators
>>>
>>> # Forward pass
>>> pred, media_coeffs, media_contrib, outputs = model(media, control, regions)
>>>
>>> # Access detailed outputs
>>> media_contrib_out = outputs['contributions']
>>> ctrl_contrib = outputs['control_contributions']
>>> dag_matrix = outputs.get('dag_matrix')
Notes
-----
The forward pass applies the following transformations in order:
1. Media processing: Adstock -> Hill saturation -> DAG interactions
2. Feature processing: Media features + Control features -> GRU
3. Coefficient generation: Time-varying coefficients from GRU states
4. Contribution calculation: Features * Coefficients
5. Final prediction: Baseline + Seasonality + Media + Control contributions
The model enforces several constraints:
- DAG acyclicity: upper-triangular mask (default) or NOTEARS penalty
when ``dag_mode='notears'``
- Non-negative baseline and seasonal contributions
- Learnable coefficient bounds to prevent explosion
- Burn-in period stabilization for initial weeks
"""
B, T, _ = Xm.shape
# Process media variables
X_processed, outputs = self.process_media(Xm)
# Process control variables through MLP - FIXED
ctrl_features = self.ctrl_mlp(Xc) # [B, T, ctrl_hidden]
# SOV-AWARE FEATURE PROCESSING: Media data is already SOV-scaled to [0,1], controls are standardized
# No additional normalization needed - SOV scaling already provides balanced importance learning
X_processed_norm = X_processed # SOV-scaled media features are already normalized [0,1]
ctrl_features_norm = ctrl_features # Control features are already properly scaled by pipeline
# GRU processing with NORMALIZED media and control features - REVERTED TO STABLE
gru_in = torch.cat([X_processed_norm, ctrl_features_norm], dim=-1)
h0 = self.h0.repeat(self.gru_layers, B, 1) # Single layer GRU
h_seq, _ = self.gru(gru_in, h0)
# Generate time-varying coefficients - REGULARIZED for stable attribution
media_coeffs_raw = self.coeff_gen(h_seq)
# LEARNABLE BOUNDS: Each channel learns its optimal maximum coefficient with non-zero guarantee
learned_max = F.softplus(self.media_coeff_max_raw) + 0.1 # Range: [0.1, ∞) - NON-ZERO GUARANTEE!
media_coeffs_unstable = torch.sigmoid(media_coeffs_raw) * learned_max.unsqueeze(0).unsqueeze(0) # [0, learned_max] per channel
ctrl_coeffs_raw = self.ctrl_coeff_gen(h_seq)
# LEARNABLE BOUNDS: Each control variable learns its optimal maximum coefficient
learned_ctrl_max = F.softplus(self.ctrl_coeff_max_raw) + 0.1 # Range: [0.1, ∞) - NON-ZERO GUARANTEE!
ctrl_coeffs_unstable = torch.tanh(ctrl_coeffs_raw) * learned_ctrl_max.unsqueeze(0).unsqueeze(0) # [-learned_max, learned_max] per control
# NEW: Apply burn-in stabilization
media_coeffs = self.apply_burn_in_stabilization(media_coeffs_unstable, self.stable_media_coeff)
ctrl_coeffs = self.apply_burn_in_stabilization(ctrl_coeffs_unstable, self.stable_ctrl_coeff)
# Calculate contributions
media_contrib = X_processed * media_coeffs
media_term = media_contrib.sum(-1) # [B, T]
# Control contributions using original control values
ctrl_contrib = Xc * ctrl_coeffs
ctrl_term = ctrl_contrib.sum(-1) # [B, T]
# Region baseline - CONSTRAINED to be non-negative
region_ids = R[:, 0] if R.dim() > 1 else R # Handle both 1D and 2D region tensors
region_baselines = F.softplus(self.region_baseline[region_ids]) # CONSTRAINED: Always ≥ 0
reg_term = region_baselines.unsqueeze(1).expand(-1, T) # [B, T]
# NEW: Time trend component - Add linear growth capability
time_steps = torch.arange(T, dtype=torch.float32, device=Xm.device).unsqueeze(0).expand(B, -1) # [B, T]
trend_term = self.time_trend_weight * time_steps + self.time_trend_bias # [B, T]
# NEW: Seasonal component - Add data-driven seasonality to baseline
seasonal_term = torch.zeros(B, T, device=Xm.device)
if self.seasonal_components is not None:
# Get seasonal components for current time window
seasonal_data = self.seasonal_components.to(Xm.device) # [n_regions, n_weeks]
# Handle potential size mismatch (seasonal components might be shorter due to padding removal)
if seasonal_data.shape[1] >= T:
# Take the last T weeks (most recent)
seasonal_slice = seasonal_data[:, -T:]
else:
# Pad if seasonal data is shorter
pad_size = T - seasonal_data.shape[1]
seasonal_slice = F.pad(seasonal_data, (pad_size, 0), mode='replicate')
# CONSTRAINT: Ensure seasonal coefficient is non-negative (seasonality can only add to baseline)
seasonal_coeff_positive = F.softplus(self.seasonal_coeff) # Always ≥ 0
# Apply constrained seasonal coefficient
seasonal_term = seasonal_coeff_positive * seasonal_slice[region_ids]
# PREDICTION - CORRECT: Total baseline = global_bias + region_deviation + seasonality
# CONSTRAINT: Ensure global_bias is non-negative (baseline visits can't be negative)
global_bias_positive = F.softplus(self.global_bias) # Always ≥ 0
# COMPUTE BASELINE (calculate once, use for both prediction and attribution)
baseline_without_seasonal = F.relu(reg_term + global_bias_positive) # Baseline WITHOUT seasonal
# Step 1: Compute total raw prediction
raw_prediction = media_term + ctrl_term + baseline_without_seasonal + seasonal_term + trend_term
# Step 3: Scale final prediction
y = raw_prediction * F.softplus(self.prediction_scale)
# Store outputs (direct computation)
outputs['coefficients'] = media_coeffs
outputs['control_coefficients'] = ctrl_coeffs
outputs['contributions'] = media_contrib # [B, T, n_media]
outputs['trend_contribution'] = trend_term # Trend is frozen at 0
outputs['control_contributions'] = ctrl_contrib # [B, T, n_control]
outputs['seasonal_contribution'] = seasonal_term # Seasonal component (separate for waterfall)
outputs['baseline'] = baseline_without_seasonal # Baseline WITHOUT seasonal (avoids double-counting)
outputs['raw_prediction'] = raw_prediction
outputs['prediction_scale'] = F.softplus(self.prediction_scale)
outputs['burn_in_weeks'] = self.burn_in_weeks
# DEBUG: Verify components sum to raw_prediction
components_sum = media_term + ctrl_term + baseline_without_seasonal + seasonal_term + trend_term
diff = torch.abs(components_sum - raw_prediction).mean().item()
# Also verify per-channel media_contrib sums to media_term
media_contrib_sum = media_contrib.sum(dim=-1) # Sum across channels
media_diff = torch.abs(media_contrib_sum - media_term).mean().item()
if diff > 1e-5:
print(f"WARNING: Components don't sum to raw_prediction! Diff: {diff:.6f}")
if media_diff > 1e-5:
print(f"WARNING: media_contrib.sum() != media_term! Diff: {media_diff:.6f}")
# ======================================================================
# ATTRIBUTION REGULARIZATION: Media contribution should match prior
# ======================================================================
# Use ABSOLUTE value constraint (not ratio) to prevent "make everything tiny" loophole
# Target: media_term should be ~40% of total_prediction in absolute terms
# Calculate target media contribution (40% of total prediction)
target_media = self.media_contribution_prior * raw_prediction # [B, T]
# Regularization loss: penalize deviation from target absolute contribution
attribution_reg_loss_raw = F.mse_loss(media_term, target_media)
# For monitoring: also calculate actual proportion
media_proportion = media_term / (raw_prediction + 1e-8)
# DYNAMIC SCALING: Scale attribution loss to match MSE magnitude
# This ensures both losses are comparable before weighting
# We'll compute MSE in trainer and pass it here, but for now store raw loss
outputs['attribution_reg_loss_raw'] = attribution_reg_loss_raw
outputs['media_proportion'] = media_proportion.mean() # For monitoring
outputs['media_term_mean'] = media_term.mean() # For debugging
outputs['target_media_mean'] = target_media.mean() # For debugging
# ======================================================================
# SEASONAL REGULARIZATION: Prevent seasonal_coeff from going to zero
# ======================================================================
# Target: seasonal_coeff should stay near its prior (e.g., 1.0)
# This prevents the model from suppressing seasonality
seasonal_prior = getattr(self, 'seasonal_prior', 1.0)
seasonal_reg_loss_raw = F.mse_loss(
F.softplus(self.seasonal_coeff), # Current seasonal coefficient
torch.tensor([seasonal_prior], device=self.seasonal_coeff.device) # Target prior
)
outputs['seasonal_reg_loss_raw'] = seasonal_reg_loss_raw
return y, media_coeffs, media_contrib, outputs # Return direct contributions
[docs]
def h_acyclicity(self, W: torch.Tensor) -> torch.Tensor:
"""NOTEARS acyclicity scalar: h(W) = tr(exp(W ⊙ W)) − d.
Equals zero iff W is the adjacency of a DAG; smooth and differentiable
elsewhere. See Zheng et al., 2018 (https://arxiv.org/abs/1803.01422).
Args:
W: Square adjacency matrix (n_media × n_media).
Returns:
Scalar tensor; minimised toward 0 during training.
"""
M = W * W
return torch.trace(torch.matrix_exp(M)) - W.shape[0]
[docs]
def get_dag_loss(self) -> torch.Tensor:
"""DAG regularisation. Mode-aware: triangular uses sparsity/confidence
penalties only (acyclicity is structural); NOTEARS additionally adds the
augmented-Lagrangian acyclicity term 0.5·rho·h(W)² + alpha·h(W) and an
L1 penalty on the full adjacency."""
if not (self.enable_dag and self.enable_interactions):
return torch.tensor(0.0, device=self.global_bias.device)
T_dag = float(getattr(self, 'dag_temperature', 1.0))
adj_probs = torch.sigmoid(self.adj_logits / max(T_dag, 1e-3))
adj = adj_probs * self.tri_mask
sparsity_loss = torch.sum(adj)
confidence_loss = torch.sum(adj_probs * (1 - adj_probs))
l1_penalty = torch.sum(torch.abs(self.adj_logits))
if getattr(self, 'dag_mode', 'triangular') == 'triangular':
total_dag_loss = (0.01 * sparsity_loss
+ 0.001 * confidence_loss
+ 0.0002 * l1_penalty)
else:
total_dag_loss = 1e-4 * l1_penalty
if bool(getattr(self, 'notears_active', torch.tensor(True)).item()):
h = self.h_acyclicity(adj)
# Column-group L1: per-target column L2 norm — encourages each
# channel j to pick a focused parent set or none at all.
group_l1 = torch.sum(torch.sqrt(torch.sum(adj.pow(2), dim=0) + 1e-8))
notears_term = (0.5 * self.notears_rho * h.pow(2)
+ self.notears_alpha * h
+ self.notears_lambda1 * sparsity_loss
+ self.notears_group_l1 * group_l1)
total_dag_loss = total_dag_loss + notears_term
return total_dag_loss
[docs]
@torch.no_grad()
def notears_update_duals(self, factor: float = 10.0,
progress: float = 0.25) -> Dict[str, float]:
"""Augmented-Lagrangian dual update (NOTEARS outer loop).
Call once every K epochs from the trainer. Returns diagnostic dict
({"h", "rho", "alpha"}) for logging; returns empty dict in triangular mode.
Args:
factor: Multiplicative growth applied to rho when h(W) stalls.
progress: Required relative shrinkage of h between outer iterations.
If h_new > progress * h_prev, rho is grown by `factor`.
"""
if (not (self.enable_dag and self.enable_interactions)
or getattr(self, 'dag_mode', 'triangular') != 'notears'
or not bool(getattr(self, 'notears_active', torch.tensor(True)).item())):
return {}
T_dag = max(float(getattr(self, 'dag_temperature', 1.0)), 1e-3)
adj = torch.sigmoid(self.adj_logits / T_dag) * self.tri_mask
h = self.h_acyclicity(adj).item()
if h > progress * self._notears_h_prev:
self.notears_rho.mul_(factor).clamp_(max=self.notears_rho_max)
# Standard augmented-Lagrangian dual ascent on the equality constraint h = 0
self.notears_alpha.add_(self.notears_rho * h)
self._notears_h_prev = h
return {
"h": float(h),
"rho": float(self.notears_rho.item()),
"alpha": float(self.notears_alpha.item()),
}
[docs]
@torch.no_grad()
def get_dag_adjacency_matrix(self, eps: Optional[float] = None) -> torch.Tensor:
"""Learned adjacency with ``dag_temperature`` and ``tri_mask`` applied.
Args:
eps: If ``None``, return continuous edge weights. If a float, zero
entries with ``|w| < eps`` (same rule as ``threshold_dag``).
Returns:
Square adjacency tensor ``[n_media, n_media]``.
"""
if not (self.enable_dag and self.enable_interactions):
return torch.zeros(self.n_media, self.n_media,
device=self.global_bias.device)
T_dag = max(float(getattr(self, 'dag_temperature', 1.0)), 1e-3)
W = torch.sigmoid(self.adj_logits / T_dag) * self.tri_mask
if eps is not None:
W = W.detach().clone()
W[W.abs() < eps] = 0.0
return W
return W.detach()
[docs]
@torch.no_grad()
def threshold_dag(self, eps: float = 0.3) -> torch.Tensor:
"""Post-training pruning: zero out adjacency entries with |w| < eps.
Returns the thresholded adjacency tensor. For NOTEARS mode this is the
recommended way to obtain a clean discrete DAG from the continuous W.
"""
return self.get_dag_adjacency_matrix(eps=eps)
[docs]
def get_sparsity_loss(self) -> torch.Tensor:
"""Sparsity loss to encourage sparse coefficients."""
if not (self.enable_dag and self.enable_interactions):
return torch.tensor(0.0, device=self.global_bias.device)
# L1 penalty on media coefficients for sparsity
media_sparsity = torch.sum(torch.abs(self.stable_media_coeff))
# L1 penalty on control coefficients
ctrl_sparsity = torch.sum(torch.abs(self.stable_ctrl_coeff))
# L1 penalty on GRU weights for temporal sparsity
gru_sparsity = sum(torch.sum(torch.abs(param)) for param in self.gru.parameters())
total_sparsity = media_sparsity + ctrl_sparsity + 0.1 * gru_sparsity
return total_sparsity
[docs]
def get_regularization_loss(self) -> torch.Tensor:
"""Calculate combined regularization loss including DAG penalty and coefficient regularization."""
l1_loss = torch.tensor(0.0, device=self.global_bias.device)
l2_loss = torch.tensor(0.0, device=self.global_bias.device)
for param in self.parameters():
if param.requires_grad:
l1_loss += torch.abs(param).mean()
l2_loss += (param ** 2).mean()
reg_loss = self.l1_weight * l1_loss + self.l2_weight * l2_loss
# COEFFICIENT-SPECIFIC REGULARIZATION: Prevent coefficient explosion
coeff_reg_loss = torch.tensor(0.0, device=self.global_bias.device)
# Strong L2 penalty on coefficient range parameters to prevent explosion
coeff_range_penalty = (self.coeff_range_raw ** 2).mean() * self.coeff_l2_weight
ctrl_coeff_range_penalty = (self.ctrl_coeff_range_raw ** 2).mean() * self.coeff_l2_weight
coeff_reg_loss += coeff_range_penalty + ctrl_coeff_range_penalty
# L2 penalty on coefficient generator weights (prevents generating extreme coefficients)
for name, param in self.named_parameters():
if 'coeff_gen' in name and param.requires_grad:
coeff_reg_loss += (param ** 2).mean() * self.coeff_gen_l2_weight
reg_loss += coeff_reg_loss
# Add DAG loss if enabled
if self.enable_dag and self.enable_interactions:
reg_loss = reg_loss + self.get_dag_loss()
return reg_loss
[docs]
def get_parameters(self) -> Dict[str, torch.Tensor]:
"""Get model parameters for analysis."""
params = {
'adstock_alpha': torch.sigmoid(self.alpha).detach(),
'hill_a': F.softplus(self.hill_a).detach(),
'hill_g': F.softplus(self.hill_g).detach(),
'global_bias': self.global_bias.detach(),
'prediction_scale': F.softplus(self.prediction_scale).detach(),
}
if self.enable_dag and self.enable_interactions:
adj_probs = torch.sigmoid(self.adj_logits)
params['adjacency'] = (adj_probs * self.tri_mask).detach()
params['interaction_weight'] = torch.sigmoid(self.interaction_weight).detach()
return params
[docs]
def create_unified_mmm(
n_media: int,
n_control: int,
hidden_size: int = 64,
n_regions: int = 2,
dropout: float = 0.1,
sparsity_weight: float = 0.1,
enable_dag: bool = True,
enable_interactions: bool = True,
l1_weight: float = 0.01,
l2_weight: float = 0.01,
coeff_range: float = 2.0
) -> DeepCausalMMM:
"""Factory function to create a DeepCausalMMM model."""
model = DeepCausalMMM(
n_media=n_media,
ctrl_dim=n_control,
hidden=hidden_size,
n_regions=n_regions,
dropout=dropout,
sparsity_weight=sparsity_weight,
enable_dag=enable_dag,
enable_interactions=enable_interactions,
l1_weight=l1_weight,
l2_weight=l2_weight,
coeff_range=coeff_range
)
return model