Source code for deepcausalmmm.core.unified_model

"""
DeepCausalMMM model implementation combining GRU, DAG, and interaction components.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any
import numpy as np

from deepcausalmmm.core.dag_model import NodeToEdge, EdgeToNode, DAGConstraint
from deepcausalmmm.core.seasonality import DetectSeasonality

import logging

logger = logging.getLogger('deepcausalmmm')

[docs] class DeepCausalMMM(nn.Module): """ Deep Causal Marketing Mix Model with DAG structure and channel interactions. This model combines deep learning with causal inference to understand the impact of marketing channels on business KPIs while learning causal relationships between channels through a Directed Acyclic Graph (DAG). The model features: - GRU-based temporal modeling for time-varying coefficients - Learnable coefficient bounds for realistic attribution - DAG learning for causal channel interactions (triangular mask or opt-in NOTEARS) - Adstock and saturation transformations - Multi-region support with shared and region-specific parameters - Zero hardcoding philosophy - all parameters are learnable or configurable Parameters ---------- n_media : int, default=10 Number of media channels in the dataset ctrl_dim : int, default=15 Number of control variables (weather, events, etc.) hidden : int, default=32 Hidden dimension size for GRU and MLP layers n_regions : int, default=2 Number of geographic regions or DMAs dropout : float, default=0.1 Dropout rate for regularization during training sparsity_weight : float, default=0.01 Weight for sparsity regularization on coefficients enable_dag : bool, default=True Whether to enable DAG learning for channel interactions enable_interactions : bool, default=True Whether to enable channel interaction modeling l1_weight : float, default=0.001 L1 regularization weight for coefficient sparsity l2_weight : float, default=0.001 L2 regularization weight for coefficient smoothness burn_in_weeks : int, default=4 Number of initial weeks for GRU stabilization use_coefficient_momentum : bool, default=True Whether to use momentum for coefficient stabilization momentum_decay : float, default=0.9 Decay rate for coefficient momentum use_warm_start : bool, default=True Whether to use warm start training initialization warm_start_epochs : int, default=50 Number of epochs for warm start phase stabilization_method : str, default="exponential" Method for coefficient stabilization ("linear", "exponential", "sigmoid") coeff_l2_weight : float, default=0.1 L2 regularization specifically for media coefficients coeff_gen_l2_weight : float, default=0.05 L2 regularization for coefficient generation layers gru_layers : int, default=1 Number of GRU layers (configured, not hardcoded) ctrl_hidden_ratio : float, default=0.5 Control hidden size as ratio of main hidden dimension dag_mode : str, default="triangular" DAG acyclicity mode: ``"triangular"`` (upper-triangular mask, default) or ``"notears"`` (continuous NOTEARS penalty; Zheng et al., 2018) notears_lambda1 : float, default=0.01 L1 sparsity on the full adjacency in NOTEARS mode notears_rho_init : float, default=1.0 Initial augmented-Lagrangian penalty ``rho`` notears_alpha_init : float, default=0.0 Initial dual variable ``alpha`` for NOTEARS notears_rho_max : float, default=1e16 Upper cap on ``rho`` for numerical safety dag_temperature : float, default=1.0 Sigmoid temperature for DAG edge weights (``< 1`` sharpens toward {0, 1}) notears_group_l1 : float, default=0.0 Column-group L1 over adjacency columns (NOTEARS only) Attributes ---------- media_coeffs : torch.nn.Parameter Time-varying coefficients for media channels ctrl_coeffs : torch.nn.Parameter Coefficients for control variables dag_matrix : torch.nn.Parameter Learnable DAG adjacency matrix for channel interactions region_baseline : torch.nn.Parameter Region-specific baseline contributions seasonal_coeff : torch.nn.Parameter Learnable coefficient for seasonal component Examples -------- >>> import torch >>> from deepcausalmmm import DeepCausalMMM >>> >>> # Initialize model >>> model = DeepCausalMMM( ... n_media=5, ... ctrl_dim=3, ... n_regions=2, ... hidden=64 ... ) >>> >>> # Prepare data tensors >>> media_data = torch.randn(2, 104, 5) # [regions, weeks, channels] >>> control_data = torch.randn(2, 104, 3) # [regions, weeks, controls] >>> regions = torch.arange(2).unsqueeze(1).repeat(1, 104) >>> >>> # Forward pass >>> predictions, media_coeffs, media_contributions, outputs = model( ... media_data, control_data, regions ... ) >>> >>> print(f"Predictions shape: {predictions.shape}") >>> print(f"Media contributions: {outputs['contributions'].shape}") """
[docs] def __init__( self, n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, # Smaller default n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, # Smaller default enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, # Smaller default l2_weight: float = 0.001, # Smaller default burn_in_weeks: int = 4, # NEW: Number of weeks to stabilize GRU # NEW: Advanced stabilization parameters use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = "exponential", # "linear", "exponential", "sigmoid" # COEFFICIENT REGULARIZATION: Prevent coefficient explosion coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, # NEW: Config-driven parameters (no hardcoding!) gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, # Control hidden size as ratio of main hidden size # NOTEARS DAG learning (Zheng et al., 2018). Default keeps prior behaviour. dag_mode: str = "triangular", # "triangular" | "notears" notears_lambda1: float = 0.01, # L1 sparsity weight on full W in NOTEARS mode notears_rho_init: float = 1.0, # Initial penalty parameter rho notears_alpha_init: float = 0.0, # Initial dual variable alpha notears_rho_max: float = 1e16, # Cap on rho for numerical safety dag_temperature: float = 1.0, # <1.0 sharpens sigmoid edges toward {0,1} notears_group_l1: float = 0.0, # L2,1 group L1 over adj columns (NOTEARS only) ): super().__init__() # DAG-edge sigmoid temperature (used in dag_interaction()). self.dag_temperature = float(dag_temperature) self.notears_group_l1 = float(notears_group_l1) # Store dimensions and flags self.n_media = n_media self.ctrl_dim = ctrl_dim self.hidden_size = hidden self.n_regions = n_regions self.enable_dag = enable_dag self.enable_interactions = enable_interactions self.l1_weight = l1_weight self.l2_weight = l2_weight self.burn_in_weeks = burn_in_weeks # NEW: Store burn-in period # COEFFICIENT REGULARIZATION: Store parameters to prevent explosion self.coeff_l2_weight = coeff_l2_weight self.coeff_gen_l2_weight = coeff_gen_l2_weight # NEW: Advanced stabilization parameters self.use_coefficient_momentum = use_coefficient_momentum self.momentum_decay = momentum_decay self.use_warm_start = use_warm_start self.warm_start_epochs = warm_start_epochs self.stabilization_method = stabilization_method # Coefficient momentum tracking if self.use_coefficient_momentum: self.register_buffer('media_coeff_momentum', torch.zeros(n_media)) self.register_buffer('ctrl_coeff_momentum', torch.zeros(ctrl_dim)) self.register_buffer('momentum_step', torch.tensor(0)) # Adstock parameters - STABILIZED initialization self.alpha = nn.Parameter(torch.ones(n_media) * 0.8) # Start with reasonable adstock # STABILIZED HILL - Initialize for proper saturation curves # a (slope) should be >= 2.0 for clear diminishing returns # Initialize to inverse_softplus(2.5) so after softplus + clamp we get ~2.5 # Initialize hill_a so that softplus(hill_a) >= 2.0 naturally (without clamp floor) # softplus(2.5) ≈ 2.58, giving room to learn both up and down within [2.0, 5.0] self.hill_a = nn.Parameter(torch.ones(n_media) * 2.5) # softplus(2.5) ≈ 2.58 # hill_g will be initialized per-channel based on SOV data in initialize_hill_from_data() self.hill_g = nn.Parameter(torch.rand(n_media) * 0.2 + 0.1) # Default: 0.1-0.3 (will be overwritten) # CAUSAL DAG components - discover meaningful relationships if enable_dag and enable_interactions: if dag_mode not in ("triangular", "notears"): raise ValueError( f"dag_mode must be 'triangular' or 'notears', got {dag_mode!r}" ) self.dag_mode = dag_mode if dag_mode == "triangular": # Existing behaviour: acyclicity guaranteed by construction via # an upper-triangular mask. Requires a fixed channel ordering. self.adj_logits = nn.Parameter(torch.randn(n_media, n_media) * 0.3 - 0.1) mask = torch.triu(torch.ones(n_media, n_media), diagonal=1) self.register_buffer('tri_mask', mask) else: # NOTEARS mode: acyclicity enforced via a smooth scalar penalty # h(W) = tr(exp(W ⊙ W)) − d. The mask only zeros the diagonal # so the model can learn arbitrary DAG topology from data. # Match the triangular-mode init scale so edges are active from # epoch 0 and contribute to predictions immediately. The initial # h(W) will be large but that is fine: warmup keeps the penalty # off until Huber has established a fit, then the augmented # Lagrangian outer loop drives h(W) → 0. self.adj_logits = nn.Parameter(torch.randn(n_media, n_media) * 0.3 - 0.1) mask = torch.ones(n_media, n_media) - torch.eye(n_media) # zero diagonal only self.register_buffer('tri_mask', mask) # Augmented Lagrangian state. Buffers (not parameters) so they # are not touched by the optimiser; updated in notears_update_duals(). self.register_buffer('notears_rho', torch.tensor(notears_rho_init)) self.register_buffer('notears_alpha', torch.tensor(notears_alpha_init)) self.notears_lambda1 = notears_lambda1 self.notears_rho_max = notears_rho_max self._notears_h_prev = float('inf') # Warmup gate: when False, get_dag_loss() and notears_update_duals() # behave as if NOTEARS were disabled. The trainer flips this to True # after `notears_warmup_epochs` to let Huber loss establish a good # data fit before the acyclicity penalty starts pulling weights around. self.register_buffer('notears_active', torch.tensor(True)) # Per-channel DAG-strength: each target channel learns how much # of its effective input comes from its causal parents vs itself. # Initialised at sigmoid(-1.4) ≈ 0.2 so the DAG starts with a # modest 20% blend, then the model can grow per-channel mixing # where the parents help and shrink it where they don't. self.interaction_weight = nn.Parameter(torch.ones(n_media) * -1.4) else: # Sensible default so attribute always exists self.dag_mode = dag_mode # FIXED control processing - use config-driven dimensions (NO HARDCODING!) self.ctrl_hidden = int(hidden * ctrl_hidden_ratio) # Config-driven control hidden size self.ctrl_mlp = nn.Sequential( nn.Linear(ctrl_dim, self.ctrl_hidden), nn.Tanh(), # Bounded activation nn.Dropout(dropout) ) # CONFIG-DRIVEN GRU - NO HARDCODING! gru_input_size = n_media + self.ctrl_hidden self.gru_layers = gru_layers # Use config value, no hardcoding! self.gru = nn.GRU( input_size=gru_input_size, hidden_size=hidden, num_layers=self.gru_layers, # Config-driven layers batch_first=True, dropout=dropout if self.gru_layers > 1 else 0 # Conditional dropout based on layers ) # DISABLE residual connections for stability self.use_residual = False # REVERT - was causing instability # ENHANCED coefficient generator for ultra-low RMSE self.coeff_gen = nn.Sequential( nn.Linear(hidden, hidden), # FULL capacity first layer nn.ReLU(), # Better gradient flow than Tanh nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Linear(hidden // 2, n_media) ) # Initialize coefficient generator carefully for layer in self.coeff_gen: if isinstance(layer, nn.Linear): nn.init.xavier_uniform_(layer.weight, gain=0.1) nn.init.zeros_(layer.bias) # ENHANCED control coefficient generator for ultra-low RMSE self.ctrl_coeff_gen = nn.Sequential( nn.Linear(hidden, hidden), # FULL capacity first layer nn.ReLU(), # Better gradient flow than Tanh nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Linear(hidden // 2, ctrl_dim) ) # Initialize control coefficients for layer in self.ctrl_coeff_gen: if isinstance(layer, nn.Linear): nn.init.xavier_uniform_(layer.weight, gain=0.1) nn.init.zeros_(layer.bias) # ATTRIBUTION PRIOR: Target media contribution (business knowledge) self.media_contribution_prior = 0.40 # 40% is typical for MMM self.attribution_reg_weight = 0.5 # Balance between prediction and attribution (0.5 = equal priority) # SEASONAL REGULARIZATION: Prevent seasonal_coeff from being suppressed to zero self.seasonal_prior = 1.0 # Target seasonal coefficient value self.seasonal_reg_weight = 0.1 # Weight for seasonal regularization (lower than attribution) # REGION-SPECIFIC baselines - IMPROVED initialization self.region_baseline = nn.Parameter(torch.randn(n_regions) * 0.1) # Global bias and prediction scaling self.global_bias = nn.Parameter(torch.zeros(1)) self.prediction_scale = nn.Parameter(torch.ones(1)) # NEW: Time trend component to handle growth patterns # FROZEN: Set to zero and not learnable (disabled per user request) self.register_buffer('time_trend_weight', torch.zeros(1)) # FROZEN at 0 self.register_buffer('time_trend_bias', torch.zeros(1)) # FROZEN at 0 # NEW: Seasonal component with learnable coefficient self.seasonal_coeff = nn.Parameter(torch.ones(1)) # Learnable seasonal coefficient self.seasonal_components = None # Will be initialized with actual data from decomposition self.seasonality_detector = DetectSeasonality() # For seasonal decomposition # FULLY LEARNABLE: No hardcoded bounds - let model discover everything self.coeff_range_raw = nn.Parameter(torch.tensor(0.0)) # exp(0) = 1.0, no upper bound self.ctrl_coeff_range_raw = nn.Parameter(torch.tensor(0.0)) # exp(0) = 1.0, no upper bound # FROZEN: Trend damping (disabled - trend is frozen at zero) self.register_buffer('trend_damping_raw', torch.tensor(0.0)) # FROZEN - not learnable # HYBRID APPROACH: Fixed regularization weights (handled in trainer) # Only core model parameters are learnable - not loss balancing weights # FULLY LEARNABLE: All initialization scaling factors - no hardcoded multipliers self.stable_coeff_scale_raw = nn.Parameter(torch.tensor(-2.3)) # exp(-2.3) ≈ 0.1, but can learn optimal self.region_baseline_scale_raw = nn.Parameter(torch.tensor(-2.3)) # exp(-2.3) ≈ 0.1, but can learn optimal # LEARNABLE COEFFICIENT BOUNDS: Each channel learns its optimal maximum coefficient self.media_coeff_max_raw = nn.Parameter(torch.ones(n_media) * 1.0) # Start at ~2.7 (softplus(1.0)) self.ctrl_coeff_max_raw = nn.Parameter(torch.ones(ctrl_dim) * 1.5) # Start at ~4.5 (softplus(1.5)) # SIMPLE TRAINABLE CONSTANTS: Global bounds only (prevent overfitting) # Channel-specific bounds caused severe overfitting (R² dropped from 0.924 to 0.340) # Using global bounds with natural activation functions for optimal generalization # GRU hidden state initialization - STABILIZED self.h0 = nn.Parameter(torch.randn(1, 1, hidden) * 0.01) # NEW: Stable coefficient reference for burn-in stabilization # Initialize to small values - will be set properly from data self.stable_media_coeff = nn.Parameter(torch.zeros(n_media)) self.stable_ctrl_coeff = nn.Parameter(torch.zeros(ctrl_dim))
[docs] def initialize_baseline(self, y_data: torch.Tensor): """Initialize baseline to match target data statistics. CRITICAL: y_data is ALREADY in scaled space (y/y_mean per region)! Extract ALL parameters directly from the actual data distribution. IMPORTANT: Skip padding weeks to avoid baseline bias! """ with torch.no_grad(): # y_data shape: [n_regions, n_timesteps] - already in scaled space # CRITICAL: Remove padding weeks (first 4 weeks) to get true data statistics y_no_padding = y_data[:, self.burn_in_weeks:] if y_data.shape[1] > self.burn_in_weeks else y_data y_numpy = y_no_padding.cpu().numpy() n_regions, n_timesteps = y_numpy.shape # 1. GLOBAL BASELINE: Keep learnable but start small to let model learn the right scale # The model will learn the appropriate global offset during training self.global_bias.data = torch.FloatTensor([0.0]) # 2. REGION BASELINES: Use actual per-region means as absolute baselines (not deviations) # This ensures each region starts with a positive baseline equal to its historical mean region_means_scaled = y_numpy.mean(axis=1) # [n_regions] - actual region means self.region_baseline.data = torch.FloatTensor(region_means_scaled) # 3. PREDICTION SCALE: Initialize to 1.0 (neutral scaling) # The model will learn the appropriate scaling during training # softplus(0) = 1.0, so no initial scaling applied self.prediction_scale.data = torch.FloatTensor([0.0]) # softplus(0) = 1.0 # 4. TIME TREND: DISABLED - Frozen at zero (user request) # Trend is a buffer (not a parameter), so it won't be learned during training # No need to set - already initialized to zero in __init__ logger.info(f" Trend initialization: DISABLED (FROZEN at zero - not learnable)") # Get original scale means for display if hasattr(self, 'scaling_constants'): y_mean_per_region = self.scaling_constants.get('y_mean_per_region') if y_mean_per_region is not None: avg_original_visits = (region_means_scaled.mean() * y_mean_per_region.mean().item()) else: avg_original_visits = None else: avg_original_visits = None logger.info(f"Initialized baselines (Linear scaled - y/mean):") logger.info(f" Region baselines range: [{region_means_scaled.min():.3f}, {region_means_scaled.max():.3f}]") logger.info(f" Global bias (Linear scale): {self.global_bias.item():.3f} [LEARNABLE - constrained ≥ 0 via softplus]") if avg_original_visits is not None: logger.info(f" Expected prediction baseline: {region_means_scaled.mean():.3f} (normalized) -> ~{avg_original_visits:.0f} visits") else: logger.info(f" Expected prediction baseline: {region_means_scaled.mean():.3f} (normalized)") # Initialize seasonal components using actual data decomposition self._initialize_seasonal_components(y_data)
def _initialize_seasonal_components(self, y_data: torch.Tensor): """ Initialize seasonal components using multiplicative decomposition per region. Args: y_data: Target data [n_regions, n_timesteps] in linear scaled space (y/mean) """ with torch.no_grad(): # Skip padding weeks for seasonal decomposition y_no_padding = y_data[:, self.burn_in_weeks:] if y_data.shape[1] > self.burn_in_weeks else y_data # Convert to original scale by multiplying by region means if hasattr(self, 'scaling_constants'): y_mean_per_region = self.scaling_constants.get('y_mean_per_region') if y_mean_per_region is not None: # y_original = y_scaled * y_mean # y_mean_per_region shape: [n_regions, 1], broadcasts to [n_regions, n_timesteps] y_original_scale = (y_no_padding * y_mean_per_region).cpu().numpy() else: # Fallback: assume data is already reasonably scaled y_original_scale = y_no_padding.cpu().numpy() else: # Fallback: assume data is already reasonably scaled y_original_scale = y_no_padding.cpu().numpy() # Extract seasonal components per region seasonal_components = self.seasonality_detector.extract_seasonal_components_per_region( y_original_scale, start_week=0 ) # Apply Min-Max scaling per region to bring seasonality to [0, 1] range # This ensures seasonality can only ADD to baseline, never subtract n_regions, n_weeks = seasonal_components.shape seasonal_normalized = torch.zeros_like(seasonal_components) for region_idx in range(n_regions): region_seasonal = seasonal_components[region_idx, :] region_min = region_seasonal.min() region_max = region_seasonal.max() if region_max > region_min: # Min-Max scaling: (x - min) / (max - min) -> [0, 1] range seasonal_normalized[region_idx, :] = (region_seasonal - region_min) / (region_max - region_min) else: # If no variation, set to middle of range seasonal_normalized[region_idx, :] = 0.5 # Store seasonal components (will be used in forward pass) self.seasonal_components = seasonal_normalized logger.info(f" Initialized seasonal components:") logger.info(f" Seasonal coefficient: {self.seasonal_coeff.item():.3f} [LEARNABLE - constrained ≥ 0 via softplus]") logger.info(f" Components range: [{seasonal_normalized.min():.3f}, {seasonal_normalized.max():.3f}] (Min-Max scaled [0, 1])") logger.info(f" Components mean: {seasonal_normalized.mean():.3f}") logger.info(f" Scaling: Min-Max per region - seasonality can only ADD to baseline")
[docs] def initialize_hill_from_data(self, Xm: torch.Tensor): """ Initialize hill_g based on per-channel SOV (Share of Voice) distribution. For each channel, hill_g is set to the 60th percentile of its SOV values, ensuring the inflection point matches where the channel typically operates. Args: Xm: Media data [n_regions, n_timesteps, n_channels] in SOV-scaled space [0, 1] """ with torch.no_grad(): B, T, n_media = Xm.shape # Skip padding weeks Xm_no_padding = Xm[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xm logger.info(f"\nInitializing Hill g parameters from SOV data:") for i in range(n_media): # Get all SOV values for this channel channel_sov = Xm_no_padding[:, :, i].flatten() # Filter out zeros (weeks with no spend) channel_sov_nonzero = channel_sov[channel_sov > 1e-4] if len(channel_sov_nonzero) > 10: # Use 60th percentile as inflection point # This ensures the channel shows good S-curve behavior around typical spend levels p60 = torch.quantile(channel_sov_nonzero, 0.6).item() # Clamp to reasonable range [0.05, 0.9] # - Below 0.05: Too early saturation # - Above 0.9: Inflection too late (most data will be linear) g_target = max(0.05, min(0.9, p60)) # Convert to raw value (before softplus) # softplus(x) = log(1 + exp(x)), so inverse: x = log(exp(g) - 1) # For numerical stability, use: x = g + log(exp(-g) + 1) - log(2) for g > 1 if g_target > 0.5: g_raw = np.log(np.exp(g_target) - 1.0) else: g_raw = np.log(g_target / (1.0 - g_target + 1e-8)) self.hill_g.data[i] = g_raw logger.info(f" Channel {i+1:2d}: p60={p60:.4f} → g_target={g_target:.4f} (raw={g_raw:.4f})") else: # Not enough data, keep default logger.info(f" Channel {i+1:2d}: Insufficient data, keeping default g") logger.info(f"\nHill g initialization complete") g_after_softplus = torch.nn.functional.softplus(self.hill_g) logger.info(f" Range after softplus: [{g_after_softplus.min():.4f}, {g_after_softplus.max():.4f}]")
[docs] def initialize_stable_coefficients_from_data(self, Xm: torch.Tensor, Xc: torch.Tensor, y: torch.Tensor): """ Initialize stable coefficients based on simple linear regression on the data. This provides domain-informed starting points for coefficient stabilization. """ with torch.no_grad(): B, T, n_media = Xm.shape _, _, n_control = Xc.shape # CRITICAL: Remove padding weeks to avoid biasing the regression Xm_no_padding = Xm[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xm Xc_no_padding = Xc[:, self.burn_in_weeks:] if T > self.burn_in_weeks else Xc y_no_padding = y[:, self.burn_in_weeks:] if T > self.burn_in_weeks else y # MEDIA COEFFICIENTS - Flatten data for regression (no padding) X_media_flat = Xm_no_padding.reshape(-1, n_media) # [B*(T-burn_in), n_media] y_flat = y_no_padding.reshape(-1) # [B*(T-burn_in)] # Simple ridge regression to get initial media coefficients XtX = torch.mm(X_media_flat.t(), X_media_flat) lambda_reg = 0.01 * torch.eye(n_media, device=Xm.device) XtX_reg = XtX + lambda_reg Xty = torch.mv(X_media_flat.t(), y_flat) try: beta_media = torch.linalg.solve(XtX_reg, Xty) # POSITIVE-ONLY INITIALIZATION: Use abs and sigmoid to ensure non-negative media coefficients beta_media_raw = torch.abs(beta_media) # Ensure positive beta_media_positive = torch.sigmoid(beta_media_raw) # Range: [0, 1] - POSITIVE ONLY! # Update stable media coefficients data_scale = y_flat.std() / X_media_flat.std() self.stable_media_coeff.data = beta_media_positive * data_scale logger.info(f"Initialized stable coefficients from data:") logger.info(f" Media coeff range (POSITIVE-ONLY): [{beta_media_positive.min().item():.4f}, {beta_media_positive.max().item():.4f}]") except torch.linalg.LinAlgError: logger.warning("Warning: Could not solve for media coefficients, using correlation fallback") correlations = torch.zeros(n_media, device=Xm.device) for i in range(n_media): if X_media_flat[:, i].std() > 1e-8: correlations[i] = torch.corrcoef(torch.stack([X_media_flat[:, i], y_flat]))[0, 1] correlations = torch.nan_to_num(correlations, 0.0) stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling self.stable_media_coeff.data = correlations * stable_coeff_scale # CONTROL COEFFICIENTS - Allow negative effects (no padding) X_control_flat = Xc_no_padding.reshape(-1, n_control) # [B*(T-burn_in), n_control] # Ridge regression for control coefficients XtX_ctrl = torch.mm(X_control_flat.t(), X_control_flat) lambda_reg_ctrl = 0.01 * torch.eye(n_control, device=Xc.device) XtX_reg_ctrl = XtX_ctrl + lambda_reg_ctrl Xty_ctrl = torch.mv(X_control_flat.t(), y_flat) try: beta_control = torch.linalg.solve(XtX_reg_ctrl, Xty_ctrl) # Allow both positive and negative control effects - NO clipping # Control variables should be able to have strong negative effects # Scale control coefficients appropriately ctrl_data_scale = y_flat.std() / X_control_flat.std() stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling self.stable_ctrl_coeff.data = beta_control * ctrl_data_scale * stable_coeff_scale logger.info(f" Control coeff range: [{beta_control.min().item():.4f}, {beta_control.max().item():.4f}]") except torch.linalg.LinAlgError: logger.warning("Warning: Could not solve for control coefficients, using correlation fallback") correlations_ctrl = torch.zeros(n_control, device=Xc.device) for i in range(n_control): if X_control_flat[:, i].std() > 1e-8: correlations_ctrl[i] = torch.corrcoef(torch.stack([X_control_flat[:, i], y_flat]))[0, 1] correlations_ctrl = torch.nan_to_num(correlations_ctrl, 0.0) # Allow negative correlations for controls stable_coeff_scale = torch.exp(self.stable_coeff_scale_raw) # FULLY LEARNABLE scaling self.stable_ctrl_coeff.data = correlations_ctrl * stable_coeff_scale
[docs] def warm_start_training(self, Xm: torch.Tensor, Xc: torch.Tensor, R: torch.Tensor, y: torch.Tensor, optimizer: torch.optim.Optimizer, epochs: int = None): """ Warm-start training phase to stabilize GRU coefficients before main training. Uses only stable coefficients and focuses on learning good hidden state initialization. """ if not self.use_warm_start: return epochs = epochs or self.warm_start_epochs logger.info(f"Starting warm-start training for {epochs} epochs...") # Initialize stable coefficients from data self.initialize_stable_coefficients_from_data(Xm, Xc, y) # Save original parameters original_coeff_range_raw = self.coeff_range_raw.data.clone() original_ctrl_coeff_range_raw = self.ctrl_coeff_range_raw.data.clone() original_burn_in = self.burn_in_weeks # Temporarily use only stable coefficients (no dynamic variation) self.coeff_range_raw.data = torch.tensor(-10.0) # sigmoid(-10) ≈ 0, so range ≈ 1.0 (minimal) self.ctrl_coeff_range_raw.data = torch.tensor(-10.0) # sigmoid(-10) ≈ 0, so range ≈ 1.0 (minimal) self.burn_in_weeks = 999 # Force all weeks to use stable coefficients # Freeze coefficient generators during warm-start for param in self.coeff_gen.parameters(): param.requires_grad = False for param in self.ctrl_coeff_gen.parameters(): param.requires_grad = False # Train only GRU, baselines, and stable coefficients warm_start_params = [ self.h0, self.stable_media_coeff, self.stable_ctrl_coeff, self.region_baseline, self.global_bias, self.prediction_scale ] + list(self.gru.parameters()) # Use config learning rate but scaled UP for faster warm-start convergence config_lr = getattr(self, 'config_lr', 0.005) # Default fallback warm_start_lr = config_lr * 2.0 # 2x main LR for faster coefficient stabilization warm_optimizer = torch.optim.Adam(warm_start_params, lr=warm_start_lr) # Scaled from config LR self.train() for epoch in range(epochs): warm_optimizer.zero_grad() y_pred, _, _, outputs = self.forward(Xm, Xc, R) loss = F.mse_loss(y_pred, y) # Add stronger regularization to prevent extreme values reg_loss = 0.01 * (self.h0.pow(2).mean() + self.stable_media_coeff.pow(2).mean()) total_loss = loss + reg_loss total_loss.backward() # Aggressive gradient clipping for regional scaling stability torch.nn.utils.clip_grad_norm_(warm_start_params, max_norm=0.1) # Much more aggressive warm_optimizer.step() if epoch % 10 == 0: logger.info(f" Warm-start epoch {epoch}/{epochs}, Loss: {loss.item():.6f}") # Restore original parameters self.coeff_range_raw.data = original_coeff_range_raw self.ctrl_coeff_range_raw.data = original_ctrl_coeff_range_raw self.burn_in_weeks = original_burn_in # Unfreeze coefficient generators for param in self.coeff_gen.parameters(): param.requires_grad = True for param in self.ctrl_coeff_gen.parameters(): param.requires_grad = True logger.info(f" Warm-start training completed. GRU initialized for stable coefficients.")
[docs] def adstock(self, x: torch.Tensor) -> torch.Tensor: """STABILIZED adstock transformation.""" B, T, C = x.shape alpha = torch.sigmoid(self.alpha).view(1, 1, -1) alpha = torch.clamp(alpha, 0, 0.8) # Cap at 0.8 for stability out_list = [x[:, 0:1]] for t in range(1, T): prev_adstock = out_list[-1] current = x[:, t:t+1] + alpha * prev_adstock # Clip to prevent explosion current = torch.clamp(current, 0, 10) out_list.append(current) return torch.cat(out_list, dim=1)
[docs] def hill(self, x: torch.Tensor) -> torch.Tensor: """STABILIZED Hill saturation transformation.""" a = F.softplus(self.hill_a).view(1, 1, -1) g = F.softplus(self.hill_g).view(1, 1, -1) # Ensure positive input x_safe = F.relu(x) + 1e-8 # Stabilized Hill with clipping # CRITICAL: a (slope) must be >= 2.0 for proper saturation curves a = torch.clamp(a, 2.0, 5.0) # Changed from [0.1, 2.0] to [2.0, 5.0] g = torch.clamp(g, 0.01, 1.0) num = torch.pow(x_safe, a) denom = num + torch.pow(g, a) result = num / (denom + 1e-8) return torch.clamp(result, 0, 1) # Ensure bounded output
[docs] def dag_interaction(self, x: torch.Tensor) -> torch.Tensor: """Load-bearing DAG-driven channel interactions. Each target channel j's effective input becomes a learned blend of itself and a DAG-driven aggregation of its causal parents: parents_j = sum_i adj[i, j] * x_i (column-j of adj * x) x_j_new = (1 - mix_j) * x_j + mix_j * parents_j Two design choices make NOTEARS actually structure-learning here rather than decorative: * `mix` is a *per-channel* learnable scalar in (0, 1). The model can turn the DAG up where it helps prediction (strong causal parents) and down where it doesn't, so each adj column receives genuine per-channel gradient signal rather than a globally-averaged one. * adj uses a temperature-controlled sigmoid (``dag_temperature``), so the model is encouraged toward {near-0, near-1} edges instead of a soft cluster around 0.5 / sigmoid floor — this is the standard trick for making sigmoid gates bi-modal. With the previous additive form ``x + scalar * matmul(x, adj)`` the loss was satisfied even with a uniform adjacency at the L1 floor, which is why the learned graph collapsed to ~equal edges. The blended form forces the model to either pick informative parents or keep ``mix_j`` close to 0 (effectively ignoring the DAG for that channel). """ if not (self.enable_dag and self.enable_interactions): return x T_dag = float(getattr(self, 'dag_temperature', 1.0)) # Lower temperature -> sharper {0,1} edges; T_dag=1.0 reproduces # the standard sigmoid behaviour. adj_probs = torch.sigmoid(self.adj_logits / max(T_dag, 1e-3)) adj = adj_probs * self.tri_mask # [B,T,C] @ [C,C] -> [B,T,C]; column j of adj weights parents of j. parents = torch.matmul(x, adj) # Per-target mix: shape [n_channels], broadcast over batch and time. mix = torch.sigmoid(self.interaction_weight) return (1.0 - mix) * x + mix * parents
[docs] def process_media(self, X: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Process media variables through transformations.""" outputs = {} B, T, J = X.shape # Apply Adstock X_adstock = self.adstock(X) outputs['media_adstock'] = X_adstock # Apply Hill transformation X_hill = self.hill(X_adstock) outputs['media_hill'] = X_hill # Apply DAG and interactions if enabled if self.enable_dag and self.enable_interactions: # Use the new dag_interaction function X_processed = self.dag_interaction(X_hill) outputs['media_dag'] = X_processed else: X_processed = X_hill return X_processed, outputs
[docs] def apply_burn_in_stabilization(self, coeffs: torch.Tensor, stable_coeff: torch.Tensor) -> torch.Tensor: """ Advanced burn-in stabilization with multiple transition methods. Args: coeffs: Time-varying coefficients [B, T, dim] stable_coeff: Stable reference coefficients [dim] Returns: Stabilized coefficients with smooth burn-in transition """ B, T, dim = coeffs.shape if T <= self.burn_in_weeks: # If sequence is shorter than burn-in, use stable coefficients return stable_coeff.unsqueeze(0).unsqueeze(0).expand(B, T, -1) # Create stabilization weights based on method stabilized_coeffs = coeffs.clone() for week in range(self.burn_in_weeks): # Calculate transition weight based on method if self.stabilization_method == "linear": # Linear transition from 1.0 (fully stable) to 0.0 (fully dynamic) stable_weight = 1.0 - (week / self.burn_in_weeks) elif self.stabilization_method == "exponential": # Exponential decay - slower initial transition, faster later stable_weight = float(torch.exp(torch.tensor(-3.0 * week / self.burn_in_weeks))) elif self.stabilization_method == "sigmoid": # Sigmoid transition - smooth S-curve x = (week - self.burn_in_weeks/2) / (self.burn_in_weeks/4) stable_weight = float(1.0 / (1.0 + torch.exp(torch.tensor(x)))) else: # Default to linear stable_weight = 1.0 - (week / self.burn_in_weeks) dynamic_weight = 1.0 - stable_weight # Blend stable and dynamic coefficients stabilized_coeffs[:, week, :] = ( stable_weight * stable_coeff.unsqueeze(0).expand(B, -1) + dynamic_weight * coeffs[:, week, :] ) # DISABLED: Coefficient momentum system removed to prevent gradient blocking # The momentum system was using .detach() which blocked gradients return stabilized_coeffs
[docs] def forward( self, Xm: torch.Tensor, Xc: torch.Tensor, R: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: """ Forward pass through the DeepCausalMMM model. Processes media and control variables through the neural network to generate predictions, time-varying coefficients, per-channel media contributions, and a detailed ``outputs`` dict (baseline, seasonality, control contributions, DAG, etc.). Parameters ---------- Xm : torch.Tensor Media data tensor of shape [batch_size, time_steps, n_media] Should be SOV-scaled (Share of Voice) normalized to [0, 1] range Xc : torch.Tensor Control variables tensor of shape [batch_size, time_steps, ctrl_dim] Should be standardized (z-score normalized) R : torch.Tensor Region indicators tensor of shape [batch_size, time_steps] Integer values representing region/DMA IDs Returns ------- predictions : torch.Tensor Model predictions (scaled KPI), shape typically ``[batch_size, time_steps]`` (broadcast with components before ``prediction_scale``). media_coefficients : torch.Tensor Time-varying media coefficients, shape ``[batch_size, time_steps, n_media]``. media_contributions : torch.Tensor Per-channel media contributions ``X_processed * media_coefficients``, shape ``[batch_size, time_steps, n_media]``. outputs : Dict[str, Any] Detailed tensors, including: - ``contributions``: same as ``media_contributions`` (media channel breakdown) - ``coefficients``: same as returned ``media_coefficients`` - ``control_contributions``: control variable contributions ``[batch, time, ctrl_dim]`` - ``control_coefficients``: control coefficients ``[batch, time, ctrl_dim]`` - ``baseline``: baseline without seasonality (for waterfall-style splits) - ``seasonal_contribution``: seasonal term ``[batch, time]`` - ``dag_matrix`` (when DAG enabled): adjacency ``[n_media, n_media]`` - ``adstocked_media``, ``media_hill``, ``media_dag`` (when applicable): media pipeline stages Examples -------- >>> import torch >>> model = DeepCausalMMM(n_media=3, ctrl_dim=2, n_regions=2) >>> >>> # Prepare input tensors >>> media = torch.rand(2, 52, 3) # 2 regions, 52 weeks, 3 channels >>> control = torch.randn(2, 52, 2) # 2 control variables >>> regions = torch.tensor([[0]*52, [1]*52]) # Region indicators >>> >>> # Forward pass >>> pred, media_coeffs, media_contrib, outputs = model(media, control, regions) >>> >>> # Access detailed outputs >>> media_contrib_out = outputs['contributions'] >>> ctrl_contrib = outputs['control_contributions'] >>> dag_matrix = outputs.get('dag_matrix') Notes ----- The forward pass applies the following transformations in order: 1. Media processing: Adstock -> Hill saturation -> DAG interactions 2. Feature processing: Media features + Control features -> GRU 3. Coefficient generation: Time-varying coefficients from GRU states 4. Contribution calculation: Features * Coefficients 5. Final prediction: Baseline + Seasonality + Media + Control contributions The model enforces several constraints: - DAG acyclicity: upper-triangular mask (default) or NOTEARS penalty when ``dag_mode='notears'`` - Non-negative baseline and seasonal contributions - Learnable coefficient bounds to prevent explosion - Burn-in period stabilization for initial weeks """ B, T, _ = Xm.shape # Process media variables X_processed, outputs = self.process_media(Xm) # Process control variables through MLP - FIXED ctrl_features = self.ctrl_mlp(Xc) # [B, T, ctrl_hidden] # SOV-AWARE FEATURE PROCESSING: Media data is already SOV-scaled to [0,1], controls are standardized # No additional normalization needed - SOV scaling already provides balanced importance learning X_processed_norm = X_processed # SOV-scaled media features are already normalized [0,1] ctrl_features_norm = ctrl_features # Control features are already properly scaled by pipeline # GRU processing with NORMALIZED media and control features - REVERTED TO STABLE gru_in = torch.cat([X_processed_norm, ctrl_features_norm], dim=-1) h0 = self.h0.repeat(self.gru_layers, B, 1) # Single layer GRU h_seq, _ = self.gru(gru_in, h0) # Generate time-varying coefficients - REGULARIZED for stable attribution media_coeffs_raw = self.coeff_gen(h_seq) # LEARNABLE BOUNDS: Each channel learns its optimal maximum coefficient with non-zero guarantee learned_max = F.softplus(self.media_coeff_max_raw) + 0.1 # Range: [0.1, ∞) - NON-ZERO GUARANTEE! media_coeffs_unstable = torch.sigmoid(media_coeffs_raw) * learned_max.unsqueeze(0).unsqueeze(0) # [0, learned_max] per channel ctrl_coeffs_raw = self.ctrl_coeff_gen(h_seq) # LEARNABLE BOUNDS: Each control variable learns its optimal maximum coefficient learned_ctrl_max = F.softplus(self.ctrl_coeff_max_raw) + 0.1 # Range: [0.1, ∞) - NON-ZERO GUARANTEE! ctrl_coeffs_unstable = torch.tanh(ctrl_coeffs_raw) * learned_ctrl_max.unsqueeze(0).unsqueeze(0) # [-learned_max, learned_max] per control # NEW: Apply burn-in stabilization media_coeffs = self.apply_burn_in_stabilization(media_coeffs_unstable, self.stable_media_coeff) ctrl_coeffs = self.apply_burn_in_stabilization(ctrl_coeffs_unstable, self.stable_ctrl_coeff) # Calculate contributions media_contrib = X_processed * media_coeffs media_term = media_contrib.sum(-1) # [B, T] # Control contributions using original control values ctrl_contrib = Xc * ctrl_coeffs ctrl_term = ctrl_contrib.sum(-1) # [B, T] # Region baseline - CONSTRAINED to be non-negative region_ids = R[:, 0] if R.dim() > 1 else R # Handle both 1D and 2D region tensors region_baselines = F.softplus(self.region_baseline[region_ids]) # CONSTRAINED: Always ≥ 0 reg_term = region_baselines.unsqueeze(1).expand(-1, T) # [B, T] # NEW: Time trend component - Add linear growth capability time_steps = torch.arange(T, dtype=torch.float32, device=Xm.device).unsqueeze(0).expand(B, -1) # [B, T] trend_term = self.time_trend_weight * time_steps + self.time_trend_bias # [B, T] # NEW: Seasonal component - Add data-driven seasonality to baseline seasonal_term = torch.zeros(B, T, device=Xm.device) if self.seasonal_components is not None: # Get seasonal components for current time window seasonal_data = self.seasonal_components.to(Xm.device) # [n_regions, n_weeks] # Handle potential size mismatch (seasonal components might be shorter due to padding removal) if seasonal_data.shape[1] >= T: # Take the last T weeks (most recent) seasonal_slice = seasonal_data[:, -T:] else: # Pad if seasonal data is shorter pad_size = T - seasonal_data.shape[1] seasonal_slice = F.pad(seasonal_data, (pad_size, 0), mode='replicate') # CONSTRAINT: Ensure seasonal coefficient is non-negative (seasonality can only add to baseline) seasonal_coeff_positive = F.softplus(self.seasonal_coeff) # Always ≥ 0 # Apply constrained seasonal coefficient seasonal_term = seasonal_coeff_positive * seasonal_slice[region_ids] # PREDICTION - CORRECT: Total baseline = global_bias + region_deviation + seasonality # CONSTRAINT: Ensure global_bias is non-negative (baseline visits can't be negative) global_bias_positive = F.softplus(self.global_bias) # Always ≥ 0 # COMPUTE BASELINE (calculate once, use for both prediction and attribution) baseline_without_seasonal = F.relu(reg_term + global_bias_positive) # Baseline WITHOUT seasonal # Step 1: Compute total raw prediction raw_prediction = media_term + ctrl_term + baseline_without_seasonal + seasonal_term + trend_term # Step 3: Scale final prediction y = raw_prediction * F.softplus(self.prediction_scale) # Store outputs (direct computation) outputs['coefficients'] = media_coeffs outputs['control_coefficients'] = ctrl_coeffs outputs['contributions'] = media_contrib # [B, T, n_media] outputs['trend_contribution'] = trend_term # Trend is frozen at 0 outputs['control_contributions'] = ctrl_contrib # [B, T, n_control] outputs['seasonal_contribution'] = seasonal_term # Seasonal component (separate for waterfall) outputs['baseline'] = baseline_without_seasonal # Baseline WITHOUT seasonal (avoids double-counting) outputs['raw_prediction'] = raw_prediction outputs['prediction_scale'] = F.softplus(self.prediction_scale) outputs['burn_in_weeks'] = self.burn_in_weeks # DEBUG: Verify components sum to raw_prediction components_sum = media_term + ctrl_term + baseline_without_seasonal + seasonal_term + trend_term diff = torch.abs(components_sum - raw_prediction).mean().item() # Also verify per-channel media_contrib sums to media_term media_contrib_sum = media_contrib.sum(dim=-1) # Sum across channels media_diff = torch.abs(media_contrib_sum - media_term).mean().item() if diff > 1e-5: print(f"WARNING: Components don't sum to raw_prediction! Diff: {diff:.6f}") if media_diff > 1e-5: print(f"WARNING: media_contrib.sum() != media_term! Diff: {media_diff:.6f}") # ====================================================================== # ATTRIBUTION REGULARIZATION: Media contribution should match prior # ====================================================================== # Use ABSOLUTE value constraint (not ratio) to prevent "make everything tiny" loophole # Target: media_term should be ~40% of total_prediction in absolute terms # Calculate target media contribution (40% of total prediction) target_media = self.media_contribution_prior * raw_prediction # [B, T] # Regularization loss: penalize deviation from target absolute contribution attribution_reg_loss_raw = F.mse_loss(media_term, target_media) # For monitoring: also calculate actual proportion media_proportion = media_term / (raw_prediction + 1e-8) # DYNAMIC SCALING: Scale attribution loss to match MSE magnitude # This ensures both losses are comparable before weighting # We'll compute MSE in trainer and pass it here, but for now store raw loss outputs['attribution_reg_loss_raw'] = attribution_reg_loss_raw outputs['media_proportion'] = media_proportion.mean() # For monitoring outputs['media_term_mean'] = media_term.mean() # For debugging outputs['target_media_mean'] = target_media.mean() # For debugging # ====================================================================== # SEASONAL REGULARIZATION: Prevent seasonal_coeff from going to zero # ====================================================================== # Target: seasonal_coeff should stay near its prior (e.g., 1.0) # This prevents the model from suppressing seasonality seasonal_prior = getattr(self, 'seasonal_prior', 1.0) seasonal_reg_loss_raw = F.mse_loss( F.softplus(self.seasonal_coeff), # Current seasonal coefficient torch.tensor([seasonal_prior], device=self.seasonal_coeff.device) # Target prior ) outputs['seasonal_reg_loss_raw'] = seasonal_reg_loss_raw return y, media_coeffs, media_contrib, outputs # Return direct contributions
[docs] def h_acyclicity(self, W: torch.Tensor) -> torch.Tensor: """NOTEARS acyclicity scalar: h(W) = tr(exp(W ⊙ W)) − d. Equals zero iff W is the adjacency of a DAG; smooth and differentiable elsewhere. See Zheng et al., 2018 (https://arxiv.org/abs/1803.01422). Args: W: Square adjacency matrix (n_media × n_media). Returns: Scalar tensor; minimised toward 0 during training. """ M = W * W return torch.trace(torch.matrix_exp(M)) - W.shape[0]
[docs] def get_dag_loss(self) -> torch.Tensor: """DAG regularisation. Mode-aware: triangular uses sparsity/confidence penalties only (acyclicity is structural); NOTEARS additionally adds the augmented-Lagrangian acyclicity term 0.5·rho·h(W)² + alpha·h(W) and an L1 penalty on the full adjacency.""" if not (self.enable_dag and self.enable_interactions): return torch.tensor(0.0, device=self.global_bias.device) T_dag = float(getattr(self, 'dag_temperature', 1.0)) adj_probs = torch.sigmoid(self.adj_logits / max(T_dag, 1e-3)) adj = adj_probs * self.tri_mask sparsity_loss = torch.sum(adj) confidence_loss = torch.sum(adj_probs * (1 - adj_probs)) l1_penalty = torch.sum(torch.abs(self.adj_logits)) if getattr(self, 'dag_mode', 'triangular') == 'triangular': total_dag_loss = (0.01 * sparsity_loss + 0.001 * confidence_loss + 0.0002 * l1_penalty) else: total_dag_loss = 1e-4 * l1_penalty if bool(getattr(self, 'notears_active', torch.tensor(True)).item()): h = self.h_acyclicity(adj) # Column-group L1: per-target column L2 norm — encourages each # channel j to pick a focused parent set or none at all. group_l1 = torch.sum(torch.sqrt(torch.sum(adj.pow(2), dim=0) + 1e-8)) notears_term = (0.5 * self.notears_rho * h.pow(2) + self.notears_alpha * h + self.notears_lambda1 * sparsity_loss + self.notears_group_l1 * group_l1) total_dag_loss = total_dag_loss + notears_term return total_dag_loss
[docs] @torch.no_grad() def notears_update_duals(self, factor: float = 10.0, progress: float = 0.25) -> Dict[str, float]: """Augmented-Lagrangian dual update (NOTEARS outer loop). Call once every K epochs from the trainer. Returns diagnostic dict ({"h", "rho", "alpha"}) for logging; returns empty dict in triangular mode. Args: factor: Multiplicative growth applied to rho when h(W) stalls. progress: Required relative shrinkage of h between outer iterations. If h_new > progress * h_prev, rho is grown by `factor`. """ if (not (self.enable_dag and self.enable_interactions) or getattr(self, 'dag_mode', 'triangular') != 'notears' or not bool(getattr(self, 'notears_active', torch.tensor(True)).item())): return {} T_dag = max(float(getattr(self, 'dag_temperature', 1.0)), 1e-3) adj = torch.sigmoid(self.adj_logits / T_dag) * self.tri_mask h = self.h_acyclicity(adj).item() if h > progress * self._notears_h_prev: self.notears_rho.mul_(factor).clamp_(max=self.notears_rho_max) # Standard augmented-Lagrangian dual ascent on the equality constraint h = 0 self.notears_alpha.add_(self.notears_rho * h) self._notears_h_prev = h return { "h": float(h), "rho": float(self.notears_rho.item()), "alpha": float(self.notears_alpha.item()), }
[docs] @torch.no_grad() def get_dag_adjacency_matrix(self, eps: Optional[float] = None) -> torch.Tensor: """Learned adjacency with ``dag_temperature`` and ``tri_mask`` applied. Args: eps: If ``None``, return continuous edge weights. If a float, zero entries with ``|w| < eps`` (same rule as ``threshold_dag``). Returns: Square adjacency tensor ``[n_media, n_media]``. """ if not (self.enable_dag and self.enable_interactions): return torch.zeros(self.n_media, self.n_media, device=self.global_bias.device) T_dag = max(float(getattr(self, 'dag_temperature', 1.0)), 1e-3) W = torch.sigmoid(self.adj_logits / T_dag) * self.tri_mask if eps is not None: W = W.detach().clone() W[W.abs() < eps] = 0.0 return W return W.detach()
[docs] @torch.no_grad() def threshold_dag(self, eps: float = 0.3) -> torch.Tensor: """Post-training pruning: zero out adjacency entries with |w| < eps. Returns the thresholded adjacency tensor. For NOTEARS mode this is the recommended way to obtain a clean discrete DAG from the continuous W. """ return self.get_dag_adjacency_matrix(eps=eps)
[docs] def get_sparsity_loss(self) -> torch.Tensor: """Sparsity loss to encourage sparse coefficients.""" if not (self.enable_dag and self.enable_interactions): return torch.tensor(0.0, device=self.global_bias.device) # L1 penalty on media coefficients for sparsity media_sparsity = torch.sum(torch.abs(self.stable_media_coeff)) # L1 penalty on control coefficients ctrl_sparsity = torch.sum(torch.abs(self.stable_ctrl_coeff)) # L1 penalty on GRU weights for temporal sparsity gru_sparsity = sum(torch.sum(torch.abs(param)) for param in self.gru.parameters()) total_sparsity = media_sparsity + ctrl_sparsity + 0.1 * gru_sparsity return total_sparsity
[docs] def get_regularization_loss(self) -> torch.Tensor: """Calculate combined regularization loss including DAG penalty and coefficient regularization.""" l1_loss = torch.tensor(0.0, device=self.global_bias.device) l2_loss = torch.tensor(0.0, device=self.global_bias.device) for param in self.parameters(): if param.requires_grad: l1_loss += torch.abs(param).mean() l2_loss += (param ** 2).mean() reg_loss = self.l1_weight * l1_loss + self.l2_weight * l2_loss # COEFFICIENT-SPECIFIC REGULARIZATION: Prevent coefficient explosion coeff_reg_loss = torch.tensor(0.0, device=self.global_bias.device) # Strong L2 penalty on coefficient range parameters to prevent explosion coeff_range_penalty = (self.coeff_range_raw ** 2).mean() * self.coeff_l2_weight ctrl_coeff_range_penalty = (self.ctrl_coeff_range_raw ** 2).mean() * self.coeff_l2_weight coeff_reg_loss += coeff_range_penalty + ctrl_coeff_range_penalty # L2 penalty on coefficient generator weights (prevents generating extreme coefficients) for name, param in self.named_parameters(): if 'coeff_gen' in name and param.requires_grad: coeff_reg_loss += (param ** 2).mean() * self.coeff_gen_l2_weight reg_loss += coeff_reg_loss # Add DAG loss if enabled if self.enable_dag and self.enable_interactions: reg_loss = reg_loss + self.get_dag_loss() return reg_loss
[docs] def get_parameters(self) -> Dict[str, torch.Tensor]: """Get model parameters for analysis.""" params = { 'adstock_alpha': torch.sigmoid(self.alpha).detach(), 'hill_a': F.softplus(self.hill_a).detach(), 'hill_g': F.softplus(self.hill_g).detach(), 'global_bias': self.global_bias.detach(), 'prediction_scale': F.softplus(self.prediction_scale).detach(), } if self.enable_dag and self.enable_interactions: adj_probs = torch.sigmoid(self.adj_logits) params['adjacency'] = (adj_probs * self.tri_mask).detach() params['interaction_weight'] = torch.sigmoid(self.interaction_weight).detach() return params
[docs] def create_unified_mmm( n_media: int, n_control: int, hidden_size: int = 64, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.1, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.01, l2_weight: float = 0.01, coeff_range: float = 2.0 ) -> DeepCausalMMM: """Factory function to create a DeepCausalMMM model.""" model = DeepCausalMMM( n_media=n_media, ctrl_dim=n_control, hidden=hidden_size, n_regions=n_regions, dropout=dropout, sparsity_weight=sparsity_weight, enable_dag=enable_dag, enable_interactions=enable_interactions, l1_weight=l1_weight, l2_weight=l2_weight, coeff_range=coeff_range ) return model