Source code for deepcausalmmm.core.train_model

"""
Training functions for DeepCausalMMM models.
Updated to use UnifiedDataPipeline for consistent data processing.
"""

import logging
import torch
import numpy as np
from typing import Dict, List, Optional, Tuple, Any
from tqdm import tqdm
from sklearn.metrics import r2_score, mean_squared_error

logger = logging.getLogger('deepcausalmmm')

from deepcausalmmm.core.unified_model import DeepCausalMMM, create_unified_mmm
from deepcausalmmm.core.config import get_default_config, update_config
from deepcausalmmm.core.scaling import SimpleGlobalScaler, GlobalScaler
from deepcausalmmm.core.data import UnifiedDataPipeline
from deepcausalmmm.core.trainer import ModelTrainer
# Device utilities available but not needed for this implementation


def calculate_r2(y_true: torch.Tensor, y_pred: torch.Tensor) -> float:
    """Return the R-squared score between two tensors.

    Both tensors are detached from the autograd graph, moved to the CPU and
    flattened to one dimension before being handed to scikit-learn's
    ``r2_score``.

    Args:
        y_true: Observed target values.
        y_pred: Predicted values; must hold the same number of elements.

    Returns:
        The coefficient of determination computed by ``r2_score``.
    """
    observed, predicted = (
        tensor.detach().cpu().numpy().flatten() for tensor in (y_true, y_pred)
    )
    return r2_score(observed, predicted)
# DEPRECATED: Use ModelTrainer class instead
def _evaluate_holdout(
    model: Any,
    holdout_data: Dict[str, torch.Tensor],
    pipeline: Any,
    burn_in: int,
    epoch: int,
) -> Optional[Tuple[float, float, float]]:
    """Best-effort holdout evaluation for a single training epoch.

    The model is switched to eval mode for the forward pass (so stochastic
    layers such as dropout do not add noise to the early-stopping signal) and
    its previous mode is restored afterwards.

    Args:
        model: Model being trained; called as ``model(X_media, X_control, R)``
            and expected to return a 4-tuple whose first item is the prediction.
        holdout_data: Dict with the keys ``'X_media'``, ``'X_control'``,
            ``'R'`` and ``'y'``.
        pipeline: Object exposing ``inverse_transform_predictions``.
        burn_in: Number of leading time steps excluded from the metrics.
        epoch: Current epoch; diagnostics are only emitted loudly on the first
            evaluation (epoch 10).

    Returns:
        ``(rmse, r2, loss)`` where RMSE/R² are in the original scale and the
        loss is the MSE in scaled space, or ``None`` when the evaluation could
        not be performed (no weeks left after burn-in, empty tensors, error).
    """
    first_eval = epoch == 10
    was_training = getattr(model, 'training', True)
    try:
        model.eval()
        with torch.no_grad():
            if first_eval:
                logger.debug(f"\n DEBUG: Holdout X_media shape: {holdout_data['X_media'].shape}")
                logger.debug(f" DEBUG: Holdout X_control shape: {holdout_data['X_control'].shape}")
                logger.debug(f" DEBUG: Burn-in weeks: {burn_in}")
                logger.debug(f" DEBUG: Holdout weeks after burn-in: {holdout_data['X_media'].shape[1] - burn_in}")

            pred_full, _, _, _ = model(
                holdout_data['X_media'],
                holdout_data['X_control'],
                holdout_data['R']
            )

            # Anything inside the burn-in window is padding, not real holdout.
            actual_holdout_weeks = pred_full.shape[1] - burn_in
            if actual_holdout_weeks <= 0:
                if first_eval:
                    logger.warning(f" DEBUG: No actual holdout weeks after removing padding ({actual_holdout_weeks} weeks)")
                return None

            pred_scaled = pred_full[:, burn_in:].detach()
            true_scaled = holdout_data['y'][:, burn_in:]

            # The slices above already dropped the padding, so the pipeline
            # must not strip it a second time.
            pred_orig = pipeline.inverse_transform_predictions(pred_scaled, remove_padding=False)
            true_orig = pipeline.inverse_transform_predictions(true_scaled, remove_padding=False)

            if pred_orig.numel() == 0 or true_orig.numel() == 0:
                if first_eval:
                    logger.warning(f" DEBUG: Empty holdout tensors after processing")
                return None

            true_np = true_orig.detach().cpu().numpy().flatten()
            pred_np = pred_orig.detach().cpu().numpy().flatten()
            holdout_rmse = np.sqrt(mean_squared_error(true_np, pred_np))
            holdout_r2 = r2_score(true_np, pred_np)

            # Cap extremely large RMSE values (100 million) for display.
            if holdout_rmse > 1e8:
                holdout_rmse = 1e8

            # Loss in scaled space, directly comparable with the training loss.
            holdout_loss = torch.nn.functional.mse_loss(pred_scaled, true_scaled).item()

            if first_eval:
                logger.debug(f" DEBUG: Holdout evaluation successful - RMSE: {holdout_rmse:,.0f}, R²: {holdout_r2:.3f}")

            return holdout_rmse, holdout_r2, holdout_loss
    except Exception as exc:
        # Holdout monitoring is deliberately best-effort: never abort training,
        # but do not hide the failure either.
        if first_eval:
            logger.warning(f"\n Holdout evaluation error (epoch {epoch}): {exc}")
        else:
            logger.debug(f"Holdout evaluation error (epoch {epoch}): {exc}")
        return None
    finally:
        model.train(was_training)


