deepcausalmmm.utils
Utility functions for DeepCausalMMM.
- deepcausalmmm.utils.get_device(device: str | None = None) -> device
Get the appropriate device for model training/inference.
- Parameters:
device – Device specification ('auto', 'cpu', 'cuda', 'cuda:0', etc.). If None or 'auto', CUDA is used when available.
- Returns:
Selected device
- Return type:
device
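The resolution rule described above can be sketched without any deep-learning dependency. This is a minimal illustration, not the library's implementation: `resolve_device_spec` and its `cuda_available` flag are hypothetical names, and the fallback to 'cpu' when CUDA is unavailable is an assumption (the documentation only states that CUDA is used when available).

```python
from typing import Optional


def resolve_device_spec(spec: Optional[str] = None, cuda_available: bool = False) -> str:
    """Hypothetical helper: map a device specification to a concrete device string."""
    # None and 'auto' both mean "pick automatically".
    if spec is None or spec == "auto":
        # Assumption: fall back to 'cpu' when CUDA is not available.
        return "cuda" if cuda_available else "cpu"
    # Explicit specifications such as 'cpu', 'cuda', or 'cuda:0' pass through unchanged.
    return spec


auto_with_gpu = resolve_device_spec(None, cuda_available=True)
auto_without_gpu = resolve_device_spec("auto", cuda_available=False)
explicit = resolve_device_spec("cuda:0", cuda_available=True)
print(auto_with_gpu, auto_without_gpu, explicit)
```

Passing the availability flag as an argument keeps the sketch testable on machines without a GPU; real code would more likely query the framework for it.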
- deepcausalmmm.utils.get_amp_settings(device: device, mixed_precision: bool = True) -> Tuple[GradScaler, bool]
Get Automatic Mixed Precision (AMP) settings.
- Parameters:
device – Current device
mixed_precision – Whether to enable mixed precision training
- Returns:
Tuple of (gradient scaler, use mixed precision flag)
- deepcausalmmm.utils.move_to_device(data: Tensor | dict | list | tuple, device: device) -> Tensor | dict | list | tuple
Recursively move data to specified device.
- Parameters:
data – Data to move (can be tensor, dict, list, or tuple)
device – Target device
- Returns:
Data on target device
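A recursive move of this kind can be sketched as follows. The sketch is an illustration of the traversal pattern only, not the library's code: `FakeTensor` is a hypothetical stand-in for a tensor that merely records its device, and `move_nested` is a hypothetical name. Leaving non-tensor leaves (such as strings) untouched is an assumption.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for a tensor: it only records which device it is on."""

    def __init__(self, value: Any, device: str = "cpu") -> None:
        self.value = value
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like a real tensor move, return a new object instead of mutating in place.
        return FakeTensor(self.value, device)


def move_nested(data: Any, device: str) -> Any:
    """Recursively move every tensor-like leaf in a nested structure to `device`."""
    if hasattr(data, "to"):
        # Leaf case: anything exposing .to(device) is moved directly.
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_nested(value, device) for key, value in data.items()}
    if isinstance(data, (list, tuple)):
        # Rebuild the same container type (list stays list, tuple stays tuple).
        return type(data)(move_nested(item, device) for item in data)
    # Assumption: other leaves (strings, numbers, None) are returned unchanged.
    return data


batch = {"media": FakeTensor([1, 2]), "extras": [FakeTensor([3]), (FakeTensor([4]), "label")]}
moved = move_nested(batch, "cuda:0")
print(moved["media"].device, moved["extras"][0].device, moved["extras"][1][1])
```

The container type is preserved at each level, so code that unpacks a tuple batch keeps working after the move.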
- class deepcausalmmm.utils.DeviceContext(device: str | None = None, mixed_precision: bool = True)
Context manager for device management.
Example:

    with DeviceContext(device='auto', mixed_precision=True) as ctx:
        model = model.to(ctx.device)
        for batch in dataloader:
            with ctx.autocast():
                output = model(batch)
- deepcausalmmm.utils.generate_synthetic_mmm_data(n_regions: int = 10, n_weeks: int = 52, n_media: int = 5, n_controls: int = 3, seed: int = 42)
Simple wrapper to generate synthetic MMM data as a DataFrame.
- Parameters:
n_regions – Number of regions/DMAs
n_weeks – Number of weeks
n_media – Number of media channels
n_controls – Number of control variables
seed – Random seed for reproducibility
- Returns:
pandas DataFrame with synthetic MMM data
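To make the parameters concrete, here is a toy sketch of seeded synthetic data in a long format. Everything in it is an assumption made for illustration: the function name `sketch_synthetic_rows`, the one-row-per-region-and-week layout, the column names, and the toy linear response are hypothetical and are not taken from the library.

```python
import random
from typing import Dict, List


def sketch_synthetic_rows(n_regions: int = 10, n_weeks: int = 52, n_media: int = 5,
                          n_controls: int = 3, seed: int = 42) -> List[Dict[str, float]]:
    """Hypothetical generator: one dict per (region, week) pair."""
    rng = random.Random(seed)  # local RNG, so the seed fully determines the output
    rows = []
    for region in range(n_regions):
        for week in range(n_weeks):
            media = [rng.uniform(0.0, 100.0) for _ in range(n_media)]
            controls = [rng.gauss(0.0, 1.0) for _ in range(n_controls)]
            row = {"region": region, "week": week}
            for i, value in enumerate(media, start=1):
                row[f"media_{i}"] = value
            for i, value in enumerate(controls, start=1):
                row[f"control_{i}"] = value
            # Toy response: linear in media and controls plus Gaussian noise.
            row["target"] = 0.5 * sum(media) + sum(controls) + rng.gauss(0.0, 1.0)
            rows.append(row)
    return rows


rows = sketch_synthetic_rows(n_regions=2, n_weeks=4)
print(len(rows))     # 8 rows: 2 regions x 4 weeks
print(len(rows[0]))  # 11 columns: region, week, 5 media, 3 controls, target
```

A list of dicts like this can be passed to `pandas.DataFrame(rows)` to obtain a DataFrame. Using a local `random.Random(seed)` rather than the global generator means repeated calls with the same seed give identical data.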
- class deepcausalmmm.utils.ConfigurableDataGenerator(config: Dict[str, Any] | None = None)
Generate synthetic MMM data using configuration parameters.
All data generation parameters are driven by configuration to ensure consistency and reproducibility across examples and tests.
- __init__(config: Dict[str, Any] | None = None)
Initialize the data generator.
- Parameters:
config – Configuration dictionary. If None, uses default config.
- generate_mmm_dataset(n_regions: int = 2, n_weeks: int = 104, n_media_channels: int = 5, n_control_channels: int = 3) -> Tuple[ndarray, ndarray, ndarray]
Generate a complete MMM dataset with realistic patterns.
- Parameters:
n_regions – Number of regions
n_weeks – Number of weeks
n_media_channels – Number of media channels
n_control_channels – Number of control variables
- Returns:
Tuple of (X_media, X_control, y) arrays
Modules
- Config-driven synthetic data generator for DeepCausalMMM.
- Device management utilities for DeepCausalMMM.