deepcausalmmm.postprocess.analysis

Post-processing utilities for analyzing and visualizing DeepCausalMMM results.

Classes

ModelAnalyzer([inference, legacy_inference, ...])

Analyze and visualize DeepCausalMMM model results with a modern class-based architecture.

class deepcausalmmm.postprocess.analysis.ModelAnalyzer(inference: InferenceManager | None = None, legacy_inference: ModelInference | None = None, scaler: SimpleGlobalScaler | None = None, pipeline=None, config: Dict | None = None, output_dir: str | None = None)

Analyze and visualize DeepCausalMMM model results with a modern class-based architecture.

__init__(inference: InferenceManager | None = None, legacy_inference: ModelInference | None = None, scaler: SimpleGlobalScaler | None = None, pipeline=None, config: Dict | None = None, output_dir: str | None = None)

Initialize the enhanced analyzer with unified pipeline support.

Parameters:
  • inference – Modern InferenceManager instance (preferred)

  • legacy_inference – Legacy ModelInference instance (for compatibility)

  • scaler – SimpleGlobalScaler for proper inverse transformations (legacy)

  • pipeline – UnifiedDataPipeline instance (preferred)

  • config – Model configuration dictionary

  • output_dir – Directory to save outputs
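A minimal sketch of assembling the constructor arguments, assuming the modern path (`inference` plus `pipeline`) is the one in use. The helper `analyzer_kwargs` and the directory name `analysis_output` are hypothetical and not part of the library; only the keyword names come from the signature above.

```python
from typing import Any, Dict, Optional


def analyzer_kwargs(
    inference: Optional[Any] = None,
    pipeline: Optional[Any] = None,
    legacy_inference: Optional[Any] = None,
    scaler: Optional[Any] = None,
    config: Optional[Dict[str, Any]] = None,
    output_dir: Optional[str] = "analysis_output",  # hypothetical default
) -> Dict[str, Any]:
    """Hypothetical helper: collect ModelAnalyzer keyword arguments.

    Arguments left as None are dropped so the constructor's own defaults apply.
    """
    candidates = {
        "inference": inference,
        "pipeline": pipeline,
        "legacy_inference": legacy_inference,
        "scaler": scaler,
        "config": config,
        "output_dir": output_dir,
    }
    return {name: value for name, value in candidates.items() if value is not None}


# Placeholder objects stand in for real InferenceManager and
# UnifiedDataPipeline instances.
kwargs = analyzer_kwargs(inference=object(), pipeline=object())
print(sorted(kwargs))

# With the library installed, the call would then be:
#   from deepcausalmmm.postprocess.analysis import ModelAnalyzer
#   analyzer = ModelAnalyzer(**kwargs)
```

Dropping unset arguments keeps the legacy parameters (`legacy_inference`, `scaler`) out of the call unless they are supplied explicitly.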

analyze_with_unified_pipeline(model, X_media: ndarray, X_control: ndarray, y_true: ndarray, channel_names: List[str], control_names: List[str]) → Dict[str, Any]

Analyze model results using the unified pipeline.

Parameters:
  • model – Trained model

  • X_media – Media data

  • X_control – Control data

  • y_true – True target values

  • channel_names – Media channel names

  • control_names – Control variable names

Returns:

Analysis results dictionary
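The inputs for this method can be checked before the call, as in the sketch below. The array shapes are an assumption carried over from `analyze_predictions` (`[n_regions, n_weeks, n_features]`, with `y_true` as `[n_regions, n_weeks]`); this method's documentation does not state them. `check_unified_inputs` and the channel and control names are hypothetical.

```python
from typing import List

import numpy as np


def check_unified_inputs(
    X_media: np.ndarray,
    X_control: np.ndarray,
    y_true: np.ndarray,
    channel_names: List[str],
    control_names: List[str],
) -> None:
    """Hypothetical pre-flight check: raise ValueError if names and arrays disagree."""
    if X_media.shape[-1] != len(channel_names):
        raise ValueError("channel_names must have one entry per media channel")
    if X_control.shape[-1] != len(control_names):
        raise ValueError("control_names must have one entry per control variable")
    if not (X_media.shape[:2] == X_control.shape[:2] == y_true.shape[:2]):
        raise ValueError("arrays must share the same regions and weeks")


# Synthetic data with assumed shapes; real data would come from your dataset.
rng = np.random.default_rng(0)
channel_names = ["tv", "search", "social"]
control_names = ["price", "holiday"]
X_media = rng.random((2, 52, len(channel_names)))
X_control = rng.random((2, 52, len(control_names)))
y_true = rng.random((2, 52))

check_unified_inputs(X_media, X_control, y_true, channel_names, control_names)

# With a trained model and a configured analyzer, the call would then be:
#   results = analyzer.analyze_with_unified_pipeline(
#       model, X_media, X_control, y_true, channel_names, control_names
#   )
```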

analyze_predictions(X_m: ndarray, X_c: ndarray, R: ndarray, y_true: ndarray | None = None, generate_plots: bool = True) → Dict[str, Any]

Analyze model predictions and generate visualizations.

Parameters:
  • X_m – Media variables [n_regions, n_weeks, n_channels]

  • X_c – Control variables [n_regions, n_weeks, n_controls]

  • R – Region indices [n_regions] (reserved; InferenceManager currently builds region indices internally)

  • y_true – Optional ground truth values

  • generate_plots – Whether to generate and save plots

Returns:

Dictionary containing analysis results and metrics
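A sketch of building inputs with the documented shapes and calling the method with plotting disabled. The `y_true` shape `[n_regions, n_weeks]` is an assumption, and `run_analysis` is a hypothetical wrapper that accepts any object exposing `analyze_predictions`.

```python
from typing import Any, Dict

import numpy as np

n_regions, n_weeks, n_channels, n_controls = 3, 52, 4, 2
rng = np.random.default_rng(0)

X_m = rng.random((n_regions, n_weeks, n_channels))  # media variables
X_c = rng.random((n_regions, n_weeks, n_controls))  # control variables
R = np.arange(n_regions)                            # region indices
y_true = rng.random((n_regions, n_weeks))           # assumed shape, not documented


def run_analysis(analyzer: Any) -> Dict[str, Any]:
    """Hypothetical wrapper: run the analysis without generating plots."""
    return analyzer.analyze_predictions(
        X_m, X_c, R, y_true=y_true, generate_plots=False
    )


# With a configured ModelAnalyzer instance named `analyzer`:
#   results = run_analysis(analyzer)
```

Passing `generate_plots=False` is useful when only the metrics are needed, for example in automated checks.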

static plot_coefficients_over_time(coefficients: ndarray, channel_names: List[str]) → Figure

Plot mean coefficients over time for each channel.
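The quantity being plotted can be sketched with NumPy. This assumes `coefficients` has shape `[n_regions, n_weeks, n_channels]` and that the mean is taken over regions; neither is stated in the documentation, and `mean_coefficients_over_time` is a hypothetical helper rather than the library's implementation.

```python
import numpy as np


def mean_coefficients_over_time(coefficients: np.ndarray) -> np.ndarray:
    """Average over the region axis, giving one series per channel.

    Input:  [n_regions, n_weeks, n_channels] (assumed layout)
    Output: [n_weeks, n_channels]
    """
    return coefficients.mean(axis=0)


coefficients = np.ones((3, 10, 2))
coefficients[..., 1] = 2.0  # second channel has a constant coefficient of 2

series = mean_coefficients_over_time(coefficients)
print(series.shape)  # (10, 2)
```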

static plot_contribution_comparison(predictions: ndarray, actuals: ndarray | None, burn_in_weeks: int) → Figure

Plot actual vs predicted revenue comparison.
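One plausible reading of `burn_in_weeks` is that the first weeks are excluded before actual and predicted values are compared. The sketch below illustrates that reading only; it assumes arrays shaped `[n_regions, n_weeks]`, and `drop_burn_in` is a hypothetical helper, not the library's implementation.

```python
import numpy as np


def drop_burn_in(values: np.ndarray, burn_in_weeks: int) -> np.ndarray:
    """Drop the first `burn_in_weeks` entries along the week axis (axis 1)."""
    if burn_in_weeks < 0:
        raise ValueError("burn_in_weeks must be non-negative")
    return values[:, burn_in_weeks:]


predictions = np.arange(24, dtype=float).reshape(2, 12)  # [n_regions, n_weeks]
actuals = predictions + 1.0

# Sum over regions to get one total series per week after the burn-in.
pred_total = drop_burn_in(predictions, burn_in_weeks=4).sum(axis=0)
actual_total = drop_burn_in(actuals, burn_in_weeks=4).sum(axis=0)
print(pred_total.shape)  # (8,)
```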

static plot_waterfall_chart(marketing_contribs: ndarray, control_contribs: ndarray | None, baseline: ndarray | None, channel_names: List[str], control_names: List[str]) → Figure

Create waterfall chart of contributions.
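A waterfall chart stacks each component on the running total of the ones before it. The sketch below computes those bar positions under stated assumptions: contributions are shaped `[n_regions, n_weeks, n_features]`, the baseline is `[n_regions, n_weeks]`, and bars are ordered baseline, controls, then channels. `waterfall_steps` is a hypothetical helper and may differ from what the library does.

```python
from typing import List, Optional, Tuple

import numpy as np


def waterfall_steps(
    marketing_contribs: np.ndarray,
    control_contribs: Optional[np.ndarray],
    baseline: Optional[np.ndarray],
    channel_names: List[str],
    control_names: List[str],
) -> List[Tuple[str, float, float]]:
    """Return (label, start, height) for each bar; optional inputs may be None."""
    labels: List[str] = []
    heights: List[float] = []

    if baseline is not None:
        labels.append("Baseline")
        heights.append(float(np.sum(baseline)))

    if control_contribs is not None:
        # Collapse regions and weeks, keeping one total per control variable.
        totals = control_contribs.reshape(-1, control_contribs.shape[-1]).sum(axis=0)
        labels.extend(control_names)
        heights.extend(float(t) for t in totals)

    # Collapse regions and weeks, keeping one total per media channel.
    totals = marketing_contribs.reshape(-1, marketing_contribs.shape[-1]).sum(axis=0)
    labels.extend(channel_names)
    heights.extend(float(t) for t in totals)

    # Each bar starts where the cumulative total of the previous bars ends.
    starts = np.concatenate(([0.0], np.cumsum(heights)[:-1]))
    return list(zip(labels, starts.tolist(), heights))


marketing = np.ones((2, 5, 2))  # each channel totals 10
baseline = np.full((2, 5), 3.0)  # baseline totals 30
steps = waterfall_steps(marketing, None, baseline, ["tv", "search"], [])
print(steps)
```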

static plot_contribution_donut(marketing_contribs: ndarray, control_contribs: ndarray | None, baseline: ndarray | None) → Figure

Create donut chart of contribution percentages.
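The percentages behind such a chart can be sketched as each component's share of the overall total. Grouping into three slices (marketing, controls, baseline) is an assumption based on the three array arguments; `contribution_shares` is a hypothetical helper, not the library's implementation.

```python
from typing import Dict, Optional

import numpy as np


def contribution_shares(
    marketing_contribs: np.ndarray,
    control_contribs: Optional[np.ndarray],
    baseline: Optional[np.ndarray],
) -> Dict[str, float]:
    """Return each component's percentage of the total contribution."""
    totals = {"Marketing": float(np.sum(marketing_contribs))}
    if control_contribs is not None:
        totals["Controls"] = float(np.sum(control_contribs))
    if baseline is not None:
        totals["Baseline"] = float(np.sum(baseline))

    grand_total = sum(totals.values())
    if grand_total == 0:
        raise ValueError("total contribution is zero; shares are undefined")
    return {name: 100.0 * value / grand_total for name, value in totals.items()}


marketing = np.ones((2, 5, 2))   # totals 20
baseline = np.full((2, 5), 3.0)  # totals 30
shares = contribution_shares(marketing, None, baseline)
print(shares)
```

Omitted components (passed as `None`) are left out of the total, so the remaining shares still sum to 100.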