Core Model Components
DeepCausalMMM Model
DeepCausalMMM supports two DAG modes via dag_mode (set in
get_default_config() or passed to the
constructor): triangular (default) and notears. See
DAG and NOTEARS structure learning for configuration and inspection.
DeepCausalMMM model implementation combining GRU, DAG, and interaction components.
- class deepcausalmmm.core.unified_model.DeepCausalMMM(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)[source]
Bases: Module
Deep Causal Marketing Mix Model with DAG structure and channel interactions.
This model combines deep learning with causal inference to understand the impact of marketing channels on business KPIs while learning causal relationships between channels through a Directed Acyclic Graph (DAG).
The model features:
- GRU-based temporal modeling for time-varying coefficients
- Learnable coefficient bounds for realistic attribution
- DAG learning for causal channel interactions (triangular mask or opt-in NOTEARS)
- Adstock and saturation transformations
- Multi-region support with shared and region-specific parameters
- Zero-hardcoding philosophy: all parameters are learnable or configurable
- Parameters:
n_media (int, default=10) – Number of media channels in the dataset
ctrl_dim (int, default=15) – Number of control variables (weather, events, etc.)
hidden (int, default=32) – Hidden dimension size for GRU and MLP layers
n_regions (int, default=2) – Number of geographic regions or DMAs
dropout (float, default=0.1) – Dropout rate for regularization during training
sparsity_weight (float, default=0.01) – Weight for sparsity regularization on coefficients
enable_dag (bool, default=True) – Whether to enable DAG learning for channel interactions
enable_interactions (bool, default=True) – Whether to enable channel interaction modeling
l1_weight (float, default=0.001) – L1 regularization weight for coefficient sparsity
l2_weight (float, default=0.001) – L2 regularization weight for coefficient smoothness
burn_in_weeks (int, default=4) – Number of initial weeks for GRU stabilization
use_coefficient_momentum (bool, default=True) – Whether to use momentum for coefficient stabilization
momentum_decay (float, default=0.9) – Decay rate for coefficient momentum
use_warm_start (bool, default=True) – Whether to use warm start training initialization
warm_start_epochs (int, default=50) – Number of epochs for warm start phase
stabilization_method (str, default="exponential") – Method for coefficient stabilization (“linear”, “exponential”, “sigmoid”)
coeff_l2_weight (float, default=0.1) – L2 regularization specifically for media coefficients
coeff_gen_l2_weight (float, default=0.05) – L2 regularization for coefficient generation layers
gru_layers (int, default=1) – Number of GRU layers (configured, not hardcoded)
ctrl_hidden_ratio (float, default=0.5) – Control hidden size as ratio of main hidden dimension
dag_mode (str, default="triangular") – DAG acyclicity mode: "triangular" (upper-triangular mask, default) or "notears" (continuous NOTEARS penalty; Zheng et al., 2018)
notears_lambda1 (float, default=0.01) – L1 sparsity on the full adjacency in NOTEARS mode
notears_rho_init (float, default=1.0) – Initial augmented-Lagrangian penalty rho
notears_alpha_init (float, default=0.0) – Initial dual variable alpha for NOTEARS
notears_rho_max (float, default=1e16) – Upper cap on rho for numerical safety
dag_temperature (float, default=1.0) – Sigmoid temperature for DAG edge weights (< 1 sharpens toward {0, 1})
notears_group_l1 (float, default=0.0) – Column-group L1 over adjacency columns (NOTEARS only)
- media_coeffs
Time-varying coefficients for media channels
- Type:
torch.nn.Parameter
- ctrl_coeffs
Coefficients for control variables
- Type:
torch.nn.Parameter
- dag_matrix
Learnable DAG adjacency matrix for channel interactions
- Type:
torch.nn.Parameter
- region_baseline
Region-specific baseline contributions
- Type:
torch.nn.Parameter
- seasonal_coeff
Learnable coefficient for seasonal component
- Type:
torch.nn.Parameter
Examples
>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>>
>>> # Initialize model
>>> model = DeepCausalMMM(
...     n_media=5,
...     ctrl_dim=3,
...     n_regions=2,
...     hidden=64
... )
>>>
>>> # Prepare data tensors
>>> media_data = torch.rand(2, 104, 5)      # [regions, weeks, channels], SOV-scaled to [0, 1]
>>> control_data = torch.randn(2, 104, 3)   # [regions, weeks, controls], standardized
>>> regions = torch.arange(2).unsqueeze(1).repeat(1, 104)  # [regions, weeks] region IDs
>>>
>>> # Forward pass
>>> predictions, media_coeffs, media_contributions, outputs = model(
...     media_data, control_data, regions
... )
>>>
>>> print(f"Predictions shape: {predictions.shape}")
>>> print(f"Media contributions: {outputs['contributions'].shape}")
- __init__(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)[source]
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- initialize_baseline(y_data: Tensor)[source]
Initialize baseline to match target data statistics.
Note: y_data is expected to already be in scaled space (y / y_mean per region). All baseline parameters are extracted directly from the observed data distribution, and padding weeks are skipped to avoid biasing the baseline.
- initialize_hill_from_data(Xm: Tensor)[source]
Initialize hill_g based on per-channel SOV (Share of Voice) distribution.
For each channel, hill_g is set to the 60th percentile of its SOV values, ensuring the inflection point matches where the channel typically operates.
- Parameters:
Xm – Media data [n_regions, n_timesteps, n_channels] in SOV-scaled space [0, 1]
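As a rough illustration of the percentile rule described above, the sketch below computes a per-channel 60th percentile with NumPy. The helper name hill_g_from_sov is hypothetical, and pooling all regions and timesteps per channel is an assumption; the source does not say how regions or padding weeks are handled.

```python
import numpy as np

def hill_g_from_sov(Xm: np.ndarray, q: float = 60.0) -> np.ndarray:
    """Per-channel q-th percentile of SOV values (hypothetical helper).

    Xm has shape [n_regions, n_timesteps, n_channels] with values in [0, 1].
    Assumption: regions and timesteps are pooled together for each channel.
    """
    n_channels = Xm.shape[-1]
    flat = Xm.reshape(-1, n_channels)        # [n_regions * n_timesteps, n_channels]
    return np.percentile(flat, q, axis=0)    # [n_channels]

rng = np.random.default_rng(0)
Xm = rng.uniform(0.0, 1.0, size=(2, 104, 3))
hill_g = hill_g_from_sov(Xm)                 # one inflection point per channel
```

Placing the inflection point slightly above the median means roughly 60% of observed weeks fall on the steeper, pre-saturation part of the curve.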
- initialize_stable_coefficients_from_data(Xm: Tensor, Xc: Tensor, y: Tensor)[source]
Initialize stable coefficients based on simple linear regression on the data. This provides domain-informed starting points for coefficient stabilization.
- warm_start_training(Xm: Tensor, Xc: Tensor, R: Tensor, y: Tensor, optimizer: Optimizer, epochs: int = None)[source]
Warm-start training phase to stabilize GRU coefficients before main training. Uses only stable coefficients and focuses on learning good hidden state initialization.
- dag_interaction(x: Tensor) Tensor[source]
Load-bearing DAG-driven channel interactions.
Each target channel j’s effective input becomes a learned blend of itself and a DAG-driven aggregation of its causal parents:
parents_j = sum_i adj[i, j] * x_i    (column j of adj applied to x)
x_j_new = (1 - mix_j) * x_j + mix_j * parents_j
Two design choices make NOTEARS actually structure-learning here rather than decorative:
- mix is a per-channel learnable scalar in (0, 1). The model can turn the DAG up where it helps prediction (strong causal parents) and down where it doesn’t, so each adj column receives genuine per-channel gradient signal rather than a globally-averaged one.
- adj uses a temperature-controlled sigmoid (dag_temperature), so the model is encouraged toward {near-0, near-1} edges instead of a soft cluster around 0.5 / sigmoid floor. This is the standard trick for making sigmoid gates bi-modal.
With the previous additive form x + scalar * matmul(x, adj), the loss was satisfied even with a uniform adjacency at the L1 floor, which is why the learned graph collapsed to ~equal edges. The blended form forces the model to either pick informative parents or keep mix_j close to 0 (effectively ignoring the DAG for that channel).
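The blended update can be sketched in a few lines of NumPy. This is only an illustration of the two formulas given for dag_interaction, not the library's implementation: the function name, the random inputs, and the use of NumPy arrays in place of torch tensors are all assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def blended_dag_interaction(x, adj, mix):
    """x: [batch, time, n_media]; adj: [n_media, n_media]; mix: [n_media]."""
    parents = x @ adj                        # parents[..., j] = sum_i x[..., i] * adj[i, j]
    return (1.0 - mix) * x + mix * parents   # per-channel blend of self and parents

n_media = 3
rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 5, n_media))
# Strictly upper-triangular adjacency: channel i may only feed channels j > i.
adj = np.triu(sigmoid(rng.normal(size=(n_media, n_media))), k=1)
mix = sigmoid(np.array([-2.0, 0.0, 2.0]))    # per-channel blend weights in (0, 1)

x_new = blended_dag_interaction(x, adj, mix)
# Limiting case mix = 0: the DAG is ignored and the input passes through unchanged.
x_off = blended_dag_interaction(x, adj, np.zeros(n_media))
```

Because channel 0 has no parents under the triangular mask, its output is simply scaled by (1 - mix_0), which shows why a channel without informative parents is pushed toward a small mix value.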
- process_media(X: Tensor) Tuple[Tensor, Dict[str, Tensor]][source]
Process media variables through transformations.
- apply_burn_in_stabilization(coeffs: Tensor, stable_coeff: Tensor) Tensor[source]
Advanced burn-in stabilization with multiple transition methods.
- Parameters:
coeffs – Time-varying coefficients [B, T, dim]
stable_coeff – Stable reference coefficients [dim]
- Returns:
Stabilized coefficients with smooth burn-in transition
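The source does not spell out the transition formulas, so the following is only a guess at what a burn-in blend might look like: a weight ramps from 0 to 1 over the burn-in weeks and mixes the stable reference with the time-varying coefficients. The ramp shapes, the constants 3.0 and 10.0, and both function names are hypothetical.

```python
import numpy as np

def burn_in_weights(T, burn_in_weeks=4, method="exponential"):
    """Weight placed on the time-varying coefficients at each step (hypothetical ramps)."""
    t = np.arange(T, dtype=float)
    u = np.clip(t / max(burn_in_weeks, 1), 0.0, 1.0)   # progress through burn-in, 0..1
    if method == "linear":
        w = u
    elif method == "exponential":
        w = 1.0 - np.exp(-3.0 * u)                      # rate 3.0 is an arbitrary choice
    elif method == "sigmoid":
        w = 1.0 / (1.0 + np.exp(-10.0 * (u - 0.5)))     # steepness 10.0 is arbitrary
    else:
        raise ValueError(f"unknown method: {method}")
    w[t >= burn_in_weeks] = 1.0                         # fully time-varying after burn-in
    return w

def stabilize(coeffs, stable_coeff, burn_in_weeks=4, method="exponential"):
    """coeffs: [B, T, dim]; stable_coeff: [dim]. Returns blended coefficients [B, T, dim]."""
    w = burn_in_weights(coeffs.shape[1], burn_in_weeks, method)[None, :, None]
    return (1.0 - w) * stable_coeff + w * coeffs

rng = np.random.default_rng(0)
coeffs = rng.normal(size=(2, 10, 3))
stable = np.array([0.5, 0.2, 0.1])
out = stabilize(coeffs, stable, burn_in_weeks=4, method="linear")
```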
- forward(Xm: Tensor, Xc: Tensor, R: Tensor) Tuple[Tensor, Tensor, Tensor, Dict[str, Any]][source]
Forward pass through the DeepCausalMMM model.
Processes media and control variables through the neural network to generate predictions, time-varying coefficients, per-channel media contributions, and a detailed outputs dict (baseline, seasonality, control contributions, DAG, etc.).
- Parameters:
Xm (torch.Tensor) – Media data tensor of shape [batch_size, time_steps, n_media] Should be SOV-scaled (Share of Voice) normalized to [0, 1] range
Xc (torch.Tensor) – Control variables tensor of shape [batch_size, time_steps, ctrl_dim] Should be standardized (z-score normalized)
R (torch.Tensor) – Region indicators tensor of shape [batch_size, time_steps] Integer values representing region/DMA IDs
- Returns:
predictions (torch.Tensor) – Model predictions (scaled KPI), shape typically [batch_size, time_steps] (broadcast with components before prediction_scale).
media_coefficients (torch.Tensor) – Time-varying media coefficients, shape [batch_size, time_steps, n_media].
media_contributions (torch.Tensor) – Per-channel media contributions X_processed * media_coefficients, shape [batch_size, time_steps, n_media].
outputs (Dict[str, Any]) – Detailed tensors, including:
- contributions: same as media_contributions (media channel breakdown)
- coefficients: same as returned media_coefficients
- control_contributions: control variable contributions [batch, time, ctrl_dim]
- control_coefficients: control coefficients [batch, time, ctrl_dim]
- baseline: baseline without seasonality (for waterfall-style splits)
- seasonal_contribution: seasonal term [batch, time]
- dag_matrix (when DAG enabled): adjacency [n_media, n_media]
- adstocked_media, media_hill, media_dag (when applicable): media pipeline stages
Examples
>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>>
>>> model = DeepCausalMMM(n_media=3, ctrl_dim=2, n_regions=2)
>>>
>>> # Prepare input tensors
>>> media = torch.rand(2, 52, 3)               # 2 regions, 52 weeks, 3 channels
>>> control = torch.randn(2, 52, 2)            # 2 control variables
>>> regions = torch.tensor([[0]*52, [1]*52])   # Region indicators
>>>
>>> # Forward pass
>>> pred, media_coeffs, media_contrib, outputs = model(media, control, regions)
>>>
>>> # Access detailed outputs
>>> media_contrib_out = outputs['contributions']
>>> ctrl_contrib = outputs['control_contributions']
>>> dag_matrix = outputs.get('dag_matrix')
Notes
The forward pass applies the following transformations in order:
1. Media processing: Adstock -> Hill saturation -> DAG interactions
2. Feature processing: Media features + Control features -> GRU
3. Coefficient generation: Time-varying coefficients from GRU states
4. Contribution calculation: Features * Coefficients
5. Final prediction: Baseline + Seasonality + Media + Control contributions
The model enforces several constraints:
- DAG acyclicity: upper-triangular mask (default) or NOTEARS penalty when dag_mode='notears'
- Non-negative baseline and seasonal contributions
- Learnable coefficient bounds to prevent explosion
- Burn-in period stabilization for initial weeks
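To make the ordering of the five steps concrete, here is a toy single-region walk-through in NumPy. It is a sketch under several assumptions: geometric adstock and the x^a / (x^a + g^a) Hill curve are common textbook forms rather than confirmed library formulas, the DAG blend is omitted, and constant coefficients stand in for the GRU output.

```python
import numpy as np

def geometric_adstock(x, decay=0.5):
    """Carry over a fraction `decay` of the previous adstocked value. x: [T, C]."""
    out = np.zeros_like(x)
    carry = np.zeros(x.shape[1])
    for t in range(x.shape[0]):
        carry = x[t] + decay * carry
        out[t] = carry
    return out

def hill(x, g=0.5, a=2.0):
    """Hill saturation with half-saturation point g and slope a (assumed form)."""
    return x**a / (x**a + g**a)

T, C, K = 6, 2, 1
rng = np.random.default_rng(0)
media = rng.uniform(size=(T, C))
control = rng.normal(size=(T, K))

media_features = hill(geometric_adstock(media))   # step 1 (DAG interactions omitted)
media_coeffs = np.full((T, C), 0.3)               # steps 2-3: stand-in for GRU-generated coefficients
ctrl_coeffs = np.full((T, K), 0.1)
media_contrib = media_features * media_coeffs     # step 4
ctrl_contrib = control * ctrl_coeffs
baseline = np.full(T, 1.0)
seasonal = np.zeros(T)
# Step 5: additive prediction over all components.
prediction = baseline + seasonal + media_contrib.sum(axis=1) + ctrl_contrib.sum(axis=1)
```

The additive final step is what lets each component be reported separately in a waterfall-style decomposition.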
- h_acyclicity(W: Tensor) Tensor[source]
NOTEARS acyclicity scalar: h(W) = tr(exp(W ⊙ W)) − d.
Equals zero iff W is the adjacency of a DAG; smooth and differentiable elsewhere. See Zheng et al., 2018 (https://arxiv.org/abs/1803.01422).
- Parameters:
W – Square adjacency matrix (n_media × n_media).
- Returns:
Scalar tensor; minimised toward 0 during training.
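A small NumPy sketch of the acyclicity scalar may help build intuition. It uses a truncated power series for the matrix exponential to stay dependency-free; in PyTorch, torch.matrix_exp would play that role. The helper names and the example matrices are illustrative only.

```python
import numpy as np

def matrix_exp(A, terms=30):
    """Truncated power series for exp(A); adequate for small, well-scaled matrices."""
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms):
        term = term @ A / k          # term now holds A^k / k!
        result = result + term
    return result

def h_acyclicity(W):
    """h(W) = tr(exp(W * W)) - d, with * the elementwise product."""
    d = W.shape[0]
    return np.trace(matrix_exp(W * W)) - d

W_dag = np.array([[0.0, 0.8, 0.3],
                  [0.0, 0.0, 0.5],
                  [0.0, 0.0, 0.0]])
W_cyclic = W_dag.copy()
W_cyclic[2, 0] = 0.7                 # adds the cycle 0 -> 1 -> 2 -> 0

h_dag = h_acyclicity(W_dag)          # zero: no closed walks exist
h_cyc = h_acyclicity(W_cyclic)       # strictly positive: the cycle contributes to the trace
```

The trace of the k-th power of W ⊙ W sums the weights of closed walks of length k, so any cycle makes h strictly positive.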
- get_dag_loss() Tensor[source]
DAG regularisation. Mode-aware: triangular uses sparsity/confidence penalties only (acyclicity is structural); NOTEARS additionally adds the augmented-Lagrangian acyclicity term 0.5·rho·h(W)² + alpha·h(W) and an L1 penalty on the full adjacency.
- notears_update_duals(factor: float = 10.0, progress: float = 0.25) Dict[str, float][source]
Augmented-Lagrangian dual update (NOTEARS outer loop).
Call once every K epochs from the trainer. Returns diagnostic dict ({“h”, “rho”, “alpha”}) for logging; returns empty dict in triangular mode.
- Parameters:
factor – Multiplicative growth applied to rho when h(W) stalls.
progress – Required relative shrinkage of h between outer iterations. If h_new > progress * h_prev, rho is grown by factor.
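The outer-loop schedule can be sketched in plain Python. The rho growth rule follows the description of factor and progress above; the multiplier step alpha += rho * h is the standard NOTEARS dual ascent and is assumed here, as is the order of the two updates. The function names are hypothetical.

```python
def update_duals(h_new, h_prev, rho, alpha, factor=10.0, progress=0.25, rho_max=1e16):
    """One outer-loop step of an augmented-Lagrangian schedule (illustrative)."""
    # Dual ascent on the multiplier (standard NOTEARS step; assumed, not confirmed).
    alpha = alpha + rho * h_new
    # Grow the penalty when h did not shrink enough since the last outer step.
    if h_new > progress * h_prev:
        rho = min(rho * factor, rho_max)
    return rho, alpha

def acyclicity_penalty(h, rho, alpha):
    """Augmented-Lagrangian term 0.5 * rho * h^2 + alpha * h."""
    return 0.5 * rho * h * h + alpha * h

rho, alpha = 1.0, 0.0
# h stalled (0.9 > 0.25 * 1.0): rho is multiplied by factor.
rho_stalled, alpha_stalled = update_duals(h_new=0.9, h_prev=1.0, rho=rho, alpha=alpha)
# h shrank enough (0.1 <= 0.25 * 1.0): rho is left unchanged.
rho_ok, alpha_ok = update_duals(h_new=0.1, h_prev=1.0, rho=rho, alpha=alpha)
```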
- get_dag_adjacency_matrix(eps: float | None = None) Tensor[source]
Learned adjacency with dag_temperature and tri_mask applied.
- Parameters:
eps – If None, return continuous edge weights. If a float, zero entries with |w| < eps (same rule as threshold_dag).
- Returns:
Square adjacency tensor [n_media, n_media].
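The effect of the temperature and the triangular mask can be illustrated as follows. Dividing the logits by the temperature before the sigmoid is an assumption about how dag_temperature is applied, and adjacency_from_logits is a hypothetical name; only the triangular mode is shown.

```python
import numpy as np

def adjacency_from_logits(logits, temperature=1.0):
    """Sigmoid edge weights with a strictly upper-triangular mask (illustrative)."""
    adj = 1.0 / (1.0 + np.exp(-logits / temperature))   # assumed form: sigmoid(logits / T)
    tri_mask = np.triu(np.ones_like(adj), k=1)           # keep only edges i -> j with i < j
    return adj * tri_mask

logits = np.array([[0.0,  2.0, -2.0],
                   [1.0,  0.0,  0.5],
                   [3.0, -1.0,  0.0]])
adj_soft = adjacency_from_logits(logits, temperature=1.0)
adj_sharp = adjacency_from_logits(logits, temperature=0.25)   # T < 1 pushes edges toward 0 or 1
```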
- threshold_dag(eps: float = 0.3) Tensor[source]
Post-training pruning: zero out adjacency entries with |w| < eps.
Returns the thresholded adjacency tensor. For NOTEARS mode this is the recommended way to obtain a clean discrete DAG from the continuous W.
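The pruning rule itself is simple enough to show directly. This NumPy sketch mirrors the stated rule (zero every entry with |w| < eps) on a made-up matrix; threshold_adjacency is a stand-in name, not the library method.

```python
import numpy as np

def threshold_adjacency(W, eps=0.3):
    """Return a copy of W with entries of magnitude below eps set to zero."""
    return np.where(np.abs(W) < eps, 0.0, W)

W = np.array([[0.0, 0.85, 0.10],
              [0.0, 0.00, 0.31],
              [0.0, 0.00, 0.00]])
W_pruned = threshold_adjacency(W, eps=0.3)

# Surviving edges as (parent, child) index pairs.
edges = [(int(i), int(j)) for i, j in zip(*np.nonzero(W_pruned))]
```

With eps=0.3 the weak 0 -> 2 edge (0.10) is dropped while 0 -> 1 and 1 -> 2 survive.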
- deepcausalmmm.core.unified_model.create_unified_mmm(n_media: int, n_control: int, hidden_size: int = 64, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.1, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.01, l2_weight: float = 0.01, coeff_range: float = 2.0) DeepCausalMMM[source]
Factory function to create a DeepCausalMMM model.
Configuration
Configuration settings for DeepCausalMMM model.
- deepcausalmmm.core.config.get_default_config() Dict[str, Any][source]
Get default configuration settings for the model.
Includes DAG structure-learning keys under dag_mode:
- 'triangular' (default): upper-triangular acyclicity mask
- 'notears': NOTEARS augmented-Lagrangian mode; see also notears_warmup_epochs, notears_lambda1, notears_dual_*, dag_temperature, and notears_group_l1
- Returns:
Dict containing all configuration parameters
- deepcausalmmm.core.config.update_config(base_config: Dict[str, Any], updates: Dict[str, Any]) Dict[str, Any][source]
Update base configuration with new values.
- Parameters:
base_config – Base configuration dictionary
updates – Dictionary containing updates to apply
- Returns:
Updated configuration dictionary
DAG Model
DAG model implementation with Node-to-Edge and Edge-to-Node transformations.
This module implements the DAG-based neural network architecture with:
- NodeToEdge: Transform node features to edge features
- EdgeToNode: Aggregate edge features back to nodes
- DAGConstraint: Enforce acyclicity in the graph structure
- class deepcausalmmm.core.dag_model.NodeToEdge(node_dim: int, edge_dim: int)[source]
Bases: Module
Transform node features to edge features using an attention mechanism.
- class deepcausalmmm.core.dag_model.EdgeToNode(edge_dim: int, node_dim: int)[source]
Bases: Module
Aggregate edge features back to nodes.
- __init__(edge_dim: int, node_dim: int)[source]
Initialize the edge to node transformation.
- Parameters:
edge_dim – Dimension of edge features
node_dim – Dimension of node features
- forward(edges: Tensor, nodes: Tensor, adj_matrix: Tensor) Tensor[source]
Aggregate edge features to update node features.
- Parameters:
edges – Edge features [batch_size, n_nodes, n_nodes, edge_dim]
nodes – Node features [batch_size, n_nodes, 1]
adj_matrix – Adjacency matrix [n_nodes, n_nodes]
- Returns:
Updated node features [batch_size, n_nodes, 1]
- class deepcausalmmm.core.dag_model.DAGConstraint(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)[source]
Bases: Module
Enforce acyclicity in the graph structure using a strict triangular constraint.
- __init__(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)[source]
Initialize the DAG constraint module.
- Parameters:
n_nodes – Number of nodes in the graph
sparsity_weight – Weight for the sparsity penalty
temperature – Initial temperature for Gumbel-Softmax
- gumbel_softmax(logits: Tensor, tau: float) Tensor[source]
Gumbel-Softmax sampling with straight-through gradients.
- Parameters:
logits – Input logits
tau – Temperature parameter
- Returns:
Sampled probabilities
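For readers unfamiliar with the technique, the forward half of Gumbel-Softmax sampling can be sketched in NumPy. This is not the module's code: the function name is hypothetical, and the straight-through gradient (hard one-hot forward, soft gradient backward) needs an autograd framework, for example torch.nn.functional.gumbel_softmax with hard=True, so it is not reproduced here.

```python
import numpy as np

def gumbel_softmax_sample(logits, tau, rng):
    """Soft sample over the last axis; lower tau gives outputs closer to one-hot."""
    u = np.clip(rng.uniform(size=logits.shape), 1e-10, 1.0 - 1e-10)
    gumbel = -np.log(-np.log(u))                  # Gumbel(0, 1) noise
    z = (logits + gumbel) / tau
    z = z - z.max(axis=-1, keepdims=True)         # stabilise the softmax numerically
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
logits = np.array([[2.0, 0.0, -1.0],
                   [0.0, 0.0,  0.0]])
soft = gumbel_softmax_sample(logits, tau=0.5, rng=rng)

# Hard (one-hot) version of the same sample, as used on the forward pass
# of a straight-through estimator.
hard = np.eye(logits.shape[-1])[soft.argmax(axis=-1)]
```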
- get_adjacency() Tensor[source]
Get the current adjacency matrix using Gumbel-Softmax sampling. This enforces unidirectional edges and allows learning discrete structure.
- class deepcausalmmm.core.dag_model.DAGModel(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)[source]
Bases: Module
Complete DAG-based model combining NodeToEdge and EdgeToNode transformations.
- __init__(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)[source]
Initialize the DAG model.
- Parameters:
n_nodes – Number of nodes in the graph
node_dim – Dimension of node features
edge_dim – Dimension of edge features
n_layers – Number of message passing layers
sparsity_weight – Weight for the sparsity penalty
Scaling
Simple global scaling implementation, based on the approach used in dashboard_rmse_optimized.py.
- class deepcausalmmm.core.scaling.SimpleScalingParams(control_mean: Tensor, control_std: Tensor, total_impressions: Tensor | None = None)[source]
Bases: object
Store simple global scaling parameters.
- class deepcausalmmm.core.scaling.SimpleGlobalScaler(config: Dict[str, Any] | None = None)[source]
Bases: object
Linear scaling approach (y/y_mean) for additive attribution.
Scaling features:
- Media: Share-of-voice scaling with outlier smoothing
- Control: Robust standardization with adaptive clipping
- Target: Linear scaling by region mean (y/y_mean) for additive decomposition
- Adaptive normalization with distribution-aware clipping
- Advanced outlier handling for extreme value stability
- __init__(config: Dict[str, Any] | None = None)[source]
Initialize the scaler with optional config parameters.
- fit(X_media: ndarray, X_control: ndarray, y: ndarray) None[source]
Fit the scaler using simple global statistics.
- Parameters:
X_media – Media variables [n_regions, n_timesteps, n_channels]
X_control – Control variables [n_regions, n_timesteps, n_controls]
y – Target variable [n_regions, n_timesteps]
- transform(X_media: ndarray, X_control: ndarray, y: ndarray) Tuple[Tensor, Tensor, Tensor][source]
Transform data using fitted parameters.
- Parameters:
X_media – Media variables [n_regions, n_timesteps, n_channels]
X_control – Control variables [n_regions, n_timesteps, n_controls]
y – Target variable [n_regions, n_timesteps]
- Returns:
Tuple of (X_media_scaled, X_control_scaled, y_scaled)
- inverse_transform_target(y_scaled: Tensor) Tensor[source]
Inverse transform target variable.
- Parameters:
y_scaled – Scaled target [n_regions, n_timesteps]
- Returns:
Original scale target
- inverse_transform_contributions(media_contributions: Tensor, baseline: Tensor = None, control_contributions: Tensor = None, seasonal_contributions: Tensor = None, trend_contributions: Tensor = None, prediction_scale: Tensor = None) dict[source]
Inverse transform ALL contributions to original scale using simple multiplication.
With linear scaling (y/y_mean), the inverse transform is straightforward:
component_orig = component_scaled * prediction_scale * y_mean_per_region
This preserves additivity: sum(components_orig) = prediction_orig
- Parameters:
media_contributions – Media contributions in scaled space [regions, timesteps, channels]
baseline – Baseline in scaled space [regions, timesteps]
control_contributions – Control contributions in scaled space [regions, timesteps, controls]
seasonal_contributions – Seasonal contributions in scaled space [regions, timesteps]
trend_contributions – Trend contributions in scaled space [regions, timesteps]
prediction_scale – Model’s prediction_scale factor (from F.softplus(self.prediction_scale))
- Returns:
Dictionary with all contributions in original scale
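The additivity claim can be checked numerically with a short NumPy sketch. The numbers, the to_original helper, and the scalar prediction_scale are made up for illustration; the sketch also assumes the scaled prediction is the plain sum of the scaled components before prediction_scale is applied.

```python
import numpy as np

rng = np.random.default_rng(0)
n_regions, T, C = 2, 8, 3
y_mean = np.array([1000.0, 250.0])     # per-region mean of the raw KPI (made-up values)
prediction_scale = 1.2                  # stand-in for softplus(model.prediction_scale)

media_scaled = rng.uniform(0.0, 0.2, size=(n_regions, T, C))
baseline_scaled = rng.uniform(0.5, 0.8, size=(n_regions, T))

def to_original(component, y_mean, prediction_scale):
    """component_orig = component_scaled * prediction_scale * y_mean_per_region."""
    # Reshape y_mean so it broadcasts over time (and channels, when present).
    shape = (-1,) + (1,) * (component.ndim - 1)
    return component * prediction_scale * y_mean.reshape(shape)

media_orig = to_original(media_scaled, y_mean, prediction_scale)
baseline_orig = to_original(baseline_scaled, y_mean, prediction_scale)

# Assumed: the scaled prediction is the sum of the scaled components.
pred_scaled = baseline_scaled + media_scaled.sum(axis=-1)
pred_orig = to_original(pred_scaled, y_mean, prediction_scale)
```

Because every component is multiplied by the same per-region factor, the components in original units still sum to the prediction in original units.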
- fit_transform(X_media: ndarray, X_control: ndarray, y: ndarray) Tuple[Tensor, Tensor, Tensor][source]
Fit the scaler and transform data in one step.
- Parameters:
X_media – Media variables [n_regions, n_timesteps, n_channels]
X_control – Control variables [n_regions, n_timesteps, n_controls]
y – Target variable [n_regions, n_timesteps]
- Returns:
Tuple of (X_media_scaled, X_control_scaled, y_scaled)
- deepcausalmmm.core.scaling.GlobalScaler
alias of
SimpleGlobalScaler