Source code for deepcausalmmm.core.trainer

"""
Reusable ModelTrainer class for DeepCausalMMM training.
Eliminates code duplication and provides consistent training interface.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts, LambdaLR
import numpy as np
from typing import Dict, Any, Tuple, List, Optional
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, r2_score

from deepcausalmmm.core.unified_model import DeepCausalMMM
from deepcausalmmm.core.config import get_default_config
from deepcausalmmm.utils.device import get_device

import logging

logger = logging.getLogger('deepcausalmmm')

class ModelTrainer:
    """
    Reusable trainer class for DeepCausalMMM models.

    Bundles model construction, optimizer/scheduler setup, an optional
    warm-start phase, the main full-batch training loop, periodic holdout
    evaluation and final original-scale reporting behind one interface.

    Features:
        - Config-driven model initialization (every hyperparameter is read
          from ``config`` with a fallback default)
        - Automatic device detection (CPU/CUDA)
        - Huber or MSE loss, with an optional focal-style re-weighting term
        - Early stopping with patience (on holdout RMSE)
        - Learning rate scheduling (ReduceLROnPlateau, or
          CosineAnnealingWarmRestarts with optional linear warmup)
        - Gradient clipping (separate norms for coefficient parameters and
          all remaining parameters)
        - Metrics tracking (RMSE, R², plus MAE-based holdout metrics)
        - Progress bars with detailed statistics
        - Holdout evaluation every 10 epochs during training

    Parameters
    ----------
    config : Dict[str, Any], optional
        Configuration dictionary containing all training parameters.
        If None (or empty), uses ``get_default_config()``.

    Attributes
    ----------
    model : DeepCausalMMM
        The model instance (set by ``create_model``)
    optimizer : torch.optim.Optimizer
        The optimizer (AdamW, set by ``create_optimizer_and_scheduler``)
    scheduler : torch.optim.lr_scheduler
        ReduceLROnPlateau or CosineAnnealingWarmRestarts
    warmup_scheduler : torch.optim.lr_scheduler.LambdaLR or None
        Linear warmup scheduler, only used with cosine annealing
    pipeline : Any or None
        Data pipeline handed to ``train``; must expose ``get_scaler()``
    device : torch.device
        Training device (CPU or CUDA)
    best_rmse : float
        Best holdout RMSE seen during training (best training RMSE when no
        holdout data is supplied)
    train_losses : List[float]
        Training loss history
    train_rmses : List[float]
        Training RMSE history (scaled space)
    train_r2s : List[float]
        Training R² history (scaled space)

    Examples
    --------
    >>> from deepcausalmmm.core.trainer import ModelTrainer
    >>> from deepcausalmmm.core.config import get_default_config
    >>>
    >>> config = get_default_config()
    >>> config['n_epochs'] = 1000
    >>> config['learning_rate'] = 0.01
    >>> trainer = ModelTrainer(config)
    >>>
    >>> # Tensors and `pipeline` are assumed to come from data preparation
    >>> trainer.create_model(n_media, n_control, n_regions)
    >>> trainer.create_optimizer_and_scheduler()
    >>> results = trainer.train(X_media, X_control, R, y, pipeline=pipeline)
    >>>
    >>> print(f"Final RMSE: {results['final_train_rmse']:.0f}")
    >>> print(f"Final R²: {results['final_train_r2']:.3f}")
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """
        Initialize the trainer with configuration.

        Args:
            config: Configuration dictionary. If None, uses default config.
        """
        self.config = config or get_default_config()
        self.model = None
        self.optimizer = None
        self.scheduler = None
        # Defined up-front so train() can consult them even when the caller
        # wired the optimizer manually or never supplied a pipeline.
        self.warmup_scheduler = None
        self.pipeline = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Note: seeds are expected to be set by the calling script; model
        # creation simply uses the current RNG state.

        # Training state
        self.best_rmse = float('inf')
        self.patience_counter = 0
        self.train_losses = []
        self.train_rmses = []
        self.train_r2s = []

    @staticmethod
    def _trim_burn_in_time(tensor: torch.Tensor, burn_in_weeks: int) -> torch.Tensor:
        """Drop leading burn-in timesteps from ``[batch, time, ...]`` tensors.

        The input is returned unchanged when it is None, when
        ``burn_in_weeks`` is not positive, or when the tensor has fewer than
        two dimensions or no more than ``burn_in_weeks`` steps on axis 1.
        """
        if tensor is None or burn_in_weeks <= 0:
            return tensor
        if tensor.dim() < 2 or tensor.shape[1] <= burn_in_weeks:
            return tensor
        return tensor[:, burn_in_weeks:]

    def create_model(self, n_media: int, n_control: int, n_regions: int) -> "DeepCausalMMM":
        """
        Create the model from config and move it to ``self.device``.

        Args:
            n_media: Number of media channels
            n_control: Number of control variables
            n_regions: Number of regions

        Returns:
            Initialized DeepCausalMMM model (also stored on ``self.model``)
        """
        cfg = self.config
        self.model = DeepCausalMMM(
            n_media=n_media,
            ctrl_dim=n_control,
            n_regions=n_regions,
            hidden=cfg.get('hidden_dim', 64),
            dropout=cfg.get('dropout', 0.1),
            l1_weight=cfg.get('l1_weight', 0.001),
            l2_weight=cfg.get('l2_weight', 0.001),
            burn_in_weeks=cfg.get('burn_in_weeks', 4),
            momentum_decay=cfg.get('momentum_decay', 0.9),
            warm_start_epochs=cfg.get('warm_start_epochs', 50),
            enable_dag=cfg.get('enable_dag', True),
            enable_interactions=cfg.get('enable_interactions', True),
            # Coefficient regularization weights
            coeff_l2_weight=cfg.get('coeff_l2_weight', 0.1),
            coeff_gen_l2_weight=cfg.get('coeff_gen_l2_weight', 0.05),
            # NOTEARS DAG learning (default 'triangular' preserves prior behaviour)
            dag_mode=cfg.get('dag_mode', 'triangular'),
            notears_lambda1=cfg.get('notears_lambda1', 0.005),
            notears_rho_init=cfg.get('notears_rho_init', 1.0),
            notears_alpha_init=cfg.get('notears_alpha_init', 0.0),
            notears_rho_max=cfg.get('notears_rho_max', 1e16),
            dag_temperature=cfg.get('dag_temperature', 1.0),
            notears_group_l1=cfg.get('notears_group_l1', 0.0),
        ).to(self.device)
        return self.model

    def create_optimizer_and_scheduler(self):
        """
        Create the AdamW optimizer and the learning rate scheduler from config.

        With ``use_cosine_annealing`` the scheduler is
        CosineAnnealingWarmRestarts, plus a linear LambdaLR warmup when
        ``warmup_epochs`` > 0; otherwise it is ReduceLROnPlateau.

        Raises:
            ValueError: If the model has not been created yet.
        """
        if self.model is None:
            raise ValueError("Model must be created before optimizer")

        opt_config = self.config.get('optimizer', {})
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.get('learning_rate', 0.001),
            betas=opt_config.get('betas', (0.9, 0.999)),
            eps=opt_config.get('eps', 1e-8),
            weight_decay=opt_config.get('weight_decay', 1e-5)
        )

        self.warmup_scheduler = None
        if self.config.get('use_cosine_annealing', False):
            # T_mult must be an integer >= 1; fractional config values are
            # truncated, and anything below 1 is clamped instead of crashing.
            t_mult = max(1, int(self.config.get('cosine_t_mult', 1.2)))
            self.scheduler = CosineAnnealingWarmRestarts(
                self.optimizer,
                T_0=self.config.get('cosine_t_initial', 500),
                T_mult=t_mult,
                eta_min=self.config.get('cosine_eta_min', 1e-6)
            )

            warmup_epochs = self.config.get('warmup_epochs', 0)
            if warmup_epochs > 0:
                def lr_lambda(epoch):
                    # Linear ramp from 0 to 1 over the warmup window
                    if epoch < warmup_epochs:
                        return epoch / warmup_epochs
                    return 1.0

                self.warmup_scheduler = LambdaLR(self.optimizer, lr_lambda)
        else:
            scheduler_config = self.config.get('scheduler', {})
            self.scheduler = ReduceLROnPlateau(
                self.optimizer,
                mode='min',
                patience=scheduler_config.get('patience', 300),
                factor=scheduler_config.get('factor', 0.8),
                min_lr=scheduler_config.get('min_lr', 1e-8)
            )

    # ------------------------------------------------------------------
    # Private helpers shared by warm-start, training and evaluation
    # ------------------------------------------------------------------

    def _prediction_loss(self, predictions: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the data-fit loss (Huber or MSE, plus optional focal term).

        Used for both training and holdout evaluation so the two always
        share the same loss definition and the same ``huber_delta`` default.
        """
        if self.config.get('use_huber_loss', True):
            huber_delta = self.config.get('huber_delta', 0.3)
            base_loss = nn.HuberLoss(delta=huber_delta)(predictions, y)
        else:
            base_loss = nn.MSELoss()(predictions, y)

        if not self.config.get('use_focal_loss', False):
            return base_loss

        alpha = self.config.get('focal_alpha', 0.25)
        gamma = self.config.get('focal_gamma', 1.5)
        # Up-weight samples whose absolute error is large relative to the mean
        abs_error = torch.abs(predictions - y)
        normalized_error = abs_error / (abs_error.mean() + 1e-8)
        focal_weight = alpha * torch.pow(normalized_error, gamma)
        focal_loss = focal_weight * base_loss
        return base_loss + self.config.get('focal_loss_weight', 0.1) * focal_loss.mean()

    def _structure_losses(self):
        """Return ``(dag_loss, sparsity_loss)``; 0 for anything the model lacks."""
        dag_loss = self.model.get_dag_loss() if hasattr(self.model, 'get_dag_loss') else 0
        sparsity_loss = (self.model.get_sparsity_loss()
                         if hasattr(self.model, 'get_sparsity_loss') else 0)
        return dag_loss, sparsity_loss

    def _scaled_aux_loss(self, outputs: Dict[str, Any], raw_key: str, weight_attr: str,
                         reference_loss: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Rescale a raw auxiliary loss to the magnitude of ``reference_loss``.

        ``outputs[raw_key]`` is multiplied by ``reference / raw`` (both
        detached, so no gradient flows through the scale factor) and by the
        model attribute named ``weight_attr``. A zero tensor is returned when
        the raw loss is missing or not positive.
        """
        raw = outputs.get(raw_key, torch.tensor(0.0, device=device))
        if raw.item() > 0:
            scale = reference_loss.detach() / (raw.detach() + 1e-8)
            return getattr(self.model, weight_attr) * (raw * scale)
        return torch.tensor(0.0, device=device)

    def _parameter_norms(self):
        """Return ``(l1, l2)``: summed |p| and summed p² over all model parameters."""
        l1_reg = sum(torch.sum(torch.abs(param)) for param in self.model.parameters())
        l2_reg = sum(torch.sum(param ** 2) for param in self.model.parameters())
        return l1_reg, l2_reg

    @staticmethod
    def _scaled_metrics(y: torch.Tensor, predictions: torch.Tensor) -> Tuple[float, float]:
        """Return ``(rmse, r2)`` on the tensors as given (no inverse scaling).

        R² is reported as 0.0 when the target is constant.
        """
        y_np = y.detach().cpu().numpy().flatten()
        pred_np = predictions.detach().cpu().numpy().flatten()
        rmse = np.sqrt(mean_squared_error(y_np, pred_np))
        r2 = r2_score(y_np, pred_np) if len(np.unique(y_np)) > 1 else 0.0
        return rmse, r2

    def _step_schedulers(self, epoch: int, monitored_rmse: float) -> None:
        """Advance the LR schedule by one epoch.

        Cosine annealing (and its warmup) steps unconditionally;
        ReduceLROnPlateau steps on ``monitored_rmse``.
        """
        if self.config.get('use_cosine_annealing', False):
            if self.warmup_scheduler and epoch < self.config.get('warmup_epochs', 0):
                self.warmup_scheduler.step()
            else:
                self.scheduler.step()
        else:
            self.scheduler.step(monitored_rmse)

    # ------------------------------------------------------------------
    # Training phases
    # ------------------------------------------------------------------

    def warm_start_training(self,
                            X_media: torch.Tensor,
                            X_control: torch.Tensor,
                            R: torch.Tensor,
                            y: torch.Tensor,
                            verbose: bool = True) -> None:
        """
        Run a short low-learning-rate phase before the main training loop.

        Uses a dedicated AdamW optimizer at 1% of the configured learning
        rate and a plain MSE data term. Does nothing when
        ``warm_start_epochs`` is not positive.

        Args:
            X_media: Media data tensor
            X_control: Control data tensor
            R: Region tensor
            y: Target tensor
            verbose: Whether to show progress

        Raises:
            ValueError: If model or optimizer has not been created.
        """
        if self.model is None or self.optimizer is None:
            raise ValueError("Model and optimizer must be created before training")

        warm_epochs = self.config.get('warm_start_epochs', 50)
        if warm_epochs <= 0:
            return

        if verbose:
            logger.info(f"\nWarm-start Training ({warm_epochs} epochs)...")

        # Same learning-rate fallback as the main optimizer, so a config
        # without 'learning_rate' works in both phases.
        warm_optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.get('learning_rate', 0.001) * 0.01,
            weight_decay=self.config.get('optimizer', {}).get('weight_decay', 1e-7)
        )

        self.model.train()
        pbar = tqdm(range(warm_epochs), desc="Warm-start", leave=False) if verbose else range(warm_epochs)

        for _ in pbar:
            warm_optimizer.zero_grad()

            predictions, _coeffs, _contribs, outputs = self.model(X_media, X_control, R)

            mse_loss = nn.MSELoss()(predictions, y)
            dag_loss, sparsity_loss = self._structure_losses()
            attribution_reg = self._scaled_aux_loss(
                outputs, 'attribution_reg_loss_raw', 'attribution_reg_weight',
                mse_loss, predictions.device
            )
            seasonal_reg = self._scaled_aux_loss(
                outputs, 'seasonal_reg_loss_raw', 'seasonal_reg_weight',
                mse_loss, predictions.device
            )
            l1_reg, l2_reg = self._parameter_norms()

            total_loss = (mse_loss +
                          self.config.get('dag_weight', 1.0) * dag_loss +
                          self.config.get('sparsity_weight', 0.1) * sparsity_loss +
                          self.config.get('l1_weight', 0.0) * l1_reg +
                          self.config.get('l2_weight', 0.0) * l2_reg +
                          attribution_reg +
                          seasonal_reg)

            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           self.config.get('max_grad_norm', 1.0))
            warm_optimizer.step()

            if verbose and isinstance(pbar, tqdm):
                pbar.set_postfix({'Loss': f'{total_loss.item():.4f}'})

    def train_epoch(self,
                    X_media: torch.Tensor,
                    X_control: torch.Tensor,
                    R: torch.Tensor,
                    y: torch.Tensor) -> Tuple[float, float, float]:
        """
        Train for one epoch (a single full-batch optimizer step).

        Args:
            X_media: Media data tensor
            X_control: Control data tensor
            R: Region tensor
            y: Target tensor

        Returns:
            Tuple of (loss, rmse, r2); rmse and r2 are computed on the
            tensors as given (scaled space).
        """
        self.model.train()
        self.optimizer.zero_grad()

        predictions, _coeffs, _contribs, outputs = self.model(X_media, X_control, R)

        mse_loss = self._prediction_loss(predictions, y)
        dag_loss, sparsity_loss = self._structure_losses()
        attribution_reg = self._scaled_aux_loss(
            outputs, 'attribution_reg_loss_raw', 'attribution_reg_weight',
            mse_loss, predictions.device
        )
        seasonal_reg = self._scaled_aux_loss(
            outputs, 'seasonal_reg_loss_raw', 'seasonal_reg_weight',
            mse_loss, predictions.device
        )
        l1_reg, l2_reg = self._parameter_norms()

        # Fixed (config-driven) regularization weights
        total_loss = (mse_loss +
                      self.config.get('dag_weight', 0.005) * dag_loss +
                      self.config.get('sparsity_weight', 0.001) * sparsity_loss +
                      self.config.get('l1_weight', 1e-5) * l1_reg +
                      self.config.get('l2_weight', 5e-5) * l2_reg +
                      attribution_reg +
                      seasonal_reg)

        total_loss.backward()

        # Coefficient-related parameters get their own clipping norm so they
        # can be clipped more tightly than the rest of the network.
        coeff_params = []
        other_params = []
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if 'coeff' in name.lower() or 'range_raw' in name:
                coeff_params.append(param)
            else:
                other_params.append(param)

        if coeff_params:
            torch.nn.utils.clip_grad_norm_(
                coeff_params, max_norm=self.config.get('coeff_grad_clip', 1.0)
            )
        if other_params:
            torch.nn.utils.clip_grad_norm_(
                other_params, max_norm=self.config.get('gradient_clip_norm', 2.0)
            )

        self.optimizer.step()

        # Metrics stay in scaled space; original-scale conversion is done
        # only for final reporting in train().
        with torch.no_grad():
            rmse, r2 = self._scaled_metrics(y, predictions)

        return total_loss.item(), rmse, r2

    def evaluate_holdout(self,
                         X_media: torch.Tensor,
                         X_control: torch.Tensor,
                         R: torch.Tensor,
                         y: torch.Tensor) -> Tuple[float, float, float]:
        """
        Evaluate model on holdout data.

        The loss is the same data-fit loss used for training plus the DAG
        and sparsity terms (parameter-norm and auxiliary regularizers are
        not included).

        Args:
            X_media: Holdout media data tensor
            X_control: Holdout control data tensor
            R: Holdout region tensor
            y: Holdout target tensor

        Returns:
            Tuple of (loss, rmse, r2); rmse and r2 are in scaled space.
        """
        self.model.eval()

        with torch.no_grad():
            predictions, _, _, _ = self.model(X_media, X_control, R)

            mse_loss = self._prediction_loss(predictions, y)
            dag_loss, sparsity_loss = self._structure_losses()
            total_loss = (mse_loss +
                          self.config.get('dag_weight', 0.005) * dag_loss +
                          self.config.get('sparsity_weight', 0.001) * sparsity_loss)

            rmse, r2 = self._scaled_metrics(y, predictions)

        return total_loss.item(), rmse, r2

    def should_stop_early(self, current_rmse: float) -> bool:
        """
        Check if training should stop early based on RMSE improvement.

        Updates ``best_rmse`` and resets the patience counter on
        improvement; otherwise increments the counter.

        Args:
            current_rmse: Current epoch's RMSE

        Returns:
            True if training should stop (always False when
            ``early_stopping`` is disabled in config)
        """
        if not self.config.get('early_stopping', False):
            return False

        if current_rmse < self.best_rmse:
            self.best_rmse = current_rmse
            self.patience_counter = 0
            return False

        self.patience_counter += 1
        return self.patience_counter >= self.config.get('patience', 600)

    def train(self,
              X_media_train: torch.Tensor,
              X_control_train: torch.Tensor,
              R_train: torch.Tensor,
              y_train: torch.Tensor,
              X_media_holdout: Optional[torch.Tensor] = None,
              X_control_holdout: Optional[torch.Tensor] = None,
              R_holdout: Optional[torch.Tensor] = None,
              y_holdout: Optional[torch.Tensor] = None,
              pipeline: Optional[Any] = None,
              verbose: bool = True) -> Dict[str, Any]:
        """
        Full training loop with warm-start, main training, and holdout evaluation.

        Args:
            X_media_train: Training media data
            X_control_train: Training control data
            R_train: Training region data
            y_train: Training target data
            X_media_holdout: Optional holdout media data
            X_control_holdout: Optional holdout control data
            R_holdout: Optional holdout region data
            y_holdout: Optional holdout target data
            pipeline: Data pipeline exposing ``get_scaler()``. May be omitted
                only if ``self.pipeline`` was set beforehand.
            verbose: Whether to show progress

        Returns:
            Dictionary with the model, training histories, ``best_rmse`` and
            final original-scale train metrics; holdout metrics are added
            when holdout data is supplied.

        Raises:
            ValueError: If model/optimizer have not been created, or no
                pipeline is available.
        """
        if pipeline is not None:
            self.pipeline = pipeline

        if self.model is None or self.optimizer is None:
            raise ValueError("Model and optimizer must be created before training")

        # The scaler is needed both for model initialization and for the
        # final original-scale metrics, so fail early with a clear message.
        if getattr(self, 'pipeline', None) is None:
            raise ValueError(
                "A data pipeline is required for training: pass `pipeline=` "
                "or set `trainer.pipeline` before calling train()"
            )

        # Move data to device
        X_media_train = X_media_train.to(self.device)
        X_control_train = X_control_train.to(self.device)
        R_train = R_train.to(self.device)
        y_train = y_train.to(self.device)

        has_holdout = X_media_holdout is not None
        if has_holdout:
            X_media_holdout = X_media_holdout.to(self.device)
            X_control_holdout = X_control_holdout.to(self.device)
            R_holdout = R_holdout.to(self.device)
            y_holdout = y_holdout.to(self.device)

        # Hand the scaler's constants to the model before any data-driven
        # initialization runs.
        scaler = self.pipeline.get_scaler()
        self.model.scaling_constants = scaler.scaling_constants

        if hasattr(self.model, 'initialize_baseline'):
            self.model.initialize_baseline(y_train)

        if hasattr(self.model, 'initialize_stable_coefficients_from_data'):
            self.model.initialize_stable_coefficients_from_data(
                X_media_train, X_control_train, y_train
            )

        if hasattr(self.model, 'initialize_hill_from_data'):
            logger.info("\n Initializing Hill parameters from channel-specific SOV distributions...")
            self.model.initialize_hill_from_data(X_media_train)
            logger.info(" Hill initialization complete - each channel has its own saturation curve")

        # Warm-start training
        self.warm_start_training(X_media_train, X_control_train, R_train, y_train, verbose)

        # Main training
        n_epochs = self.config.get('n_epochs', 1000)
        if verbose:
            logger.info(f"\n Main Training ({n_epochs} epochs)...")

        # Most recent holdout metrics (holdout is only evaluated every 10 epochs)
        last_holdout_loss = None
        last_holdout_rmse = None
        last_holdout_r2 = None

        pbar = tqdm(range(n_epochs), desc="Training") if verbose else range(n_epochs)

        notears_update_every = int(self.config.get('notears_dual_update_every', 100))
        is_notears = (self.config.get('dag_mode', 'triangular') == 'notears')
        notears_warmup = int(self.config.get('notears_warmup_epochs', 0))
        notears_factor = float(self.config.get('notears_dual_factor', 3.0))
        notears_gated = (is_notears and notears_warmup > 0
                         and hasattr(self.model, 'notears_active'))

        # During warmup the model's `notears_active` flag is switched off and
        # flipped back on once `notears_warmup` epochs have run.
        if notears_gated:
            self.model.notears_active.fill_(False)
            if verbose:
                logger.info(
                    f"[NOTEARS] warmup enabled: Huber-only for first {notears_warmup} epochs"
                )

        for epoch in pbar:
            train_loss, train_rmse, train_r2 = self.train_epoch(
                X_media_train, X_control_train, R_train, y_train
            )

            # End of warmup: enable NOTEARS penalty + dual updates from now on.
            if notears_gated and epoch == notears_warmup:
                self.model.notears_active.fill_(True)
                if verbose:
                    logger.info(f"[NOTEARS] warmup complete at epoch {epoch}; activating penalty")

            # NOTEARS augmented-Lagrangian outer step on the configured
            # cadence (a non-positive cadence disables it).
            if (is_notears and epoch > 0 and notears_update_every > 0
                    and epoch % notears_update_every == 0
                    and hasattr(self.model, 'notears_update_duals')):
                info = self.model.notears_update_duals(factor=notears_factor)
                if verbose and info:
                    logger.info(
                        f"[NOTEARS] epoch={epoch} h={info['h']:.2e} "
                        f"rho={info['rho']:.2e} alpha={info['alpha']:.2e}"
                    )

            self.train_losses.append(train_loss)
            self.train_rmses.append(train_rmse)
            self.train_r2s.append(train_r2)

            # Holdout evaluation (every 10 epochs to save time)
            evaluated_holdout = has_holdout and epoch % 10 == 0
            if evaluated_holdout:
                last_holdout_loss, last_holdout_rmse, last_holdout_r2 = self.evaluate_holdout(
                    X_media_holdout, X_control_holdout, R_holdout, y_holdout
                )

            # The schedule advances every epoch. With holdout data the
            # plateau scheduler always monitors the latest holdout RMSE
            # (epoch 0 is an evaluation epoch, so it is always available);
            # training RMSE is used only when there is no holdout set.
            monitored_rmse = last_holdout_rmse if has_holdout else train_rmse
            self._step_schedulers(epoch, monitored_rmse)

            if evaluated_holdout:
                # Ask the early-stopping check first: it compares against the
                # previous best, so best_rmse must not be overwritten before it.
                stop_now = self.should_stop_early(last_holdout_rmse)
                # Keep best_rmse current even when early stopping is disabled.
                if last_holdout_rmse < self.best_rmse:
                    self.best_rmse = last_holdout_rmse
                if stop_now:
                    if verbose:
                        logger.info(f"\n Early stopping at epoch {epoch}")
                    break
            elif not has_holdout:
                # No holdout available: track the best training RMSE instead.
                if train_rmse < self.best_rmse:
                    self.best_rmse = train_rmse

            # Update progress bar
            if verbose and isinstance(pbar, tqdm):
                progress_dict = {
                    'TrL': f'{train_loss:.2f}',
                    'TrR': f'{train_rmse:.4f}',
                    'TrR²': f'{train_r2:.3f}',
                    'Best': f'{self.best_rmse:.4f}'
                }

                if last_holdout_rmse is not None:
                    # Clamp very negative R² so the progress bar stays readable
                    r2_min = self.config.get('training_display', {}).get('r2_display_min', -10.0)
                    r2_display = last_holdout_r2 if last_holdout_r2 > r2_min else r2_min
                    progress_dict.update({
                        'HoL': f'{last_holdout_loss:.1f}',
                        'HoR': f'{last_holdout_rmse:.4f}',
                        'HoR²': f'{r2_display:.3f}'
                    })

                pbar.set_postfix(progress_dict)

        # Final evaluation: original-scale reporting on post-burn-in
        # timesteps only. Burn-in is trimmed before the inverse transform.
        self.model.eval()
        with torch.no_grad():
            train_pred_scaled, _, _, _ = self.model(X_media_train, X_control_train, R_train)

            burn_in_weeks = int(
                self.config.get('burn_in_weeks', getattr(self.pipeline, 'padding_weeks', 0))
            )
            train_pred_eval = ModelTrainer._trim_burn_in_time(train_pred_scaled, burn_in_weeks)
            train_true_eval = ModelTrainer._trim_burn_in_time(y_train, burn_in_weeks)

            train_pred_orig = scaler.inverse_transform_target(train_pred_eval)
            train_true_orig = scaler.inverse_transform_target(train_true_eval)

            train_true_np = train_true_orig.detach().cpu().numpy().flatten()
            train_pred_np = train_pred_orig.detach().cpu().numpy().flatten()
            final_train_rmse_orig = np.sqrt(mean_squared_error(train_true_np, train_pred_np))
            final_train_r2_orig = r2_score(train_true_np, train_pred_np)

            # Final holdout evaluation in original scale (if available)
            if X_media_holdout is not None and y_holdout is not None:
                holdout_pred_scaled, _, _, _ = self.model(
                    X_media_holdout, X_control_holdout, R_holdout
                )
                holdout_pred_eval = ModelTrainer._trim_burn_in_time(holdout_pred_scaled, burn_in_weeks)
                holdout_true_eval = ModelTrainer._trim_burn_in_time(y_holdout, burn_in_weeks)

                holdout_pred_orig = scaler.inverse_transform_target(holdout_pred_eval)
                holdout_true_orig = scaler.inverse_transform_target(holdout_true_eval)

                y_true_np = holdout_true_orig.detach().cpu().numpy().flatten()
                y_pred_np = holdout_pred_orig.detach().cpu().numpy().flatten()
                y_true_scaled_np = holdout_true_eval.detach().cpu().numpy().flatten()
                y_pred_scaled_np = holdout_pred_eval.detach().cpu().numpy().flatten()

                # Standard metrics (original scale)
                final_holdout_rmse_orig = np.sqrt(mean_squared_error(y_true_np, y_pred_np))
                final_holdout_r2_orig = r2_score(y_true_np, y_pred_np)

                from sklearn.metrics import mean_absolute_error

                # 1. Mean and median absolute error (original scale)
                abs_errors = np.abs(y_true_np - y_pred_np)
                holdout_mae_orig = mean_absolute_error(y_true_np, y_pred_np)
                holdout_median_ae = np.median(abs_errors)

                # 2. Trimmed RMSE: drop points above the 95th error percentile
                trimmed_mask = abs_errors <= np.percentile(abs_errors, 95)
                if np.sum(trimmed_mask) > 10:  # Need enough data points
                    holdout_trimmed_rmse = np.sqrt(mean_squared_error(
                        y_true_np[trimmed_mask], y_pred_np[trimmed_mask]
                    ))
                else:
                    holdout_trimmed_rmse = final_holdout_rmse_orig

                # 3./4. Scaled-space R² and RMSE
                holdout_r2_scaled = r2_score(y_true_scaled_np, y_pred_scaled_np)
                holdout_rmse_scaled = np.sqrt(mean_squared_error(y_true_scaled_np, y_pred_scaled_np))

                final_holdout_loss_orig = last_holdout_loss  # Keep scaled-space loss
            else:
                final_holdout_rmse_orig = None
                final_holdout_r2_orig = None
                final_holdout_loss_orig = None
                holdout_mae_orig = None
                holdout_median_ae = None
                holdout_trimmed_rmse = None
                holdout_r2_scaled = None
                holdout_rmse_scaled = None

        final_results = {
            'model': self.model,
            'train_losses': self.train_losses,
            'train_rmses': self.train_rmses,
            'train_r2s': self.train_r2s,
            'best_rmse': self.best_rmse,
            'final_train_loss': self.train_losses[-1] if self.train_losses else 0.0,
            'final_train_rmse': final_train_rmse_orig,  # original scale
            'final_train_r2': final_train_r2_orig,  # original scale
        }

        if final_holdout_rmse_orig is not None:
            final_results.update({
                'final_holdout_loss': final_holdout_loss_orig,
                'final_holdout_rmse': final_holdout_rmse_orig,  # original scale
                'final_holdout_r2': final_holdout_r2_orig,  # original scale
                'holdout_mae_orig': holdout_mae_orig,  # mean absolute error (original scale)
                'holdout_median_ae': holdout_median_ae,  # median absolute error (original scale)
                'holdout_trimmed_rmse': holdout_trimmed_rmse,  # RMSE without top 5% errors
                'holdout_r2_scaled': holdout_r2_scaled,  # R² in scaled space
                'holdout_rmse_scaled': holdout_rmse_scaled,  # RMSE in scaled space
            })

        return final_results
# REMOVED: _set_random_seeds method # Seeds should be set once by the main script and not interfered with during training. # Multiple seed resets can disrupt the intended random number sequence.