def train_model_with_config(
    model: DeepCausalMMM,
    X_media_padded: torch.Tensor,
    X_control_padded: torch.Tensor,
    R: Optional[torch.Tensor],
    y_padded: torch.Tensor,
    config: Dict[str, Any],
    verbose: bool = True,
    holdout_data: Optional[Dict[str, torch.Tensor]] = None,
    pipeline: Optional[Any] = None
) -> Tuple[List[float], List[float], List[float], float]:
    """
    Train model with config-driven parameters and warm-start.

    The routine first runs a warm-start phase with a reduced learning rate,
    then the main full-batch training loop with optional per-epoch holdout
    monitoring, LR scheduling and early stopping.

    Args:
        model: DeepCausalMMM model instance
        X_media_padded: Padded media data [n_regions, n_timesteps, n_channels]
        X_control_padded: Padded control data [n_regions, n_timesteps, n_controls]
        R: Region tensor (can be None, in which case a zero tensor is created)
        y_padded: Padded target data [n_regions, n_timesteps]
        config: Configuration dictionary. ``n_epochs``, ``learning_rate``,
            ``l1_weight``, ``l2_weight`` and ``burn_in_weeks`` are required;
            every other key read here falls back to a default.
        verbose: Whether to log progress
        holdout_data: Optional dict with ``'X_media'``, ``'X_control'``,
            ``'R'`` and ``'y'`` tensors. When given together with ``pipeline``
            the holdout set is evaluated on every epoch from epoch 10 onwards.
        pipeline: Optional object exposing ``inverse_transform_predictions``,
            used to report holdout metrics in the original scale.

    Returns:
        Tuple of (train_losses, train_rmses, train_r2s, best_rmse)
    """
    optimizer_cfg = config.get('optimizer', 'adamw')
    scheduler_cfg = config.get('scheduler', 'reduce_on_plateau')
    warm_start_epochs = config.get('warm_start_epochs', 50)
    burn_in = config['burn_in_weeks']

    if verbose:
        logger.info("Training Model with Config Parameters...")
        logger.info("Training Configuration from Config:")
        logger.info(f" Epochs: {config['n_epochs']}")
        logger.info(f" Hidden units: {config.get('hidden_dim', 'n/a')}")
        logger.info(f" Warm-start: {warm_start_epochs}")
        logger.info(f" Learning rate: {config['learning_rate']}")
        logger.info(f" Optimizer: {optimizer_cfg}")
        logger.info(f" Scheduler: {scheduler_cfg}")

    # 1. Warm-start training for coefficient stabilization
    if verbose:
        logger.info(f" Config-driven warm-start training for {warm_start_epochs} epochs...")

    # config['optimizer'] is normally a plain string (it is compared with
    # 'adamw' below); only a dict-style setting can carry a weight decay.
    if isinstance(optimizer_cfg, dict):
        warm_weight_decay = optimizer_cfg.get('weight_decay', 1e-7)
    else:
        warm_weight_decay = 1e-7

    warm_optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'] * 0.01,  # Reduced LR for warm-start
        weight_decay=warm_weight_decay
    )

    # Create region tensor if needed
    if R is None:
        R = torch.zeros(X_media_padded.shape[0], dtype=torch.long)

    # Initialize baseline / stable coefficients from the training data when the
    # model offers those hooks.
    if hasattr(model, 'initialize_baseline'):
        if verbose:
            logger.info(" Initializing model baseline from training data...")
        model.initialize_baseline(y_padded)

    if hasattr(model, 'initialize_stable_coefficients_from_data'):
        if verbose:
            logger.info(" Initializing stable coefficients from training data...")
        model.initialize_stable_coefficients_from_data(X_media_padded, X_control_padded, y_padded)

    model.warm_start_training(
        X_media_padded, X_control_padded, R, y_padded,
        warm_optimizer, warm_start_epochs
    )

    if verbose:
        logger.info(f" Warm-start training completed. GRU initialized for stable coefficients.")

    # 2. Main training with full configuration
    if verbose:
        logger.info(f" Main training for {config['n_epochs']} epochs...")

    if optimizer_cfg == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=1e-5)
    else:
        optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])

    use_plateau_scheduler = scheduler_cfg == 'reduce_on_plateau'
    if use_plateau_scheduler:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=config.get('patience', 800),
            factor=0.5, min_lr=1e-6
        )
    else:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['n_epochs'])

    # Training tracking
    train_losses = []
    train_rmses = []
    train_r2s = []
    best_rmse = float('inf')
    best_holdout_loss = float('inf')  # Track best holdout loss for early stopping
    patience_counter = 0

    # Most recent successful holdout metrics (shown in the progress bar)
    last_holdout_rmse = None
    last_holdout_r2 = None
    last_holdout_loss = None

    monitor_holdout = holdout_data is not None and pipeline is not None

    pbar = tqdm(range(config['n_epochs']), desc="Config-Driven Training")
    for epoch in pbar:
        model.train()
        optimizer.zero_grad()

        # Forward pass
        y_pred_full, _, _, _ = model(X_media_padded, X_control_padded, R)

        # MSE in scaled space (y/y_mean)
        mse_loss = torch.nn.functional.mse_loss(y_pred_full, y_padded)

        # DAG and sparsity penalties, when the model provides them
        dag_loss = model.get_dag_loss() if hasattr(model, 'get_dag_loss') else 0
        sparsity_loss = model.get_sparsity_loss() if hasattr(model, 'get_sparsity_loss') else 0

        # L1 and L2 regularization over all parameters
        l1_reg = sum(torch.sum(torch.abs(param)) for param in model.parameters())
        l2_reg = sum(torch.sum(param ** 2) for param in model.parameters())

        total_loss = (mse_loss +
                      config.get('dag_weight', 0.1) * dag_loss +
                      config.get('sparsity_weight', 0.1) * sparsity_loss +
                      config['l1_weight'] * l1_reg +
                      config['l2_weight'] * l2_reg)

        # Backward pass with optional gradient clipping
        total_loss.backward()
        if config.get('max_grad_norm'):
            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
        optimizer.step()

        # Training metrics (scaled space, consistent with the loss)
        with torch.no_grad():
            train_rmse = torch.sqrt(mse_loss).item()
            y_eval = y_padded[:, burn_in:].contiguous()
            pred_eval = y_pred_full[:, burn_in:].contiguous()
            train_r2 = r2_score(
                y_eval.detach().cpu().numpy().flatten(),
                pred_eval.detach().cpu().numpy().flatten()
            )

        # Holdout monitoring: every epoch, starting at epoch 10
        if monitor_holdout and epoch >= 10:
            holdout_metrics = _evaluate_holdout(model, holdout_data, pipeline, burn_in, epoch)
            if holdout_metrics is not None:
                last_holdout_rmse, last_holdout_r2, last_holdout_loss = holdout_metrics

        # Track metrics
        train_losses.append(total_loss.item())
        train_rmses.append(train_rmse)
        train_r2s.append(train_r2)

        # Balanced early stopping: ANY improvement (train RMSE or holdout loss)
        # resets the patience counter. The comparison must happen before
        # best_rmse is overwritten, otherwise it can never be true.
        improvement = train_rmse < best_rmse
        if improvement:
            best_rmse = train_rmse

        if last_holdout_loss is not None and last_holdout_loss < best_holdout_loss:
            best_holdout_loss = last_holdout_loss
            improvement = True

        if improvement:
            patience_counter = 0
        else:
            patience_counter += 1

        # Progress bar with both train and holdout metrics (shortened names)
        progress_dict = {
            'TrL': f'{total_loss.item():.1f}',
            'TrR': f'{train_rmse:.4f}',
            'TrR²': f'{train_r2:.3f}',
            'Best': f'{best_rmse:.4f}'
        }

        if last_holdout_rmse is not None and last_holdout_loss is not None:
            # Clamp very negative R² values; they only clutter the display.
            r2_min = config.get('training_display', {}).get('r2_display_min', -10.0)
            r2_display = last_holdout_r2 if last_holdout_r2 > r2_min else r2_min
            progress_dict.update({
                'HoL': f'{last_holdout_loss:.1f}',
                # Shown in millions to save space
                'HoR': f'{last_holdout_rmse/1e6:.1f}M' if last_holdout_rmse > 0 else '0.0M',
                'HoR²': f'{r2_display:.3f}'
            })

        pbar.set_postfix(progress_dict)

        # Scheduler step
        if use_plateau_scheduler:
            scheduler.step(train_rmse)
        else:
            scheduler.step()

        # Early stopping
        if (config.get('early_stopping', False) and
                patience_counter >= config.get('patience', 500)):
            if verbose:
                logger.info(f" Early stopping at epoch {epoch}")
                logger.info(f" Best RMSE: {best_rmse:.2f}")
            break

    pbar.close()

    if verbose:
        logger.info(f" Config-driven training completed!")
        logger.info(f" Final Best RMSE: {best_rmse:.2f}")

    return train_losses, train_rmses, train_r2s, best_rmse
