deepcausalmmm.core.trainer

Reusable ModelTrainer class for DeepCausalMMM training. It eliminates code duplication and provides a consistent training interface.

Classes

ModelTrainer([config])

Reusable trainer class for DeepCausalMMM models.

class deepcausalmmm.core.trainer.ModelTrainer(config: Dict[str, Any] | None = None)

Reusable trainer class for DeepCausalMMM models.

This class provides a complete training pipeline for DeepCausalMMM models, including early stopping, learning rate scheduling, gradient clipping, and logging of training metrics. It supports MSE and Huber loss functions (Focal loss is optional), automatic device detection, and mixed precision training.

Features:
  • Config-driven model initialization (zero hardcoding)

  • Automatic device detection (CPU/CUDA)

  • Multiple loss functions (MSE, Huber, optional Focal)

  • Early stopping with patience

  • Learning rate scheduling (StepLR, Cosine Annealing)

  • Gradient clipping (global and parameter-specific)

  • Comprehensive metrics tracking (RMSE, R², MAE)

  • Progress bars with detailed statistics

  • Holdout evaluation during training
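The MSE and Huber losses listed above differ in how they treat large residuals. The following is a minimal pure-Python sketch of the standard definitions (the same piecewise form that torch.nn.HuberLoss uses with mean reduction). It is not the trainer's own code; the function names and the delta default are assumptions made for illustration.

```python
def mse_loss(pred, target):
    """Mean squared error: every residual is penalised quadratically."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def huber_loss(pred, target, delta=1.0):
    """Huber loss: quadratic for |residual| <= delta, linear beyond it."""
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        if err <= delta:
            total += 0.5 * err ** 2
        else:
            total += delta * (err - 0.5 * delta)
    return total / len(pred)


# One large outlier (residual of 7) dominates MSE far more than Huber.
pred = [1.0, 2.0, 10.0]
target = [1.0, 2.0, 3.0]
mse_value = mse_loss(pred, target)
huber_value = huber_loss(pred, target, delta=1.0)
```

Because the outlier contributes linearly rather than quadratically, the Huber value stays much smaller than the MSE value here, which is the usual reason to prefer it on noisy targets.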

Parameters:

config (Dict[str, Any], optional) – Configuration dictionary containing all training parameters. If None, uses default configuration from get_default_config().

model

The initialized model instance

Type:

DeepCausalMMM

optimizer

The optimizer (Adam by default)

Type:

torch.optim.Optimizer

scheduler

Learning rate scheduler if enabled

Type:

torch.optim.lr_scheduler._LRScheduler

device

Training device (CPU or CUDA)

Type:

torch.device

best_rmse

Best holdout RMSE achieved during training

Type:

float

train_losses

Training loss history

Type:

List[float]

train_rmses

Training RMSE history

Type:

List[float]

train_r2s

Training R² history

Type:

List[float]

Examples

>>> from deepcausalmmm.core.trainer import ModelTrainer
>>> from deepcausalmmm.core.config import get_default_config
>>>
>>> # Initialize trainer with custom config
>>> config = get_default_config()
>>> config['n_epochs'] = 1000
>>> config['learning_rate'] = 0.01
>>> trainer = ModelTrainer(config)
>>>
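>>> # Train model (assumes the training and holdout tensors are already prepared)
>>> results = trainer.train(
...     X_media_train, X_control_train, R_train, y_train,
...     X_media_holdout, X_control_holdout, R_holdout, y_holdout,
... )
>>>
>>> # Inspect holdout metrics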
>>> print(f"Final RMSE: {results['holdout_rmse']:.0f}")
>>> print(f"Final R²: {results['holdout_r2']:.3f}")

__init__(config: Dict[str, Any] | None = None)

Initialize the trainer with configuration.

Parameters:

config – Configuration dictionary. If None, uses default config.

create_model(n_media: int, n_control: int, n_regions: int) → DeepCausalMMM

Create and initialize model from config with reproducible initialization.

Parameters:
  • n_media – Number of media channels

  • n_control – Number of control variables

  • n_regions – Number of regions

Returns:

Initialized DeepCausalMMM model

create_optimizer_and_scheduler()

Create optimizer and learning rate scheduler from config.
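As a rough illustration of what building an optimizer and scheduler from a config dictionary can look like in PyTorch, here is a minimal sketch. It is not the library's implementation: the helper name build_optimizer_and_scheduler and the config keys 'scheduler', 'step_size', and 'gamma' are hypothetical, while 'learning_rate' and 'n_epochs' are the keys used in the class example. Only standard torch.optim calls are used.

```python
import torch


def build_optimizer_and_scheduler(model, config):
    """Hypothetical helper: Adam plus an optional LR scheduler from a dict."""
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

    name = config.get("scheduler")  # hypothetical key, not confirmed by the docs
    if name == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get("step_size", 100),  # hypothetical key
            gamma=config.get("gamma", 0.5),          # hypothetical key
        )
    elif name == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=config["n_epochs"]
        )
    else:
        scheduler = None  # scheduling disabled
    return optimizer, scheduler


# Stand-in model; the real trainer would pass its DeepCausalMMM instance.
model = torch.nn.Linear(4, 1)
config = {
    "learning_rate": 0.01,
    "n_epochs": 10,
    "scheduler": "step",
    "step_size": 2,
    "gamma": 0.5,
}
optimizer, scheduler = build_optimizer_and_scheduler(model, config)

# Record the learning rate after each epoch-level scheduler step.
lrs = []
for epoch in range(4):
    optimizer.step()   # no gradients were computed, so parameters are unchanged
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])
```

With StepLR the learning rate is multiplied by gamma every step_size epochs, so the recorded values halve after every second scheduler step.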

warm_start_training(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor, verbose: bool = True) → None

Perform warm-start training to stabilize GRU coefficients.

Parameters:
  • X_media – Media data tensor

  • X_control – Control data tensor

  • R – Region tensor

  • y – Target tensor

  • verbose – Whether to show progress

train_epoch(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor) → Tuple[float, float, float]

Train for one epoch.

Parameters:
  • X_media – Media data tensor

  • X_control – Control data tensor

  • R – Region tensor

  • y – Target tensor

Returns:

Tuple of (loss, rmse, r2)
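To make the (loss, rmse, r2) return value concrete, the sketch below shows one full-batch training step with global gradient-norm clipping in plain PyTorch. It is an assumption-laden stand-in rather than the trainer's code: the function train_one_epoch, the max_grad_norm value, the single feature tensor X (instead of separate media, control, and region tensors), and the Linear model are all hypothetical.

```python
import torch


def train_one_epoch(model, optimizer, loss_fn, X, y, max_grad_norm=1.0):
    """Hypothetical single full-batch epoch returning (loss, rmse, r2)."""
    model.train()
    optimizer.zero_grad()
    pred = model(X)
    loss = loss_fn(pred, y)
    loss.backward()
    # Rescale all gradients so their combined L2 norm is at most max_grad_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()

    # Metrics are computed from the predictions made before this update.
    with torch.no_grad():
        resid = y - pred
        rmse = torch.sqrt(torch.mean(resid ** 2)).item()
        ss_res = torch.sum(resid ** 2)
        ss_tot = torch.sum((y - y.mean()) ** 2)
        r2 = (1.0 - ss_res / ss_tot).item()
    return loss.item(), rmse, r2


# Synthetic regression data, used only to exercise the function.
torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1 * torch.randn(64, 1)

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn = torch.nn.HuberLoss(delta=1.0)

history = [train_one_epoch(model, optimizer, loss_fn, X, y) for _ in range(200)]
```

Each entry of history is one (loss, rmse, r2) tuple, which is the shape of data a trainer would append to attributes such as train_losses, train_rmses, and train_r2s.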

evaluate_holdout(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor) → Tuple[float, float, float]

Evaluate model on holdout data.

Parameters:
  • X_media – Holdout media data tensor

  • X_control – Holdout control data tensor

  • R – Holdout region tensor

  • y – Holdout target tensor

Returns:

Tuple of (loss, rmse, r2)
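Holdout evaluation differs from a training step mainly in that the model is put in evaluation mode and no gradients are tracked. The following sketch shows that pattern with standard PyTorch calls; the function name evaluate, the single input tensor, and the small Sequential model are hypothetical and stand in for the real model and its media, control, and region inputs.

```python
import torch


def evaluate(model, loss_fn, X, y):
    """Hypothetical holdout evaluation returning (loss, rmse, r2)."""
    model.eval()               # disables dropout and similar train-only behaviour
    with torch.no_grad():      # no autograd graph, no gradient accumulation
        pred = model(X)
        loss = loss_fn(pred, y).item()
        resid = y - pred
        rmse = resid.pow(2).mean().sqrt().item()
        ss_res = resid.pow(2).sum()
        ss_tot = (y - y.mean()).pow(2).sum()
        r2 = 1.0 - (ss_res / ss_tot).item()
    return loss, rmse, r2


# Random stand-in holdout data and a model that contains dropout.
torch.manual_seed(0)
X_holdout = torch.randn(16, 3)
y_holdout = torch.randn(16, 1)
model = torch.nn.Sequential(
    torch.nn.Linear(3, 8),
    torch.nn.Dropout(0.5),
    torch.nn.Linear(8, 1),
)

first = evaluate(model, torch.nn.MSELoss(), X_holdout, y_holdout)
second = evaluate(model, torch.nn.MSELoss(), X_holdout, y_holdout)
```

Because dropout is inactive in evaluation mode, repeated calls on the same data give identical metrics, and the model's parameters never receive gradients.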

should_stop_early(current_rmse: float) → bool

Check if training should stop early based on RMSE improvement.

Parameters:

current_rmse – Current epoch’s RMSE

Returns:

True if training should stop
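Patience-based early stopping can be sketched in a few lines. The class below is a hypothetical illustration, not the trainer's implementation: the names EarlyStopper, patience, and min_delta are assumptions, and the actual improvement rule used by should_stop_early may differ.

```python
class EarlyStopper:
    """Hypothetical patience counter keyed on RMSE (lower is better)."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_rmse = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, current_rmse):
        if current_rmse < self.best_rmse - self.min_delta:
            # Improvement: remember the new best and reset the counter.
            self.best_rmse = current_rmse
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up RMSE sequence, used only to exercise the counter.
stopper = EarlyStopper(patience=2)
rmses = [10.0, 8.0, 8.5, 7.9, 8.0, 8.1, 7.0]
stopped_at = None
for epoch, rmse in enumerate(rmses):
    if stopper.should_stop(rmse):
        stopped_at = epoch
        break
```

In this run the counter resets at epoch 3 (a new best of 7.9) and training stops at epoch 5 after two consecutive epochs without improvement, so the final value of 7.0 is never reached. That trade-off is inherent to patience-based stopping.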

train(X_media_train: Tensor, X_control_train: Tensor, R_train: Tensor, y_train: Tensor, X_media_holdout: Tensor | None = None, X_control_holdout: Tensor | None = None, R_holdout: Tensor | None = None, y_holdout: Tensor | None = None, pipeline: Any | None = None, verbose: bool = True) → Dict[str, Any]

Full training loop with warm-start, main training, and holdout evaluation.

Parameters:
  • X_media_train – Training media data

  • X_control_train – Training control data

  • R_train – Training region data

  • y_train – Training target data

  • X_media_holdout – Optional holdout media data

  • X_control_holdout – Optional holdout control data

  • R_holdout – Optional holdout region data

  • y_holdout – Optional holdout target data

  • pipeline – Optional UnifiedDataPipeline, used to access its scaler

  • verbose – Whether to show progress

Returns:

Dictionary with training results; the class example reads the 'holdout_rmse' and 'holdout_r2' entries from it.