Source code for deepcausalmmm.core.visualization

"""
Reusable VisualizationManager class for creating consistent plots.
Eliminates code duplication and provides config-driven visualization.
"""

import numpy as np
import pandas as pd
import torch
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx
from typing import Dict, Any, List, Optional, Tuple
import os

import logging

logger = logging.getLogger('deepcausalmmm')

from deepcausalmmm.core.config import get_default_config


[docs] class VisualizationManager: """ Visualization manager for creating consistent plots in DeepCausalMMM analysis. Provides a unified interface for creating training progress, coefficient analysis, contribution plots, DAG visualizations, and other MMM-related charts. All plot parameters are driven by configuration for consistency. Parameters ---------- config : Dict[str, Any], optional Configuration dictionary. If None, uses default configuration. """
[docs] def __init__(self, config: Optional[Dict[str, Any]] = None): """ Initialize the visualization manager. Args: config: Configuration dictionary. If None, uses default config. """ self.config = config or get_default_config() self.viz_params = self._get_viz_params()
def _get_viz_params(self) -> Dict[str, Any]: """Get visualization parameters from config with defaults""" viz_config = self.config.get('visualization', {}) return { 'node_opacity': viz_config.get('node_opacity', 0.7), 'line_opacity': viz_config.get('line_opacity', 0.6), 'fill_opacity': viz_config.get('fill_opacity', 0.1), 'marker_size': viz_config.get('marker_size', 8), 'correlation_threshold': viz_config.get('correlation_threshold', 0.2), 'edge_width_multiplier': viz_config.get('edge_width_multiplier', 8), 'subplot_vertical_spacing': viz_config.get('subplot_vertical_spacing', 0.08), 'subplot_horizontal_spacing': viz_config.get('subplot_horizontal_spacing', 0.06), }
[docs] def create_training_progress_plot(self, train_losses: List[float], train_rmses: List[float], train_r2s: List[float], title: str = "Training Progress") -> go.Figure: """ Create a training progress plot with loss, RMSE, and R². Args: train_losses: Training losses over epochs train_rmses: Training RMSEs over epochs train_r2s: Training R² scores over epochs title: Plot title Returns: Plotly figure """ epochs = list(range(1, len(train_losses) + 1)) fig = make_subplots( rows=1, cols=3, subplot_titles=['Loss', 'RMSE', 'R²'], horizontal_spacing=self.viz_params['subplot_horizontal_spacing'] ) # Loss plot fig.add_trace( go.Scatter(x=epochs, y=train_losses, name='Loss', line=dict(color='blue'), opacity=self.viz_params['line_opacity']), row=1, col=1 ) # RMSE plot fig.add_trace( go.Scatter(x=epochs, y=train_rmses, name='RMSE', line=dict(color='red'), opacity=self.viz_params['line_opacity']), row=1, col=2 ) # R² plot fig.add_trace( go.Scatter(x=epochs, y=train_r2s, name='R²', line=dict(color='green'), opacity=self.viz_params['line_opacity']), row=1, col=3 ) fig.update_layout( title=title, showlegend=False, height=400 ) return fig
[docs] def create_actual_vs_predicted_plot(self, y_actual: np.ndarray, y_predicted: np.ndarray, title: str = "Actual vs Predicted", weeks: Optional[List[int]] = None) -> go.Figure: """ Create an actual vs predicted time series plot. Args: y_actual: Actual values y_predicted: Predicted values title: Plot title weeks: Optional week indices for x-axis Returns: Plotly figure """ if weeks is None: weeks = list(range(len(y_actual))) fig = go.Figure() # Actual values fig.add_trace(go.Scatter( x=weeks, y=y_actual, mode='lines', name='Actual', line=dict(color='blue', width=2), opacity=self.viz_params['line_opacity'] )) # Predicted values fig.add_trace(go.Scatter( x=weeks, y=y_predicted, mode='lines', name='Predicted', line=dict(color='red', width=2, dash='dot'), opacity=self.viz_params['line_opacity'] )) fig.update_layout( title=title, xaxis_title='Week', yaxis_title='Value', height=500, hovermode='x unified' ) return fig
[docs] def create_scatter_plot(self, x: np.ndarray, y: np.ndarray, title: str = "Scatter Plot", x_label: str = "X", y_label: str = "Y", color: str = 'blue') -> go.Figure: """ Create a scatter plot with perfect correlation line. Args: x: X values y: Y values title: Plot title x_label: X-axis label y_label: Y-axis label color: Marker color Returns: Plotly figure """ # Calculate R² from sklearn.metrics import r2_score r2 = r2_score(x, y) if len(np.unique(x)) > 1 else 0.0 fig = go.Figure() # Scatter plot fig.add_trace(go.Scatter( x=x, y=y, mode='markers', name=f'Data (R²={r2:.3f})', marker=dict( size=self.viz_params['marker_size'], color=color, opacity=self.viz_params['node_opacity'] ) )) # Perfect correlation line min_val, max_val = min(x.min(), y.min()), max(x.max(), y.max()) fig.add_trace(go.Scatter( x=[min_val, max_val], y=[min_val, max_val], mode='lines', name='Perfect Correlation', line=dict(color='gray', dash='dash', width=1), opacity=0.5 )) fig.update_layout( title=title, xaxis_title=x_label, yaxis_title=y_label, height=500 ) return fig
[docs] def create_waterfall_chart(self, categories: List[str], values: List[float], title: str = "Waterfall Chart") -> go.Figure: """ Create a proper waterfall chart using Plotly's go.Waterfall. Args: categories: Category names values: Values for each category title: Chart title Returns: Plotly figure """ fig = go.Figure(go.Waterfall( name="Contributions", orientation="v", measure=["relative"] * (len(categories) - 1) + ["total"], x=categories, textposition="outside", text=[f"{v:,.0f}" for v in values], y=values, connector={"line": {"color": "rgb(63, 63, 63)"}}, )) fig.update_layout( title=title, showlegend=False, height=600 ) return fig
[docs] def create_contribution_stacked_bar(self, media_contributions: np.ndarray, control_contributions: np.ndarray, baseline: np.ndarray, media_names: List[str], control_names: List[str], weeks: Optional[List[int]] = None, title: str = "Contributions Over Time") -> go.Figure: """ Create a stacked bar chart of contributions over time. Args: media_contributions: Media contributions [n_weeks, n_media] control_contributions: Control contributions [n_weeks, n_controls] baseline: Baseline values [n_weeks] media_names: Media channel names control_names: Control variable names weeks: Optional week indices title: Chart title Returns: Plotly figure """ if weeks is None: weeks = list(range(len(baseline))) fig = go.Figure() # Add baseline fig.add_trace(go.Bar( x=weeks, y=baseline, name='Baseline', marker_color='lightgray', opacity=self.viz_params['node_opacity'] )) # Add media contributions colors = px.colors.qualitative.Set3 for i, name in enumerate(media_names): fig.add_trace(go.Bar( x=weeks, y=media_contributions[:, i], name=f'Media: {name}', marker_color=colors[i % len(colors)], opacity=self.viz_params['node_opacity'] )) # Add control contributions for i, name in enumerate(control_names): fig.add_trace(go.Bar( x=weeks, y=control_contributions[:, i], name=f'Control: {name}', marker_color=colors[(len(media_names) + i) % len(colors)], opacity=self.viz_params['node_opacity'] )) fig.update_layout( title=title, xaxis_title='Week', yaxis_title='Contribution', barmode='stack', height=600, legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1) ) return fig
[docs] def create_dag_network_plot(self, adjacency_matrix: np.ndarray, node_names: List[str], title: str = "DAG Network") -> go.Figure: """ Create a DAG network visualization. Args: adjacency_matrix: Adjacency matrix [n_nodes, n_nodes] node_names: Node names title: Plot title Returns: Plotly figure """ # Create network graph G = nx.DiGraph() # Add nodes for i, name in enumerate(node_names): G.add_node(i, name=name, label=name) # Add edges based on adjacency matrix for i in range(len(node_names)): for j in range(len(node_names)): if adjacency_matrix[i, j] > self.viz_params['correlation_threshold']: G.add_edge(i, j, weight=adjacency_matrix[i, j]) # Create layout pos = nx.spring_layout(G, k=3, iterations=50) # Create Plotly figure fig = go.Figure() # Add edges for edge in G.edges(): x0, y0 = pos[edge[0]] x1, y1 = pos[edge[1]] weight = G[edge[0]][edge[1]]['weight'] fig.add_trace(go.Scatter( x=[x0, x1, None], y=[y0, y1, None], mode='lines', line=dict( width=weight * self.viz_params['edge_width_multiplier'], color=f'rgba(125,125,125,{self.viz_params["line_opacity"]})' ), hoverinfo='none', showlegend=False )) # Add nodes node_x = [pos[node][0] for node in G.nodes()] node_y = [pos[node][1] for node in G.nodes()] node_text = [node_names[node] for node in G.nodes()] fig.add_trace(go.Scatter( x=node_x, y=node_y, mode='markers+text', marker=dict( size=self.viz_params['marker_size'] * 5, # Scale up for nodes color='lightblue', line=dict(width=2, color='darkblue') ), text=node_text, textposition="middle center", textfont=dict(size=10, color='black'), hovertemplate='<b>%{text}</b><extra></extra>', name='Nodes' )) fig.update_layout( title=title, showlegend=False, xaxis=dict(showgrid=False, zeroline=False, showticklabels=False), yaxis=dict(showgrid=False, zeroline=False, showticklabels=False), height=600 ) return fig
[docs] def create_dag_heatmap(self, adjacency_matrix: np.ndarray, node_names: List[str], title: str = "DAG Adjacency Matrix") -> go.Figure: """ Create a DAG adjacency matrix heatmap. Args: adjacency_matrix: Adjacency matrix [n_nodes, n_nodes] node_names: Node names title: Plot title Returns: Plotly figure """ fig = go.Figure(data=go.Heatmap( z=adjacency_matrix, x=node_names, y=node_names, colorscale='RdYlBu_r', hoverongaps=False, hovertemplate='From: %{y}<br>To: %{x}<br>Strength: %{z:.3f}<extra></extra>' )) fig.update_layout( title=title, xaxis_title='Influenced Channel', yaxis_title='Influencing Channel', height=600 ) return fig
[docs] def save_plot(self, fig: go.Figure, filepath: str, include_plotlyjs: str = 'cdn') -> bool: """ Save a Plotly figure to HTML file. Args: fig: Plotly figure to save filepath: Output file path include_plotlyjs: How to include Plotly.js ('cdn', 'inline', etc.) Returns: True if successful, False otherwise """ try: # Ensure directory exists os.makedirs(os.path.dirname(filepath), exist_ok=True) # Save figure fig.write_html( filepath, include_plotlyjs=include_plotlyjs, config={'displayModeBar': True, 'responsive': True} ) return True except Exception as e: logger.warning(f" Failed to save plot to {filepath}: {e}") return False
[docs] def create_comprehensive_dashboard(self, results: Dict[str, Any], output_dir: str = "dashboard_comprehensive") -> List[Tuple[str, str]]: """ Create a comprehensive dashboard with multiple plots. Args: results: Training results dictionary output_dir: Output directory for plots Returns: List of (plot_name, filepath) tuples for created plots """ plots_created = [] # Ensure output directory exists os.makedirs(output_dir, exist_ok=True) # Extract data from results model = results['model'] train_losses = results.get('train_losses', []) train_rmses = results.get('train_rmses', []) train_r2s = results.get('train_r2s', []) # 1. Training Progress if train_losses: fig = self.create_training_progress_plot( train_losses, train_rmses, train_r2s, "Training Progress" ) filepath = os.path.join(output_dir, "training_progress.html") if self.save_plot(fig, filepath): plots_created.append(("Training Progress", filepath)) # Add more plots as needed... return plots_created