# Source code for deepcausalmmm.utils.device

"""
Device management utilities for DeepCausalMMM.

This module handles:
- GPU/CPU device selection
- Memory management
- Mixed precision training
- Multi-GPU support
"""

import torch
import logging
from typing import Union, Tuple, Optional

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


def get_device(device: Union[str, torch.device, None] = None) -> torch.device:
    """
    Get the appropriate device for model training/inference.

    Args:
        device: Device specification ('auto', 'cpu', 'cuda', 'cuda:0', etc.)
            or an already-constructed ``torch.device``.
            If None or 'auto', will use CUDA if available.

    Returns:
        torch.device: Selected device. A CUDA request silently degrades to
        CPU (with a warning) when CUDA is not available.
    """
    # Normalise torch.device inputs to their string form so the string-based
    # checks below work for both kinds of argument.
    spec = str(device) if isinstance(device, torch.device) else device

    if spec is None or spec == 'auto':
        spec = 'cuda' if torch.cuda.is_available() else 'cpu'

    if spec.startswith('cuda') and not torch.cuda.is_available():
        logger.warning("CUDA requested but not available. Falling back to CPU.")
        spec = 'cpu'

    resolved = torch.device(spec)

    if resolved.type != 'cuda':
        logger.info("Using CPU")
        return resolved

    # Pass the device object itself instead of `index or 0`: for a bare 'cuda'
    # (index is None) the torch.cuda queries then report on the current CUDA
    # device rather than always on GPU 0.
    gib = 1024**3
    gpu_name = torch.cuda.get_device_name(resolved)
    memory_allocated = torch.cuda.memory_allocated(resolved) / gib
    memory_total = torch.cuda.get_device_properties(resolved).total_memory / gib
    logger.info("Using GPU: %s", gpu_name)
    logger.info(
        "GPU Memory: %.2fGB used / %.2fGB total",
        memory_allocated,
        memory_total,
    )
    return resolved
def get_amp_settings(
    device: torch.device,
    mixed_precision: bool = True
) -> Tuple[Optional[torch.cuda.amp.GradScaler], bool]:
    """
    Get Automatic Mixed Precision (AMP) settings.

    AMP is enabled only when it is requested *and* the device is a CUDA
    device; in every other case no gradient scaler is created.

    Args:
        device: Current device
        mixed_precision: Whether to enable mixed precision training

    Returns:
        Tuple of (gradient scaler, use mixed precision flag). The scaler is
        ``None`` when the flag is ``False``.
    """
    # bool(...) keeps the flag a real boolean even if a caller passes a
    # falsy non-bool such as 0 or None (plain `and` would return it as-is).
    use_amp = bool(mixed_precision) and device.type == 'cuda'

    if not use_amp:
        return None, False

    scaler = torch.cuda.amp.GradScaler()
    logger.info("Using Automatic Mixed Precision (AMP)")
    return scaler, use_amp
def move_to_device(
    data: Union[torch.Tensor, dict, list, tuple],
    device: torch.device
) -> Union[torch.Tensor, dict, list, tuple]:
    """
    Recursively move data to specified device.

    Tensors are moved with ``Tensor.to``; dicts, lists and tuples (including
    namedtuples) are rebuilt with their elements moved; any other object is
    returned unchanged.

    Args:
        data: Data to move (can be tensor, dict, list, or tuple)
        device: Target device

    Returns:
        Data on target device. Dicts come back as plain ``dict``; lists and
        tuples keep their concrete type.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device)

    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}

    if isinstance(data, (list, tuple)):
        moved = [move_to_device(item, device) for item in data]
        # namedtuple constructors take one positional argument per field,
        # not a single iterable, so they must be rebuilt with *moved.
        if isinstance(data, tuple) and hasattr(data, '_fields'):
            return type(data)(*moved)
        return type(data)(moved)

    return data
def clear_gpu_memory():
    """Empty the CUDA cache via ``torch.cuda.empty_cache`` when CUDA is available."""
    # Nothing to do on machines without CUDA.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    logger.info("Cleared GPU memory cache")
class DeviceContext:
    """
    Context manager bundling the selected device with its AMP settings.

    On construction it resolves the device, builds the AMP scaler/flag pair
    and picks the autocast context factory. On exit it calls
    ``clear_gpu_memory()``.

    Example:
        with DeviceContext(device='auto', mixed_precision=True) as ctx:
            model = model.to(ctx.device)
            for batch in dataloader:
                with ctx.autocast():
                    output = model(batch)
    """

    def __init__(
        self,
        device: Optional[str] = None,
        mixed_precision: bool = True
    ):
        """
        Set up the device, gradient scaler and autocast factory.

        Args:
            device: Device specification, forwarded to ``get_device``
            mixed_precision: Whether to use mixed precision
        """
        self.device = get_device(device)

        amp_settings = get_amp_settings(self.device, mixed_precision)
        self.scaler, self.use_amp = amp_settings

        # `autocast` stores a callable (not an instance), so callers write
        # `with ctx.autocast():` regardless of which branch was taken.
        if self.use_amp:
            self.autocast = torch.cuda.amp.autocast
        else:
            self.autocast = nullcontext

    def __enter__(self):
        """Return this context object for use in the ``with`` body."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Clear the GPU memory cache; exceptions are not suppressed."""
        clear_gpu_memory()
class nullcontext:
    """No-op context manager used in place of autocast on the CPU fallback path."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        return None

    def __enter__(self):
        """Enter the context and hand back this instance."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Do nothing on exit; any exception propagates to the caller."""
        return None