"""
Post-processing utilities for analyzing and visualizing DeepCausalMMM results.
"""
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from typing import Dict, List, Optional, Any
import os
import logging
logger = logging.getLogger('deepcausalmmm')
from deepcausalmmm.core.inference import InferenceManager, ModelInference # ModelInference for legacy compatibility
from deepcausalmmm.core.scaling import SimpleGlobalScaler
from deepcausalmmm.core.config import get_default_config
import torch
[docs]
class ModelAnalyzer:
"""Analyze and visualize DeepCausalMMM model results with modern class-based architecture."""
[docs]
def __init__(
self,
inference: Optional[InferenceManager] = None, # Modern InferenceManager
legacy_inference: Optional[ModelInference] = None, # Legacy compatibility
scaler: Optional[SimpleGlobalScaler] = None, # Legacy compatibility
pipeline = None, # UnifiedDataPipeline instance
config: Optional[Dict] = None,
output_dir: Optional[str] = None
):
"""
Initialize the enhanced analyzer with unified pipeline support.
Args:
inference: Modern InferenceManager instance (preferred)
legacy_inference: Legacy ModelInference instance (for compatibility)
scaler: SimpleGlobalScaler for proper inverse transformations (legacy)
pipeline: UnifiedDataPipeline instance (preferred)
config: Model configuration dictionary
output_dir: Directory to save outputs
"""
# Use modern InferenceManager if available, otherwise fall back to legacy
self.inference = inference or legacy_inference
self.scaler = scaler # Legacy support
self.pipeline = pipeline # Unified pipeline support
self.config = config or get_default_config()
self.output_dir = output_dir
if output_dir:
os.makedirs(output_dir, exist_ok=True)
[docs]
def analyze_with_unified_pipeline(
self,
model,
X_media: np.ndarray,
X_control: np.ndarray,
y_true: np.ndarray,
channel_names: List[str],
control_names: List[str]
) -> Dict[str, Any]:
"""
Analyze model results using the unified pipeline.
Args:
model: Trained model
X_media: Media data
X_control: Control data
y_true: True target values
channel_names: Media channel names
control_names: Control variable names
Returns:
Analysis results dictionary
"""
if self.pipeline is None:
raise ValueError("UnifiedDataPipeline is required for this method")
logger.info(f"\n ModelAnalyzer: Unified Pipeline Analysis")
# Get predictions and contributions
results = self.pipeline.predict_and_postprocess(
model=model,
X_media=X_media,
X_control=X_control,
channel_names=channel_names,
control_names=control_names,
combine_with_holdout=True
)
# Calculate metrics
metrics = self.pipeline.calculate_metrics(
y_true, results['predictions'], prefix='pipeline_'
)
# Combine results
analysis_results = {
**results,
**metrics,
'y_true': y_true
}
logger.info(f" Pipeline analysis complete")
logger.info(f" R²: {metrics['pipeline_r2']:.3f}")
logger.info(f" RMSE: {metrics['pipeline_rmse']:.0f}")
return analysis_results
[docs]
def analyze_predictions(
self,
X_m: np.ndarray,
X_c: np.ndarray,
R: np.ndarray,
y_true: Optional[np.ndarray] = None,
generate_plots: bool = True
) -> Dict[str, Any]:
"""
Analyze model predictions and generate visualizations.
Args:
X_m: Media variables [n_regions, n_weeks, n_channels]
X_c: Control variables [n_regions, n_weeks, n_controls]
R: Region indices [n_regions] (reserved; ``InferenceManager`` currently
builds region indices internally)
y_true: Optional ground truth values
generate_plots: Whether to generate and save plots
Returns:
Dictionary containing analysis results and metrics
"""
if self.inference is None:
raise ValueError("InferenceManager (or legacy ModelInference) is required for analyze_predictions")
_ = R # API compatibility; not yet passed through to InferenceManager.predict
results = self.inference.predict(
np.asarray(X_m, dtype=np.float32),
np.asarray(X_c, dtype=np.float32),
return_contributions=True,
remove_padding=True,
return_media_coefficients=True,
)
if y_true is not None:
results['actual_revenue'] = np.asarray(y_true, dtype=np.float32)
# Generate plots if requested
if generate_plots:
self._generate_plots(results)
return results
def _generate_plots(self, results: Dict[str, Any]) -> None:
"""Generate all visualization plots."""
burn_in_weeks = int(getattr(self.inference.model, 'burn_in_weeks', 0))
media_coefficients = results.get('media_coefficients') or results.get('coefficients')
if media_coefficients is None:
raise ValueError(
"Expected 'media_coefficients' in results; ensure predict(..., return_media_coefficients=True)"
)
# 1. Coefficients over time
coeff_fig = self.plot_coefficients_over_time(
media_coefficients,
self.inference.channel_names
)
# 2. Actual vs Predicted
comparison_fig = self.plot_contribution_comparison(
results['predictions'],
results.get('actual_revenue'), # May be None
burn_in_weeks
)
# 3. Waterfall chart
waterfall_fig = self.plot_waterfall_chart(
results['media_contributions'],
results.get('control_contributions'),
results.get('baseline'),
self.inference.channel_names,
self.inference.control_names
)
# 4. Contribution donut
donut_fig = self.plot_contribution_donut(
results['media_contributions'],
results.get('control_contributions'),
results.get('baseline')
)
# Save plots if output directory is specified
if self.output_dir:
coeff_fig.write_html(os.path.join(self.output_dir, 'coefficients.html'))
comparison_fig.write_html(os.path.join(self.output_dir, 'comparison.html'))
waterfall_fig.write_html(os.path.join(self.output_dir, 'waterfall.html'))
donut_fig.write_html(os.path.join(self.output_dir, 'donut.html'))
# Display plots - Commented out to prevent browser popups
# coeff_fig.show()
# comparison_fig.show()
# waterfall_fig.show()
# donut_fig.show()
[docs]
@staticmethod
def plot_coefficients_over_time(
coefficients: np.ndarray,
channel_names: List[str]
) -> go.Figure:
"""Plot mean coefficients over time for each channel."""
mean_coeffs = coefficients.mean(axis=0) # Average across regions
fig = go.Figure()
for i, name in enumerate(channel_names):
fig.add_trace(
go.Scatter(
y=mean_coeffs[:, i],
name=name,
mode='lines'
)
)
fig.update_layout(
title="Channel Coefficients Over Time",
xaxis_title="Week",
yaxis_title="Coefficient Value",
showlegend=True
)
return fig
[docs]
@staticmethod
def plot_contribution_comparison(
predictions: np.ndarray,
actuals: Optional[np.ndarray],
burn_in_weeks: int
) -> go.Figure:
"""Plot actual vs predicted revenue comparison."""
fig = make_subplots(
rows=1, cols=2,
subplot_titles=('Time Series Comparison', 'Scatter Plot'),
specs=[[{'type': 'scatter'}, {'type': 'scatter'}]]
)
# Sum across regions
total_pred = predictions.sum(axis=0)[burn_in_weeks:]
if actuals is not None:
total_actual = actuals.sum(axis=0)[burn_in_weeks:]
# Time series plot
fig.add_trace(
go.Scatter(
y=total_actual,
name='Actual Revenue',
mode='lines',
line=dict(color='blue')
),
row=1, col=1
)
# Scatter plot
fig.add_trace(
go.Scatter(
x=total_actual,
y=total_pred,
mode='markers',
name='Actual vs Predicted',
marker=dict(color='green')
),
row=1, col=2
)
# Add 45-degree line
min_val = min(total_actual.min(), total_pred.min())
max_val = max(total_actual.max(), total_pred.max())
fig.add_trace(
go.Scatter(
x=[min_val, max_val],
y=[min_val, max_val],
mode='lines',
name='45° line',
line=dict(dash='dash', color='gray')
),
row=1, col=2
)
# Add predicted line
fig.add_trace(
go.Scatter(
y=total_pred,
name='Predicted Revenue',
mode='lines',
line=dict(color='red')
),
row=1, col=1
)
fig.update_layout(
title="Actual vs Predicted Revenue Comparison",
showlegend=True,
height=500
)
return fig
[docs]
@staticmethod
def plot_waterfall_chart(
marketing_contribs: np.ndarray,
control_contribs: Optional[np.ndarray],
baseline: Optional[np.ndarray],
channel_names: List[str],
control_names: List[str]
) -> go.Figure:
"""Create waterfall chart of contributions."""
# Calculate mean contributions
mean_marketing = marketing_contribs.mean(axis=(0, 1)) # Average across regions and time
measures = ['relative'] * len(channel_names)
values = list(mean_marketing)
labels = channel_names.copy()
if control_contribs is not None:
mean_control = control_contribs.mean(axis=(0, 1))
measures.extend(['relative'] * len(control_names))
values.extend(mean_control)
labels.extend(control_names)
if baseline is not None:
mean_baseline = baseline.mean()
measures.append('total')
values.append(mean_baseline)
labels.append('Baseline')
fig = go.Figure(go.Waterfall(
name="Contribution Breakdown",
orientation="v",
measure=measures,
x=labels,
y=values,
connector={"line": {"color": "rgb(63, 63, 63)"}},
))
fig.update_layout(
title="Contribution Waterfall Chart",
showlegend=True,
xaxis_title="Components",
yaxis_title="Contribution"
)
return fig
[docs]
@staticmethod
def plot_contribution_donut(
marketing_contribs: np.ndarray,
control_contribs: Optional[np.ndarray],
baseline: Optional[np.ndarray]
) -> go.Figure:
"""Create donut chart of contribution percentages."""
# Calculate total contributions
total_marketing = marketing_contribs.sum()
total_control = control_contribs.sum() if control_contribs is not None else 0
total_baseline = baseline.sum() if baseline is not None else 0
total = total_marketing + total_control + total_baseline
values = [
(total_marketing / total) * 100,
(total_control / total) * 100 if control_contribs is not None else 0,
(total_baseline / total) * 100 if baseline is not None else 0
]
labels = ['Marketing', 'Control', 'Baseline']
fig = go.Figure(data=[go.Pie(
labels=labels,
values=values,
hole=.4,
textinfo='label+percent',
textposition='outside'
)])
fig.update_layout(
title="Contribution Split",
annotations=[dict(text='Total', x=0.5, y=0.5, font_size=20, showarrow=False)]
)
return fig