Visualization
Reusable VisualizationManager class for creating consistent plots. It eliminates duplicated plotting code and drives visualization settings from configuration.
- class deepcausalmmm.core.visualization.VisualizationManager(config: Dict[str, Any] | None = None)
Bases: object
Visualization manager for creating consistent plots in DeepCausalMMM analysis.
Provides a unified interface for creating training progress, coefficient analysis, contribution plots, DAG visualizations, and other MMM-related charts. All plot parameters are driven by configuration for consistency.
- Parameters:
config (Dict[str, Any], optional) – Configuration dictionary. If None, uses default configuration.
- __init__(config: Dict[str, Any] | None = None)
Initialize the visualization manager.
- Parameters:
config – Configuration dictionary. If None, uses default config.
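One way to picture "if None, uses default config" is a shallow merge of user-supplied keys over a defaults dictionary. This is a minimal sketch under stated assumptions: DEFAULT_VIZ_CONFIG, its keys, and resolve_config are hypothetical, and the library may instead use a supplied config as-is without merging.

```python
from typing import Any, Dict, Optional

# Hypothetical defaults for illustration only; the real default keys used by
# VisualizationManager are not documented in this reference.
DEFAULT_VIZ_CONFIG: Dict[str, Any] = {
    "width": 1000,
    "height": 600,
    "template": "plotly_white",
}


def resolve_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return the defaults, overlaid with any keys the caller supplied."""
    merged = dict(DEFAULT_VIZ_CONFIG)  # copy so the defaults are never mutated
    if config is not None:
        merged.update(config)  # user values win over defaults
    return merged
```

A merge like this lets a caller override a single setting (for example only the height) while every plot still shares the remaining defaults.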
- create_training_progress_plot(train_losses: List[float], train_rmses: List[float], train_r2s: List[float], title: str = 'Training Progress') → Figure
Create a training progress plot with loss, RMSE, and R².
- Parameters:
train_losses – Training losses over epochs
train_rmses – Training RMSEs over epochs
train_r2s – Training R² scores over epochs
title – Plot title
- Returns:
Plotly figure
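The method plots per-epoch lists that the caller has already accumulated. If you need to produce the RMSE and R² series yourself, a minimal sketch using the standard definitions (RMSE = √(SSE/n), R² = 1 − SSE/SST) follows; rmse_and_r2 is a hypothetical helper, not part of deepcausalmmm.

```python
import math
from typing import Sequence, Tuple


def rmse_and_r2(y_true: Sequence[float], y_pred: Sequence[float]) -> Tuple[float, float]:
    """Root-mean-square error and coefficient of determination (R^2)."""
    if len(y_true) == 0 or len(y_true) != len(y_pred):
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(y_true)
    sse = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    mean_t = sum(y_true) / n
    sst = sum((t - mean_t) ** 2 for t in y_true)  # total sum of squares
    rmse = math.sqrt(sse / n)
    r2 = 1.0 - sse / sst if sst > 0 else float("nan")  # undefined for constant targets
    return rmse, r2
```

Calling it once per epoch and appending the results to train_rmses and train_r2s yields lists in the shape this method expects.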
- create_actual_vs_predicted_plot(y_actual: ndarray, y_predicted: ndarray, title: str = 'Actual vs Predicted', weeks: List[int] | None = None) → Figure
Create an actual vs predicted time series plot.
- Parameters:
y_actual – Actual values
y_predicted – Predicted values
title – Plot title
weeks – Optional week indices for x-axis
- Returns:
Plotly figure
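Both series share one x-axis, so they must have the same length, and weeks (when given) must supply one index per point. The sketch below pairs each x position with its actual and predicted values; falling back to 0..n-1 when weeks is None is an assumption, not documented behavior, and align_series is a hypothetical name.

```python
from typing import List, Optional, Sequence, Tuple


def align_series(
    y_actual: Sequence[float],
    y_predicted: Sequence[float],
    weeks: Optional[List[int]] = None,
) -> List[Tuple[int, float, float]]:
    """Pair each x position with its actual and predicted value."""
    n = len(y_actual)
    if len(y_predicted) != n:
        raise ValueError("y_actual and y_predicted must have the same length")
    # Assumed fallback: 0-based positions when no week indices are supplied.
    x = list(range(n)) if weeks is None else list(weeks)
    if len(x) != n:
        raise ValueError(f"expected {n} week indices, got {len(x)}")
    return list(zip(x, y_actual, y_predicted))
```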
- create_scatter_plot(x: ndarray, y: ndarray, title: str = 'Scatter Plot', x_label: str = 'X', y_label: str = 'Y', color: str = 'blue') → Figure
Create a scatter plot with a perfect-correlation reference line.
- Parameters:
x – X values
y – Y values
title – Plot title
x_label – X-axis label
y_label – Y-axis label
color – Marker color
- Returns:
Plotly figure
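Assuming the perfect-correlation line is the diagonal y = x (the natural reading when x holds actual values and y holds predictions), its endpoints only need to span the combined range of both series so that every point can be compared against it. identity_line is a hypothetical helper; the endpoints could then be drawn as a line-mode trace.

```python
from typing import Sequence, Tuple


def identity_line(
    x: Sequence[float], y: Sequence[float]
) -> Tuple[Tuple[float, float], Tuple[float, float]]:
    """Endpoints ((x0, x1), (y0, y1)) of a y = x line covering both data ranges."""
    lo = min(min(x), min(y))
    hi = max(max(x), max(y))
    return (lo, hi), (lo, hi)  # identical x and y coordinates give slope 1
```

Using the shared minimum and maximum (rather than each axis separately) keeps the line at exactly 45° even when the two series cover different ranges.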
- create_waterfall_chart(categories: List[str], values: List[float], title: str = 'Waterfall Chart') → Figure
Create a waterfall chart using Plotly’s go.Waterfall.
- Parameters:
categories – Category names
values – Values for each category
title – Chart title
- Returns:
Plotly figure
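A waterfall chart draws each category as a floating bar that starts where the previous running total ended. The sketch below shows that arithmetic, which go.Waterfall performs for bars whose measure is "relative". Whether create_waterfall_chart also appends a closing "total" bar is not documented here, so the sketch omits one; waterfall_segments is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def waterfall_segments(values: Sequence[float]) -> List[Tuple[float, float]]:
    """(start, end) of each floating bar when every value is a relative step."""
    segments: List[Tuple[float, float]] = []
    running = 0.0
    for v in values:
        segments.append((running, running + v))  # bar spans old total -> new total
        running += v
    return segments
```

For example, values of 100, -20 and 30 produce bars spanning 0 to 100, 100 to 80, and 80 to 110.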
- create_contribution_stacked_bar(media_contributions: ndarray, control_contributions: ndarray, baseline: ndarray, media_names: List[str], control_names: List[str], weeks: List[int] | None = None, title: str = 'Contributions Over Time') → Figure
Create a stacked bar chart of contributions over time.
- Parameters:
media_contributions – Media contributions [n_weeks, n_media]
control_contributions – Control contributions [n_weeks, n_controls]
baseline – Baseline values [n_weeks]
media_names – Media channel names
control_names – Control variable names
weeks – Optional week indices
title – Chart title
- Returns:
Plotly figure
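The three inputs share a leading n_weeks dimension, and each week's stack is made of the baseline, one segment per media channel, and one segment per control variable. Assuming an additive decomposition, the net height of each stack is their sum. The sketch below checks the shapes and computes those totals; weekly_totals is a hypothetical helper, and plain nested lists stand in for ndarrays.

```python
from typing import List, Sequence


def weekly_totals(
    media_contributions: Sequence[Sequence[float]],    # [n_weeks][n_media]
    control_contributions: Sequence[Sequence[float]],  # [n_weeks][n_controls]
    baseline: Sequence[float],                          # [n_weeks]
) -> List[float]:
    """Net height of each week's stack: baseline + all media + all controls."""
    n_weeks = len(baseline)
    if len(media_contributions) != n_weeks or len(control_contributions) != n_weeks:
        raise ValueError("all inputs must share the n_weeks leading dimension")
    return [
        b + sum(m) + sum(c)
        for b, m, c in zip(baseline, media_contributions, control_contributions)
    ]
```

Comparing these totals against the model's predictions is a quick sanity check that the decomposition being plotted is complete.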
- create_dag_network_plot(adjacency_matrix: ndarray, node_names: List[str], title: str = 'DAG Network') → Figure
Create a DAG network visualization.
- Parameters:
adjacency_matrix – Adjacency matrix [n_nodes, n_nodes]
node_names – Node names
title – Plot title
- Returns:
Plotly figure
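A network plot needs an explicit edge list, and a DAG must contain no cycles. The sketch below extracts edges from the adjacency matrix and verifies acyclicity with Kahn's topological-sort algorithm. It assumes that a nonzero adjacency_matrix[i][j] means an edge from node i to node j; that orientation, the threshold parameter, and both function names are assumptions, not taken from the library.

```python
from collections import deque
from typing import List, Sequence, Tuple


def dag_edges(
    adjacency_matrix: Sequence[Sequence[float]],
    node_names: List[str],
    threshold: float = 0.0,
) -> List[Tuple[str, str, float]]:
    """(source, target, weight) for every entry with |weight| > threshold.

    Assumes adjacency_matrix[i][j] is the weight of the edge node i -> node j.
    """
    return [
        (node_names[i], node_names[j], w)
        for i, row in enumerate(adjacency_matrix)
        for j, w in enumerate(row)
        if abs(w) > threshold
    ]


def is_acyclic(adjacency_matrix: Sequence[Sequence[float]], threshold: float = 0.0) -> bool:
    """Kahn's algorithm: True iff every node can be placed in a topological order."""
    n = len(adjacency_matrix)
    indegree = [0] * n
    for i in range(n):
        for j in range(n):
            if abs(adjacency_matrix[i][j]) > threshold:
                indegree[j] += 1
    queue = deque(k for k in range(n) if indegree[k] == 0)
    ordered = 0
    while queue:
        i = queue.popleft()
        ordered += 1
        for j in range(n):
            if abs(adjacency_matrix[i][j]) > threshold:
                indegree[j] -= 1
                if indegree[j] == 0:
                    queue.append(j)
    return ordered == n  # nodes never ordered lie on, or downstream of, a cycle
```

Self-loops (nonzero diagonal entries) are kept as edges, so is_acyclic correctly reports them as cycles.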
- create_dag_heatmap(adjacency_matrix: ndarray, node_names: List[str], title: str = 'DAG Adjacency Matrix') → Figure
Create a DAG adjacency matrix heatmap.
- Parameters:
adjacency_matrix – Adjacency matrix [n_nodes, n_nodes]
node_names – Node names
title – Plot title
- Returns:
Plotly figure
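Both DAG methods expect a square [n_nodes, n_nodes] matrix whose rows and columns follow node_names. A check like the one below catches a mismatch before plotting and formats each cell as text that a heatmap could display as an annotation. heatmap_cells and the two-decimal formatting are hypothetical choices, not the library's behavior.

```python
from typing import List, Sequence


def heatmap_cells(
    adjacency_matrix: Sequence[Sequence[float]],
    node_names: List[str],
    decimals: int = 2,
) -> List[List[str]]:
    """Validate an [n_nodes, n_nodes] matrix and format each cell as text."""
    n = len(node_names)
    if len(adjacency_matrix) != n or any(len(row) != n for row in adjacency_matrix):
        raise ValueError(f"adjacency_matrix must be {n}x{n} to match node_names")
    return [[f"{w:.{decimals}f}" for w in row] for row in adjacency_matrix]
```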
- save_plot(fig: Figure, filepath: str, include_plotlyjs: str = 'cdn') → bool
Save a Plotly figure to HTML file.
- Parameters:
fig – Plotly figure to save
filepath – Output file path
include_plotlyjs – How to include Plotly.js (‘cdn’, ‘inline’, etc.)
- Returns:
True if successful, False otherwise
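Returning a bool instead of raising lets a caller save many plots and simply skip failures. The sketch below shows that pattern, assuming the method delegates to Plotly's Figure.write_html (a real method that accepts include_plotlyjs) and creates missing parent directories; both are assumptions about the implementation. The function works with any object that has a write_html method.

```python
import os
from typing import Any


def save_plot(fig: Any, filepath: str, include_plotlyjs: str = "cdn") -> bool:
    """Write fig to an HTML file, reporting success as a bool instead of raising."""
    try:
        parent = os.path.dirname(filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)  # create the output folder if needed
        fig.write_html(filepath, include_plotlyjs=include_plotlyjs)
        return True
    except (OSError, ValueError):  # unwritable path, or an invalid argument
        return False
```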
- create_comprehensive_dashboard(results: Dict[str, Any], output_dir: str = 'dashboard_comprehensive') → List[Tuple[str, str]]
Create a comprehensive dashboard with multiple plots.
- Parameters:
results – Training results dictionary
output_dir – Output directory for plots
- Returns:
List of (plot_name, filepath) tuples for created plots
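The return value pairs each plot name with the file it was written to. The sketch below shows one way such a list could be assembled, reading "created plots" as "plots that were saved successfully". The builders mapping, the injected save callable, and the "<name>.html" naming scheme are all hypothetical; the real method derives its plots from the results dictionary, whose keys are not documented here.

```python
import os
from typing import Callable, Dict, List, Tuple


def build_dashboard(
    builders: Dict[str, Callable[[], object]],
    save: Callable[[object, str], bool],
    output_dir: str = "dashboard_comprehensive",
) -> List[Tuple[str, str]]:
    """Build and save each plot, returning (plot_name, filepath) for those saved."""
    created: List[Tuple[str, str]] = []
    for name, build in builders.items():
        filepath = os.path.join(output_dir, f"{name}.html")  # hypothetical naming
        if save(build(), filepath):  # keep only plots that were actually written
            created.append((name, filepath))
    return created
```

Injecting save keeps the sketch testable without writing files, and any bool-returning saver fits the same slot.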