deepcausalmmm.core.dag_model
DAG model implementation with Node-to-Edge and Edge-to-Node transformations.
This module implements the DAG-based neural network architecture with:
- NodeToEdge: Transform node features to edge features
- EdgeToNode: Aggregate edge features back to nodes
- DAGConstraint: Enforce acyclicity in the graph structure
Classes
| DAGConstraint | Enforce acyclicity in the graph structure using a strict triangular constraint. |
| DAGModel | Complete DAG-based model combining NodeToEdge and EdgeToNode transformations. |
| EdgeToNode | Aggregate edge features back to nodes. |
| NodeToEdge | Transform node features to edge features using an attention mechanism. |
- class deepcausalmmm.core.dag_model.NodeToEdge(node_dim: int, edge_dim: int)
Transform node features to edge features using an attention mechanism.
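The page does not document NodeToEdge's forward pass, so the following is only a minimal sketch of what an attention-based node-to-edge transform could look like. It assumes node features of shape [batch_size, n_nodes, node_dim] and produces edges of shape [batch_size, n_nodes, n_nodes, edge_dim] (the edge shape EdgeToNode.forward expects). NodeToEdgeSketch and its query/key/value layers are hypothetical, not the library's implementation.

```python
import torch
from torch import nn


class NodeToEdgeSketch(nn.Module):
    """Hypothetical attention-based node-to-edge transform (not the library's code)."""

    def __init__(self, node_dim: int, edge_dim: int):
        super().__init__()
        self.query = nn.Linear(node_dim, edge_dim)
        self.key = nn.Linear(node_dim, edge_dim)
        # Edge value is computed from the concatenated (source, target) node pair
        self.value = nn.Linear(2 * node_dim, edge_dim)

    def forward(self, nodes: torch.Tensor) -> torch.Tensor:
        # nodes: [B, N, node_dim]
        b, n, d = nodes.shape
        q = self.query(nodes)  # [B, N, E]
        k = self.key(nodes)    # [B, N, E]
        # Scaled dot-product attention score for every ordered pair (i, j)
        scores = torch.einsum("bie,bje->bij", q, k) / q.shape[-1] ** 0.5
        attn = torch.softmax(scores, dim=-1)  # [B, N, N]
        src = nodes.unsqueeze(2).expand(b, n, n, d)  # entry [b, i, j] = node i
        dst = nodes.unsqueeze(1).expand(b, n, n, d)  # entry [b, i, j] = node j
        pair = torch.cat([src, dst], dim=-1)         # [B, N, N, 2 * node_dim]
        # Attention-weighted edge features: [B, N, N, E]
        return attn.unsqueeze(-1) * self.value(pair)
```

Building every ordered pair explicitly costs O(N² · edge_dim) memory, which is acceptable for the small channel graphs typical of a media-mix model but would not scale to large graphs.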
- class deepcausalmmm.core.dag_model.EdgeToNode(edge_dim: int, node_dim: int)
Aggregate edge features back to nodes.
- __init__(edge_dim: int, node_dim: int)
Initialize the edge-to-node transformation.
- Parameters:
edge_dim – Dimension of edge features
node_dim – Dimension of node features
- forward(edges: Tensor, nodes: Tensor, adj_matrix: Tensor) → Tensor
Aggregate edge features to update node features.
- Parameters:
edges – Edge features [batch_size, n_nodes, n_nodes, edge_dim]
nodes – Node features [batch_size, n_nodes, 1]
adj_matrix – Adjacency matrix [n_nodes, n_nodes]
- Returns:
Updated node features [batch_size, n_nodes, 1]
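A minimal sketch of an aggregation with the documented shapes follows. It assumes that adj_matrix[i, j] = 1 means an edge from node i to node j, that each node sums the features of its incoming edges, and that the result is applied as a residual update. EdgeToNodeSketch and its single linear projection are hypothetical, not the library's code.

```python
import torch
from torch import nn


class EdgeToNodeSketch(nn.Module):
    """Hypothetical edge-to-node aggregation (not the library's code)."""

    def __init__(self, edge_dim: int, node_dim: int):
        super().__init__()
        self.proj = nn.Linear(edge_dim + node_dim, node_dim)

    def forward(self, edges: torch.Tensor, nodes: torch.Tensor,
                adj_matrix: torch.Tensor) -> torch.Tensor:
        # edges: [B, N, N, E], nodes: [B, N, node_dim], adj_matrix: [N, N]
        # Zero out edges that are absent from the graph: [1, N, N, 1] mask
        masked = edges * adj_matrix.unsqueeze(0).unsqueeze(-1)
        # Sum over source nodes i (dim 1): messages arriving at each node j
        incoming = masked.sum(dim=1)  # [B, N, E]
        # Residual update from the incoming messages and the current node state
        return nodes + self.proj(torch.cat([incoming, nodes], dim=-1))
```

Because absent edges are multiplied by zero before the sum, a node with no parents in adj_matrix is updated from its own state only.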
- class deepcausalmmm.core.dag_model.DAGConstraint(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)
Enforce acyclicity in the graph structure using a strict triangular constraint.
- __init__(n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0)
Initialize the DAG constraint module.
- Parameters:
n_nodes – Number of nodes in the graph
sparsity_weight – Weight for the sparsity penalty
temperature – Initial temperature for Gumbel-Softmax
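Why a strict triangular constraint guarantees acyclicity can be checked numerically. For a nonnegative adjacency matrix A on n nodes, (A^k)[i, j] > 0 exactly when there is a walk of length k from i to j, and the graph is acyclic if and only if A^n = 0. The sketch below assumes an upper-triangular orientation (edges only from lower to higher node index); the helper names are hypothetical.

```python
import torch


def strict_triangular_mask(n_nodes: int) -> torch.Tensor:
    # 1 strictly above the diagonal, 0 on and below it:
    # edges may only go from node i to a node j > i
    return torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1)


def is_acyclic(adj: torch.Tensor) -> bool:
    # For nonnegative adj, A^n != 0 iff some walk of length n exists; such a
    # walk visits n + 1 nodes, so it must repeat one and therefore contains a cycle.
    n = adj.shape[0]
    return bool(torch.count_nonzero(torch.linalg.matrix_power(adj, n)) == 0)
```

Any binary matrix multiplied elementwise by the mask is strictly upper triangular and therefore nilpotent, so no learned edge pattern can introduce a cycle.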
- gumbel_softmax(logits: Tensor, tau: float) → Tensor
Gumbel-Softmax sampling with straight-through gradients.
- Parameters:
logits – Input logits
tau – Temperature parameter
- Returns:
Sampled probabilities
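A minimal sketch of Gumbel-Softmax sampling with straight-through gradients, for illustration only: the forward pass returns a one-hot sample, while gradients flow through the relaxed softmax. PyTorch's built-in torch.nn.functional.gumbel_softmax(..., hard=True) implements the same trick; whether this module's gumbel_softmax returns the hard or the soft sample is not stated on this page, and gumbel_softmax_st is a hypothetical name.

```python
import torch


def gumbel_softmax_st(logits: torch.Tensor, tau: float) -> torch.Tensor:
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1), clamped away from 0 and 1
    u = torch.rand_like(logits).clamp_(1e-10, 1.0 - 1e-10)
    gumbel = -torch.log(-torch.log(u))
    # Relaxed sample; lower tau pushes it closer to one-hot
    soft = torch.softmax((logits + gumbel) / tau, dim=-1)
    # Hard one-hot sample at the argmax of the relaxed sample
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    # Straight-through: forward value equals `hard`, gradient is that of `soft`
    return hard - soft.detach() + soft
```

The temperature trades off bias and variance: a small tau makes the relaxed sample nearly discrete but its gradients noisier, which is why the constraint's temperature is described as an initial value.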
- get_adjacency() → Tensor
Get the current adjacency matrix via Gumbel-Softmax sampling. This enforces unidirectional edges and allows a discrete graph structure to be learned.
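A minimal sketch of how a constraint module of this shape could produce its adjacency matrix, assuming two logits per ordered node pair ("no edge" vs. "edge"), a strictly upper-triangular mask, and an L1-style sparsity penalty on edge probabilities. DAGConstraintSketch, its logits layout, and the sparsity_penalty method are hypothetical; only torch.nn.functional.gumbel_softmax is a real PyTorch API.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DAGConstraintSketch(nn.Module):
    """Hypothetical learnable DAG structure (not the library's code)."""

    def __init__(self, n_nodes: int, sparsity_weight: float = 0.1, temperature: float = 1.0):
        super().__init__()
        self.sparsity_weight = sparsity_weight
        self.temperature = temperature
        # Two logits per ordered pair (i, j): index 0 = "no edge", index 1 = "edge"
        self.logits = nn.Parameter(torch.zeros(n_nodes, n_nodes, 2))
        # Strictly upper-triangular mask: only i -> j with j > i is allowed
        self.register_buffer("mask", torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1))

    def get_adjacency(self) -> torch.Tensor:
        # Straight-through Gumbel-Softmax: binary in the forward pass, differentiable backward
        sample = F.gumbel_softmax(self.logits, tau=self.temperature, hard=True, dim=-1)
        return sample[..., 1] * self.mask

    def sparsity_penalty(self) -> torch.Tensor:
        # Weighted sum of expected edge probabilities over the allowed pairs
        edge_prob = torch.softmax(self.logits, dim=-1)[..., 1] * self.mask
        return self.sparsity_weight * edge_prob.sum()
```

Masking after sampling means the diagonal and the lower triangle are exactly zero regardless of the logits, so every sampled graph is acyclic and each allowed pair can hold at most one direction.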
- class deepcausalmmm.core.dag_model.DAGModel(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)
Complete DAG-based model combining NodeToEdge and EdgeToNode transformations.
- __init__(n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3, sparsity_weight: float = 0.1)
Initialize the DAG model.
- Parameters:
n_nodes – Number of nodes in the graph
node_dim – Dimension of node features
edge_dim – Dimension of edge features
n_layers – Number of message passing layers
sparsity_weight – Weight for the sparsity penalty
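A compact, self-contained sketch of how n_layers rounds of node-to-edge and edge-to-node message passing could be stacked. For brevity it uses a fixed strictly upper-triangular adjacency instead of a learned one and plain linear layers instead of attention; DAGModelSketch and its internals are hypothetical and are not the library's DAGModel.

```python
import torch
from torch import nn


class DAGModelSketch(nn.Module):
    """Hypothetical stack of node->edge->node layers over a fixed DAG (not the library's code)."""

    def __init__(self, n_nodes: int, node_dim: int, edge_dim: int, n_layers: int = 3):
        super().__init__()
        self.edge_mlps = nn.ModuleList([nn.Linear(2 * node_dim, edge_dim) for _ in range(n_layers)])
        self.node_mlps = nn.ModuleList([nn.Linear(edge_dim + node_dim, node_dim) for _ in range(n_layers)])
        # Fixed strictly upper-triangular adjacency: edge i -> j only for j > i
        self.register_buffer("adj", torch.triu(torch.ones(n_nodes, n_nodes), diagonal=1))

    def forward(self, nodes: torch.Tensor) -> torch.Tensor:
        # nodes: [B, N, node_dim]
        b, n, d = nodes.shape
        mask = self.adj.unsqueeze(0).unsqueeze(-1)  # [1, N, N, 1]
        for edge_mlp, node_mlp in zip(self.edge_mlps, self.node_mlps):
            # Node -> edge: features for every ordered pair (i, j), masked by the DAG
            pair = torch.cat([nodes.unsqueeze(2).expand(b, n, n, d),
                              nodes.unsqueeze(1).expand(b, n, n, d)], dim=-1)
            edges = torch.relu(edge_mlp(pair)) * mask  # [B, N, N, edge_dim]
            # Edge -> node: sum incoming edges and apply a residual update
            incoming = edges.sum(dim=1)  # [B, N, edge_dim]
            nodes = nodes + node_mlp(torch.cat([incoming, nodes], dim=-1))
        return nodes
```

With a triangular adjacency, each node's output depends only on itself and lower-indexed nodes, so information flows in a single causal direction through all layers.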