deepcausalmmm.core.unified_model

DeepCausalMMM model implementation combining GRU, DAG, and interaction components.

Functions

create_unified_mmm(n_media, n_control[, ...])

Factory function to create a DeepCausalMMM model.

Classes

DeepCausalMMM([n_media, ctrl_dim, hidden, ...])

Deep Causal Marketing Mix Model with DAG structure and channel interactions.

class deepcausalmmm.core.unified_model.DeepCausalMMM(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)

Deep Causal Marketing Mix Model with DAG structure and channel interactions.

This model combines deep learning with causal inference to understand the impact of marketing channels on business KPIs while learning causal relationships between channels through a Directed Acyclic Graph (DAG).

The model features:

  • GRU-based temporal modeling for time-varying coefficients

  • Learnable coefficient bounds for realistic attribution

  • DAG learning for causal channel interactions (triangular mask or opt-in NOTEARS)

  • Adstock and saturation transformations

  • Multi-region support with shared and region-specific parameters

  • Zero-hardcoding philosophy: all parameters are learnable or configurable

Parameters:
  • n_media (int, default=10) – Number of media channels in the dataset

  • ctrl_dim (int, default=15) – Number of control variables (weather, events, etc.)

  • hidden (int, default=32) – Hidden dimension size for GRU and MLP layers

  • n_regions (int, default=2) – Number of geographic regions or DMAs

  • dropout (float, default=0.1) – Dropout rate for regularization during training

  • sparsity_weight (float, default=0.01) – Weight for sparsity regularization on coefficients

  • enable_dag (bool, default=True) – Whether to enable DAG learning for channel interactions

  • enable_interactions (bool, default=True) – Whether to enable channel interaction modeling

  • l1_weight (float, default=0.001) – L1 regularization weight for coefficient sparsity

  • l2_weight (float, default=0.001) – L2 regularization weight for coefficient smoothness

  • burn_in_weeks (int, default=4) – Number of initial weeks for GRU stabilization

  • use_coefficient_momentum (bool, default=True) – Whether to use momentum for coefficient stabilization

  • momentum_decay (float, default=0.9) – Decay rate for coefficient momentum

  • use_warm_start (bool, default=True) – Whether to use warm start training initialization

  • warm_start_epochs (int, default=50) – Number of epochs for warm start phase

  • stabilization_method (str, default="exponential") – Method for coefficient stabilization (“linear”, “exponential”, “sigmoid”)

  • coeff_l2_weight (float, default=0.1) – L2 regularization specifically for media coefficients

  • coeff_gen_l2_weight (float, default=0.05) – L2 regularization for coefficient generation layers

  • gru_layers (int, default=1) – Number of stacked GRU layers

  • ctrl_hidden_ratio (float, default=0.5) – Control hidden size as ratio of main hidden dimension

  • dag_mode (str, default="triangular") – DAG acyclicity mode: "triangular" (upper-triangular mask, default) or "notears" (continuous NOTEARS penalty; Zheng et al., 2018)

  • notears_lambda1 (float, default=0.01) – L1 sparsity on the full adjacency in NOTEARS mode

  • notears_rho_init (float, default=1.0) – Initial augmented-Lagrangian penalty rho

  • notears_alpha_init (float, default=0.0) – Initial dual variable alpha for NOTEARS

  • notears_rho_max (float, default=1e16) – Upper cap on rho for numerical safety

  • dag_temperature (float, default=1.0) – Sigmoid temperature for DAG edge weights (< 1 sharpens toward {0, 1})

  • notears_group_l1 (float, default=0.0) – Column-group L1 over adjacency columns (NOTEARS only)

media_coeffs

Time-varying coefficients for media channels

Type:

torch.nn.Parameter

ctrl_coeffs

Coefficients for control variables

Type:

torch.nn.Parameter

dag_matrix

Learnable DAG adjacency matrix for channel interactions

Type:

torch.nn.Parameter

region_baseline

Region-specific baseline contributions

Type:

torch.nn.Parameter

seasonal_coeff

Learnable coefficient for seasonal component

Type:

torch.nn.Parameter

Examples

>>> import torch
>>> from deepcausalmmm import DeepCausalMMM
>>>
>>> # Initialize model
>>> model = DeepCausalMMM(
...     n_media=5,
...     ctrl_dim=3,
...     n_regions=2,
...     hidden=64
... )
>>>
>>> # Prepare data tensors
>>> media_data = torch.rand(2, 104, 5)     # [regions, weeks, channels], SOV-scaled to [0, 1]
>>> control_data = torch.randn(2, 104, 3)  # [regions, weeks, controls]
>>> regions = torch.arange(2).unsqueeze(1).repeat(1, 104)
>>>
>>> # Forward pass
>>> predictions, media_coeffs, media_contributions, outputs = model(
...     media_data, control_data, regions
... )
>>>
>>> print(f"Predictions shape: {predictions.shape}")
>>> print(f"Media contributions: {outputs['contributions'].shape}")
__init__(n_media: int = 10, ctrl_dim: int = 15, hidden: int = 32, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.01, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.001, l2_weight: float = 0.001, burn_in_weeks: int = 4, use_coefficient_momentum: bool = True, momentum_decay: float = 0.9, use_warm_start: bool = True, warm_start_epochs: int = 50, stabilization_method: str = 'exponential', coeff_l2_weight: float = 0.1, coeff_gen_l2_weight: float = 0.05, gru_layers: int = 1, ctrl_hidden_ratio: float = 0.5, dag_mode: str = 'triangular', notears_lambda1: float = 0.01, notears_rho_init: float = 1.0, notears_alpha_init: float = 0.0, notears_rho_max: float = 1e+16, dag_temperature: float = 1.0, notears_group_l1: float = 0.0)

Initialize the model. Arguments are the same as the class parameters documented above.

initialize_baseline(y_data: Tensor)

Initialize baseline to match target data statistics.

y_data must already be in scaled space (y / y_mean per region). All baseline parameters are extracted directly from the observed data distribution, and padding weeks are skipped so they do not bias the baseline.
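The scaling that initialize_baseline expects can be illustrated with a minimal sketch, written in NumPy rather than PyTorch to stay dependency-light. The helpers scale_by_region_mean and baseline_level are hypothetical; the [n_regions, n_weeks] layout, the assumption that padding occupies the first n_pad weeks, and the use of the median are illustrative choices, not the library's implementation. The example shows why padding must be skipped: zero-filled padding weeks drag a naive mean downward.

```python
from typing import Tuple

import numpy as np


def scale_by_region_mean(y: np.ndarray, n_pad: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """Divide each region's KPI series by its own mean (hypothetical helper).

    y has shape [n_regions, n_weeks]. The first n_pad weeks are assumed to be
    padding and are excluded when computing the per-region mean.
    """
    y_mean = y[:, n_pad:].mean(axis=1, keepdims=True)  # [n_regions, 1]
    return y / y_mean, y_mean


def baseline_level(y_scaled: np.ndarray, n_pad: int = 0) -> np.ndarray:
    """Per-region baseline estimate from non-padding weeks only.

    The median is an illustrative statistic, not necessarily the library's choice.
    """
    return np.median(y_scaled[:, n_pad:], axis=1)


# Two regions with two leading padding weeks filled with zeros.
y = np.array([[0.0, 0.0, 10.0, 12.0, 8.0],
              [0.0, 0.0, 100.0, 90.0, 110.0]])
y_scaled, y_mean = scale_by_region_mean(y, n_pad=2)
print(baseline_level(y_scaled, n_pad=2))  # padding skipped: [1. 1.]
print(y_scaled.mean(axis=1))              # padding included: biased low
```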

initialize_hill_from_data(Xm: Tensor)

Initialize hill_g based on per-channel SOV (Share of Voice) distribution.

For each channel, hill_g is set to the 60th percentile of its SOV values, so that the curve's half-saturation point sits where the channel typically operates.

Parameters:

Xm – Media data [n_regions, n_timesteps, n_channels] in SOV-scaled space [0, 1]
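A minimal NumPy sketch of the percentile rule described above, assuming values are pooled across all regions and weeks for each channel. hill_g_from_sov is a hypothetical name, not part of the package.

```python
import numpy as np


def hill_g_from_sov(Xm: np.ndarray, q: float = 60.0) -> np.ndarray:
    """Per-channel q-th percentile of SOV values (hypothetical helper).

    Xm: [n_regions, n_timesteps, n_channels], already SOV-scaled to [0, 1].
    All regions and weeks are pooled per channel (an assumption).
    """
    n_channels = Xm.shape[-1]
    flat = Xm.reshape(-1, n_channels)      # [n_regions * n_timesteps, n_channels]
    return np.percentile(flat, q, axis=0)  # one value per channel


# One region, 11 weeks: channel 0 ramps from 0.0 to 1.0, channel 1 is flat at 0.2.
Xm = np.stack([np.linspace(0.0, 1.0, 11), np.full(11, 0.2)], axis=-1)[None, ...]
print(hill_g_from_sov(Xm))  # approximately [0.6 0.2]
```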

initialize_stable_coefficients_from_data(Xm: Tensor, Xc: Tensor, y: Tensor)

Initialize stable coefficients based on simple linear regression on the data. This provides domain-informed starting points for coefficient stabilization.
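The docstring does not specify the regression setup. As one hedged illustration, a pooled ordinary-least-squares fit with an intercept over all regions and weeks could supply starting values. ols_init is hypothetical, and the library's actual procedure may differ (for example in constraints or per-region handling).

```python
import numpy as np


def ols_init(Xm: np.ndarray, Xc: np.ndarray, y: np.ndarray):
    """Pooled least-squares fit y ~ Xm @ b_media + Xc @ b_ctrl + intercept.

    Xm: [n_regions, T, n_media], Xc: [n_regions, T, ctrl_dim], y: [n_regions, T].
    Returns (b_media, b_ctrl, intercept). Hypothetical helper for illustration.
    """
    n_media, ctrl_dim = Xm.shape[-1], Xc.shape[-1]
    X = np.concatenate([Xm, Xc], axis=-1).reshape(-1, n_media + ctrl_dim)
    X = np.hstack([X, np.ones((X.shape[0], 1))])  # add an intercept column
    beta, *_ = np.linalg.lstsq(X, y.reshape(-1), rcond=None)
    return beta[:n_media], beta[n_media:n_media + ctrl_dim], beta[-1]


rng = np.random.default_rng(0)
Xm = rng.random((2, 52, 3))           # SOV-like media in [0, 1)
Xc = rng.standard_normal((2, 52, 2))  # z-scored controls
y = Xm @ np.array([0.5, 1.0, 0.2]) + Xc @ np.array([0.3, -0.1]) + 2.0
b_media, b_ctrl, intercept = ols_init(Xm, Xc, y)
```

With noise-free synthetic data the fit recovers the generating coefficients; on real data it only provides a rough starting point.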

warm_start_training(Xm: Tensor, Xc: Tensor, R: Tensor, y: Tensor, optimizer: Optimizer, epochs: int = None)

Warm-start training phase to stabilize GRU coefficients before main training. Uses only stable coefficients and focuses on learning good hidden state initialization.

adstock(x: Tensor) → Tensor

Numerically stabilized adstock (carryover) transformation.
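The exact stabilized form is not documented here. For orientation, a common adstock variant is geometric carryover; the NumPy sketch below assumes that form. The optional (1 - decay) normalization is only one possible stabilization and is not necessarily what DeepCausalMMM does.

```python
import numpy as np


def geometric_adstock(x: np.ndarray, decay, normalize: bool = False) -> np.ndarray:
    """Geometric carryover a_t = x_t + decay * a_{t-1} along the time axis.

    x: [n_regions, T, n_channels]; decay: scalar or per-channel array in [0, 1).
    With normalize=True the output is scaled by (1 - decay) so that a constant
    input converges to itself (an assumed stabilization, for illustration).
    """
    decay = np.asarray(decay, dtype=float)
    out = np.empty(x.shape, dtype=float)
    carry = np.zeros((x.shape[0], x.shape[2]))
    for t in range(x.shape[1]):
        carry = x[:, t, :] + decay * carry  # broadcasts per-channel decay
        out[:, t, :] = carry
    return out * (1.0 - decay) if normalize else out


impulse = np.zeros((1, 5, 1))
impulse[0, 0, 0] = 1.0
print(geometric_adstock(impulse, 0.5)[0, :, 0])  # a single pulse decaying geometrically
```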

hill(x: Tensor) → Tensor

Numerically stabilized Hill saturation transformation.
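Assuming the common Hill form x^a / (x^a + g^a), with hill_g playing the role of g, a minimal NumPy sketch looks like this. The clipping at 0 is a placeholder for whatever stabilization the library actually applies.

```python
import numpy as np


def hill(x: np.ndarray, a: float, g: float) -> np.ndarray:
    """Hill saturation x^a / (x^a + g^a) for x >= 0 and g > 0.

    Inputs are clipped at 0 before the power (an illustrative safeguard).
    The output equals 0.5 exactly where x == g.
    """
    xa = np.power(np.clip(x, 0.0, None), a)
    return xa / (xa + g ** a)


x = np.linspace(0.0, 1.0, 6)
y = hill(x, a=2.0, g=0.3)  # rises monotonically from 0 toward 1
```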

dag_interaction(x: Tensor) → Tensor

DAG-driven channel interactions that directly reshape each channel's effective input.

Each target channel j’s effective input becomes a learned blend of itself and a DAG-driven aggregation of its causal parents:

parents_j = Σ_i adj[i, j] · x_i    (column j of matmul(x, adj))

x_j_new = (1 − mix_j) · x_j + mix_j · parents_j

Two design choices make NOTEARS genuinely learn structure here rather than being decorative:

  • mix is a per-channel learnable scalar in (0, 1). The model can turn the DAG up where it helps prediction (strong causal parents) and down where it doesn’t, so each adj column receives genuine per-channel gradient signal rather than a globally averaged one.

  • adj uses a temperature-controlled sigmoid (dag_temperature). With a temperature below 1 the sigmoid sharpens, pushing edges toward near-0 or near-1 values instead of a soft cluster around 0.5 or the sigmoid floor. This is a common trick for making sigmoid gates bimodal.

With the previous additive form x + scalar * matmul(x, adj), the loss was satisfied even by a uniform adjacency at the L1 floor, which is why the learned graph collapsed to roughly equal edge weights. The blended form forces the model either to pick informative parents or to keep mix_j close to 0 (effectively ignoring the DAG for that channel).
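The blend formula above can be sketched in NumPy as follows. The parameterization through raw edge_logits and mix_logits passed through a sigmoid, and the strictly upper-triangular mask (no self-loops), are assumptions for illustration; only the update x_j_new = (1 − mix_j) · x_j + mix_j · parents_j comes from the docstring.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def dag_blend(x, edge_logits, mix_logits, temperature=1.0, triangular=True):
    """Return (x_new, adj) with x_new = (1 - mix) * x + mix * (x @ adj).

    x: [..., n]; edge_logits: [n, n]; mix_logits: [n] (hypothetical parameterization).
    """
    n = x.shape[-1]
    adj = sigmoid(edge_logits / temperature)       # temperature < 1 sharpens toward {0, 1}
    if triangular:
        adj = adj * np.triu(np.ones((n, n)), k=1)  # keep edges i -> j only for i < j
    mix = sigmoid(mix_logits)                      # per-channel gate in (0, 1)
    parents = x @ adj                              # parents[..., j] = sum_i adj[i, j] * x[..., i]
    return (1.0 - mix) * x + mix * parents, adj


x = np.array([1.0, 2.0, 3.0])
edge_logits = np.full((3, 3), 20.0)                            # edges close to 1
out_off, _ = dag_blend(x, edge_logits, np.full(3, -50.0))      # mix ~ 0: DAG ignored
out_on, adj = dag_blend(x, edge_logits, np.full(3, 50.0))      # mix ~ 1: pure parents
```

Setting mix near 0 recovers the raw input, while mix near 1 replaces each channel with the weighted sum of its parents, which is the trade-off the docstring describes.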

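process_media(X: Tensor) → Tuple[Tensor, Dict[str, Tensor]]

Apply the media transformations (adstock -> Hill saturation -> DAG interactions, when enabled) and return the processed media together with a dict of intermediate tensors.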

apply_burn_in_stabilization(coeffs: Tensor, stable_coeff: Tensor) → Tensor

Burn-in stabilization of time-varying coefficients, using the transition method selected by stabilization_method ("linear", "exponential", or "sigmoid").

Parameters:
  • coeffs – Time-varying coefficients [B, T, dim]

  • stable_coeff – Stable reference coefficients [dim]

Returns:

Stabilized coefficients with smooth burn-in transition
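One way to realize such a burn-in blend is to ramp a weight w_t from 0 to 1 over the first burn_in weeks and mix stable and time-varying coefficients as (1 − w_t) · stable + w_t · coeffs. The NumPy sketch below assumes exactly that; the three ramp shapes and their rate constants (5.0, 10.0) are hypothetical stand-ins for the linear, exponential, and sigmoid options of stabilization_method.

```python
import numpy as np


def burn_in_weights(T: int, burn_in: int, method: str = "exponential") -> np.ndarray:
    """Weight on the GRU coefficients at each step; 1.0 from t = burn_in onward.

    The ramp shapes and rate constants are illustrative assumptions.
    """
    if burn_in <= 0:
        return np.ones(T)
    t = np.arange(T, dtype=float)
    r = np.clip(t / burn_in, 0.0, 1.0)  # progress through burn-in
    if method == "linear":
        w = r
    elif method == "exponential":
        w = 1.0 - np.exp(-5.0 * r)
    elif method == "sigmoid":
        w = 1.0 / (1.0 + np.exp(-10.0 * (r - 0.5)))
    else:
        raise ValueError(f"unknown method: {method}")
    return np.where(t >= burn_in, 1.0, w)  # fully time-varying after burn-in


def stabilize(coeffs: np.ndarray, stable_coeff: np.ndarray, burn_in: int,
              method: str = "exponential") -> np.ndarray:
    """coeffs: [B, T, dim]; stable_coeff: [dim]. Returns blended [B, T, dim]."""
    w = burn_in_weights(coeffs.shape[1], burn_in, method)[None, :, None]
    return (1.0 - w) * stable_coeff + w * coeffs


out = stabilize(np.ones((1, 6, 2)), np.zeros(2), burn_in=4, method="linear")
```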

forward(Xm: Tensor, Xc: Tensor, R: Tensor) → Tuple[Tensor, Tensor, Tensor, Dict[str, Any]]

Forward pass through the DeepCausalMMM model.

Processes media and control variables through the neural network to generate predictions, time-varying coefficients, per-channel media contributions, and a detailed outputs dict (baseline, seasonality, control contributions, DAG, etc.).

Parameters:
  • Xm (torch.Tensor) – Media data tensor of shape [batch_size, time_steps, n_media]. Should be SOV-scaled (Share of Voice), i.e. normalized to the [0, 1] range.

  • Xc (torch.Tensor) – Control variables tensor of shape [batch_size, time_steps, ctrl_dim]. Should be standardized (z-score normalized).

  • R (torch.Tensor) – Region indicators tensor of shape [batch_size, time_steps]. Integer values representing region/DMA IDs.

Returns:

  • predictions (torch.Tensor) – Model predictions (scaled KPI), shape typically [batch_size, time_steps] (broadcast with components before prediction_scale).

  • media_coefficients (torch.Tensor) – Time-varying media coefficients, shape [batch_size, time_steps, n_media].

  • media_contributions (torch.Tensor) – Per-channel media contributions X_processed * media_coefficients, shape [batch_size, time_steps, n_media].

  • outputs (Dict[str, Any]) – Detailed tensors, including:

    - contributions: same as media_contributions (media channel breakdown)
    - coefficients: same as returned media_coefficients
    - control_contributions: control variable contributions [batch, time, ctrl_dim]
    - control_coefficients: control coefficients [batch, time, ctrl_dim]
    - baseline: baseline without seasonality (for waterfall-style splits)
    - seasonal_contribution: seasonal term [batch, time]
    - dag_matrix (when DAG enabled): adjacency [n_media, n_media]
    - adstocked_media, media_hill, media_dag (when applicable): media pipeline stages

Examples

>>> import torch
>>> model = DeepCausalMMM(n_media=3, ctrl_dim=2, n_regions=2)
>>>
>>> # Prepare input tensors
>>> media = torch.rand(2, 52, 3)  # 2 regions, 52 weeks, 3 channels
>>> control = torch.randn(2, 52, 2)  # 2 control variables
>>> regions = torch.tensor([[0]*52, [1]*52])  # Region indicators
>>>
>>> # Forward pass
>>> pred, media_coeffs, media_contrib, outputs = model(media, control, regions)
>>>
>>> # Access detailed outputs
>>> media_contrib_out = outputs['contributions']
>>> ctrl_contrib = outputs['control_contributions']
>>> dag_matrix = outputs.get('dag_matrix')

Notes

The forward pass applies the following transformations in order:

  1. Media processing: Adstock -> Hill saturation -> DAG interactions

  2. Feature processing: Media features + Control features -> GRU

  3. Coefficient generation: Time-varying coefficients from GRU states

  4. Contribution calculation: Features * Coefficients

  5. Final prediction: Baseline + Seasonality + Media + Control contributions

The model enforces several constraints:

  • DAG acyclicity: upper-triangular mask (default) or NOTEARS penalty when dag_mode='notears'

  • Non-negative baseline and seasonal contributions

  • Learnable coefficient bounds to prevent explosion

  • Burn-in period stabilization for initial weeks
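Step 5 can be checked against the outputs dict with a small additive recomposition. This NumPy sketch assumes baseline and seasonal terms of shape [B, T] and ignores any final prediction_scale, so on a real model the recomposed sum may not match predictions exactly. compose_prediction is a hypothetical helper.

```python
import numpy as np


def compose_prediction(baseline, seasonal, media_contrib, ctrl_contrib):
    """Additive prediction: baseline + seasonality + media + control.

    baseline, seasonal: [B, T]; media_contrib: [B, T, n_media];
    ctrl_contrib: [B, T, ctrl_dim]. Any final prediction_scale is omitted.
    """
    return (baseline + seasonal
            + media_contrib.sum(axis=-1)   # sum over media channels
            + ctrl_contrib.sum(axis=-1))   # sum over control variables


B, T = 2, 4
pred = compose_prediction(np.full((B, T), 1.0), np.full((B, T), 0.1),
                          np.full((B, T, 3), 0.2), np.full((B, T, 2), -0.05))
```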

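h_acyclicity(W: Tensor) → Tensor

NOTEARS acyclicity scalar: h(W) = tr(exp(W ⊙ W)) − d, where exp is the matrix exponential, ⊙ is the elementwise product, and d is the number of nodes (n_media).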

Equals zero iff W is the adjacency of a DAG; smooth and differentiable elsewhere. See Zheng et al., 2018 (https://arxiv.org/abs/1803.01422).

Parameters:

W – Square adjacency matrix (n_media × n_media).

Returns:

Scalar tensor; minimised toward 0 during training.
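A small NumPy sketch of h(W), using the identity tr(exp(A)) = Σ_i exp(λ_i(A)) instead of forming the matrix exponential explicitly; in PyTorch one would more typically call torch.linalg.matrix_exp. The example matrices are illustrative.

```python
import numpy as np


def h_acyclicity(W: np.ndarray) -> float:
    """NOTEARS h(W) = tr(exp(W * W)) - d, with * the elementwise product.

    Uses tr(exp(A)) = sum_i exp(lambda_i), which holds for any square A.
    """
    d = W.shape[0]
    eigvals = np.linalg.eigvals(W * W)  # W * W is the Hadamard square
    return float(np.sum(np.exp(eigvals)).real - d)


dag = np.array([[0.0, 0.8, 0.5],
                [0.0, 0.0, 0.3],
                [0.0, 0.0, 0.0]])  # upper triangular, hence acyclic
cycle = np.array([[0.0, 1.0],
                  [1.0, 0.0]])     # 0 -> 1 -> 0
print(h_acyclicity(dag))    # ~0
print(h_acyclicity(cycle))  # 2*cosh(1) - 2, about 1.086
```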

get_dag_loss() → Tensor

DAG regularisation. Mode-aware: triangular uses sparsity/confidence penalties only (acyclicity is structural); NOTEARS additionally adds the augmented-Lagrangian acyclicity term 0.5·rho·h(W)² + alpha·h(W) and an L1 penalty on the full adjacency.
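For NOTEARS mode, the penalty described above can be written out as a short NumPy sketch. Here h is passed in precomputed, lambda1 defaults to the documented notears_lambda1 default of 0.01, and the column-group term controlled by notears_group_l1 is left out; the function name is hypothetical.

```python
import numpy as np


def notears_dag_loss(W: np.ndarray, h: float, rho: float, alpha: float,
                     lambda1: float = 0.01) -> float:
    """0.5 * rho * h^2 + alpha * h + lambda1 * ||W||_1 (NOTEARS mode only).

    h is the acyclicity value h(W), computed separately.
    """
    augmented_lagrangian = 0.5 * rho * h ** 2 + alpha * h
    l1 = lambda1 * float(np.abs(W).sum())  # L1 on the full adjacency
    return augmented_lagrangian + l1


W = np.array([[0.0, 0.5], [0.0, 0.0]])
loss = notears_dag_loss(W, h=0.2, rho=1.0, alpha=0.5)  # 0.02 + 0.1 + 0.005
```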

notears_update_duals(factor: float = 10.0, progress: float = 0.25) → Dict[str, float]

Augmented-Lagrangian dual update (NOTEARS outer loop).

Call once every K epochs from the trainer (K is chosen by the caller). Returns a diagnostic dict with keys "h", "rho", and "alpha" for logging; in triangular mode it returns an empty dict.

Parameters:
  • factor – Multiplicative growth applied to rho when h(W) stalls.

  • progress – Required relative shrinkage of h between outer iterations. If h_new > progress * h_prev, rho is grown by factor.
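The dual update can be sketched as a simplified version of the outer step in the original NOTEARS reference implementation: grow rho when h stalls, cap it at notears_rho_max, then take the dual-ascent step alpha ← alpha + rho · h. update_duals is a hypothetical stand-alone function, not the method's actual body, and the library may order these steps differently.

```python
from typing import Dict


def update_duals(h_new: float, h_prev: float, rho: float, alpha: float,
                 factor: float = 10.0, progress: float = 0.25,
                 rho_max: float = 1e16) -> Dict[str, float]:
    """One augmented-Lagrangian outer step (simplified NOTEARS recipe).

    rho is multiplied by `factor` (capped at rho_max) when h has not shrunk
    below progress * h_prev; alpha then takes a dual-ascent step.
    """
    if h_new > progress * h_prev:
        rho = min(rho * factor, rho_max)
    alpha = alpha + rho * h_new
    return {"h": h_new, "rho": rho, "alpha": alpha}


stalled = update_duals(0.5, 1.0, rho=1.0, alpha=0.0)   # h shrank too little: rho grows
improved = update_duals(0.1, 1.0, rho=1.0, alpha=0.0)  # enough progress: rho unchanged
```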

get_dag_adjacency_matrix(eps: float | None = None) → Tensor

Learned adjacency with dag_temperature and tri_mask applied.

Parameters:

eps – If None, return continuous edge weights. If a float, zero entries with |w| < eps (same rule as threshold_dag).

Returns:

Square adjacency tensor [n_media, n_media].

threshold_dag(eps: float = 0.3) → Tensor

Post-training pruning: zero out adjacency entries with |w| < eps.

Returns the thresholded adjacency tensor. For NOTEARS mode this is the recommended way to obtain a clean discrete DAG from the continuous W.
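The pruning rule is simple enough to sketch directly in NumPy. Reading adj[i, j] as an edge from parent i to child j (consistent with parents_j = Σ_i adj[i, j] · x_i in dag_interaction), the hypothetical helpers below threshold a matrix and list the surviving edges; on a trained model the input would come from threshold_dag or get_dag_adjacency_matrix.

```python
import numpy as np


def threshold_adjacency(W: np.ndarray, eps: float = 0.3) -> np.ndarray:
    """Zero out entries with |w| < eps (entries equal to eps are kept)."""
    return np.where(np.abs(W) < eps, 0.0, W)


def edge_list(W: np.ndarray):
    """(parent, child, weight) triples for the nonzero entries of W."""
    rows, cols = np.nonzero(W)
    return [(int(i), int(j), float(W[i, j])) for i, j in zip(rows, cols)]


W = np.array([[0.0, 0.9, 0.1],
              [0.0, 0.0, 0.35],
              [0.0, 0.0, 0.0]])
print(edge_list(threshold_adjacency(W, eps=0.3)))  # [(0, 1, 0.9), (1, 2, 0.35)]
```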

get_sparsity_loss() → Tensor

Sparsity loss to encourage sparse coefficients.

get_regularization_loss() → Tensor

Calculate combined regularization loss including DAG penalty and coefficient regularization.

get_parameters() → Dict[str, Tensor]

Get model parameters for analysis.

deepcausalmmm.core.unified_model.create_unified_mmm(n_media: int, n_control: int, hidden_size: int = 64, n_regions: int = 2, dropout: float = 0.1, sparsity_weight: float = 0.1, enable_dag: bool = True, enable_interactions: bool = True, l1_weight: float = 0.01, l2_weight: float = 0.01, coeff_range: float = 2.0) → DeepCausalMMM

Factory function to create a DeepCausalMMM model.