Source code for deepcausalmmm.core.scaling
"""
Simple, proven scaling implementation that works reliably.
Based on the successful approach from dashboard_rmse_optimized.py.
"""
import torch
import numpy as np
from typing import Tuple, Dict, Optional, Any
from dataclasses import dataclass
from deepcausalmmm.core.config import get_default_config
[docs]
@dataclass
class SimpleScalingParams:
"""Store simple global scaling parameters."""
# Control scaling (standardization)
control_mean: torch.Tensor
control_std: torch.Tensor
# Target scaling (linear: y/y_mean)
# No additional params needed - y_mean_per_region stored in scaling_constants
# Media scaling (share-of-voice) - Optional, defaults to None
total_impressions: Optional[torch.Tensor] = None # For inverse transformation (not needed for per-timestep scaling)
[docs]
class SimpleGlobalScaler:
"""
Linear scaling approach (y/y_mean) for additive attribution.
Scaling features:
- Media: Share-of-voice scaling with outlier smoothing
- Control: Robust standardization with adaptive clipping
- Target: Linear scaling by region mean (y/y_mean) for additive decomposition
- Adaptive normalization with distribution-aware clipping
- Advanced outlier handling for extreme value stability
"""
[docs]
def __init__(self, config: Optional[Dict[str, Any]] = None):
"""Initialize the scaler with optional config parameters."""
self.fitted = False
self.params = None
self.config = config or get_default_config()
self.scaling_constants = self.config.get('scaling_constants', {})
[docs]
def fit(
self,
X_media: np.ndarray, # [n_regions, n_timesteps, n_channels]
X_control: np.ndarray, # [n_regions, n_timesteps, n_controls]
y: np.ndarray, # [n_regions, n_timesteps]
) -> None:
"""
Fit the scaler using simple global statistics.
Args:
X_media: Media variables [n_regions, n_timesteps, n_channels]
X_control: Control variables [n_regions, n_timesteps, n_controls]
y: Target variable [n_regions, n_timesteps]
"""
# Convert to tensors for consistent processing
X_media_tensor = torch.FloatTensor(X_media) if not isinstance(X_media, torch.Tensor) else X_media.float()
X_control_tensor = torch.FloatTensor(X_control) if not isinstance(X_control, torch.Tensor) else X_control.float()
# Store target statistics for precision (use Float64)
if isinstance(y, torch.Tensor):
self.target_mean = float(y.mean().item())
self.target_std = float(y.std().item())
else:
self.target_mean = float(np.mean(y))
self.target_std = float(np.std(y))
# Media: Share-of-voice scaling should be per-timestep (no stored parameters needed)
# Each timestep gets normalized by its own total impressions
# Control: Use robust statistics (median/IQR) to handle distribution shifts
control_flat = X_control_tensor.reshape(-1, X_control_tensor.shape[-1])
# Use median instead of mean (more robust to outliers)
control_median = torch.median(control_flat, dim=0)[0]
# Use IQR instead of std (more robust to distribution shifts)
q75 = torch.quantile(control_flat, 0.75, dim=0)
q25 = torch.quantile(control_flat, 0.25, dim=0)
control_iqr = q75 - q25
# Convert IQR to std-equivalent for compatibility (IQR ≈ 1.349 * std for normal dist)
control_mean = control_median.unsqueeze(0).unsqueeze(0)
iqr_factor = self.scaling_constants.get('iqr_to_std_factor', 1.349)
control_std = control_iqr.unsqueeze(0).unsqueeze(0) * iqr_factor
# Target scaling: Linear scaling by region mean - preserves additivity
y_tensor = torch.DoubleTensor(y) if not isinstance(y, torch.Tensor) else y.double()
# Linear scaling by region mean - preserves additive decomposition
# y_scaled = y / y_mean ensures components can be easily attributed in original space
# CRITICAL: Calculate and STORE y_mean_per_region from TRAINING data ONLY
# This prevents data leakage when transforming holdout data
y_mean_per_region = y_tensor.mean(dim=1, keepdim=True) # [n_regions, 1]
self.scaling_constants['y_mean_per_region'] = y_mean_per_region
# Store parameters (no total_impressions needed for share-of-voice)
self.params = SimpleScalingParams(
total_impressions=None, # Not needed for per-timestep share-of-voice
control_mean=control_mean,
control_std=control_std
)
self.fitted = True
[docs]
def transform(
self,
X_media: np.ndarray,
X_control: np.ndarray,
y: np.ndarray,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Transform data using fitted parameters.
Args:
X_media: Media variables [n_regions, n_timesteps, n_channels]
X_control: Control variables [n_regions, n_timesteps, n_controls]
y: Target variable [n_regions, n_timesteps]
Returns:
Tuple of (X_media_scaled, X_control_scaled, y_scaled)
"""
if not self.fitted:
raise ValueError("Scaler must be fitted before transform")
# Convert to tensors (use Float64 for target to avoid precision loss)
X_media_tensor = torch.FloatTensor(X_media) if not isinstance(X_media, torch.Tensor) else X_media.float()
X_control_tensor = torch.FloatTensor(X_control) if not isinstance(X_control, torch.Tensor) else X_control.float()
y_tensor = torch.DoubleTensor(y) if not isinstance(y, torch.Tensor) else y.double() # Use Float64 for target precision
# Media scaling: Share-of-voice approach (per-timestep normalization)
total_impressions = torch.sum(X_media_tensor, dim=2, keepdim=True) # [regions, timesteps, 1]
# Handle zero-total-impressions case more robustly
# If total impressions is 0, set all channels to equal share
zero_threshold = self.scaling_constants.get('zero_threshold', 1e-8)
zero_mask = (total_impressions <= zero_threshold)
n_channels = X_media_tensor.shape[2]
# Normal share-of-voice calculation
X_media_scaled = X_media_tensor / (total_impressions + zero_threshold)
# For zero total impressions, distribute equally
X_media_scaled = torch.where(
zero_mask.expand_as(X_media_tensor),
torch.ones_like(X_media_tensor) / n_channels,
X_media_scaled
)
# Control scaling: Robust standardization with clipping
X_control_scaled = (X_control_tensor - self.params.control_mean) / (self.params.control_std + 1e-8)
# Adaptive clipping based on data characteristics
# More aggressive clipping for extreme distribution shifts
control_std_ratio = torch.std(X_control_scaled) / torch.std(X_control_tensor)
extreme_threshold = self.scaling_constants.get('extreme_clip_threshold', 2.0)
if control_std_ratio > extreme_threshold: # Detected extreme distribution shift
clip_range = self.scaling_constants.get('aggressive_clip_range', 3.0)
else:
clip_range = self.scaling_constants.get('standard_clip_range', 5.0)
X_control_scaled = torch.clamp(X_control_scaled, min=-clip_range, max=clip_range)
# Target scaling: Linear scaling by region mean - preserves additivity
# CRITICAL FIX: Use STORED y_mean_per_region from training data (no data leakage!)
y_mean_per_region = self.scaling_constants.get('y_mean_per_region')
if y_mean_per_region is None:
raise ValueError("y_mean_per_region not found in scaling_constants. Must call fit() before transform().")
y_scaled = y_tensor / (y_mean_per_region + 1e-8) # Normalize by region mean FROM TRAINING
# Convert back to Float32 for model compatibility
y_scaled_float32 = y_scaled.float()
return X_media_scaled, X_control_scaled, y_scaled_float32
[docs]
def inverse_transform_target(
self,
y_scaled: torch.Tensor,
) -> torch.Tensor:
"""
Inverse transform target variable.
Args:
y_scaled: Scaled target [n_regions, n_timesteps]
Returns:
Original scale target
"""
if not self.fitted:
raise ValueError("Scaler must be fitted before inverse transform")
# Ensure input is Float64 for precision
if y_scaled.dtype != torch.float64:
y_scaled = y_scaled.double()
# Inverse linear scaling: multiply by region mean
y_mean_per_region = self.scaling_constants.get('y_mean_per_region')
if y_mean_per_region is None:
raise ValueError("y_mean_per_region not found in scaling_constants. Was the scaler fitted properly?")
y_orig = y_scaled * y_mean_per_region
# Ensure non-negative (visits can't be negative)
y_orig = torch.clamp(y_orig, min=0)
# Convert back to Float32 for consistency with model
return y_orig.float()
[docs]
def inverse_transform_contributions(
self,
media_contributions: torch.Tensor, # [n_regions, n_timesteps, n_channels]
baseline: torch.Tensor = None, # [n_regions, n_timesteps] - baseline in scaled space
control_contributions: torch.Tensor = None, # [n_regions, n_timesteps, n_controls] - in scaled space
seasonal_contributions: torch.Tensor = None, # [n_regions, n_timesteps] - in scaled space
trend_contributions: torch.Tensor = None, # [n_regions, n_timesteps] - in scaled space
prediction_scale: torch.Tensor = None, # Scalar or [1] - model's prediction_scale factor
) -> dict:
"""
Inverse transform ALL contributions to original scale using simple multiplication.
With linear scaling (y/y_mean), the inverse transform is straightforward:
component_orig = component_scaled * prediction_scale * y_mean_per_region
This preserves additivity: sum(components_orig) = prediction_orig
Args:
media_contributions: Media contributions in scaled space [regions, timesteps, channels]
baseline: Baseline in scaled space [regions, timesteps]
control_contributions: Control contributions in scaled space [regions, timesteps, controls]
seasonal_contributions: Seasonal contributions in scaled space [regions, timesteps]
trend_contributions: Trend contributions in scaled space [regions, timesteps]
prediction_scale: Model's prediction_scale factor (from F.softplus(self.prediction_scale))
Returns:
Dictionary with all contributions in original scale
"""
if not self.fitted:
raise ValueError("Scaler must be fitted before inverse transform")
# Get region means
y_mean_per_region = self.scaling_constants.get('y_mean_per_region')
if y_mean_per_region is None:
raise ValueError("y_mean_per_region not found in scaling_constants")
# If no prediction_scale provided, default to 1.0
if prediction_scale is None:
prediction_scale = torch.tensor(1.0)
# Ensure prediction_scale is scalar
if prediction_scale.numel() > 1:
prediction_scale = prediction_scale.squeeze()
results = {}
# Simple multiplication for each component
# component_orig = component_scaled * prediction_scale * y_mean
if baseline is not None:
results['baseline'] = baseline * prediction_scale * y_mean_per_region
# Media contributions (per channel)
results['media'] = media_contributions * prediction_scale * y_mean_per_region.unsqueeze(-1)
# Control contributions (per control)
if control_contributions is not None:
results['control'] = control_contributions * prediction_scale * y_mean_per_region.unsqueeze(-1)
# Seasonal contributions
if seasonal_contributions is not None:
results['seasonal'] = seasonal_contributions * prediction_scale * y_mean_per_region
# Trend contributions
if trend_contributions is not None:
results['trend'] = trend_contributions * prediction_scale * y_mean_per_region
return results
[docs]
def fit_transform(
self,
X_media: np.ndarray,
X_control: np.ndarray,
y: np.ndarray,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fit the scaler and transform data in one step.
Args:
X_media: Media variables [n_regions, n_timesteps, n_channels]
X_control: Control variables [n_regions, n_timesteps, n_controls]
y: Target variable [n_regions, n_timesteps]
Returns:
Tuple of (X_media_scaled, X_control_scaled, y_scaled)
"""
self.fit(X_media, X_control, y)
return self.transform(X_media, X_control, y)
# Alias for backward compatibility and consistency
GlobalScaler = SimpleGlobalScaler