"""
Post-processing utilities for DAG visualization and analysis.
.. deprecated:: 1.0.0
This module is deprecated and will be removed in v2.0.0.
Please use the modern VisualizationManager class instead:
from deepcausalmmm.core.visualization import VisualizationManager
viz_manager = VisualizationManager(config)
viz_manager.create_dag_network_plot(...)
viz_manager.create_dag_heatmap_plot(...)
"""
import torch
import numpy as np
import plotly.graph_objects as go
from typing import List, Optional, Dict, Any
import warnings
[docs]
def plot_dag_structure(
    adjacency_matrix: torch.Tensor,
    channel_names: Optional[List[str]] = None,
    threshold: float = 0.1,
    title: str = "Media Channel DAG Structure"
) -> go.Figure:
    """
    Visualize the DAG structure using Plotly.

    .. deprecated:: 1.0.0
        This function is deprecated. Use
        ``VisualizationManager.create_dag_network_plot()`` instead.

    Nodes are placed evenly on a unit circle. Every entry with
    ``adjacency_matrix[i, j] > threshold`` is drawn as a directed edge from
    node ``i`` to node ``j``: a gray line ending in an arrowhead. Hovering over
    an edge (its midpoint is a dedicated hover point) shows
    ``"<from> → <to>: <weight>"``.

    Args:
        adjacency_matrix: Square adjacency matrix from the DAG model
            [n_nodes, n_nodes]. The row index is the source node and the
            column index is the target node.
        channel_names: List of channel names. If None, uses
            ``"Channel 1"`` ... ``"Channel n"``.
        threshold: Only weights strictly greater than this are drawn.
        title: Plot title.

    Returns:
        Plotly figure object.

    Raises:
        ValueError: If ``adjacency_matrix`` is not a square 2-D matrix.
    """
    warnings.warn(
        "plot_dag_structure() is deprecated and will be removed in v2.0.0. "
        "Please use VisualizationManager.create_dag_network_plot() instead.",
        DeprecationWarning,
        stacklevel=2
    )

    # Convert to numpy for processing
    adj_matrix = adjacency_matrix.detach().cpu().numpy()
    if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
        raise ValueError(
            "adjacency_matrix must be a square 2-D matrix, "
            f"got shape {tuple(adj_matrix.shape)}"
        )
    n_nodes = adj_matrix.shape[0]
    if channel_names is None:
        channel_names = [f"Channel {i+1}" for i in range(n_nodes)]

    # Create node positions in a circular layout
    angles = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
    radius = 1
    node_x = radius * np.cos(angles)
    node_y = radius * np.sin(angles)

    # Collect edge line segments, per-point hover labels and arrowheads.
    edge_x: List[Optional[float]] = []
    edge_y: List[Optional[float]] = []
    edge_text: List[Optional[str]] = []
    arrows: List[Dict[str, Any]] = []
    # argwhere yields (i, j) pairs in row-major order, i.e. the same order as
    # a nested `for i: for j:` loop over the matrix.
    for i, j in np.argwhere(adj_matrix > threshold):
        start_x, start_y = node_x[i], node_y[i]
        end_x, end_y = node_x[j], node_y[j]
        label = f"{channel_names[i]} → {channel_names[j]}: {adj_matrix[i, j]:.3f}"
        # Each edge contributes start, midpoint, end, and a None gap that
        # breaks the polyline between edges. The text list must carry one
        # entry per point so every hover label stays attached to its own edge.
        # The midpoint gives a hover target that is not covered by a node marker.
        edge_x.extend([start_x, (start_x + end_x) / 2, end_x, None])
        edge_y.extend([start_y, (start_y + end_y) / 2, end_y, None])
        edge_text.extend([label, label, label, None])
        if i != j:
            # Arrowhead pointing from source to target. The standoff (in
            # pixels) stops the tip at the rim of the 30px node marker below
            # instead of hiding it underneath. Self-loops have zero length
            # and get no arrow.
            arrows.append(dict(
                x=end_x, y=end_y,
                ax=start_x, ay=start_y,
                xref='x', yref='y', axref='x', ayref='y',
                text='',
                showarrow=True,
                arrowhead=2,
                arrowsize=1.5,
                arrowwidth=1,
                arrowcolor='gray',
                standoff=16,
            ))

    # Create figure
    fig = go.Figure()

    # Add edges (hover-only labels; drawing them as text would overlap nodes)
    fig.add_trace(go.Scatter(
        x=edge_x,
        y=edge_y,
        mode='lines',
        line=dict(width=1, color='gray'),
        hoverinfo='text',
        text=edge_text,
        name='Edges'
    ))

    # Add nodes
    fig.add_trace(go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers+text',
        marker=dict(
            size=30,
            color='lightblue',
            line=dict(width=2, color='darkblue')
        ),
        text=channel_names,
        textposition="middle center",
        hoverinfo='text',
        name='Channels'
    ))

    # Update layout
    fig.update_layout(
        title=dict(
            text=title,
            x=0.5,
            xanchor='center'
        ),
        showlegend=False,
        xaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False
        ),
        yaxis=dict(
            showgrid=False,
            zeroline=False,
            showticklabels=False
        ),
        annotations=arrows,
        plot_bgcolor='white',
        width=800,
        height=800
    )
    return fig