# train_mmm_with_trainer function removed - users should use ModelTrainer class directly
def train_mmm(
    X_media: np.ndarray,
    X_control: np.ndarray,
    y: np.ndarray,
    config: Optional[Dict[str, Any]] = None,
    channel_names: Optional[List[str]] = None,
    control_names: Optional[List[str]] = None,
    verbose: bool = True,
    use_unified_pipeline: bool = False,
    train_ratio: Optional[float] = None
) -> Tuple[DeepCausalMMM, Dict[str, Any]]:
    """
    Train a DeepCausalMMM model with optional UnifiedDataPipeline.

    .. deprecated:: 1.0.0
        This function-based approach is deprecated. Please use the modern
        class-based approach with ModelTrainer instead::

            from deepcausalmmm.core.trainer import ModelTrainer

            trainer = ModelTrainer(config)
            model = trainer.create_model(n_media, n_control, n_regions)
            trainer.create_optimizer_and_scheduler()
            results = trainer.train(...)

    Args:
        X_media: Media variables [n_regions, n_timesteps, n_channels]
        X_control: Control variables [n_regions, n_timesteps, n_controls]
        y: Target variable [n_regions, n_timesteps]
        config: Configuration dictionary (uses default if None)
        channel_names: List of channel names
        control_names: List of control variable names
        verbose: Whether to log progress
        use_unified_pipeline: Whether to use UnifiedDataPipeline for
            train/holdout split
        train_ratio: Deprecated and ignored. The unified-pipeline path takes
            its split from ``config['holdout_ratio']`` (default 0.27). A
            warning is emitted when a value is passed.

    Returns:
        Tuple of (trained_model, results_dict)
    """
    import warnings

    warnings.warn(
        "train_mmm() is deprecated and will be removed in v2.0.0. "
        "Please use the modern ModelTrainer class instead. "
        "See documentation for migration guide.",
        DeprecationWarning,
        stacklevel=2
    )

    # Previously the argument was dropped without any feedback, which made a
    # caller believe a custom split had been applied.
    if train_ratio is not None:
        warnings.warn(
            "The 'train_ratio' argument of train_mmm() is ignored; "
            "set config['holdout_ratio'] to control the train/holdout split.",
            UserWarning,
            stacklevel=2
        )

    if verbose:
        logger.info(" DEEPCAUSALMMM TRAINING")
        logger.info("=" * 50)
        if use_unified_pipeline:
            logger.info(" Config-driven • UnifiedDataPipeline • RMSE Optimized")
        else:
            logger.info(" Config-driven • SimpleGlobalScaler • RMSE Optimized")

    # 1. Configuration setup
    if config is None:
        config = get_default_config()
        if verbose:
            logger.info(" Using default configuration")
    elif verbose:
        logger.info(" Using provided configuration")

    # The unified pipeline performs a train/holdout split; the simple path
    # trains on all of the data.
    train_fn = _train_with_unified_pipeline if use_unified_pipeline else _train_simple
    return train_fn(
        X_media, X_control, y, config, channel_names, control_names, verbose
    )
