"""
Modern InferenceManager class for DeepCausalMMM model inference.
Provides a clean, reusable interface for model predictions and analysis.
"""
import torch
import numpy as np
from typing import Dict, Any, List, Optional, Tuple
import warnings
from deepcausalmmm.core.unified_model import DeepCausalMMM
from deepcausalmmm.core.data import UnifiedDataPipeline
from deepcausalmmm.core.scaling import SimpleGlobalScaler
from deepcausalmmm.core.config import get_default_config
from deepcausalmmm.utils.device import get_device
class InferenceManager:
    """
    Class-based interface for DeepCausalMMM model inference.

    Handles:
    - Model predictions on new data
    - Contribution analysis (media, control, baseline)
    - Coefficient extraction
    - Data preprocessing for inference

    Inverse transformation of results is only applied by
    ``analyze_contributions`` when the instance provides a
    ``predict_and_inverse_transform`` method (e.g. via a subclass);
    this class itself does not define one.
    """

    def __init__(self,
                 model: "DeepCausalMMM",
                 pipeline: Optional["UnifiedDataPipeline"] = None,
                 scaler: Optional["SimpleGlobalScaler"] = None,
                 config: Optional[Dict[str, Any]] = None,
                 channel_names: Optional[List[str]] = None,
                 control_names: Optional[List[str]] = None):
        """
        Initialize the inference manager.

        The model is moved to CUDA when available (CPU otherwise) and put
        into evaluation mode.

        Args:
            model: Trained DeepCausalMMM model
            pipeline: UnifiedDataPipeline used for training (preferred)
            scaler: SimpleGlobalScaler used for training (legacy support).
                When omitted, ``pipeline.scaler`` is used if a pipeline is given.
            config: Configuration dictionary; a falsy value selects
                ``get_default_config()``.
            channel_names: List of media channel names
            control_names: List of control variable names
        """
        self.model = model
        self.pipeline = pipeline
        # Explicit None checks: a scaler/pipeline object must not be discarded
        # just because it happens to evaluate as falsy.
        if scaler is None and pipeline is not None:
            scaler = pipeline.scaler
        self.scaler = scaler
        self.config = config or get_default_config()
        self.channel_names = channel_names or []
        self.control_names = control_names or []

        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self.model.to(self.device)
        self.model.eval()

        # Validate inputs
        if self.pipeline is None and self.scaler is None:
            warnings.warn(
                "Neither pipeline nor scaler provided. "
                "Inference may not work correctly without proper data preprocessing.",
                UserWarning
            )

    def predict(self,
                X_media: np.ndarray,
                X_control: np.ndarray,
                return_contributions: bool = True,
                remove_padding: bool = True,
                return_media_coefficients: bool = False) -> Dict[str, np.ndarray]:
        """
        Make predictions on new data.

        Preprocessing uses the pipeline when present, otherwise the legacy
        scaler, otherwise the inputs are only converted to float32 tensors.

        Args:
            X_media: Media data [n_regions, n_weeks, n_media_channels]
            X_control: Control data [n_regions, n_weeks, n_control_vars]
            return_contributions: Whether to return contribution breakdowns
            remove_padding: Whether to remove burn-in padding from results
            return_media_coefficients: If True, include time-varying media coefficients
                (second tensor from ``forward()``) as ``media_coefficients``.

        Returns:
            Dictionary with ``predictions`` and, when requested,
            ``media_contributions``, ``control_contributions`` and
            ``media_coefficients``.
        """
        if self.pipeline is not None:
            # Use modern pipeline approach
            media_t, control_t = self._preprocess_with_pipeline(X_media, X_control)
        elif self.scaler is not None:
            # Use legacy scaler approach
            media_t, control_t = self._preprocess_with_scaler(X_media, X_control)
        else:
            # No preprocessing - assume data is already processed
            media_t = self._as_float_tensor(X_media)
            control_t = self._as_float_tensor(X_control)

        # Create region tensor.
        # NOTE(review): every row receives region index 0; if the model was
        # trained with distinct per-region indices this should probably be
        # an index range instead - confirm against the training code.
        n_regions = media_t.shape[0]
        R = torch.zeros(n_regions, dtype=torch.long, device=self.device)

        # Make predictions (single forward; unpack matches DeepCausalMMM.forward)
        with torch.no_grad():
            predictions, media_coeffs, media_contributions, outputs = self.model(
                media_t, control_t, R
            )

        results: Dict[str, np.ndarray] = {
            'predictions': predictions.cpu().numpy()
        }
        if return_contributions:
            results['media_contributions'] = media_contributions.cpu().numpy()
            results['control_contributions'] = outputs['control_contributions'].cpu().numpy()
        if return_media_coefficients:
            results['media_coefficients'] = media_coeffs.cpu().numpy()

        # Remove padding if requested.
        # NOTE(review): only the pipeline path adds padding in this class;
        # trimming is still applied on the other paths - presumably the model
        # output carries the burn-in weeks regardless; verify.
        if remove_padding:
            burn_in = getattr(self.model, 'burn_in_weeks', 0) or 0
            if burn_in > 0:
                results = {
                    name: self._strip_burn_in(values, burn_in)
                    for name, values in results.items()
                }
        return results

    @staticmethod
    def _strip_burn_in(values: Optional[np.ndarray], burn_in: int) -> Optional[np.ndarray]:
        """Drop the first ``burn_in`` steps along axis 1 when that axis exists and is longer."""
        if values is None or values.ndim < 2:
            # Nothing to trim: no second axis to slice (previously an IndexError).
            return values
        if values.shape[1] <= burn_in:
            # Too short to contain anything beyond the burn-in; leave untouched.
            return values
        return values[:, burn_in:, ...]

    def _as_float_tensor(self, values: Any) -> torch.Tensor:
        """Convert array-like ``values`` to a float32 tensor on ``self.device``."""
        return torch.as_tensor(values, dtype=torch.float32).to(self.device)

    def get_coefficients(self) -> Dict[str, np.ndarray]:
        """
        Extract model coefficients.

        Each entry is only present when the model exposes the corresponding
        attribute (``coeff_gen``, ``ctrl_coeff_gen``, ``stable_media_coeff``,
        ``stable_ctrl_coeff``).

        Returns:
            Dictionary with any of the keys ``media``, ``control``,
            ``stable_media`` and ``stable_control``.
        """
        coefficients: Dict[str, np.ndarray] = {}

        # Media coefficients: the generator is evaluated on an all-zero input.
        if hasattr(self.model, 'coeff_gen'):
            with torch.no_grad():
                media_coeffs = self.model.coeff_gen(
                    torch.zeros(1, 1, self.model.n_media).to(self.device)
                )
            coefficients['media'] = media_coeffs.cpu().numpy()

        # Control coefficients: same approach with the control generator.
        if hasattr(self.model, 'ctrl_coeff_gen'):
            with torch.no_grad():
                control_coeffs = self.model.ctrl_coeff_gen(
                    torch.zeros(1, 1, self.model.ctrl_dim).to(self.device)
                )
            coefficients['control'] = control_coeffs.cpu().numpy()

        # Extract stable coefficients if available
        if hasattr(self.model, 'stable_media_coeff'):
            coefficients['stable_media'] = self.model.stable_media_coeff.detach().cpu().numpy()
        if hasattr(self.model, 'stable_ctrl_coeff'):
            coefficients['stable_control'] = self.model.stable_ctrl_coeff.detach().cpu().numpy()
        return coefficients

    def get_dag_adjacency(self,
                          threshold: bool = False,
                          eps: Optional[float] = None) -> Optional[np.ndarray]:
        """
        Extract DAG adjacency matrix if available.

        Delegates to ``model.get_dag_adjacency_matrix``. Set ``threshold=True``
        (or pass ``eps``) to prune weak edges; the cutoff defaults to
        ``config['notears_threshold']`` (0.3 when the key is absent).

        Args:
            threshold: If True, request pruning with the default cutoff.
            eps: Pruning cutoff; passing a value enables pruning on its own.

        Returns:
            Adjacency matrix, or None if the model has no
            ``get_dag_adjacency_matrix`` or its ``enable_dag`` flag is
            missing/falsy.
        """
        has_dag = hasattr(self.model, 'get_dag_adjacency_matrix')
        # getattr guards against models that expose the method but not the flag.
        if not (has_dag and getattr(self.model, 'enable_dag', False)):
            return None

        prune_eps: Optional[float] = None
        if eps is not None:
            prune_eps = eps
        elif threshold:
            prune_eps = float(self.config.get('notears_threshold', 0.3))

        with torch.no_grad():
            W = self.model.get_dag_adjacency_matrix(eps=prune_eps)
        # detach() keeps numpy() valid even if the model returns a tensor
        # that still requires grad.
        return W.detach().cpu().numpy()

    def analyze_contributions(self,
                              X_media: np.ndarray,
                              X_control: np.ndarray,
                              aggregate_regions: bool = True,
                              aggregate_time: bool = False) -> Dict[str, Any]:
        """
        Comprehensive contribution analysis.

        Uses ``predict_and_inverse_transform`` when the instance provides it;
        otherwise falls back to ``predict``, in which case the values are the
        raw model outputs (no inverse transformation is applied).

        Args:
            X_media: Media data
            X_control: Control data
            aggregate_regions: Whether to aggregate across regions; adds
                ``<key>_by_region`` entries holding the mean over axis 0.
            aggregate_time: Whether to aggregate across time; adds
                ``<key>_by_time`` entries holding the mean over axis 1.

        Returns:
            Dictionary with ``total_predictions``, ``media_contributions``,
            ``control_contributions``, ``baseline``, the requested aggregates
            and ``media_contribution_pct`` (when media contributions exist).
        """
        # The original unconditionally called a method this class does not
        # define; resolve it dynamically so subclasses can still supply it.
        inverse_predict = getattr(self, 'predict_and_inverse_transform', None)
        if callable(inverse_predict):
            results = inverse_predict(X_media, X_control, return_contributions=True)
        else:
            results = self.predict(X_media, X_control, return_contributions=True)

        analysis: Dict[str, Any] = {
            'total_predictions': results['predictions'],
            'media_contributions': results.get('media_contributions'),
            'control_contributions': results.get('control_contributions'),
            'baseline': results.get('baseline', 0)
        }

        aggregated_keys = ('total_predictions', 'media_contributions', 'control_contributions')
        # Aggregate if requested
        if aggregate_regions:
            for key in aggregated_keys:
                if analysis[key] is not None:
                    analysis[f'{key}_by_region'] = np.mean(analysis[key], axis=0)
        if aggregate_time:
            for key in aggregated_keys:
                if analysis[key] is not None:
                    analysis[f'{key}_by_time'] = np.mean(analysis[key], axis=1)

        # Contribution percentages relative to the prediction; the small
        # constant avoids division by exactly zero.
        if analysis['media_contributions'] is not None:
            total_pred = analysis['total_predictions']
            analysis['media_contribution_pct'] = (
                analysis['media_contributions'] / (total_pred[..., None] + 1e-8) * 100
            )
        return analysis

    def _preprocess_with_pipeline(self, X_media: np.ndarray, X_control: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess data using UnifiedDataPipeline (seasonality, scaling, padding)."""
        # Add seasonality features
        X_control_with_seasonality = self.pipeline._add_seasonality_features(
            torch.tensor(X_control), start_week=0
        )
        # Scale data; a zero array stands in for the target, whose scaled
        # value is discarded.
        X_media_scaled, X_control_scaled, _ = self.pipeline.scaler.transform(
            X_media, X_control_with_seasonality.numpy(), np.zeros_like(X_media[:, :, 0])
        )
        # Add padding (again with a placeholder target that is discarded)
        X_media_padded, X_control_padded, _ = self.pipeline._add_padding(
            X_media_scaled, X_control_scaled, torch.zeros(X_media_scaled.shape[0], X_media_scaled.shape[1])
        )
        return X_media_padded.to(self.device), X_control_padded.to(self.device)

    def _preprocess_with_scaler(self, X_media: np.ndarray, X_control: np.ndarray) -> Tuple[torch.Tensor, torch.Tensor]:
        """Preprocess data using legacy SimpleGlobalScaler (scaling only, no padding)."""
        # Scale data; a zero array stands in for the target, whose scaled
        # value is discarded.
        X_media_scaled, X_control_scaled, _ = self.scaler.transform(
            X_media, X_control, np.zeros_like(X_media[:, :, 0])
        )
        # Convert to tensors
        return self._as_float_tensor(X_media_scaled), self._as_float_tensor(X_control_scaled)
# Legacy compatibility class
class ModelInference(InferenceManager):
    """
    Backward-compatible shim around InferenceManager.

    Exists so that code written against the old ``ModelInference``
    interface keeps working; constructing it emits a DeprecationWarning.
    """

    def __init__(self, model, scaler, channel_names=None, control_names=None, **kwargs):
        """Warn about the deprecation, then delegate to InferenceManager.

        Args:
            model: Passed through as ``model``.
            scaler: Passed through as ``scaler``.
            channel_names: Passed through as ``channel_names``.
            control_names: Passed through as ``control_names``.
            **kwargs: Any further keyword arguments for
                ``InferenceManager.__init__``.
        """
        # stacklevel=2 attributes the warning to the caller's line rather
        # than to this constructor.
        warnings.warn(
            "ModelInference is deprecated. Please use InferenceManager instead.",
            DeprecationWarning,
            stacklevel=2
        )
        legacy_arguments = {
            'model': model,
            'scaler': scaler,
            'channel_names': channel_names,
            'control_names': control_names,
        }
        super().__init__(**legacy_arguments, **kwargs)