[docs]
def analyze_dag_structure(
    adjacency_matrix: torch.Tensor,
    channel_names: Optional[List[str]] = None,
    threshold: float = 0.1
) -> Dict[str, Any]:
    """
    Analyze the DAG structure and return key metrics.

    .. deprecated:: 1.0.0
        This function is deprecated. Use VisualizationManager for DAG
        analysis instead.

    An entry with ``adjacency_matrix[i, j] > threshold`` is treated as a
    significant edge from channel ``i`` (row) to channel ``j`` (column).

    Args:
        adjacency_matrix: Square adjacency matrix from the DAG model
            [n_nodes, n_nodes].
        channel_names: List of channel names. If None, uses
            ``"Channel 1"`` ... ``"Channel n"``.
        threshold: Threshold for edge significance (strictly greater than).

    Returns:
        Dictionary of plain Python values (JSON-serializable):

        - ``'n_edges'`` (int): number of significant edges.
        - ``'avg_edge_weight'`` (float): mean significant weight, 0.0 if none.
        - ``'max_edge_weight'`` (float): largest significant weight, 0.0 if none.
        - ``'significant_edges'``: list of ``{'from', 'to', 'weight'}`` dicts
          in row-major order.
        - ``'influential_channels'``: list of ``{'channel', 'out_degree'}``
          dicts sorted by descending out-degree.
        - ``'influenced_channels'``: list of ``{'channel', 'in_degree'}``
          dicts sorted by descending in-degree.

        Degrees are row sums (out) and column sums (in) over *all* matrix
        entries, including sub-threshold ones. A channel is listed only if
        its degree exceeds ``threshold``.

    Raises:
        ValueError: If ``adjacency_matrix`` is not a square 2-D matrix.
    """
    warnings.warn(
        "analyze_dag_structure() is deprecated and will be removed in v2.0.0. "
        "Please use VisualizationManager for DAG analysis instead.",
        DeprecationWarning,
        stacklevel=2
    )

    adj_matrix = adjacency_matrix.detach().cpu().numpy()
    if adj_matrix.ndim != 2 or adj_matrix.shape[0] != adj_matrix.shape[1]:
        raise ValueError(
            "adjacency_matrix must be a square 2-D matrix, "
            f"got shape {tuple(adj_matrix.shape)}"
        )
    n_nodes = adj_matrix.shape[0]
    if channel_names is None:
        channel_names = [f"Channel {i+1}" for i in range(n_nodes)]

    # Count edges and compute metrics. Cast numpy scalars to builtin types so
    # the result can be serialized (e.g. json.dumps) and compared naturally.
    significant_edges = adj_matrix > threshold
    n_edges = int(significant_edges.sum())

    results: Dict[str, Any] = {
        'n_edges': n_edges,
        'avg_edge_weight': 0.0,
        'max_edge_weight': 0.0,
        'significant_edges': [],
        'influential_channels': [],
        'influenced_channels': []
    }

    if n_edges > 0:
        significant_weights = adj_matrix[significant_edges]
        results['avg_edge_weight'] = float(significant_weights.mean())
        results['max_edge_weight'] = float(significant_weights.max())

    # Significant relationships; argwhere yields (i, j) in row-major order.
    results['significant_edges'] = [
        {
            'from': channel_names[i],
            'to': channel_names[j],
            'weight': float(adj_matrix[i, j])
        }
        for i, j in np.argwhere(significant_edges)
    ]

    # Weighted degrees over the full matrix (not only significant edges).
    out_degree = adj_matrix.sum(axis=1)
    in_degree = adj_matrix.sum(axis=0)

    # Top influential channels (high out-degree). A stable sort keeps ties in
    # channel-index order, so the output is deterministic.
    results['influential_channels'] = [
        {
            'channel': channel_names[i],
            'out_degree': float(out_degree[i])
        }
        for i in np.argsort(-out_degree, kind='stable')
        if out_degree[i] > threshold
    ]

    # Top influenced channels (high in-degree)
    results['influenced_channels'] = [
        {
            'channel': channel_names[i],
            'in_degree': float(in_degree[i])
        }
        for i in np.argsort(-in_degree, kind='stable')
        if in_degree[i] > threshold
    ]

    return results