def _create_model_from_config(
    config: Dict[str, Any],
    n_media: int,
    ctrl_dim: int,
    n_regions: int
) -> DeepCausalMMM:
    """Build a DeepCausalMMM whose hyper-parameters come from ``config``.

    Shared by both training paths so the defaults cannot drift apart.
    """
    return DeepCausalMMM(
        n_media=n_media,
        ctrl_dim=ctrl_dim,
        n_regions=n_regions,
        hidden=config.get('hidden_dim', 64),
        dropout=config.get('dropout', 0.1),
        l1_weight=config.get('l1_weight', 0.001),
        l2_weight=config.get('l2_weight', 0.001),
        coeff_range=config.get('coeff_range', 1.0),
        burn_in_weeks=config.get('burn_in_weeks', 4),
        momentum_decay=config.get('momentum_decay', 0.9),
        warm_start_epochs=config.get('warm_start_epochs', 50),
        enable_dag=config.get('enable_dag', True),
        enable_interactions=config.get('enable_interactions', True)
    )


def _train_with_unified_pipeline(
    X_media: np.ndarray,
    X_control: np.ndarray,
    y: np.ndarray,
    config: Dict[str, Any],
    channel_names: Optional[List[str]],
    control_names: Optional[List[str]],
    verbose: bool
) -> Tuple[DeepCausalMMM, Dict[str, Any]]:
    """Train using UnifiedDataPipeline with train/holdout split.

    Args:
        X_media: Media variables, indexed as [region, time, channel].
        X_control: Control variables, indexed as [region, time, control].
        y: Target variable.
        config: Configuration dictionary; ``burn_in_weeks`` is required and
            ``holdout_ratio`` defaults to 0.27.
        channel_names: Passed through unchanged into the results.
        control_names: Passed through unchanged into the results.
        verbose: Whether to log progress.

    Returns:
        Tuple of (trained_model, results_dict). Holdout entries of the results
        are ``None`` when no holdout weeks remain after the burn-in window.
    """
    holdout_ratio = config.get('holdout_ratio', 0.27)
    burn_in = config['burn_in_weeks']

    if verbose:
        logger.info(f"\n UNIFIED DATA PIPELINE TRAINING")
        logger.info(f" Consistent train/holdout processing • Holdout ratio: {holdout_ratio:.1%}")

    # 1. Initialize unified data pipeline
    pipeline = UnifiedDataPipeline(config)

    # 2. Temporal split (using ratio-based time series approach)
    train_data, holdout_data = pipeline.temporal_split(
        X_media, X_control, y, holdout_ratio=holdout_ratio
    )

    # 3. Process training data (fit scaler + transform + pad)
    train_tensors = pipeline.fit_and_transform_training(train_data)

    # 4. Process holdout data (transform + pad using SAME scaler)
    holdout_tensors = pipeline.transform_holdout(holdout_data)

    # 5. Create model. Feature counts are taken from the processed tensors
    # because the pipeline may change them relative to the raw inputs.
    model = _create_model_from_config(
        config,
        n_media=train_tensors['X_media'].shape[2],
        ctrl_dim=train_tensors['X_control'].shape[2],
        n_regions=X_media.shape[0]
    )

    # 6. Train model with holdout evaluation
    train_losses, train_rmses, train_r2s, best_rmse = train_model_with_config(
        model,
        train_tensors['X_media'], train_tensors['X_control'],
        train_tensors['R'], train_tensors['y'],
        config, verbose,
        holdout_data=holdout_tensors, pipeline=pipeline
    )

    # 7. Final evaluation on both train and holdout
    if verbose:
        logger.info("\n Final Evaluation (Train + Holdout)...")

    holdout_rmse = None
    holdout_r2 = None
    holdout_loss = None
    holdout_pred_orig = None

    model.eval()
    with torch.no_grad():
        train_pred_full, _, train_media_contrib, train_outputs = model(
            train_tensors['X_media'], train_tensors['X_control'], train_tensors['R']
        )
        train_control_contrib = train_outputs['control_contributions']

        # Holdout evaluation (only if data remains after the burn-in window)
        actual_holdout_weeks = holdout_tensors['X_media'].shape[1] - burn_in
        if actual_holdout_weeks > 0:
            holdout_pred_full, _, _, _ = model(
                holdout_tensors['X_media'], holdout_tensors['X_control'], holdout_tensors['R']
            )

            holdout_pred_scaled = holdout_pred_full[:, burn_in:]
            holdout_true_scaled = holdout_tensors['y'][:, burn_in:]

            # The slices already dropped the padding, so the pipeline must not
            # remove it a second time.
            holdout_pred_orig = pipeline.inverse_transform_predictions(
                holdout_pred_scaled, remove_padding=False
            )
            holdout_true_orig = pipeline.inverse_transform_predictions(
                holdout_true_scaled, remove_padding=False
            )

            holdout_true_np = holdout_true_orig.numpy().flatten()
            holdout_pred_np = holdout_pred_orig.numpy().flatten()
            holdout_rmse = np.sqrt(mean_squared_error(holdout_true_np, holdout_pred_np))
            holdout_r2 = r2_score(holdout_true_np, holdout_pred_np)

            # Loss in scaled space, consistent with the training loss
            holdout_loss = torch.nn.functional.mse_loss(
                holdout_pred_scaled, holdout_true_scaled
            ).item()

            if verbose:
                logger.info(f" HOLDOUT RESULTS:")
                logger.info(f" Loss: {holdout_loss:.1f}")
                logger.info(f" RMSE: {holdout_rmse:,.0f}")
                logger.info(f" R²: {holdout_r2:.3f}")
        elif verbose:
            logger.warning(f" Holdout data too small for evaluation")

        # Training evaluation, using the same procedure as for the holdout set
        train_pred_orig = pipeline.inverse_transform_predictions(
            train_pred_full[:, burn_in:], remove_padding=False
        )
        train_true_orig = pipeline.inverse_transform_predictions(
            train_tensors['y'][:, burn_in:], remove_padding=False
        )

        train_true_np = train_true_orig.numpy().flatten()
        train_pred_np = train_pred_orig.numpy().flatten()
        final_train_rmse = np.sqrt(mean_squared_error(train_true_np, train_pred_np))
        final_train_r2 = r2_score(train_true_np, train_pred_np)

        media_contrib_orig = pipeline.inverse_transform_contributions(
            train_media_contrib[:, burn_in:], train_true_orig
        ).numpy()
        control_contrib = train_control_contrib[:, burn_in:].numpy()

    if verbose:
        logger.info(f" TRAINING RESULTS:")
        logger.info(f" RMSE: {final_train_rmse:,.0f}")
        logger.info(f" R²: {final_train_r2:.3f}")
        logger.info(f"\n SUMMARY:")
        logger.info(f" Train: RMSE {final_train_rmse:,.0f} | R² {final_train_r2:.3f}")
        if holdout_rmse is not None:
            logger.info(f" Holdout: RMSE {holdout_rmse:,.0f} | R² {holdout_r2:.3f}")
            # The relative gap is undefined for a perfect training fit.
            if final_train_rmse > 0:
                generalization_gap = ((holdout_rmse - final_train_rmse) / final_train_rmse) * 100
                logger.info(f" Generalization Gap: {generalization_gap:+.1f}%")

    # 8. Prepare results
    results = {
        'train_losses': train_losses,
        'train_rmses': train_rmses,
        'train_r2s': train_r2s,
        'best_rmse': best_rmse,
        'final_train_rmse': final_train_rmse,
        'final_train_r2': final_train_r2,
        'final_train_loss': train_losses[-1] if train_losses else 0.0,
        'final_holdout_rmse': holdout_rmse,
        'final_holdout_r2': holdout_r2,
        'final_holdout_loss': holdout_loss,
        'holdout_predictions_orig': holdout_pred_orig,
        'pipeline': pipeline,
        'config': config,
        'predictions': train_pred_orig.numpy(),
        'media_contributions': media_contrib_orig,
        'control_contributions': control_contrib,
        'channel_names': channel_names,
        'control_names': control_names,
        'model_params': {
            'n_regions': X_media.shape[0],
            'n_weeks': X_media.shape[1],
            'n_media_channels': X_media.shape[2],
            'n_control_channels': X_control.shape[2],
            'padding_weeks': burn_in,
            'train_weeks': train_data['X_media'].shape[1],
            'holdout_weeks': holdout_data['X_media'].shape[1]
        }
    }

    if verbose:
        logger.info(f"\n UNIFIED PIPELINE TRAINING COMPLETE!")
        logger.info(f" Train RMSE: {final_train_rmse:,.0f} (R²: {final_train_r2:.3f})")
        if holdout_rmse is not None:
            logger.info(f" Holdout RMSE: {holdout_rmse:,.0f} (R²: {holdout_r2:.3f})")
        logger.info("=" * 50)

    return model, results


