deepcausalmmm.core.inference

Modern InferenceManager class for DeepCausalMMM model inference. Provides a clean, reusable interface for model predictions and analysis.

Classes

InferenceManager(model[, pipeline, scaler, ...])

Modern class-based interface for DeepCausalMMM model inference.

ModelInference(model, scaler[, ...])

Legacy compatibility wrapper for InferenceManager.

class deepcausalmmm.core.inference.InferenceManager(model: DeepCausalMMM, pipeline: UnifiedDataPipeline | None = None, scaler: SimpleGlobalScaler | None = None, config: Dict[str, Any] | None = None, channel_names: List[str] | None = None, control_names: List[str] | None = None)

Modern class-based interface for DeepCausalMMM model inference.

Handles:
  • Model predictions on new data

  • Contribution analysis (media, control, baseline)

  • Coefficient extraction

  • Data preprocessing for inference

  • Inverse transformations for interpretable results

__init__(model: DeepCausalMMM, pipeline: UnifiedDataPipeline | None = None, scaler: SimpleGlobalScaler | None = None, config: Dict[str, Any] | None = None, channel_names: List[str] | None = None, control_names: List[str] | None = None)

Initialize the inference manager.

Parameters:
  • model – Trained DeepCausalMMM model

  • pipeline – UnifiedDataPipeline used for training (preferred)

  • scaler – SimpleGlobalScaler used for training (legacy support)

  • config – Configuration dictionary

  • channel_names – List of media channel names

  • control_names – List of control variable names

predict(X_media: ndarray, X_control: ndarray, return_contributions: bool = True, remove_padding: bool = True, return_media_coefficients: bool = False) → Dict[str, ndarray]

Make predictions on new data.

Parameters:
  • X_media – Media data [n_regions, n_weeks, n_media_channels]

  • X_control – Control data [n_regions, n_weeks, n_control_vars]

  • return_contributions – Whether to return contribution breakdowns

  • remove_padding – Whether to remove burn-in padding from results

  • return_media_coefficients – If True, include the time-varying media coefficients (the second tensor returned by forward()) in the result under the key media_coefficients.

Returns:

Dictionary containing predictions and optionally contributions
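The documented input layout can be illustrated with a short sketch. It builds random arrays in the [n_regions, n_weeks, n_features] shape that predict() expects and defines a small helper that reports the shape of each returned array without assuming any key names. The helpers make_inputs and summarize_outputs are hypothetical, not part of the library, and the commented-out call assumes a trained model and its training pipeline are already available as model and pipeline.

```python
import numpy as np


def make_inputs(n_regions, n_weeks, n_media, n_control, seed=0):
    """Build random arrays in the documented [region, week, feature] layout.

    Hypothetical helper for illustration; real inputs come from your data.
    """
    rng = np.random.default_rng(seed)
    X_media = rng.random((n_regions, n_weeks, n_media))
    X_control = rng.random((n_regions, n_weeks, n_control))
    return X_media, X_control


def summarize_outputs(outputs):
    """Map each key of a result dictionary to the shape of its array.

    Hypothetical helper; it makes no assumption about which keys exist.
    """
    return {key: np.asarray(value).shape for key, value in outputs.items()}


X_media, X_control = make_inputs(n_regions=2, n_weeks=52, n_media=3, n_control=4)
print(X_media.shape, X_control.shape)

# With a trained model, the call would look roughly like this (not run here):
# manager = InferenceManager(model, pipeline=pipeline)
# outputs = manager.predict(X_media, X_control, return_media_coefficients=True)
# print(summarize_outputs(outputs))
```

Inspecting shapes first is a cheap way to confirm whether burn-in padding was removed, since remove_padding changes the length of the time axis in the returned arrays.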

predict_and_inverse_transform(X_media: ndarray, X_control: ndarray, return_contributions: bool = True) → Dict[str, ndarray]

Make predictions and apply inverse transformations for interpretable results.

Parameters:
  • X_media – Media data [n_regions, n_weeks, n_media_channels]

  • X_control – Control data [n_regions, n_weeks, n_control_vars]

  • return_contributions – Whether to return contribution breakdowns

Returns:

Dictionary containing predictions and contributions in original scale

get_coefficients() → Dict[str, ndarray]

Extract model coefficients.

Returns:

Dictionary containing media and control coefficients

get_dag_adjacency(threshold: bool = False, eps: float | None = None) → ndarray | None

Extract DAG adjacency matrix if available.

Applies the same mask and dag_temperature scaling used during training. Set threshold=True (or pass eps) to prune weak edges; by default the cutoff is notears_threshold from the config.

Parameters:
  • threshold – If True, set entries below eps to zero.

  • eps – Pruning cutoff; defaults to config['notears_threshold'].

Returns:

Adjacency matrix or None if DAG is not enabled
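The pruning step described above can be sketched in plain NumPy. This is an illustration of zeroing entries below a cutoff, not the library's implementation: the function prune_adjacency, the example matrix, and the cutoff of 0.1 are all hypothetical, and the sketch assumes non-negative edge weights compared directly against eps (the library may compare magnitudes instead).

```python
import numpy as np


def prune_adjacency(adj, eps):
    """Return a copy of adj with every entry below eps set to zero.

    Hypothetical helper; assumes non-negative edge weights.
    """
    pruned = adj.copy()
    pruned[pruned < eps] = 0.0
    return pruned


# Made-up 3x3 adjacency matrix with a few weak edges.
adj = np.array([
    [0.00, 0.80, 0.05],
    [0.02, 0.00, 0.40],
    [0.30, 0.01, 0.00],
])

pruned = prune_adjacency(adj, eps=0.1)

# Index pairs (i, j) of the entries that survive pruning.
edges = [(int(i), int(j)) for i, j in zip(*np.nonzero(pruned))]
print(edges)
```

Because the helper works on a copy, the unpruned matrix stays available for comparing different cutoffs.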

analyze_contributions(X_media: ndarray, X_control: ndarray, aggregate_regions: bool = True, aggregate_time: bool = False) → Dict[str, Any]

Comprehensive contribution analysis.

Parameters:
  • X_media – Media data

  • X_control – Control data

  • aggregate_regions – Whether to aggregate across regions

  • aggregate_time – Whether to aggregate across time

Returns:

Dictionary with detailed contribution analysis
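One way to picture the two aggregation flags is as reductions over the region and time axes of a contribution array shaped [n_regions, n_weeks, n_channels]. The sketch below assumes aggregation means summation, which the documentation does not state, and the function aggregate is a hypothetical stand-in rather than the library's code.

```python
import numpy as np


def aggregate(contrib, aggregate_regions=True, aggregate_time=False):
    """Sum a [region, week, channel] array over the requested axes.

    Hypothetical helper; summation is an assumption, the library may
    aggregate differently (for example by averaging).
    """
    axes = []
    if aggregate_regions:
        axes.append(0)  # collapse the region axis
    if aggregate_time:
        axes.append(1)  # collapse the week axis
    if not axes:
        return contrib
    return contrib.sum(axis=tuple(axes))


# Made-up contributions: 2 regions, 4 weeks, 3 channels, all equal to 1.
contrib = np.ones((2, 4, 3))

per_week = aggregate(contrib)                    # shape (4, 3)
totals = aggregate(contrib, aggregate_time=True)  # shape (3,)
print(per_week.shape, totals.shape)
```

Under this reading, the default flags keep the weekly profile per channel, while enabling both flags yields one total per channel.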

class deepcausalmmm.core.inference.ModelInference(model, scaler, channel_names=None, control_names=None, **kwargs)

Legacy compatibility wrapper for InferenceManager.

This class provides backward compatibility with existing code that uses the old ModelInference interface.

__init__(model, scaler, channel_names=None, control_names=None, **kwargs)

Initialize with legacy interface.