Source code for deepcausalmmm.core.dag_model

"""
DAG model implementation with Node-to-Edge and Edge-to-Node transformations.

This module implements the DAG-based neural network architecture with:
- NodeToEdge: Transform node features to edge features
- EdgeToNode: Aggregate edge features back to nodes
- DAGConstraint: Enforce acyclicity in the graph structure
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict, Any
import numpy as np

class NodeToEdge(nn.Module):
    """Transform node features to edge features using an attention mechanism.

    For every ordered node pair ``(i, j)`` an edge feature is built from a
    "source" embedding of node ``i`` and a "target" embedding of node ``j``,
    gated by a learned attention score and by the adjacency weight ``A[i, j]``.
    """

    def __init__(self, node_dim: int, edge_dim: int):
        """
        Initialize the node to edge transformation.

        Args:
            node_dim: Dimension of node features (last axis of ``nodes`` in
                :meth:`forward`).
            edge_dim: Dimension of edge features.
        """
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim

        # Per-node projections for the source and target roles (wider networks).
        # The input width follows ``node_dim``; it used to be hard-coded to 1,
        # which broke any caller with node_dim > 1. For node_dim == 1 the
        # parameter shapes are unchanged, so existing checkpoints still load.
        self.source_transform = nn.Sequential(
            nn.Linear(node_dim, 64),
            nn.ReLU(),
            nn.Linear(64, edge_dim),
        )
        self.target_transform = nn.Sequential(
            nn.Linear(node_dim, 64),
            nn.ReLU(),
            nn.Linear(64, edge_dim),
        )

        # Scores each (source, target) pair from the concatenated embeddings.
        self.edge_attention = nn.Sequential(
            nn.Linear(2 * edge_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

        # Xavier init with gain 1.4 (larger than the default gain of 1.0).
        # Only weights are re-initialised; biases keep PyTorch's default init.
        for seq in (self.source_transform, self.target_transform, self.edge_attention):
            for layer in seq:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight, gain=1.4)

    def forward(self, nodes: torch.Tensor, adj_matrix: torch.Tensor) -> torch.Tensor:
        """
        Transform node features to edge features.

        Args:
            nodes: Node features ``[batch_size, n_nodes, node_dim]``.
            adj_matrix: Adjacency matrix ``[n_nodes, n_nodes]``; entry ``[i, j]``
                weights the edge whose source is node ``i`` and target is ``j``.

        Returns:
            Edge features ``[batch_size, n_nodes, n_nodes, edge_dim]``, indexed
            ``[b, source, target]``. Pairs with a zero adjacency entry are zero.
        """
        _, n_nodes, _ = nodes.shape

        source_h = self.source_transform(nodes)  # [B, N, edge_dim]
        target_h = self.target_transform(nodes)  # [B, N, edge_dim]

        # Broadcast to all ordered pairs: axis 1 = source node, axis 2 = target node.
        source_e = source_h.unsqueeze(2).expand(-1, -1, n_nodes, -1)  # [B, N, N, edge_dim]
        target_e = target_h.unsqueeze(1).expand(-1, n_nodes, -1, -1)  # [B, N, N, edge_dim]

        edge_input = torch.cat([source_e, target_e], dim=-1)  # [B, N, N, 2*edge_dim]
        edge_attn = self.edge_attention(edge_input)  # [B, N, N, 1]

        # The adjacency acts as a multiplicative gate, so its magnitude matters.
        A = adj_matrix.unsqueeze(0).unsqueeze(-1)  # [1, N, N, 1]
        edge_weights = torch.sigmoid(edge_attn) * A

        # Edge feature = gated sum of the source and target embeddings.
        return edge_weights * (source_e + target_e)
class EdgeToNode(nn.Module):
    """Aggregate edge features back to nodes.

    Each edge feature is reduced to a scalar message, masked by the adjacency,
    summed per node, and combined with the node's previous features through an
    MLP plus a learnable, sigmoid-bounded skip connection.
    """

    def __init__(self, edge_dim: int, node_dim: int):
        """
        Initialize the edge to node transformation.

        Args:
            edge_dim: Dimension of edge features.
            node_dim: Dimension of node features (input and output).
        """
        super().__init__()
        self.edge_dim = edge_dim
        self.node_dim = node_dim

        # Reduces each edge feature vector to one scalar message.
        self.edge_aggregate = nn.Sequential(
            nn.Linear(edge_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

        # Maps [node features, aggregated message] -> updated node features.
        # Input is node_dim + 1 and output is node_dim. Both were previously
        # hard-coded (2 and 1), which only worked for node_dim == 1; shapes
        # are unchanged for that case, so existing checkpoints still load.
        self.node_update = nn.Sequential(
            nn.Linear(node_dim + 1, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, node_dim),
        )

        # Learnable skip-connection scale; squashed through a sigmoid in forward().
        self.skip_scale = nn.Parameter(torch.ones(1))

        # Xavier init with gain 1.4 on weights; biases keep PyTorch's default.
        for seq in (self.edge_aggregate, self.node_update):
            for layer in seq:
                if isinstance(layer, nn.Linear):
                    nn.init.xavier_uniform_(layer.weight, gain=1.4)

    def forward(self, edges: torch.Tensor, nodes: torch.Tensor, adj_matrix: torch.Tensor) -> torch.Tensor:
        """
        Aggregate edge features to update node features.

        Args:
            edges: Edge features ``[batch_size, n_nodes, n_nodes, edge_dim]``.
            nodes: Node features ``[batch_size, n_nodes, node_dim]``.
            adj_matrix: Adjacency matrix ``[n_nodes, n_nodes]``.

        Returns:
            Updated node features ``[batch_size, n_nodes, node_dim]``.
        """
        edge_aggr = self.edge_aggregate(edges)  # [B, N, N, 1]

        # Zero out messages on pairs that are not edges in the current graph.
        mask = adj_matrix.unsqueeze(0).unsqueeze(-1)  # [1, N, N, 1]
        edge_aggr = edge_aggr * mask

        # NOTE(review): this sums over axis 2, so node i collects
        # edges[:, i, j] * adj[i, j] over all j. If edges/adj are indexed
        # [source, target] (as NodeToEdge builds them), these are node i's
        # outgoing edges rather than incoming ones -- confirm intended direction.
        node_msg = edge_aggr.sum(dim=2)  # [B, N, 1]

        combined = torch.cat([nodes, node_msg], dim=-1)  # [B, N, node_dim + 1]
        transformed = self.node_update(combined)  # [B, N, node_dim]

        # Skip connection with a learnable weight bounded to (0, 1).
        skip_scale = torch.sigmoid(self.skip_scale)
        return transformed + skip_scale * nodes
class DAGConstraint(nn.Module):
    """Enforce acyclicity in the graph structure using a strict triangular constraint.

    Edges are sampled from learnable logits and multiplied by a strictly
    upper-triangular mask, so only entries ``[i, j]`` with ``i < j`` can be
    non-zero; such a graph is acyclic by construction.
    """

    def __init__(self, n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0):
        """
        Initialize the DAG constraint module.

        Args:
            n_nodes: Number of nodes in the graph.
            sparsity_weight: Weight for the sparsity penalty.
            temperature: Initial temperature for the relaxed (Gumbel-sigmoid)
                edge sampler.
        """
        super().__init__()
        self.n_nodes = n_nodes
        self.sparsity_weight = sparsity_weight
        self.temperature = temperature

        # Logits start around -3 (sigmoid ~ 0.05), i.e. a strong prior towards no edge.
        self.adj_logits = nn.Parameter(torch.randn(n_nodes, n_nodes) * 0.1 - 3.0)

        # Strictly upper-triangular mask: True only where i < j.
        mask = torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1)
        self.register_buffer('triangular_mask', mask.bool())

    def gumbel_softmax(self, logits: torch.Tensor, tau: float) -> torch.Tensor:
        """
        Binary Gumbel-softmax (Gumbel-sigmoid) sampling with straight-through gradients.

        In training mode, returns hard 0/1 samples whose forward values follow
        Bernoulli(sigmoid(logits)) and whose gradients are those of the relaxed
        sample ``sigmoid((logits + noise) / tau)``. In eval mode, returns the
        deterministic threshold ``logits > 0`` (for ``tau > 0``).

        Args:
            logits: Input logits.
            tau: Temperature parameter.

        Returns:
            Tensor of 0/1 values with the same shape as ``logits``.
        """
        if not self.training:
            # Deterministic thresholding during evaluation.
            return (torch.sigmoid(logits / tau) > 0.5).float()

        # The binary relaxation needs Logistic(0, 1) noise (the difference of two
        # Gumbel(0, 1) samples). A single Gumbel sample would bias the hard
        # samples towards 1: P(edge) = 1 - exp(-e^logit), e.g. 0.63 for logit 0,
        # instead of sigmoid(logit). rand_like draws from [0, 1), so only the
        # lower end needs clamping to keep log() finite.
        u = torch.rand_like(logits).clamp_min(1e-9)
        noise = torch.log(u) - torch.log1p(-u)
        y_soft = torch.sigmoid((logits + noise) / tau)

        # Straight-through: hard values in the forward pass, soft gradients backward.
        y_hard = (y_soft > 0.5).float()
        return y_hard.detach() - y_soft.detach() + y_soft

    def get_adjacency(self) -> torch.Tensor:
        """
        Get the current adjacency matrix.

        In training mode the matrix is sampled (a new sample on every call); in
        eval mode it is thresholded deterministically. Either way it is then
        masked to be strictly upper triangular, so edges are unidirectional
        and the graph is acyclic.

        Returns:
            Tensor ``[n_nodes, n_nodes]`` with zeros on and below the diagonal.
        """
        adj = self.gumbel_softmax(self.adj_logits, self.temperature)
        return adj * self.triangular_mask

    def update_temperature(self, epoch: int, total_epochs: int, min_temp: float = 0.1):
        """
        Update the temperature using an exponential decay schedule.

        Sets ``temperature = max(min_temp, exp(-10 * epoch / total_epochs))``.
        The schedule always starts from 1.0; it does not scale with the
        ``temperature`` passed to ``__init__``. If ``total_epochs <= 0``, the
        schedule is treated as finished and the temperature is set from
        progress 1.0 (which yields ``min_temp`` for ``min_temp >= exp(-10)``).

        Args:
            epoch: Current epoch.
            total_epochs: Total number of epochs.
            min_temp: Minimum temperature.
        """
        # Guard against ZeroDivisionError for an empty schedule.
        progress = epoch / total_epochs if total_epochs > 0 else 1.0
        # Exponential decay (faster than cosine); stored as a plain Python float.
        self.temperature = max(min_temp, float(np.exp(-10.0 * progress)))

    def dag_loss(self) -> torch.Tensor:
        """
        Compute the DAG regularisation loss.

        Acyclicity holds by construction (strict triangular mask), so this only
        regularises the structure:

        * an L1 sparsity term on the (sampled) adjacency,
        * an entropy term on ``sigmoid(adj_logits)`` that pushes edge
          probabilities towards 0 or 1, and
        * an edge-diversity term, the negated standard deviation of the
          upper-triangular edges, which rewards varied edge patterns.

        Returns:
            Scalar loss ``sparsity_weight * (sparsity + 0.1 * entropy + 0.2 * diversity)``.
        """
        adj = self.get_adjacency()

        sparsity_loss = torch.sum(torch.abs(adj))

        # Mean binary entropy over all n_nodes x n_nodes logits, including
        # masked positions.
        probs = torch.sigmoid(self.adj_logits)
        entropy_loss = -torch.mean(
            probs * torch.log(probs + 1e-9)
            + (1 - probs) * torch.log(1 - probs + 1e-9)
        )

        # torch.std returns NaN for fewer than 2 elements and has a NaN gradient
        # when all elements are equal (e.g. every sampled edge is 0, which is
        # common at init). Use sqrt(var + eps) instead, and skip the term
        # entirely when there are fewer than two candidate edges.
        edge_vals = adj[self.triangular_mask]
        if edge_vals.numel() > 1:
            edge_diversity = -torch.sqrt(torch.var(edge_vals) + 1e-8)
        else:
            edge_diversity = adj.new_zeros(())

        return self.sparsity_weight * (
            sparsity_loss + 0.1 * entropy_loss + 0.2 * edge_diversity
        )
class DAGModel(nn.Module):
    """
    End-to-end DAG model.

    A learned acyclic adjacency (``DAGConstraint``) drives ``n_layers`` rounds of
    message passing. Each round maps nodes to edges (``NodeToEdge``) and edges
    back to nodes (``EdgeToNode``); a two-layer MLP head projects the result.
    """

    def __init__(
        self,
        n_nodes: int,
        node_dim: int,
        edge_dim: int,
        n_layers: int = 3,
        sparsity_weight: float = 0.1
    ):
        """
        Build the DAG model.

        Args:
            n_nodes: Number of nodes in the graph.
            node_dim: Dimension of node features.
            edge_dim: Dimension of edge features.
            n_layers: Number of message-passing rounds.
            sparsity_weight: Weight for the sparsity penalty in the DAG loss.
        """
        super().__init__()
        self.n_nodes, self.node_dim = n_nodes, node_dim
        self.edge_dim, self.n_layers = edge_dim, n_layers

        # Sub-modules are created in a fixed order; this order determines
        # parameter registration and random-init consumption.
        self.dag = DAGConstraint(n_nodes, sparsity_weight)
        self.node_to_edge = NodeToEdge(node_dim, edge_dim)
        self.edge_to_node = EdgeToNode(edge_dim, node_dim)

        # Input embedding applied before message passing.
        self.node_embedding = nn.Linear(node_dim, node_dim)
        nn.init.xavier_uniform_(self.node_embedding.weight)

        # Output head: Linear -> ReLU -> Linear, Xavier-initialised weights.
        self.output = nn.Sequential(
            nn.Linear(node_dim, node_dim),
            nn.ReLU(),
            nn.Linear(node_dim, node_dim)
        )
        head_linears = [mod for mod in self.output if isinstance(mod, nn.Linear)]
        for lin in head_linears:
            nn.init.xavier_uniform_(lin.weight)

    def forward(self, nodes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the model.

        Args:
            nodes: Input node features ``[batch_size, n_nodes, node_dim]``.

        Returns:
            ``(output node features, adjacency matrix)``. The same adjacency is
            used by every message-passing round of this call.
        """
        # One adjacency per call. NOTE(review): get_dag_loss() calls
        # get_adjacency() again, so in training mode it may see a different
        # sample than this forward pass.
        adj = self.dag.get_adjacency()

        hidden = self.node_embedding(nodes)
        for _ in range(self.n_layers):
            hidden = self.edge_to_node(self.node_to_edge(hidden, adj), hidden, adj)

        return self.output(hidden), adj

    def get_dag_loss(self) -> torch.Tensor:
        """Return the structural regularisation loss from the DAG constraint."""
        loss = self.dag.dag_loss()
        return loss