Model Inference
Inference Manager
Modern InferenceManager class for DeepCausalMMM model inference. Provides a clean, reusable interface for model predictions and analysis.
- class deepcausalmmm.core.inference.InferenceManager(model: DeepCausalMMM, pipeline: UnifiedDataPipeline | None = None, scaler: SimpleGlobalScaler | None = None, config: Dict[str, Any] | None = None, channel_names: List[str] | None = None, control_names: List[str] | None = None)
Bases: object
Modern class-based interface for DeepCausalMMM model inference.
Handles:
- Model predictions on new data
- Contribution analysis (media, control, baseline)
- Coefficient extraction
- Data preprocessing for inference
- Inverse transformations for interpretable results
- __init__(model: DeepCausalMMM, pipeline: UnifiedDataPipeline | None = None, scaler: SimpleGlobalScaler | None = None, config: Dict[str, Any] | None = None, channel_names: List[str] | None = None, control_names: List[str] | None = None)
Initialize the inference manager.
- Parameters:
model – Trained DeepCausalMMM model
pipeline – UnifiedDataPipeline used for training (preferred)
scaler – SimpleGlobalScaler used for training (legacy support)
config – Configuration dictionary
channel_names – List of media channel names
control_names – List of control variable names
- predict(X_media: ndarray, X_control: ndarray, return_contributions: bool = True, remove_padding: bool = True, return_media_coefficients: bool = False) → Dict[str, ndarray]
Make predictions on new data.
- Parameters:
X_media – Media data [n_regions, n_weeks, n_media_channels]
X_control – Control data [n_regions, n_weeks, n_control_vars]
return_contributions – Whether to return contribution breakdowns
remove_padding – Whether to remove burn-in padding from results
return_media_coefficients – If True, include the time-varying media coefficients (the second tensor returned by forward()) as media_coefficients.
- Returns:
Dictionary containing predictions and optionally contributions
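The input layout above is the main contract for predict(). The NumPy sketch below builds arrays in that [n_regions, n_weeks, n_features] order and illustrates what remove_padding conceptually does. It assumes, purely for illustration, that burn-in padding is prepended along the week axis and that predictions are shaped [n_regions, n_weeks]; check_inputs and strip_burn_in are hypothetical helpers, not part of deepcausalmmm.

```python
import numpy as np

# Illustrative dimensions only; not taken from any real dataset.
n_regions, n_weeks, n_media, n_control = 3, 52, 4, 2

rng = np.random.default_rng(0)
# Inputs follow the documented axis order: [n_regions, n_weeks, n_features].
X_media = rng.random((n_regions, n_weeks, n_media))
X_control = rng.random((n_regions, n_weeks, n_control))


def check_inputs(X_media: np.ndarray, X_control: np.ndarray) -> None:
    """Hypothetical check: both arrays are 3-D and agree on regions and weeks."""
    if X_media.ndim != 3 or X_control.ndim != 3:
        raise ValueError("expected 3-D arrays [n_regions, n_weeks, n_features]")
    if X_media.shape[:2] != X_control.shape[:2]:
        raise ValueError("media and control must share n_regions and n_weeks")


def strip_burn_in(pred: np.ndarray, burn_in_weeks: int) -> np.ndarray:
    """Hypothetical helper: drop burn-in weeks ASSUMED to be prepended on axis 1."""
    return pred[:, burn_in_weeks:]


check_inputs(X_media, X_control)
# Pretend the model ran on a padded series with 4 extra leading weeks.
padded_pred = rng.random((n_regions, n_weeks + 4))
pred = strip_burn_in(padded_pred, burn_in_weeks=4)
print(pred.shape)

# With a trained model the documented call would look like:
# results = manager.predict(X_media, X_control, remove_padding=True)
```

Validating shapes before calling predict() catches the most common mistake, swapping the region and week axes, before it surfaces as a confusing tensor error inside the model.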
- predict_and_inverse_transform(X_media: ndarray, X_control: ndarray, return_contributions: bool = True) → Dict[str, ndarray]
Make predictions and apply inverse transformations for interpretable results.
- Parameters:
X_media – Media data [n_regions, n_weeks, n_media_channels]
X_control – Control data [n_regions, n_weeks, n_control_vars]
return_contributions – Whether to return contribution breakdowns
- Returns:
Dictionary containing predictions and contributions in original scale
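Returning to the original scale means applying the training scaler's inverse to both the prediction and each contribution. The sketch below uses a hypothetical ToyGlobalScaler that divides the target by a single global factor; this is an assumption for illustration, not how SimpleGlobalScaler or UnifiedDataPipeline are actually implemented. It shows why a purely multiplicative scale preserves additivity: inverse-transformed contributions still sum to the inverse-transformed prediction.

```python
import numpy as np


class ToyGlobalScaler:
    """Hypothetical stand-in: scales the target by one global factor."""

    def fit(self, y: np.ndarray) -> "ToyGlobalScaler":
        # Fall back to 1.0 if the target is all zeros.
        self.scale_ = float(np.mean(np.abs(y))) or 1.0
        return self

    def transform(self, y: np.ndarray) -> np.ndarray:
        return y / self.scale_

    def inverse_transform(self, y_scaled: np.ndarray) -> np.ndarray:
        return y_scaled * self.scale_


y = np.array([[100.0, 120.0, 90.0], [200.0, 210.0, 190.0]])  # [n_regions, n_weeks]
scaler = ToyGlobalScaler().fit(y)

# Pretend the model produced these two components in scaled space.
baseline_s = np.full_like(y, 0.5)
media_s = scaler.transform(y) - baseline_s

pred_orig = scaler.inverse_transform(baseline_s + media_s)
parts_orig = scaler.inverse_transform(baseline_s) + scaler.inverse_transform(media_s)

# With a purely multiplicative scale, contributions still add up to the prediction.
print(np.allclose(pred_orig, parts_orig), np.allclose(pred_orig, y))
```

If a scaler includes a nonlinear step (for example a log transform), this additivity no longer holds, so contributions in original units would need a different attribution rule.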
- get_coefficients() → Dict[str, ndarray]
Extract model coefficients.
- Returns:
Dictionary containing media and control coefficients
- get_dag_adjacency(threshold: bool = False, eps: float | None = None) → ndarray | None
Extract the DAG adjacency matrix, if available.
Uses the same mask and dag_temperature scaling as training. Set threshold=True (or pass eps) to prune weak edges; by default the cutoff is notears_threshold from config.
- Parameters:
threshold – If True, zero out entries below eps.
eps – Pruning cutoff; defaults to config['notears_threshold'].
- Returns:
Adjacency matrix or None if DAG is not enabled
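The pruning rule above can be sketched in a few lines of NumPy. This is an illustrative reimplementation, not the library's code: prune_adjacency is a hypothetical helper, it assumes "below eps" compares edge magnitude (as in NOTEARS-style thresholding), and it assumes a 0.0 fallback when neither eps nor notears_threshold is supplied.

```python
from typing import Optional

import numpy as np


def prune_adjacency(adj: np.ndarray, eps: Optional[float] = None,
                    config: Optional[dict] = None) -> np.ndarray:
    """Hypothetical helper: zero weak edges; eps falls back to config['notears_threshold']."""
    if eps is None:
        # Assumed fallback of 0.0 (no pruning) when the config key is missing.
        eps = (config or {}).get("notears_threshold", 0.0)
    pruned = adj.copy()  # leave the caller's matrix untouched
    # Assumption: "below eps" compares edge magnitude, as in NOTEARS-style pruning.
    pruned[np.abs(pruned) < eps] = 0.0
    return pruned


adj = np.array([[0.0, 0.8, 0.05],
                [0.02, 0.0, 0.3],
                [0.0, 0.0, 0.0]])
print(prune_adjacency(adj, config={"notears_threshold": 0.1}))
```

Pruning on a copy keeps the raw learned weights available for inspection, which is useful when comparing graphs at several cutoffs.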
- analyze_contributions(X_media: ndarray, X_control: ndarray, aggregate_regions: bool = True, aggregate_time: bool = False) → Dict[str, Any]
Comprehensive contribution analysis.
- Parameters:
X_media – Media data
X_control – Control data
aggregate_regions – Whether to aggregate across regions
aggregate_time – Whether to aggregate across time
- Returns:
Dictionary with detailed contribution analysis
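The two aggregation flags can be understood as sums over the region and week axes. The sketch below assumes contributions are additive and laid out as [n_regions, n_weeks, n_channels], mirroring the input layout documented for predict(); the aggregate function and the share-of-total step are hypothetical illustrations, not the method's actual internals.

```python
import numpy as np

rng = np.random.default_rng(1)
n_regions, n_weeks, n_channels = 2, 5, 3
# Hypothetical per-channel contributions laid out as [n_regions, n_weeks, n_channels].
contrib = rng.random((n_regions, n_weeks, n_channels))


def aggregate(c: np.ndarray, aggregate_regions: bool, aggregate_time: bool) -> np.ndarray:
    """Hypothetical: sum over time and/or regions, assuming additive contributions."""
    if aggregate_time:
        c = c.sum(axis=1)  # [n_regions, n_weeks, C] -> [n_regions, C]
    if aggregate_regions:
        c = c.sum(axis=0)  # region axis is always first, so axis=0 removes it
    return c


totals = aggregate(contrib, aggregate_regions=True, aggregate_time=True)  # [n_channels]
shares = totals / totals.sum()  # each channel's share of total media contribution
weekly = aggregate(contrib, aggregate_regions=True, aggregate_time=False)  # [n_weeks, C]
print(totals.shape, weekly.shape, round(float(shares.sum()), 6))
```

Because the region axis stays first after a time sum, the same axis=0 reduction works whether or not time was aggregated first.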
- class deepcausalmmm.core.inference.ModelInference(model, scaler, channel_names=None, control_names=None, **kwargs)
Bases: InferenceManager
Legacy compatibility wrapper for InferenceManager.
This class provides backward compatibility with existing code that uses the old ModelInference interface.
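Comparing the two signatures shows why a wrapper is needed: the legacy call takes scaler as its second positional argument, while InferenceManager puts pipeline second. The sketch below shows one common way to bridge that gap with a subclass that forwards arguments by keyword. NewManager and LegacyWrapper are hypothetical stand-ins, and forwarding pipeline and config out of **kwargs is an assumption, not a description of what ModelInference actually does.

```python
from typing import Any, Dict, List, Optional


class NewManager:
    """Hypothetical stand-in for a modern manager with a keyword-rich __init__."""

    def __init__(self, model: Any, pipeline: Any = None, scaler: Any = None,
                 config: Optional[Dict[str, Any]] = None,
                 channel_names: Optional[List[str]] = None,
                 control_names: Optional[List[str]] = None) -> None:
        self.model = model
        self.pipeline = pipeline
        self.scaler = scaler
        self.config = config or {}
        self.channel_names = channel_names or []
        self.control_names = control_names or []


class LegacyWrapper(NewManager):
    """Keeps the old positional (model, scaler, ...) call shape working."""

    def __init__(self, model: Any, scaler: Any, channel_names=None,
                 control_names=None, **kwargs: Any) -> None:
        # Pass everything by keyword so the positional scaler never lands in `pipeline`.
        # Assumption: only `pipeline` and `config` are read from **kwargs.
        super().__init__(model=model, scaler=scaler,
                         pipeline=kwargs.get("pipeline"),
                         config=kwargs.get("config"),
                         channel_names=channel_names,
                         control_names=control_names)


legacy = LegacyWrapper("model", "scaler", ["tv", "search"], config={"a": 1})
print(isinstance(legacy, NewManager), legacy.scaler, legacy.config)
```

Subclassing rather than duplicating the class means legacy callers automatically receive every method added to the modern interface.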