deepcausalmmm.core.trainer
Reusable ModelTrainer class for DeepCausalMMM training. Eliminates code duplication and provides a consistent training interface.
Classes
ModelTrainer – Reusable trainer class for DeepCausalMMM models.
- class deepcausalmmm.core.trainer.ModelTrainer(config: Dict[str, Any] | None = None)
Reusable trainer class for DeepCausalMMM models.
This class provides a complete training pipeline for DeepCausalMMM models, including early stopping, learning rate scheduling, gradient clipping, and metric logging. It supports MSE and Huber loss functions (plus an optional Focal loss), automatic device detection, and mixed precision training.
Features:
- Config-driven model initialization (zero hardcoding)
- Automatic device detection (CPU/CUDA)
- Multiple loss functions (MSE, Huber, optional Focal)
- Early stopping with patience
- Learning rate scheduling (StepLR, Cosine Annealing)
- Gradient clipping (global and parameter-specific)
- Comprehensive metrics tracking (RMSE, R², MAE)
- Progress bars with detailed statistics
- Holdout evaluation during training
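To illustrate why a trainer might offer Huber loss alongside MSE, the following is a minimal pure-Python sketch of both losses. It is not the library's implementation: the function names and the `delta` parameter are assumptions made for illustration, and the actual trainer presumably operates on PyTorch tensors.

```python
def mse_loss(pred, target):
    """Mean squared error over two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def huber_loss(pred, target, delta=1.0):
    """Mean Huber loss: quadratic for |error| <= delta, linear beyond it."""
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        if err <= delta:
            total += 0.5 * err ** 2
        else:
            total += delta * (err - 0.5 * delta)
    return total / len(pred)


pred = [1.0, 2.0, 3.0, 4.0]
target = [1.0, 2.0, 3.0, 14.0]  # the last observation is an outlier

print(mse_loss(pred, target))    # 25.0, dominated by the outlier
print(huber_loss(pred, target))  # 2.375, the outlier contributes only linearly
```

Because the Huber penalty grows linearly for large errors, a single extreme observation pulls the fit far less than it does under MSE.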
- Parameters:
config (Dict[str, Any], optional) – Configuration dictionary containing all training parameters. If None, uses default configuration from get_default_config().
- model
The initialized model instance
- Type:
DeepCausalMMM
- optimizer
The optimizer (Adam by default)
- Type:
torch.optim.Optimizer
- scheduler
Learning rate scheduler if enabled
- Type:
torch.optim.lr_scheduler._LRScheduler
- device
Training device (CPU or CUDA)
- Type:
torch.device
Examples
>>> from deepcausalmmm.core.trainer import ModelTrainer
>>> from deepcausalmmm.core.config import get_default_config
>>>
>>> # Initialize trainer with custom config
>>> config = get_default_config()
>>> config['n_epochs'] = 1000
>>> config['learning_rate'] = 0.01
>>> trainer = ModelTrainer(config)
>>>
>>> # Train model (assumes the train and holdout tensors are already prepared)
>>> results = trainer.train(
...     X_media_train, X_control_train, R_train, y_train,
...     X_media_holdout, X_control_holdout, R_holdout, y_holdout,
... )
>>> model = trainer.model
>>>
>>> # Access holdout metrics
>>> print(f"Final RMSE: {results['holdout_rmse']:.0f}")
>>> print(f"Final R²: {results['holdout_r2']:.3f}")
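The two learning rate schedules named in the feature list can be written in closed form. The sketch below follows the formulas used by PyTorch's StepLR and CosineAnnealingLR schedulers; that the trainer uses exactly those classes, and the parameter names `step_size`, `gamma`, `t_max`, and `eta_min`, are assumptions for illustration.

```python
import math


def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Multiply the learning rate by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Decay from base_lr to eta_min along half a cosine period of t_max epochs."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


# Compare both schedules at a few epochs of a 1000-epoch run
for epoch in (0, 250, 500, 1000):
    print(
        epoch,
        step_lr(0.01, epoch, step_size=250, gamma=0.5),
        cosine_annealing_lr(0.01, epoch, t_max=1000),
    )
```

StepLR lowers the rate in discrete jumps, while cosine annealing decays it smoothly and reaches `eta_min` at epoch `t_max`.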
- __init__(config: Dict[str, Any] | None = None)
Initialize the trainer with configuration.
- Parameters:
config – Configuration dictionary. If None, uses default config.
- create_model(n_media: int, n_control: int, n_regions: int) -> DeepCausalMMM
Create and initialize model from config with reproducible initialization.
- Parameters:
n_media – Number of media channels
n_control – Number of control variables
n_regions – Number of regions
- Returns:
Initialized DeepCausalMMM model
- warm_start_training(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor, verbose: bool = True) -> None
Perform warm-start training to stabilize GRU coefficients.
- Parameters:
X_media – Media data tensor
X_control – Control data tensor
R – Region tensor
y – Target tensor
verbose – Whether to show progress
- train_epoch(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor) -> Tuple[float, float, float]
Train for one epoch.
- Parameters:
X_media – Media data tensor
X_control – Control data tensor
R – Region tensor
y – Target tensor
- Returns:
Tuple of (loss, rmse, r2)
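The class features mention global gradient clipping, which a training step typically applies between the backward pass and the optimizer update. How train_epoch does this internally is not documented here; the following is a pure-Python sketch of clipping by global norm, analogous to what torch.nn.utils.clip_grad_norm_ computes. The function name and the `eps` default are assumptions.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradient vectors so their joint L2 norm is at most max_norm.

    grads is a list of per-parameter gradient vectors (lists of floats).
    Returns the clipped gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Never scale up: gradients already within the limit are left unchanged
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


grads = [[3.0], [4.0]]  # joint norm is 5.0
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```

Scaling every gradient by the same factor preserves the update direction and only limits its magnitude.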
- evaluate_holdout(X_media: Tensor, X_control: Tensor, R: Tensor, y: Tensor) -> Tuple[float, float, float]
Evaluate model on holdout data.
- Parameters:
X_media – Holdout media data tensor
X_control – Holdout control data tensor
R – Holdout region tensor
y – Holdout target tensor
- Returns:
Tuple of (loss, rmse, r2)
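For reference, the RMSE and R² values in the returned tuple follow standard definitions, sketched below in pure Python. Whether the library computes them on scaled or original-scale targets is not stated here, and the function names are illustrative assumptions.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error."""
    squared = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    return math.sqrt(squared / len(y_true))


def r2_score(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_true = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [2.0, 2.0, 3.0, 3.0]
print(f"RMSE: {rmse(y_true, y_pred):.3f}")
print(f"R²: {r2_score(y_true, y_pred):.3f}")
```

An R² of 1 means a perfect fit, 0 means the model does no better than predicting the mean of the targets, and negative values mean it does worse.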
- should_stop_early(current_rmse: float) -> bool
Check if training should stop early based on RMSE improvement.
- Parameters:
current_rmse – Current epoch’s RMSE
- Returns:
True if training should stop
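A patience-based check of this kind can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code: the class name `EarlyStopper` and the `patience` and `min_delta` parameters are hypothetical.

```python
class EarlyStopper:
    """Signal a stop after `patience` consecutive epochs without RMSE improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_rmse = float("inf")
        self.epochs_without_improvement = 0

    def should_stop(self, current_rmse):
        if current_rmse < self.best_rmse - self.min_delta:
            # New best value: remember it and reset the counter
            self.best_rmse = current_rmse
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopper(patience=2)
rmse_history = [10.0, 8.0, 8.5, 8.2, 7.0]
stopped_at = None
for epoch, value in enumerate(rmse_history):
    if stopper.should_stop(value):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_rmse)  # 3 8.0
```

With a patience of 2, training stops at epoch 3, before the later improvement to 7.0 is seen. A larger patience tolerates longer plateaus at the cost of extra epochs.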
- train(X_media_train: Tensor, X_control_train: Tensor, R_train: Tensor, y_train: Tensor, X_media_holdout: Tensor | None = None, X_control_holdout: Tensor | None = None, R_holdout: Tensor | None = None, y_holdout: Tensor | None = None, pipeline: Any | None = None, verbose: bool = True) -> Dict[str, Any]
Full training loop with warm-start, main training, and holdout evaluation.
- Parameters:
X_media_train – Training media data
X_control_train – Training control data
R_train – Training region data
y_train – Training target data
X_media_holdout – Optional holdout media data
X_control_holdout – Optional holdout control data
R_holdout – Optional holdout region data
y_holdout – Optional holdout target data
pipeline – Optional UnifiedDataPipeline for accessing scaler
verbose – Whether to show progress
- Returns:
Dictionary with training results