deepcausalmmm.utils.device

Device management utilities for DeepCausalMMM.

This module handles:
  • GPU/CPU device selection
  • Memory management
  • Mixed precision training
  • Multi-GPU support

Functions

clear_gpu_memory()

Clear GPU memory cache.

get_amp_settings(device[, mixed_precision])

Get Automatic Mixed Precision (AMP) settings.

get_device([device])

Get the appropriate device for model training/inference.

move_to_device(data, device)

Recursively move data to specified device.

Classes

DeviceContext([device, mixed_precision])

Context manager for device management.

nullcontext(*args, **kwargs)

Null context manager for CPU fallback.

deepcausalmmm.utils.device.get_device(device: str | None = None) → torch.device

Get the appropriate device for model training/inference.

Parameters:

device – Device specification ('auto', 'cpu', 'cuda', 'cuda:0', etc.). If None or 'auto', CUDA is used when available; otherwise the CPU is selected.

Returns:

Selected device

Return type:

torch.device
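The selection rule described above can be sketched as a small pure function. This is a hypothetical helper (resolve_device_spec is not part of DeepCausalMMM), and it assumes that explicit specs are passed through without validation; CUDA availability is injected as a parameter so the rule can be checked without a GPU or PyTorch.

```python
from typing import Optional


def resolve_device_spec(device: Optional[str], cuda_available: bool) -> str:
    """Hypothetical sketch of the documented get_device() selection rule.

    None or 'auto' -> 'cuda' if CUDA is available, otherwise 'cpu'.
    Explicit specs such as 'cpu', 'cuda' or 'cuda:0' are passed through
    unchanged (assumption: no validation or fallback is applied to them).
    """
    if device is None or device == "auto":
        return "cuda" if cuda_available else "cpu"
    return device


# CUDA availability is injected so the rule can be checked without a GPU.
print(resolve_device_spec("auto", cuda_available=False))   # cpu
print(resolve_device_spec(None, cuda_available=True))      # cuda
print(resolve_device_spec("cuda:0", cuda_available=True))  # cuda:0
```

With PyTorch installed, the real check would be passed in, e.g. torch.device(resolve_device_spec(spec, torch.cuda.is_available())). Keeping the decision separate from the torch call makes it testable on CPU-only machines.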

deepcausalmmm.utils.device.get_amp_settings(device: torch.device, mixed_precision: bool = True) → Tuple[GradScaler, bool]

Get Automatic Mixed Precision (AMP) settings.

Parameters:
  • device – Current device

  • mixed_precision – Whether to enable mixed precision training

Returns:

Tuple of (gradient scaler, flag indicating whether mixed precision is enabled)
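The boolean half of the returned tuple presumably depends on both arguments. The sketch below assumes, without confirmation from this page, that AMP is only switched on for CUDA devices; use_mixed_precision is a hypothetical name, not a library function.

```python
def use_mixed_precision(device: str, mixed_precision: bool = True) -> bool:
    """Hypothetical sketch of the AMP on/off decision.

    Assumption: AMP is only enabled on CUDA devices. The device type is the
    part before any ':' index, so 'cuda:1' counts as CUDA.
    """
    device_type = device.split(":", 1)[0]
    return bool(mixed_precision) and device_type == "cuda"


print(use_mixed_precision("cuda:0"))                       # True
print(use_mixed_precision("cuda", mixed_precision=False))  # False
print(use_mixed_precision("cpu"))                          # False
```

In PyTorch, such a flag is typically passed as GradScaler(enabled=flag) (torch.cuda.amp.GradScaler, or torch.amp.GradScaler in newer releases). A disabled scaler is a pass-through: scale() returns the loss unchanged, step() simply calls optimizer.step(), and update() does nothing, so one training loop can serve both modes.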

deepcausalmmm.utils.device.move_to_device(data: Tensor | dict | list | tuple, device: torch.device) → Tensor | dict | list | tuple

Recursively move data to specified device.

Parameters:
  • data – Data to move (can be tensor, dict, list, or tuple)

  • device – Target device

Returns:

Data on target device
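A recursive move over nested containers can be sketched as follows. This is a hypothetical helper, not the library's implementation: it duck-types leaves on a callable .to() method (as torch.Tensor provides), rebuilds dicts, lists and tuples, and, as an assumption, returns any other value unchanged.

```python
from typing import Any


def move_to_device_sketch(data: Any, device: Any) -> Any:
    """Hypothetical recursive mover mirroring the documented behaviour.

    Leaves exposing a callable .to() (e.g. torch.Tensor) are moved;
    dicts, lists and tuples are rebuilt with their contents moved;
    anything else (ints, strings, None, ...) is returned unchanged.
    """
    if callable(getattr(data, "to", None)):
        return data.to(device)
    if isinstance(data, dict):
        return {key: move_to_device_sketch(value, device) for key, value in data.items()}
    if isinstance(data, list):
        return [move_to_device_sketch(item, device) for item in data]
    if isinstance(data, tuple):
        # Note: this rebuilds a plain tuple, so namedtuple types are not preserved.
        return tuple(move_to_device_sketch(item, device) for item in data)
    return data
```

Checking for .to() before the container cases means any object that knows how to move itself (tensors, and also nn.Module instances) is handled directly rather than inspected.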

deepcausalmmm.utils.device.clear_gpu_memory()

Clear GPU memory cache.
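In PyTorch, "clearing the GPU cache" usually means torch.cuda.empty_cache(), which returns unused blocks held by PyTorch's caching allocator to the driver; it does not free memory still referenced by live tensors. The sketch below assumes the helper works along these lines; clear_gpu_memory_sketch is hypothetical, and the torch import is guarded so it also runs where PyTorch is absent.

```python
import gc


def clear_gpu_memory_sketch() -> None:
    """Hypothetical sketch of a GPU cache-clearing helper.

    Assumption: collect unreachable Python objects first, then ask PyTorch's
    caching allocator to release its unused cached blocks. Memory held by
    tensors that are still referenced is NOT freed by this.
    """
    gc.collect()  # drop unreachable objects (and any tensors they own)
    try:
        import torch
    except ImportError:
        return  # no PyTorch installed, nothing to clear
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```

Running gc.collect() first matters because tensors kept alive only by reference cycles are not returned to the allocator until the cycle is collected.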

class deepcausalmmm.utils.device.DeviceContext(device: str | None = None, mixed_precision: bool = True)

Context manager for device management.

Example

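```python
from deepcausalmmm.utils.device import DeviceContext

# `model` and `dataloader` are assumed to be defined already.
with DeviceContext(device='auto', mixed_precision=True) as ctx:
    model = model.to(ctx.device)
    for batch in dataloader:
        with ctx.autocast():
            output = model(batch)
```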

__init__(device: str | None = None, mixed_precision: bool = True)

Initialize device context.

Parameters:
  • device – Device specification

  • mixed_precision – Whether to use mixed precision
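A context manager with the shape used in the example above can be sketched in a few lines. DeviceContextSketch is hypothetical and simplified: it takes an explicit device string (no 'auto' resolution) and assumes autocast is active only when mixed precision is requested on CUDA. Otherwise autocast() returns a do-nothing context, which matches the "CPU fallback" role this page gives to nullcontext.

```python
from contextlib import nullcontext
from typing import ContextManager


class DeviceContextSketch:
    """Hypothetical, simplified stand-in for DeviceContext.

    Assumptions: the device is given explicitly (no 'auto' resolution), and
    autocast is only active when mixed precision is requested on CUDA.
    """

    def __init__(self, device: str = "cpu", mixed_precision: bool = True) -> None:
        self.device = device
        self.use_amp = bool(mixed_precision) and device.split(":", 1)[0] == "cuda"

    def __enter__(self) -> "DeviceContextSketch":
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        return False  # never swallow exceptions raised in the with-block

    def autocast(self) -> ContextManager:
        if not self.use_amp:
            # CPU or AMP disabled: a do-nothing context keeps the call site uniform.
            return nullcontext()
        import torch  # only imported when AMP is actually used
        return torch.autocast(device_type="cuda", dtype=torch.float16)
```

Because autocast() always returns a context manager, the training loop can use the same with ctx.autocast(): statement on every device.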

class deepcausalmmm.utils.device.nullcontext(*args, **kwargs)

Null context manager for CPU fallback.

__init__(*args, **kwargs)
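The documented signature accepts and ignores arbitrary arguments. One plausible reason is that it can then replace a call such as torch.autocast(device_type=..., dtype=...) without changing the call site. The standard library's contextlib.nullcontext(enter_result=None) accepts only one optional argument, so passing device_type= to it raises TypeError. A minimal sketch of such a class (NullContextSketch is a hypothetical name):

```python
class NullContextSketch:
    """Hypothetical stand-in for the documented nullcontext(*args, **kwargs).

    Accepts and ignores any arguments so it can replace a call such as
    torch.autocast(device_type="cuda", dtype=...) at the same call site.
    """

    def __init__(self, *args, **kwargs) -> None:
        pass  # arguments are intentionally ignored

    def __enter__(self) -> None:
        return None

    def __exit__(self, exc_type, exc, tb) -> bool:
        # Returning False lets exceptions raised in the block propagate.
        return False


with NullContextSketch(device_type="cuda", dtype=None) as value:
    print(value)  # None
```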