deepcausalmmm.core.dag_model

DAG model implementation with Node-to-Edge and Edge-to-Node transformations.

This module implements the DAG-based neural network architecture with:

  • NodeToEdge: Transform node features to edge features

  • EdgeToNode: Aggregate edge features back to nodes

  • DAGConstraint: Enforce acyclicity in the graph structure

Classes

DAGConstraint(n_nodes[, sparsity_weight, ...])

Enforce acyclicity in the graph structure using a strictly triangular constraint.

DAGModel(n_nodes, node_dim, edge_dim[, ...])

Complete DAG-based model combining NodeToEdge and EdgeToNode transformations.

EdgeToNode(edge_dim, node_dim)

Aggregate edge features back to nodes.

NodeToEdge(node_dim, edge_dim)

Transform node features to edge features using an attention mechanism.

class deepcausalmmm.core.dag_model.NodeToEdge(node_dim: int, edge_dim: int)

Transform node features to edge features using an attention mechanism.

__init__(node_dim: int, edge_dim: int)

Initialize the node-to-edge transformation.

Parameters:
  • node_dim – Dimension of node features

  • edge_dim – Dimension of edge features

forward(nodes: Tensor, adj_matrix: Tensor) → Tensor

Transform node features to edge features.

Parameters:
  • nodes – Node features [batch_size, n_nodes, 1]

  • adj_matrix – Adjacency matrix [n_nodes, n_nodes]

Returns:

Edge features [batch_size, n_nodes, n_nodes, edge_dim]
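The transformation above can be sketched in plain Python for a single batch element with scalar node features. This is a hypothetical illustration, not the library's implementation: the attention score (a linear function of source and target features with hypothetical weights `w_att`), the softmax over each node's parents, and the projection weights `w_proj` are all assumptions. Entries with no edge in the adjacency matrix stay zero.

```python
import math


def node_to_edge(nodes, adj, w_att, w_proj):
    """Hypothetical node-to-edge step for one batch element.

    nodes:  list of n scalar node features (feature size 1)
    adj:    n x n nested list; adj[i][j] != 0 means an edge i -> j
    w_att:  (a_src, a_dst) attention weights (assumed parameterization)
    w_proj: edge_dim weights projecting a message into edge space
    Returns an n x n x edge_dim nested list; non-edges stay zero.
    """
    n = len(nodes)
    edge_dim = len(w_proj)
    edges = [[[0.0] * edge_dim for _ in range(n)] for _ in range(n)]
    for j in range(n):  # target node
        parents = [i for i in range(n) if adj[i][j]]
        if not parents:
            continue
        # Attention score for each incoming edge i -> j
        scores = [w_att[0] * nodes[i] + w_att[1] * nodes[j] for i in parents]
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        for i, e in zip(parents, exps):
            alpha = e / total  # attention weights sum to 1 over the parents of j
            edges[i][j] = [alpha * nodes[i] * w for w in w_proj]
    return edges


adj = [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
edges = node_to_edge([1.0, 2.0, 3.0], adj, w_att=(0.0, 0.0), w_proj=[1.0, 2.0])
print(edges[0][2], edges[1][2])  # [0.5, 1.0] [1.0, 2.0]
```

With zero attention weights, node 2 splits its attention evenly between its two parents, so each incoming edge feature is scaled by 0.5.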

class deepcausalmmm.core.dag_model.EdgeToNode(edge_dim: int, node_dim: int)

Aggregate edge features back to nodes.

__init__(edge_dim: int, node_dim: int)

Initialize the edge-to-node transformation.

Parameters:
  • edge_dim – Dimension of edge features

  • node_dim – Dimension of node features

forward(edges: Tensor, nodes: Tensor, adj_matrix: Tensor) → Tensor

Aggregate edge features to update node features.

Parameters:
  • edges – Edge features [batch_size, n_nodes, n_nodes, edge_dim]

  • nodes – Node features [batch_size, n_nodes, 1]

  • adj_matrix – Adjacency matrix [n_nodes, n_nodes]

Returns:

Updated node features [batch_size, n_nodes, 1]
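The aggregation step can be sketched in plain Python for one batch element. This is an illustration under stated assumptions, not the module's code: the helper name `edge_to_node`, the projection weights `w_out`, and the residual update (each node adds its projected incoming edge features to its own value) are hypothetical choices. Edges that are absent from the adjacency matrix contribute nothing, even if their feature entries are nonzero.

```python
def edge_to_node(edges, nodes, adj, w_out):
    """Hypothetical edge-to-node step for one batch element.

    edges: n x n x edge_dim nested list of edge features
    nodes: list of n scalar node features
    adj:   n x n nested list; adj[i][j] weights the edge i -> j
    w_out: edge_dim weights projecting an edge vector to a scalar
    Returns the list of n updated scalar node features.
    """
    n = len(nodes)
    updated = []
    for j in range(n):
        incoming = 0.0
        for i in range(n):
            if adj[i][j]:
                # Project the edge_dim vector on edge i -> j down to a scalar
                incoming += adj[i][j] * sum(w * e for w, e in zip(w_out, edges[i][j]))
        updated.append(nodes[j] + incoming)  # residual update (assumption)
    return updated


adj = [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
edges = [[[0.0, 0.0] for _ in range(3)] for _ in range(3)]
edges[0][1] = [1.0, 2.0]
edges[0][2] = [0.5, 1.0]
edges[1][2] = [1.0, 2.0]
edges[2][0] = [9.0, 9.0]  # not present in adj, so it is ignored
print(edge_to_node(edges, [1.0, 2.0, 3.0], adj, w_out=[1.0, 1.0]))  # [1.0, 5.0, 7.5]
```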

class deepcausalmmm.core.dag_model.DAGConstraint(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)

Enforce acyclicity in the graph structure using a strictly triangular constraint.
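A strictly triangular constraint is enough to rule out cycles. If entry (i, j) can be nonzero only when j > i, every edge points from a lower index to a higher one, so every path strictly increases the node index and can never return to its start. The plain-Python sketch below uses hypothetical helper names. It builds such a mask and checks acyclicity with Kahn's topological-sort algorithm, and it illustrates the idea rather than reproducing this module's code.

```python
def strict_upper_mask(n):
    """1.0 where j > i (allowed edge i -> j), else 0.0; the diagonal is zero."""
    return [[1.0 if j > i else 0.0 for j in range(n)] for i in range(n)]


def is_acyclic(adj):
    """Kahn's algorithm: the graph is a DAG iff every node can be removed."""
    n = len(adj)
    indegree = [sum(1 for i in range(n) if adj[i][j]) for j in range(n)]
    ready = [j for j in range(n) if indegree[j] == 0]
    removed = 0
    while ready:
        i = ready.pop()
        removed += 1
        for j in range(n):
            if adj[i][j]:
                indegree[j] -= 1
                if indegree[j] == 0:
                    ready.append(j)
    return removed == n


dense = [[0.7] * 3 for _ in range(3)]  # unconstrained weights, incl. self-loops
mask = strict_upper_mask(3)
masked = [[w * m for w, m in zip(row_w, row_m)] for row_w, row_m in zip(dense, mask)]
print(is_acyclic(dense), is_acyclic(masked))  # False True
```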

__init__(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)

Initialize the DAG constraint module.

Parameters:
  • n_nodes – Number of nodes in the graph

  • sparsity_weight – Weight for the sparsity penalty

  • temperature – Initial temperature for Gumbel-Softmax

gumbel_softmax(logits: Tensor, tau: float) → Tensor

Gumbel-Softmax sampling with straight-through gradients.

Parameters:
  • logits – Input logits

  • tau – Temperature parameter

Returns:

Sampled probabilities
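A plain-Python sketch of one Gumbel-Softmax draw, for illustration only. The function name and the returned `(soft, hard)` pair are hypothetical. Gumbel(0, 1) noise is added to the logits, the result is divided by the temperature and passed through a softmax, and the straight-through variant uses the one-hot argmax in the forward pass. Straight-through gradients need autograd, so they appear only in a comment. PyTorch itself provides `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=True)`; whether this module uses it is not stated here.

```python
import math
import random


def gumbel_softmax_sample(logits, tau, rng):
    """Draw one Gumbel-Softmax sample and return (soft, hard).

    soft: relaxed probabilities softmax((logits + g) / tau)
    hard: one-hot vector at argmax(soft)
    With straight-through gradients the forward pass uses `hard` while the
    backward pass follows `soft`; in PyTorch this is commonly written as
    `hard - soft.detach() + soft`.
    """
    eps = 1e-10
    perturbed = []
    for x in logits:
        u = min(max(rng.random(), eps), 1.0 - eps)  # keep log() finite
        g = -math.log(-math.log(u))                 # Gumbel(0, 1) noise
        perturbed.append((x + g) / tau)
    m = max(perturbed)                              # stable softmax
    exps = [math.exp(p - m) for p in perturbed]
    total = sum(exps)
    soft = [e / total for e in exps]
    k = soft.index(max(soft))
    hard = [1.0 if i == k else 0.0 for i in range(len(soft))]
    return soft, hard


rng = random.Random(0)
soft, hard = gumbel_softmax_sample([2.0, 0.5, -1.0], tau=0.5, rng=rng)
print(soft, hard)
```

Lowering `tau` pushes `soft` toward a one-hot vector, and raising it pushes `soft` toward uniform. This trade-off is why the temperature is annealed during training.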

get_adjacency() → Tensor

Sample the current adjacency matrix using Gumbel-Softmax. Only strictly upper triangular entries can be nonzero, which keeps every edge unidirectional, and the sampling lets the model learn a discrete graph structure.
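A hypothetical plain-Python sketch of sampling a hard adjacency matrix restricted to the strict upper triangle. It assumes one logit per edge (on vs. off) and the names `edge_logits` and `sample_upper_adjacency`, none of which come from the module. For two classes with logits `[l, 0]`, the Gumbel-Softmax probability of "on" equals `sigmoid((l + g1 - g2) / tau)`. The difference of two Gumbel(0, 1) draws is Logistic(0, 1), so the sketch samples logistic noise directly (the "binary concrete" form).

```python
import math
import random


def stable_sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_upper_adjacency(edge_logits, tau, rng):
    """Sample a hard 0/1 adjacency; only entries with j > i can be 1."""
    eps = 1e-10
    n = len(edge_logits)
    adj = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # j > i only: no self-loops, no back edges
            u = min(max(rng.random(), eps), 1.0 - eps)
            noise = math.log(u) - math.log(1.0 - u)  # Logistic(0, 1) sample
            soft = stable_sigmoid((edge_logits[i][j] + noise) / tau)
            adj[i][j] = 1.0 if soft > 0.5 else 0.0   # hard (argmax) decision
    return adj


rng = random.Random(1)
adj = sample_upper_adjacency([[50.0] * 3 for _ in range(3)], tau=0.5, rng=rng)
print(adj)  # [[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
```

Even with large logits everywhere, the diagonal and lower triangle stay zero because they are never sampled.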

update_temperature(epoch: int, total_epochs: int, min_temp: float = 0.1)

Update the Gumbel-Softmax temperature using an exponential decay schedule.

Parameters:
  • epoch – Current epoch

  • total_epochs – Total number of epochs

  • min_temp – Minimum temperature
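The exact decay rate is not documented here. One plausible exponential schedule chooses the rate so that the temperature falls from its initial value to `min_temp` at the final epoch, then clamps it at `min_temp`. The function name, the `initial_temp` argument, and the rate formula are assumptions made for this sketch.

```python
import math


def decayed_temperature(epoch, total_epochs, initial_temp=1.0, min_temp=0.1):
    """Hypothetical schedule: tau = initial_temp * exp(-rate * epoch).

    The rate makes tau reach min_temp at epoch == total_epochs; the result
    is clamped so it never drops below min_temp.
    """
    if total_epochs <= 0:
        return min_temp
    rate = math.log(initial_temp / min_temp) / total_epochs
    return max(min_temp, initial_temp * math.exp(-rate * epoch))


print([round(decayed_temperature(e, 100), 3) for e in (0, 50, 100, 150)])
# [1.0, 0.316, 0.1, 0.1]
```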

dag_loss() → Tensor

Compute the DAG constraint loss with a sparsity penalty. Because the adjacency matrix is strictly upper triangular, acyclicity is guaranteed by construction and needs no separate penalty term.

Returns:

Loss term combining sparsity and entropy
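The docs say only that the loss combines sparsity and entropy. One plausible reading appears below as a plain-Python sketch. The sparsity term is an L1 penalty (the sum of edge probabilities), and the entropy term is the mean binary entropy, which penalizes uncertain edges and pushes probabilities toward 0 or 1. The function name, `entropy_weight`, and the sign and scaling of each term are assumptions and may differ from the library.

```python
import math


def dag_loss_sketch(edge_probs, sparsity_weight=0.1, entropy_weight=0.01):
    """Hypothetical loss over strictly upper triangular edge probabilities.

    No acyclicity term is included: the triangular mask already rules out cycles.
    """
    eps = 1e-8
    n = len(edge_probs)
    upper = [edge_probs[i][j] for i in range(n) for j in range(i + 1, n)]
    if not upper:
        return 0.0
    sparsity = sum(upper)  # L1 penalty: fewer, weaker edges
    entropy = 0.0
    for p in upper:
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        entropy -= p * math.log(p) + (1.0 - p) * math.log(1.0 - p)
    entropy /= len(upper)  # mean binary entropy per candidate edge
    return sparsity_weight * sparsity + entropy_weight * entropy


print(round(dag_loss_sketch([[0.0, 0.5], [0.0, 0.0]], 0.1, 1.0), 4))  # 0.7431
```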

class deepcausalmmm.core.dag_model.DAGModel(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)

Complete DAG-based model combining NodeToEdge and EdgeToNode transformations.

__init__(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)

Initialize the DAG model.

Parameters:
  • n_nodes – Number of nodes in the graph

  • node_dim – Dimension of node features

  • edge_dim – Dimension of edge features

  • n_layers – Number of message passing layers

  • sparsity_weight – Weight for the sparsity penalty

forward(nodes: Tensor) → Tuple[Tensor, Tensor]

Forward pass through the DAG model.

Parameters:
  • nodes – Input node features [batch_size, n_nodes, node_dim]

Returns:

Tuple of (output node features, adjacency matrix)
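The effect of `n_layers` can be seen in a stripped-down, hypothetical sketch. Each layer replaces the NodeToEdge/EdgeToNode pair with a single linear message (`weight` times the parent's value), and the function returns `(node values, adjacency)` to mirror the documented tuple. This is not the model's forward pass. It only shows that, in this sketch, each layer carries information one hop along the DAG.

```python
def propagate(nodes, adj, weight, n_layers=3):
    """Hypothetical simplified message passing over a DAG.

    Each layer computes, from the previous layer's values,
    x_j <- x_j + weight * sum_i adj[i][j] * x_i
    """
    n = len(nodes)
    x = list(nodes)
    for _ in range(n_layers):
        # The comprehension reads the old x before x is rebound (synchronous update)
        x = [x[j] + weight * sum(adj[i][j] * x[i] for i in range(n)) for j in range(n)]
    return x, adj


chain = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]  # 0 -> 1 -> 2
print(propagate([1.0, 0.0, 0.0], chain, weight=1.0, n_layers=1)[0])  # [1.0, 1.0, 0.0]
print(propagate([1.0, 0.0, 0.0], chain, weight=1.0, n_layers=2)[0])  # [1.0, 2.0, 1.0]
```

After one layer, the signal from node 0 has reached only its child. Node 2, two hops away, needs a second layer.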

get_dag_loss() → Tensor

Get the DAG constraint loss.