deepcausalmmm.postprocess.analysis
Post-processing utilities for analyzing and visualizing DeepCausalMMM results.
Classes
- class deepcausalmmm.postprocess.analysis.ModelAnalyzer(inference: InferenceManager | None = None, legacy_inference: ModelInference | None = None, scaler: SimpleGlobalScaler | None = None, pipeline=None, config: Dict | None = None, output_dir: str | None = None)
Analyze and visualize DeepCausalMMM model results with a modern class-based architecture.
- __init__(inference: InferenceManager | None = None, legacy_inference: ModelInference | None = None, scaler: SimpleGlobalScaler | None = None, pipeline=None, config: Dict | None = None, output_dir: str | None = None)
Initialize the enhanced analyzer with unified pipeline support.
- Parameters:
inference – Modern InferenceManager instance (preferred)
legacy_inference – Legacy ModelInference instance (for compatibility)
scaler – SimpleGlobalScaler for proper inverse transformations (legacy)
pipeline – UnifiedDataPipeline instance (preferred)
config – Model configuration dictionary
output_dir – Directory to save outputs
- analyze_with_unified_pipeline(model, X_media: ndarray, X_control: ndarray, y_true: ndarray, channel_names: List[str], control_names: List[str]) → Dict[str, Any]
Analyze model results using the unified pipeline.
- Parameters:
model – Trained model
X_media – Media data
X_control – Control data
y_true – True target values
channel_names – Media channel names
control_names – Control variable names
- Returns:
Analysis results dictionary
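Because `channel_names` and `control_names` are passed separately from the arrays, it helps to check that they line up before calling `analyze_with_unified_pipeline`. The sketch below is a hypothetical helper, not part of deepcausalmmm. It assumes that the `[n_regions, n_weeks, n_features]` layout documented for `analyze_predictions` also applies here, that the names index the last axis, and that `y_true` is `[n_regions, n_weeks]`; this page confirms none of these.

```python
from typing import List

import numpy as np


def validate_inputs(
    X_media: np.ndarray,
    X_control: np.ndarray,
    y_true: np.ndarray,
    channel_names: List[str],
    control_names: List[str],
) -> List[str]:
    """Return a list of problems; an empty list means the inputs look consistent.

    Hypothetical helper (not part of deepcausalmmm). Assumes the layout
    [n_regions, n_weeks, n_features] for both feature arrays and
    [n_regions, n_weeks] for y_true.
    """
    problems: List[str] = []
    if X_media.ndim != 3 or X_control.ndim != 3:
        problems.append("X_media and X_control must be 3-D [regions, weeks, features]")
        return problems  # later checks index axes 0-2, so stop here
    if X_media.shape[:2] != X_control.shape[:2]:
        problems.append("X_media and X_control disagree on regions/weeks")
    if y_true.shape != X_media.shape[:2]:
        problems.append("y_true should be [n_regions, n_weeks]")
    if len(channel_names) != X_media.shape[2]:
        problems.append("need exactly one name per media channel")
    if len(control_names) != X_control.shape[2]:
        problems.append("need exactly one name per control variable")
    return problems


# Example: 2 regions, 52 weeks, 3 media channels, 1 control (names are illustrative).
X_media = np.zeros((2, 52, 3))
X_control = np.zeros((2, 52, 1))
y_true = np.zeros((2, 52))
print(validate_inputs(X_media, X_control, y_true, ["tv", "search", "social"], ["price"]))
# → []
```

Returning a list instead of raising lets a caller report every mismatch at once.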
- analyze_predictions(X_m: ndarray, X_c: ndarray, R: ndarray, y_true: ndarray | None = None, generate_plots: bool = True) → Dict[str, Any]
Analyze model predictions and generate visualizations.
- Parameters:
X_m – Media variables [n_regions, n_weeks, n_channels]
X_c – Control variables [n_regions, n_weeks, n_controls]
R – Region indices [n_regions] (reserved; InferenceManager currently builds region indices internally)
y_true – Optional ground truth values
generate_plots – Whether to generate and save plots
- Returns:
Dictionary containing analysis results and metrics
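For smoke-testing `analyze_predictions`, synthetic inputs in the documented layout can be generated with NumPy. `make_synthetic_inputs` is a hypothetical helper. The shape of `y_true` is assumed to be `[n_regions, n_weeks]` because this page does not state it, and `R` is a plain `0..n_regions-1` index array since the parameter is currently reserved.

```python
from typing import Tuple

import numpy as np


def make_synthetic_inputs(
    n_regions: int, n_weeks: int, n_channels: int, n_controls: int, seed: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build random arrays in the layout documented for analyze_predictions.

    Hypothetical helper for smoke tests; the values carry no meaning.
    y_true is assumed to be [n_regions, n_weeks] (its shape is not documented).
    """
    rng = np.random.default_rng(seed)
    X_m = rng.random((n_regions, n_weeks, n_channels))       # media variables
    X_c = rng.normal(size=(n_regions, n_weeks, n_controls))  # control variables
    R = np.arange(n_regions)                                 # one index per region (reserved)
    y_true = rng.random((n_regions, n_weeks))                # observed target (assumed shape)
    return X_m, X_c, R, y_true


X_m, X_c, R, y_true = make_synthetic_inputs(n_regions=4, n_weeks=104, n_channels=5, n_controls=2)
print(X_m.shape, X_c.shape, R.shape, y_true.shape)
# → (4, 104, 5) (4, 104, 2) (4,) (4, 104)

# With a trained analyzer (not constructed here), the documented call would be:
# results = analyzer.analyze_predictions(X_m, X_c, R, y_true=y_true, generate_plots=False)
```

Passing `generate_plots=False` keeps such a smoke test from writing figures to `output_dir`.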
- static plot_coefficients_over_time(coefficients: ndarray, channel_names: List[str]) → Figure
Plot mean coefficients over time for each channel.
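This page does not document the shape of `coefficients`. Assuming it is `[n_regions, n_weeks, n_channels]`, the per-channel weekly means that such a plot would draw can be computed as below. `mean_coefficient_series` is hypothetical; each resulting series could then be drawn as one line, for example with matplotlib.

```python
from typing import Dict, List

import numpy as np


def mean_coefficient_series(
    coefficients: np.ndarray, channel_names: List[str]
) -> Dict[str, np.ndarray]:
    """Average coefficients over regions, giving one weekly series per channel.

    Assumes coefficients has shape [n_regions, n_weeks, n_channels];
    the expected shape is not documented for plot_coefficients_over_time.
    """
    if coefficients.ndim != 3:
        raise ValueError("expected [n_regions, n_weeks, n_channels]")
    if coefficients.shape[2] != len(channel_names):
        raise ValueError("one channel name per coefficient column is required")
    weekly_mean = coefficients.mean(axis=0)  # average over regions -> [n_weeks, n_channels]
    return {name: weekly_mean[:, j] for j, name in enumerate(channel_names)}


# Two regions, three weeks, two channels (names are illustrative).
coefs = np.array([
    [[0.1, 1.0], [0.2, 1.0], [0.3, 1.0]],
    [[0.3, 3.0], [0.4, 3.0], [0.5, 3.0]],
])
series = mean_coefficient_series(coefs, ["tv", "search"])
print(series["search"])
# → [2. 2. 2.]
```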
- static plot_contribution_comparison(predictions: ndarray, actuals: ndarray | None, burn_in_weeks: int) → Figure
Plot actual versus predicted revenue.
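The `burn_in_weeks` parameter suggests that early warm-up weeks are excluded from the comparison. The following is a minimal sketch under that assumption. It also assumes `predictions` and `actuals` are `[n_regions, n_weeks]` and that the comparison uses weekly totals across regions. `compare_after_burn_in` and its RMSE/R² scores are illustrative additions, not documented outputs of this method.

```python
from typing import Dict, Optional

import numpy as np


def compare_after_burn_in(
    predictions: np.ndarray, actuals: Optional[np.ndarray], burn_in_weeks: int
) -> Dict[str, object]:
    """Drop warm-up weeks, total across regions, and score the fit.

    Hypothetical helper. Assumes predictions/actuals are [n_regions, n_weeks]
    and that the first burn_in_weeks weeks should be excluded.
    """
    if not 0 <= burn_in_weeks < predictions.shape[1]:
        raise ValueError("burn_in_weeks must leave at least one week")
    pred = predictions[:, burn_in_weeks:].sum(axis=0)  # weekly total over regions
    out: Dict[str, object] = {"predicted": pred}
    if actuals is None:
        return out  # nothing to compare against; only the predicted series
    act = actuals[:, burn_in_weeks:].sum(axis=0)
    resid = act - pred
    ss_res = float(np.sum(resid ** 2))
    ss_tot = float(np.sum((act - act.mean()) ** 2))
    out["actual"] = act
    out["rmse"] = float(np.sqrt(np.mean(resid ** 2)))
    out["r2"] = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    return out


preds = np.array([[0.0, 1.0, 2.0, 3.0], [0.0, 1.0, 2.0, 3.0]])
acts = np.array([[9.0, 1.0, 2.0, 4.0], [9.0, 1.0, 2.0, 2.0]])  # week 0 is a warm-up outlier
res = compare_after_burn_in(preds, acts, burn_in_weeks=1)
print(res["rmse"], res["r2"])
# → 0.0 1.0
```

Excluding week 0 removes the large warm-up error, so the remaining weekly totals match exactly.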