def _train_simple(
    X_media: np.ndarray,
    X_control: np.ndarray,
    y: np.ndarray,
    config: Dict[str, Any],
    channel_names: Optional[List[str]],
    control_names: Optional[List[str]],
    verbose: bool
) -> Tuple[DeepCausalMMM, Dict[str, Any]]:
    """Train using simple approach without holdout splitting.

    The data is scaled with a freshly fitted SimpleGlobalScaler, zero-padded
    by ``config['burn_in_weeks']`` time steps, and the whole series is used
    for training.

    Args:
        X_media: Media variables, indexed as [region, time, channel].
        X_control: Control variables, indexed as [region, time, control].
        y: Target variable.
        config: Configuration dictionary; ``burn_in_weeks`` is required.
        channel_names: Passed through unchanged into the results.
        control_names: Passed through unchanged into the results.
        verbose: Whether to log progress.

    Returns:
        Tuple of (trained_model, results_dict).
    """
    # 2. Data preprocessing with SimpleGlobalScaler
    if verbose:
        logger.info("\n Data Preprocessing with SimpleGlobalScaler...")

    scaler = SimpleGlobalScaler()
    if verbose:
        logger.info(" Created new SimpleGlobalScaler")

    X_media_scaled, X_control_scaled, y_scaled = scaler.fit_transform(
        X_media, X_control, y
    )

    if verbose:
        logger.info(" SimpleGlobalScaler applied successfully")
        logger.info(f" Data shape: {X_media.shape[0]} regions × {X_media.shape[1]} weeks")
        logger.info(f" Media channels: {X_media.shape[2]}")
        logger.info(f" Control variables: {X_control.shape[2]}")

    # 3. Data padding for burn-in. new_zeros keeps the padding on the same
    # dtype/device as the scaled tensors it is concatenated with.
    padding_weeks = config['burn_in_weeks']

    media_padding = X_media_scaled.new_zeros(
        (X_media_scaled.shape[0], padding_weeks, X_media_scaled.shape[2])
    )
    control_padding = X_control_scaled.new_zeros(
        (X_control_scaled.shape[0], padding_weeks, X_control_scaled.shape[2])
    )
    y_padding = y_scaled.new_zeros((y_scaled.shape[0], padding_weeks))

    X_media_padded = torch.cat([media_padding, X_media_scaled], dim=1)
    X_control_padded = torch.cat([control_padding, X_control_scaled], dim=1)
    y_padded = torch.cat([y_padding, y_scaled], dim=1)

    if verbose:
        logger.info(f" Added {padding_weeks} weeks padding for burn-in")

    # 4. Model creation
    if verbose:
        logger.info("\n Creating Model from Configuration...")

    model = _create_model_from_config(
        config,
        n_media=X_media.shape[2],
        ctrl_dim=X_control.shape[2],
        n_regions=X_media.shape[0]
    )

    if verbose:
        logger.info(f" Model created with {config.get('hidden_dim', 64)} hidden units")
        logger.info(f" Config-driven parameters: dropout={config.get('dropout', 0.1)}, l1={config.get('l1_weight', 0.001)}, l2={config.get('l2_weight', 0.001)}")

    # Region tensor for training and evaluation
    R = torch.zeros(X_media_padded.shape[0], dtype=torch.long)

    # 5. Training with config-driven approach
    train_losses, train_rmses, train_r2s, best_rmse = train_model_with_config(
        model, X_media_padded, X_control_padded, R, y_padded, config, verbose
    )

    # 6. Final evaluation
    if verbose:
        logger.info("\n Final Evaluation...")

    model.eval()
    with torch.no_grad():
        predictions_full, _, media_contrib_full, outputs = model(
            X_media_padded, X_control_padded, R
        )
        control_contrib_full = outputs['control_contributions']

        # Remove padding for evaluation
        predictions = predictions_full[:, padding_weeks:]
        media_contributions = media_contrib_full[:, padding_weeks:]
        control_contributions = control_contrib_full[:, padding_weeks:]

        # Convert to original scale using scaler
        predictions_orig = scaler.inverse_transform_target(predictions)
        y_orig = scaler.inverse_transform_target(y_scaled)

    # Calculate final metrics
    y_orig_np = y_orig.numpy().flatten()
    predictions_np = predictions_orig.numpy().flatten()
    final_rmse = np.sqrt(mean_squared_error(y_orig_np, predictions_np))
    final_r2 = r2_score(y_orig_np, predictions_np)
    relative_rmse = final_rmse / y_orig.mean().item() * 100

    if verbose:
        logger.info(f" FINAL RESULTS:")
        logger.info(f" RMSE: {final_rmse:,.0f}")
        logger.info(f" Relative RMSE: {relative_rmse:.1f}%")
        logger.info(f" R²: {final_r2:.3f}")
        logger.info(f" Training Best RMSE: {best_rmse:.4f} (scaled space y/y_mean)")

    # 7. Prepare results
    results = {
        'train_losses': train_losses,
        'train_rmses': train_rmses,
        'train_r2s': train_r2s,
        'best_rmse': best_rmse,
        'final_rmse': final_rmse,
        'final_r2': final_r2,
        'relative_rmse': relative_rmse,
        'scaler': scaler,
        'config': config,
        'predictions': predictions_orig.numpy(),
        'media_contributions': media_contributions.numpy(),
        'control_contributions': control_contributions.numpy(),
        'channel_names': channel_names,
        'control_names': control_names,
        'model_params': {
            'n_regions': X_media.shape[0],
            'n_weeks': X_media.shape[1],
            'n_media_channels': X_media.shape[2],
            'n_control_channels': X_control.shape[2],
            'padding_weeks': padding_weeks
        }
    }

    if verbose:
        logger.info(f"\n TRAINING COMPLETE!")
        logger.info(f" Final RMSE: {final_rmse:,.0f} ({relative_rmse:.1f}%)")
        logger.info(f" R²: {final_r2:.3f}")
        logger.info("=" * 50)

    return model, results


# Legacy function for backward compatibility
def train_unified_mmm(*args, **kwargs):
    """Legacy wrapper for train_mmm with unified pipeline. Use train_mmm instead.

    Forwards every positional and keyword argument to ``train_mmm`` while
    forcing ``use_unified_pipeline=True`` (any caller-supplied value for that
    keyword is overridden).
    """
    forwarded = dict(kwargs, use_unified_pipeline=True)
    return train_mmm(*args, **forwarded)