deepcausalmmm.core.unified_model
DeepCausalMMM model implementation combining GRU, DAG, and interaction components.
Functions
- create_unified_mmm – Factory function to create a DeepCausalMMM model.
Classes
- DeepCausalMMM – Deep Causal Marketing Mix Model with DAG structure and channel interactions.
- class deepcausalmmm.core.unified_model.DeepCausalMMM(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)
Deep Causal Marketing Mix Model with DAG structure and channel interactions.
This model combines deep learning with causal inference to understand the impact of marketing channels on business KPIs while learning causal relationships between channels through a Directed Acyclic Graph (DAG).
The model features:
- GRU-based temporal modeling for time-varying coefficients
- Learnable coefficient bounds for realistic attribution
- DAG learning for causal channel interactions (triangular mask or opt-in NOTEARS)
- Adstock and saturation transformations
- Multi-region support with shared and region-specific parameters
- Zero-hardcoding philosophy: all parameters are learnable or configurable
- Parameters:
n_media (int, default=10) – Number of media channels in the dataset
ctrl_dim (int, default=15) – Number of control variables (weather, events, etc.)
hidden (int, default=32) – Hidden dimension size for GRU and MLP layers
n_regions (int, default=2) – Number of geographic regions or DMAs
dropout (float, default=0.1) – Dropout rate for regularization during training
sparsity_weight (float, default=0.01) – Weight for sparsity regularization on coefficients
enable_dag (bool, default=True) – Whether to enable DAG learning for channel interactions
enable_interactions (bool, default=True) – Whether to enable channel interaction modeling
l1_weight (float, default=0.001) – L1 regularization weight for coefficient sparsity
l2_weight (float, default=0.001) – L2 regularization weight for coefficient smoothness
burn_in_weeks (int, default=4) – Number of initial weeks for GRU stabilization
use_coefficient_momentum (bool, default=True) – Whether to use momentum for coefficient stabilization
momentum_decay (float, default=0.9) – Decay rate for coefficient momentum
use_warm_start (bool, default=True) – Whether to use warm start training initialization
warm_start_epochs (int, default=50) – Number of epochs for warm start phase
stabilization_method (str, default="exponential") – Method for coefficient stabilization (“linear”, “exponential”, “sigmoid”)
coeff_l2_weight (float, default=0.1) – L2 regularization specifically for media coefficients
coeff_gen_l2_weight (float, default=0.05) – L2 regularization for coefficient generation layers
gru_layers (int, default=1) – Number of GRU layers (configured, not hardcoded)
ctrl_hidden_ratio (float, default=0.5) – Control hidden size as ratio of main hidden dimension
dag_mode (str, default="triangular") – DAG acyclicity mode: "triangular" (upper-triangular mask, default) or "notears" (continuous NOTEARS penalty; Zheng et al., 2018)
notears_lambda1 (float, default=0.01) – L1 sparsity on the full adjacency in NOTEARS mode
notears_rho_init (float, default=1.0) – Initial augmented-Lagrangian penalty rho
notears_alpha_init (float, default=0.0) – Initial dual variable alpha for NOTEARS
notears_rho_max (float, default=1e16) – Upper cap on rho for numerical safety
dag_temperature (float, default=1.0) – Sigmoid temperature for DAG edge weights (< 1 sharpens toward {0, 1})
notears_group_l1 (float, default=0.0) – Column-group L1 over adjacency columns (NOTEARS only)
- media_coeffs
Time-varying coefficients for media channels
- Type:
torch.nn.Parameter
- ctrl_coeffs
Coefficients for control variables
- Type:
torch.nn.Parameter
- dag_matrix
Learnable DAG adjacency matrix for channel interactions
- Type:
torch.nn.Parameter
- region_baseline
Region-specific baseline contributions
- Type:
torch.nn.Parameter
- seasonal_coeff
Learnable coefficient for seasonal component
- Type:
torch.nn.Parameter
Examples
>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>>
>>> # Initialize model
>>> model = DeepCausalMMM(
...     n_media=5,
...     ctrl_dim=3,
...     n_regions=2,
...     hidden=64
... )
>>>
>>> # Prepare data tensors (media SOV-scaled to [0, 1], controls standardized)
>>> media_data = torch.rand(2, 104, 5)  # [regions, weeks, channels]
>>> control_data = torch.randn(2, 104, 3)  # [regions, weeks, controls]
>>> regions = torch.arange(2).unsqueeze(1).repeat(1, 104)  # [regions, weeks]
>>>
>>> # Forward pass
>>> predictions, media_coeffs, media_contributions, outputs = model(
...     media_data, control_data, regions
... )
>>>
>>> print(f"Predictions shape: {predictions.shape}")
>>> print(f"Media contributions: {outputs['contributions'].shape}")
- __init__(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)
Construct the model. Arguments are as documented for the class signature above.
- initialize_baseline(y_data: Tensor)
Initialize the baseline to match target data statistics.
y_data must already be in scaled space (y / y_mean per region). All parameters are extracted directly from the observed data distribution, and padding weeks are skipped so they do not bias the baseline.
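Because initialize_baseline expects y already divided by its per-region mean, the scaling step itself can be sketched as below. This is a minimal NumPy illustration, not the library's preprocessing: the function name scale_by_region_mean and the boolean valid_mask convention for marking padding weeks are hypothetical.

```python
import numpy as np


def scale_by_region_mean(y, valid_mask=None):
    """Divide each region's KPI series by its own mean (y / y_mean per region).

    y: array [n_regions, n_weeks].
    valid_mask: optional boolean array, same shape; False marks padding
        weeks (hypothetical convention), which are excluded from the mean.
    Returns (y_scaled, y_mean) with y_mean of shape [n_regions, 1].
    """
    y = np.asarray(y, dtype=float)
    if valid_mask is None:
        valid_mask = np.ones(y.shape, dtype=bool)
    valid_mask = np.asarray(valid_mask, dtype=bool)
    # Per-region mean over real (non-padding) weeks only.
    counts = valid_mask.sum(axis=1, keepdims=True)
    totals = np.where(valid_mask, y, 0.0).sum(axis=1, keepdims=True)
    y_mean = totals / counts
    return y / y_mean, y_mean


y = [[10.0, 20.0, 30.0], [0.0, 4.0, 8.0]]
mask = [[True, True, True], [False, True, True]]  # week 0 of region 1 is padding
y_scaled, y_mean = scale_by_region_mean(y, mask)  # y_mean: 20 and 6
```

After this step each region's valid weeks average exactly 1, so a baseline fitted in scaled space is directly comparable across regions.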
- initialize_hill_from_data(Xm: Tensor)
Initialize hill_g based on per-channel SOV (Share of Voice) distribution.
For each channel, hill_g is set to the 60th percentile of its SOV values, ensuring the inflection point matches where the channel typically operates.
- Parameters:
Xm – Media data [n_regions, n_timesteps, n_channels] in SOV-scaled space [0, 1]
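The percentile rule above can be sketched in NumPy as follows. This assumes values are pooled across all regions and weeks before taking each channel's 60th percentile, with no filtering of zero weeks; the library may differ on either point. The function name hill_g_from_sov is hypothetical (in PyTorch, torch.quantile would play the role of np.percentile).

```python
import numpy as np


def hill_g_from_sov(xm, q=60.0):
    """Per-channel q-th percentile of SOV values, pooled over regions and weeks.

    xm: array [n_regions, n_timesteps, n_channels] with values in [0, 1].
    Returns an array of shape [n_channels].
    Assumption: all region-weeks are pooled and zeros are not filtered out.
    """
    xm = np.asarray(xm, dtype=float)
    flat = xm.reshape(-1, xm.shape[-1])  # [n_regions * n_timesteps, n_channels]
    return np.percentile(flat, q, axis=0)


# One region, 11 weeks, two channels: a ramp 0.0..1.0 and a constant 0.2.
xm = np.stack([np.linspace(0.0, 1.0, 11), np.full(11, 0.2)], axis=-1)[None]
g = hill_g_from_sov(xm)  # 60th percentiles: 0.6 for the ramp, 0.2 for the constant
```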
- initialize_stable_coefficients_from_data(Xm: Tensor, Xc: Tensor, y: Tensor)
Initialize stable coefficients from a simple linear regression on the data. This provides data-driven starting points for coefficient stabilization.
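A plain least-squares version of this idea is sketched below: every (region, week) pair becomes one row, and media, control, and intercept coefficients are solved jointly with np.linalg.lstsq. This is an illustration under assumptions, not the library's routine; any sign constraints, clipping, or shrinkage the model applies are not shown, and ols_init_coefficients is a hypothetical name.

```python
import numpy as np


def ols_init_coefficients(xm, xc, y):
    """Least-squares starting coefficients from pooled region-weeks.

    xm: [n_regions, T, n_media]; xc: [n_regions, T, ctrl_dim]; y: [n_regions, T].
    Returns (media_coef [n_media], ctrl_coef [ctrl_dim], intercept).
    """
    n_media, n_ctrl = xm.shape[-1], xc.shape[-1]
    # One row per (region, week); append a column of ones for the intercept.
    X = np.concatenate([xm, xc], axis=-1).reshape(-1, n_media + n_ctrl)
    X = np.hstack([X, np.ones((X.shape[0], 1))])
    beta, *_ = np.linalg.lstsq(X, y.reshape(-1), rcond=None)
    return beta[:n_media], beta[n_media:n_media + n_ctrl], beta[-1]


rng = np.random.default_rng(0)
xm = rng.random((2, 52, 2))
xc = rng.standard_normal((2, 52, 1))
y = 2.0 * xm[..., 0] + 0.5 * xm[..., 1] - 1.0 * xc[..., 0] + 3.0  # noiseless
media_coef, ctrl_coef, intercept = ols_init_coefficients(xm, xc, y)
```

On noiseless synthetic data the generating coefficients are recovered exactly, which makes the example easy to check.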
- warm_start_training(Xm: Tensor, Xc: Tensor, R: Tensor, y: Tensor, optimizer: Optimizer, epochs: int = None)
Warm-start training phase to stabilize GRU coefficients before main training. Uses only stable coefficients and focuses on learning good hidden state initialization.
- dag_interaction(x: Tensor) → Tensor
Load-bearing DAG-driven channel interactions.
Each target channel j’s effective input becomes a learned blend of itself and a DAG-driven aggregation of its causal parents:
parents_j = sum_i adj[i, j] * x_i   (i.e. column j of x @ adj)
x_j_new = (1 - mix_j) * x_j + mix_j * parents_j
Two design choices make NOTEARS genuinely learn structure here rather than being decorative:
1. mix is a per-channel learnable scalar in (0, 1). The model can turn the DAG up where it helps prediction (strong causal parents) and down where it does not, so each adj column receives a genuine per-channel gradient signal rather than a globally averaged one.
2. adj uses a temperature-controlled sigmoid (dag_temperature), so the model is encouraged toward near-0 or near-1 edges instead of a soft cluster around 0.5 or the sigmoid floor; this is the standard trick for making sigmoid gates bimodal.
With the previous additive form, x + scalar * matmul(x, adj), the loss was satisfied even with a uniform adjacency at the L1 floor, which is why the learned graph collapsed to roughly equal edges. The blended form forces the model either to pick informative parents or to keep mix_j close to 0 (effectively ignoring the DAG for that channel).
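The blend described above can be sketched in NumPy. The parameter names adj_logits and mix_logits are hypothetical, as is the use of a strictly upper-triangular mask (k=1, no self-loops) for triangular mode; the library's exact parameterization is not documented here.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def dag_interaction(x, adj_logits, mix_logits, temperature=1.0, tri_mask=None):
    """Blend each channel with a DAG-weighted sum of its parents.

    x: [..., n_media]; adj_logits: [n_media, n_media] (hypothetical);
    mix_logits: [n_media] (hypothetical).
    """
    adj = sigmoid(adj_logits / temperature)  # temperature < 1 sharpens toward {0, 1}
    if tri_mask is not None:
        adj = adj * tri_mask                 # assumed strictly upper-triangular
    mix = sigmoid(mix_logits)                # per-channel blend weight in (0, 1)
    parents = x @ adj                        # parents[..., j] = sum_i x[..., i] * adj[i, j]
    return (1.0 - mix) * x + mix * parents


x = np.array([1.0, 2.0, 3.0])
tri = np.triu(np.ones((3, 3)), k=1)
adj_logits = np.full((3, 3), 50.0)           # all allowed edges near 1
mix_logits = np.array([-50.0, 0.0, 50.0])    # mix near 0, 0.5, 1
out = dag_interaction(x, adj_logits, mix_logits, tri_mask=tri)
```

Channel 0 has no parents and mix near 0, so it passes through unchanged; channel 2 has mix near 1, so it is replaced by the sum of its parents (1 + 2 = 3).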
- process_media(X: Tensor) → Tuple[Tensor, Dict[str, Tensor]]
Process media variables through the media transformations (adstock, Hill saturation and, when enabled, DAG interactions).
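The adstock and saturation stages can be illustrated with two common textbook forms: geometric adstock and the Hill function x^a / (x^a + g^a). Both forms are assumptions; this page does not state which adstock or Hill parameterization the model uses. Here g plays the role of a half-saturation point, in the spirit of hill_g.

```python
import numpy as np


def geometric_adstock(x, decay):
    """Geometric carry-over: a[t] = x[t] + decay * a[t-1], per channel.

    x: [T, n_media]; decay: [n_media] with values in [0, 1).
    Assumed form; the library's adstock may differ.
    """
    x = np.asarray(x, dtype=float)
    out = np.zeros_like(x)
    carry = np.zeros(x.shape[-1])
    for t in range(x.shape[0]):
        carry = x[t] + decay * carry
        out[t] = carry
    return out


def hill(x, g, a):
    """Hill saturation x^a / (x^a + g^a); equals 0.5 at x == g (assumed form)."""
    xa = np.power(x, a)
    return xa / (xa + np.power(g, a))


x = np.array([[1.0], [0.0], [0.0]])          # a single pulse in week 0
ad = geometric_adstock(x, decay=np.array([0.5]))
sat = hill(ad, g=0.5, a=2.0)                  # adstock first, then saturation
```

The order (adstock, then saturation) follows the pipeline stated in the forward() notes.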
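- apply_burn_in_stabilization(coeffs: Tensor, stable_coeff: Tensor) → Tensor
Burn-in stabilization with multiple transition methods (selected by stabilization_method: "linear", "exponential", or "sigmoid"), providing a smooth transition between the stable reference coefficients and the time-varying coefficients over the initial burn_in_weeks.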
- Parameters:
coeffs – Time-varying coefficients [B, T, dim]
stable_coeff – Stable reference coefficients [dim]
- Returns:
Stabilized coefficients with smooth burn-in transition
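One way such a burn-in transition can work is sketched below: a weight w(t) ramps from 0 to 1 over burn_in_weeks, and the output is w * coeffs + (1 - w) * stable_coeff. The ramp shapes and the rates 5.0 and 10.0 are hypothetical choices for illustration, and coefficient momentum is not modeled; the library's actual formulas are not given on this page.

```python
import numpy as np


def burn_in_weights(T, burn_in_weeks, method="exponential"):
    """Weight on the time-varying coefficient at each step, ramping 0 -> 1.

    The exponential rate 5.0 and sigmoid slope 10.0 are hypothetical.
    """
    if burn_in_weeks <= 0:
        return np.ones(T)
    t = np.arange(T, dtype=float)
    r = np.clip(t / burn_in_weeks, 0.0, 1.0)  # progress through burn-in
    if method == "linear":
        w = r
    elif method == "exponential":
        w = 1.0 - np.exp(-5.0 * r)
    elif method == "sigmoid":
        w = 1.0 / (1.0 + np.exp(-10.0 * (r - 0.5)))
    else:
        raise ValueError(f"unknown method: {method}")
    w[t >= burn_in_weeks] = 1.0  # after burn-in, keep the GRU coefficients as-is
    return w


def stabilize(coeffs, stable_coeff, burn_in_weeks, method="exponential"):
    """coeffs: [B, T, dim]; stable_coeff: [dim]."""
    w = burn_in_weights(coeffs.shape[1], burn_in_weeks, method)[None, :, None]
    return w * coeffs + (1.0 - w) * stable_coeff


coeffs = np.ones((1, 6, 2))
stable = np.zeros(2)
out_lin = stabilize(coeffs, stable, burn_in_weeks=4, method="linear")
out_exp = stabilize(coeffs, stable, burn_in_weeks=4, method="exponential")
```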
- forward(Xm: Tensor, Xc: Tensor, R: Tensor) → Tuple[Tensor, Tensor, Tensor, Dict[str, Any]]
Forward pass through the DeepCausalMMM model.
Processes media and control variables through the neural network to generate predictions, time-varying coefficients, per-channel media contributions, and a detailed outputs dict (baseline, seasonality, control contributions, DAG, etc.).
- Parameters:
Xm (torch.Tensor) – Media data tensor of shape [batch_size, time_steps, n_media]. Should be SOV-scaled (Share of Voice), i.e. normalized to the [0, 1] range.
Xc (torch.Tensor) – Control variables tensor of shape [batch_size, time_steps, ctrl_dim]. Should be standardized (z-score normalized).
R (torch.Tensor) – Region indicators tensor of shape [batch_size, time_steps]. Integer values representing region/DMA IDs.
- Returns:
predictions (torch.Tensor) – Model predictions (scaled KPI), shape typically [batch_size, time_steps] (broadcast with components before prediction_scale).
media_coefficients (torch.Tensor) – Time-varying media coefficients, shape [batch_size, time_steps, n_media].
media_contributions (torch.Tensor) – Per-channel media contributions X_processed * media_coefficients, shape [batch_size, time_steps, n_media].
outputs (Dict[str, Any]) – Detailed tensors, including:
- contributions: same as media_contributions (media channel breakdown)
- coefficients: same as the returned media_coefficients
- control_contributions: control variable contributions [batch, time, ctrl_dim]
- control_coefficients: control coefficients [batch, time, ctrl_dim]
- baseline: baseline without seasonality (for waterfall-style splits)
- seasonal_contribution: seasonal term [batch, time]
- dag_matrix (when DAG enabled): adjacency [n_media, n_media]
- adstocked_media, media_hill, media_dag (when applicable): media pipeline stages
Examples
>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>> model = DeepCausalMMM(n_media=3, ctrl_dim=2, n_regions=2)
>>>
>>> # Prepare input tensors
>>> media = torch.rand(2, 52, 3)  # 2 regions, 52 weeks, 3 channels
>>> control = torch.randn(2, 52, 2)  # 2 control variables
>>> regions = torch.tensor([[0]*52, [1]*52])  # Region indicators
>>>
>>> # Forward pass
>>> pred, media_coeffs, media_contrib, outputs = model(media, control, regions)
>>>
>>> # Access detailed outputs
>>> media_contrib_out = outputs['contributions']
>>> ctrl_contrib = outputs['control_contributions']
>>> dag_matrix = outputs.get('dag_matrix')
Notes
The forward pass applies the following transformations in order:
1. Media processing: Adstock -> Hill saturation -> DAG interactions
2. Feature processing: Media features + Control features -> GRU
3. Coefficient generation: Time-varying coefficients from GRU states
4. Contribution calculation: Features * Coefficients
5. Final prediction: Baseline + Seasonality + Media + Control contributions
The model enforces several constraints:
- DAG acyclicity: upper-triangular mask (default) or NOTEARS penalty when dag_mode='notears'
- Non-negative baseline and seasonal contributions
- Learnable coefficient bounds to prevent explosion
- Burn-in period stabilization for initial weeks
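Steps 4 and 5 of the pipeline amount to an elementwise product followed by an additive sum, which the NumPy sketch below shows with the shapes documented for forward(). It is an illustration only: the function name compose_prediction is hypothetical, and any final rescaling (for example prediction_scale) is deliberately omitted.

```python
import numpy as np


def compose_prediction(baseline, seasonal, x_processed, media_coeffs, ctrl_contrib):
    """Additive assembly: baseline + seasonality + media + control.

    baseline, seasonal: [B, T]; x_processed, media_coeffs: [B, T, n_media];
    ctrl_contrib: [B, T, ctrl_dim]. Returns (predictions [B, T], media_contrib).
    Any final rescaling such as prediction_scale is omitted.
    """
    media_contrib = x_processed * media_coeffs            # step 4: features * coefficients
    pred = baseline + seasonal + media_contrib.sum(-1) + ctrl_contrib.sum(-1)  # step 5
    return pred, media_contrib


baseline = np.ones((1, 2))
seasonal = np.full((1, 2), 0.1)
x_processed = np.array([[[0.5, 1.0], [0.0, 0.2]]])
media_coeffs = np.array([[[0.4, 0.3], [0.1, 0.5]]])
ctrl_contrib = np.array([[[-0.1], [0.5]]])
pred, media_contrib = compose_prediction(baseline, seasonal, x_processed,
                                         media_coeffs, ctrl_contrib)
```

Because the model is additive, each component's share of a prediction can be read off directly, which is what the waterfall-style baseline split in outputs relies on.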
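- h_acyclicity(W: Tensor) → Tensor
NOTEARS acyclicity scalar: h(W) = tr(exp(W ⊙ W)) − d, where exp is the matrix exponential, ⊙ the elementwise (Hadamard) product, and d the number of nodes (n_media).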
Equals zero iff W is the adjacency of a DAG; smooth and differentiable elsewhere. See Zheng et al., 2018 (https://arxiv.org/abs/1803.01422).
- Parameters:
W – Square adjacency matrix (n_media × n_media).
- Returns:
Scalar tensor; minimised toward 0 during training.
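A small NumPy check of this formula is sketched below. It uses the identity tr(exp(A)) = sum_i exp(lambda_i(A)), which holds for any square matrix, so no matrix-exponential routine is needed. This is not the library's implementation; in PyTorch, torch.linalg.matrix_exp provides a differentiable matrix exponential for the same purpose.

```python
import math

import numpy as np


def h_acyclicity(W):
    """NOTEARS h(W) = tr(exp(W * W)) - d, with * the elementwise product.

    Computed via tr(exp(A)) = sum of exp(eigenvalues of A).
    """
    W = np.asarray(W, dtype=float)
    A = W * W
    eig = np.linalg.eigvals(A)
    return float(np.real(np.exp(eig).sum())) - W.shape[0]


dag = np.triu(np.full((3, 3), 0.7), k=1)   # strictly upper-triangular: a DAG
cycle = np.array([[0.0, 1.0], [1.0, 0.0]])  # 2-cycle between two nodes
h_dag, h_cycle = h_acyclicity(dag), h_acyclicity(cycle)
```

For the DAG, A is nilpotent, so every eigenvalue is 0 and h is 0. The 2-cycle has eigenvalues +1 and -1, giving h = e + 1/e - 2, which is positive.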
- get_dag_loss() → Tensor
DAG regularisation. Mode-aware: triangular uses sparsity/confidence penalties only (acyclicity is structural); NOTEARS additionally adds the augmented-Lagrangian acyclicity term 0.5·rho·h(W)² + alpha·h(W) and an L1 penalty on the full adjacency.
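The NOTEARS-mode terms listed above can be written out directly, given a precomputed h = h(W). In the sketch below, the optional column-group term reads notears_group_l1 as a group lasso (the sum of the L2 norms of the adjacency columns); that reading is an assumption, and the triangular-mode sparsity and confidence penalties are not shown. The function name notears_dag_penalty is hypothetical.

```python
import numpy as np


def notears_dag_penalty(W, h, rho, alpha, lambda1, group_l1=0.0):
    """NOTEARS-mode DAG terms: 0.5*rho*h^2 + alpha*h + lambda1*||W||_1.

    h is h(W) from the acyclicity function, computed elsewhere.
    Assumption: group_l1 weights the sum of per-column L2 norms (group lasso).
    """
    W = np.asarray(W, dtype=float)
    loss = 0.5 * rho * h ** 2 + alpha * h + lambda1 * np.abs(W).sum()
    if group_l1 > 0.0:
        loss += group_l1 * np.linalg.norm(W, axis=0).sum()
    return float(loss)


# 0.5*1*0.1^2 + 0.5*0.1 + 0.01*(0.5 + 0.2) = 0.005 + 0.05 + 0.007
loss = notears_dag_penalty([[0.0, 0.5], [0.2, 0.0]], h=0.1, rho=1.0,
                           alpha=0.5, lambda1=0.01)
```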
- notears_update_duals(factor: float = 10.0, progress: float = 0.25) → Dict[str, float]
Augmented-Lagrangian dual update (NOTEARS outer loop).
Call once every K epochs from the trainer. Returns diagnostic dict ({“h”, “rho”, “alpha”}) for logging; returns empty dict in triangular mode.
- Parameters:
factor – Multiplicative growth applied to rho when h(W) stalls.
progress – Required relative shrinkage of h between outer iterations. If h_new > progress * h_prev, rho is grown by factor.
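The outer-loop rule can be sketched as a standalone function. It follows the standard augmented-Lagrangian scheme from the NOTEARS paper: grow rho when h did not shrink by the required fraction, cap it at notears_rho_max, then take a dual-ascent step alpha += rho * h. The function name dual_step, the explicit h_prev argument, and the order of the two updates are assumptions for illustration; they are not taken from the method's source.

```python
def dual_step(h_new, h_prev, rho, alpha, factor=10.0, progress=0.25, rho_max=1e16):
    """One augmented-Lagrangian outer iteration; returns (rho, alpha).

    h_prev is None on the first call. Hypothetical standalone helper.
    """
    if h_prev is not None and h_new > progress * h_prev:
        # h(W) did not shrink enough since the last outer step: grow the penalty.
        rho = min(rho * factor, rho_max)
    # Dual ascent on the equality constraint h(W) = 0.
    alpha = alpha + rho * h_new
    return rho, alpha


rho, alpha = dual_step(h_new=0.5, h_prev=1.0, rho=1.0, alpha=0.0)  # stalled: rho grows
```

Calling this every K epochs, as the method's docstring suggests, gives the inner optimizer time to reduce h before the penalty is tightened.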
- get_dag_adjacency_matrix(eps: float | None = None) → Tensor
Learned adjacency with dag_temperature and tri_mask applied.
- Parameters:
eps – If None, return continuous edge weights. If a float, zero entries with |w| < eps (same rule as threshold_dag).
- Returns:
Square adjacency tensor [n_media, n_media].
- threshold_dag(eps: float = 0.3) → Tensor
Post-training pruning: zero out adjacency entries with |w| < eps.
Returns the thresholded adjacency tensor. For NOTEARS mode this is the recommended way to obtain a clean discrete DAG from the continuous W.
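The pruning rule and a readable edge list can be sketched in NumPy as follows, using the convention from dag_interaction that adj[i, j] weights the edge i -> j (channel i is a parent of channel j). The helper dag_edges and the channel names are hypothetical; only the |w| < eps rule comes from this page.

```python
import numpy as np


def threshold_dag(adj, eps=0.3):
    """Zero out entries with |w| < eps (entries equal to eps are kept)."""
    adj = np.asarray(adj, dtype=float)
    return np.where(np.abs(adj) < eps, 0.0, adj)


def dag_edges(adj, names, eps=0.3):
    """List surviving edges as (parent, child, weight); adj[i, j] is i -> j."""
    pruned = threshold_dag(adj, eps)
    rows, cols = np.nonzero(pruned)
    return [(names[i], names[j], float(pruned[i, j])) for i, j in zip(rows, cols)]


adj = np.array([[0.0, 0.9, 0.1],
                [0.0, 0.0, 0.3],
                [0.0, 0.0, 0.0]])
edges = dag_edges(adj, ["tv", "search", "social"])  # hypothetical channel names
```

The 0.1 edge from tv to social is pruned, while the 0.3 edge survives because the rule only removes weights strictly below eps.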
- deepcausalmmm.core.unified_model.create_unified_mmm(n_media: int, n_control: int, hidden_size: int = 64, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.1, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.01, l2_weight: float = 0.01, coeff_range: float = 2.0) → DeepCausalMMM
Factory function to create a DeepCausalMMM model.