deepcausalmmm.core.visualization

Reusable VisualizationManager class for creating consistent plots. It removes duplicated plotting code and provides config-driven visualization.

Classes

VisualizationManager([config])

Visualization manager for creating consistent plots in DeepCausalMMM analysis.

class deepcausalmmm.core.visualization.VisualizationManager(config: Dict[str, Any] | None = None)

Visualization manager for creating consistent plots in DeepCausalMMM analysis.

Provides a unified interface for creating training-progress plots, coefficient-analysis plots, contribution plots, DAG visualizations, and other MMM-related charts. All plot parameters are driven by the configuration, so plots stay visually consistent.

Parameters:

config (Dict[str, Any], optional) – Configuration dictionary. If None, uses default configuration.

__init__(config: Dict[str, Any] | None = None)

Initialize the visualization manager.

Parameters:

config – Configuration dictionary. If None, uses default config.

create_training_progress_plot(train_losses: List[float], train_rmses: List[float], train_r2s: List[float], title: str = 'Training Progress') → Figure

Create a training progress plot with loss, RMSE, and R².

Parameters:
  • train_losses – Training losses over epochs

  • train_rmses – Training RMSEs over epochs

  • train_r2s – Training R² scores over epochs

  • title – Plot title

Returns:

Plotly figure
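The three lists are parallel, with one entry per epoch. The sketch below shows one way to build train_rmses and train_r2s from per-epoch predictions, using the standard definitions RMSE = sqrt(mean((y - ŷ)²)) and R² = 1 - SS_res / SS_tot. The data, the helper name epoch_metrics, and the use of MSE as the loss are illustrative assumptions. This is not necessarily how DeepCausalMMM computes its training metrics.

```python
import numpy as np

def epoch_metrics(y_true, y_pred):
    """Return (RMSE, R²) for one epoch's predictions (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    rmse = float(np.sqrt(np.mean(residuals ** 2)))
    ss_res = float(np.sum(residuals ** 2))                  # residual sum of squares
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))   # total sum of squares
    return rmse, 1.0 - ss_res / ss_tot

# Illustrative target and predictions that improve epoch by epoch
y = np.array([10.0, 12.0, 14.0, 16.0])
predictions_by_epoch = [y + 2.0, y + 1.0, y + 0.5]

train_rmses, train_r2s = [], []
for preds in predictions_by_epoch:
    rmse, r2 = epoch_metrics(y, preds)
    train_rmses.append(rmse)
    train_r2s.append(r2)

# Assumption: the training loss is the mean squared error
train_losses = [r ** 2 for r in train_rmses]
```

The resulting train_losses, train_rmses and train_r2s lists could then be passed to create_training_progress_plot in that order.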

create_actual_vs_predicted_plot(y_actual: ndarray, y_predicted: ndarray, title: str = 'Actual vs Predicted', weeks: List[int] | None = None) → Figure

Create an actual vs predicted time series plot.

Parameters:
  • y_actual – Actual values

  • y_predicted – Predicted values

  • title – Plot title

  • weeks – Optional week indices for x-axis

Returns:

Plotly figure
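Both series must describe the same observations, and weeks, when given, needs one label per observation. The sketch below validates the inputs before plotting. The helper prepare_series is hypothetical. Its fallback to a 0-based index when weeks is None is an assumption, not documented behavior of the method.

```python
import numpy as np

def prepare_series(y_actual, y_predicted, weeks=None):
    """Return (x, actual, predicted) as 1-D arrays after basic checks.

    Hypothetical helper; assumes x falls back to 0..n-1 when weeks is None.
    """
    actual = np.asarray(y_actual, dtype=float).ravel()
    predicted = np.asarray(y_predicted, dtype=float).ravel()
    if actual.shape != predicted.shape:
        raise ValueError(
            f"length mismatch: {actual.size} actual vs {predicted.size} predicted"
        )
    if weeks is None:
        x = np.arange(actual.size)
    else:
        x = np.asarray(weeks)
        if x.size != actual.size:
            raise ValueError("weeks must have one entry per observation")
    return x, actual, predicted

x, a, p = prepare_series([100, 120, 110], [98, 118, 115])
```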

create_scatter_plot(x: ndarray, y: ndarray, title: str = 'Scatter Plot', x_label: str = 'X', y_label: str = 'Y', color: str = 'blue') → Figure

Create a scatter plot with a perfect-correlation reference line (y = x).

Parameters:
  • x – X values

  • y – Y values

  • title – Plot title

  • x_label – X-axis label

  • y_label – Y-axis label

  • color – Marker color

Returns:

Plotly figure
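When actual values are plotted against predicted values, points that fall on the y = x line are perfect predictions. The reference line has to span the full range of both axes. The sketch below computes its endpoints. The helper identity_line and the sample data are illustrative assumptions, not the method's internals.

```python
import numpy as np

def identity_line(x, y):
    """Return (xs, ys) endpoints of a y = x line covering both data ranges."""
    lo = float(min(np.min(x), np.min(y)))
    hi = float(max(np.max(x), np.max(y)))
    return [lo, hi], [lo, hi]

actual = np.array([3.0, 5.0, 8.0])
predicted = np.array([2.5, 5.5, 9.0])
line_x, line_y = identity_line(actual, predicted)
```

If the span were taken from only one axis, points outside that axis's range would lie beyond the end of the reference line.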

create_waterfall_chart(categories: List[str], values: List[float], title: str = 'Waterfall Chart') → Figure

Create a waterfall chart using Plotly’s go.Waterfall trace.

Parameters:
  • categories – Category names

  • values – Values for each category

  • title – Chart title

Returns:

Plotly figure
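A waterfall bar starts where the previous bar ended, so the chart shows the running total of the values. The sketch below prepares lists in the shape go.Waterfall accepts: x, y, and measure, where "relative" adds to the running level and "total" draws a cumulative bar. Appending a closing "Total" bar is an assumption about the layout, and waterfall_inputs and the category names are hypothetical.

```python
from itertools import accumulate

def waterfall_inputs(categories, values):
    """Return x, y, measure lists plus the level each relative bar ends at."""
    if len(categories) != len(values):
        raise ValueError("categories and values must have the same length")
    x = list(categories) + ["Total"]
    y = list(values) + [sum(values)]
    measure = ["relative"] * len(values) + ["total"]
    running = list(accumulate(values))  # running total after each bar
    return x, y, measure, running

x, y, measure, running = waterfall_inputs(
    ["Baseline", "TV", "Search", "Price"], [500.0, 120.0, 80.0, -30.0]
)
```

With Plotly installed, these lists would map onto go.Waterfall(x=x, y=y, measure=measure). The negative "Price" entry is drawn as a decrease from 700 to 670.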

create_contribution_stacked_bar(media_contributions: ndarray, control_contributions: ndarray, baseline: ndarray, media_names: List[str], control_names: List[str], weeks: List[int] | None = None, title: str = 'Contributions Over Time') → Figure

Create a stacked bar chart of contributions over time.

Parameters:
  • media_contributions – Media contributions [n_weeks, n_media]

  • control_contributions – Control contributions [n_weeks, n_controls]

  • baseline – Baseline values [n_weeks]

  • media_names – Media channel names

  • control_names – Control variable names

  • weeks – Optional week indices

  • title – Chart title

Returns:

Plotly figure
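The three inputs share the week axis. Each weekly bar therefore stacks the baseline, every control, and every media channel, and the bar's height is the sum of those layers. The sketch below checks the documented shapes and builds one [n_weeks, 1 + n_controls + n_media] matrix of layers. The stacking order, the helper stack_contributions, and the sample numbers are assumptions made for illustration.

```python
import numpy as np

def stack_contributions(media, control, baseline, media_names, control_names):
    """Return (layer_names, layers) with layers shaped [n_weeks, n_layers]."""
    media = np.asarray(media, dtype=float)
    control = np.asarray(control, dtype=float)
    baseline = np.asarray(baseline, dtype=float)
    n_weeks = baseline.shape[0]
    if media.shape != (n_weeks, len(media_names)):
        raise ValueError("media_contributions must be [n_weeks, n_media]")
    if control.shape != (n_weeks, len(control_names)):
        raise ValueError("control_contributions must be [n_weeks, n_controls]")
    names = ["Baseline"] + list(control_names) + list(media_names)
    layers = np.column_stack([baseline, control, media])  # 1-D baseline becomes a column
    return names, layers

media = np.array([[10.0, 5.0], [12.0, 4.0], [8.0, 6.0]])  # 3 weeks x 2 channels
control = np.array([[2.0], [-1.0], [3.0]])               # 3 weeks x 1 control
baseline = np.array([50.0, 50.0, 50.0])
names, layers = stack_contributions(media, control, baseline, ["TV", "Search"], ["Price"])
weekly_total = layers.sum(axis=1)  # height of each stacked bar
```

Contributions can be negative, as the control is in week 2. In Plotly, barmode="relative" stacks negative values below zero rather than subtracting them from the positive stack.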

create_dag_network_plot(adjacency_matrix: ndarray, node_names: List[str], title: str = 'DAG Network') → Figure

Create a DAG network visualization.

Parameters:
  • adjacency_matrix – Adjacency matrix [n_nodes, n_nodes]

  • node_names – Node names

  • title – Plot title

Returns:

Plotly figure
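Drawing a network means turning the square adjacency matrix into a list of edges. A layered layout also needs a topological order, which exists only if the graph has no cycles. The sketch below does both and uses Kahn's algorithm for the ordering. It assumes that A[i, j] != 0 means an edge from node i to node j. That orientation, the threshold parameter, the helper names, and the channel names are all assumptions, not the library's documented convention.

```python
from collections import deque
import numpy as np

def edges_from_adjacency(adj, names, threshold=0.0):
    """List (source, target, weight) for entries with |A[i, j]| > threshold."""
    adj = np.asarray(adj, dtype=float)
    n = len(names)
    if adj.shape != (n, n):
        raise ValueError("adjacency_matrix must be [n_nodes, n_nodes]")
    return [(names[i], names[j], float(adj[i, j]))
            for i in range(n) for j in range(n)
            if abs(adj[i, j]) > threshold]

def topological_order(adj, names, threshold=0.0):
    """Kahn's algorithm: return node names in causal order, or None on a cycle."""
    mask = np.abs(np.asarray(adj, dtype=float)) > threshold
    indegree = mask.sum(axis=0).astype(int)  # column sums = incoming edges
    queue = deque(i for i in range(len(names)) if indegree[i] == 0)
    order = []
    while queue:
        i = queue.popleft()
        order.append(names[i])
        for j in np.flatnonzero(mask[i]):    # targets of node i
            indegree[j] -= 1
            if indegree[j] == 0:
                queue.append(j)
    return order if len(order) == len(names) else None

names = ["TV", "Search", "Social"]
A = np.array([[0.0, 0.4, 0.2],
              [0.0, 0.0, 0.3],
              [0.0, 0.0, 0.0]])
edges = edges_from_adjacency(A, names)
order = topological_order(A, names)
```

If topological_order returns None, the matrix contains a cycle and is not a valid DAG.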

create_dag_heatmap(adjacency_matrix: ndarray, node_names: List[str], title: str = 'DAG Adjacency Matrix') → Figure

Create a DAG adjacency matrix heatmap.

Parameters:
  • adjacency_matrix – Adjacency matrix [n_nodes, n_nodes]

  • node_names – Node names

  • title – Plot title

Returns:

Plotly figure

save_plot(fig: Figure, filepath: str, include_plotlyjs: str = 'cdn') → bool

Save a Plotly figure to an HTML file.

Parameters:
  • fig – Plotly figure to save

  • filepath – Output file path

  • include_plotlyjs – How to include Plotly.js (‘cdn’, ‘inline’, etc.)

Returns:

True if successful, False otherwise
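Because save_plot reports failure as False rather than raising, callers can batch-save figures without wrapping each call in try/except. The sketch below shows one way such a wrapper could work. It creates missing parent directories and treats I/O errors as failures. The helper save_html, the FakeFigure stand-in (used so the sketch runs without Plotly), and the choice of caught exceptions are all hypothetical. The sketch only assumes a Plotly-style write_html(path, include_plotlyjs=...) method.

```python
import os
import tempfile

def save_html(fig, filepath, include_plotlyjs="cdn"):
    """Write fig to HTML; return True on success, False otherwise (hypothetical)."""
    try:
        parent = os.path.dirname(filepath)
        if parent:
            os.makedirs(parent, exist_ok=True)  # create output folders on demand
        fig.write_html(filepath, include_plotlyjs=include_plotlyjs)
        return True
    except (OSError, ValueError):  # assumption: I/O and argument errors count as failure
        return False

class FakeFigure:
    """Hypothetical stand-in exposing a Plotly-like write_html method."""
    def write_html(self, path, include_plotlyjs=True):
        with open(path, "w", encoding="utf-8") as fh:
            fh.write(f"<html><!-- plotlyjs={include_plotlyjs} --></html>")

with tempfile.TemporaryDirectory() as tmp:
    ok = save_html(FakeFigure(), os.path.join(tmp, "plots", "fig.html"))
```

With Plotly's write_html, include_plotlyjs="cdn" links to Plotly.js on a CDN instead of embedding it. Files are much smaller, but viewing them requires network access.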

create_comprehensive_dashboard(results: Dict[str, Any], output_dir: str = 'dashboard_comprehensive') → List[Tuple[str, str]]

Create a comprehensive dashboard with multiple plots.

Parameters:
  • results – Training results dictionary

  • output_dir – Output directory for plots

Returns:

List of (plot_name, filepath) tuples for created plots
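The returned (plot_name, filepath) pairs are enough to build a simple landing page for the dashboard. The sketch below writes an index.html that links each plot by a path relative to output_dir, so the links keep working if the folder is moved. The helper write_index, the plot names, and the file names are hypothetical. They only mimic the documented return shape.

```python
import html
import os
import tempfile

def write_index(created, output_dir):
    """Write index.html linking each (plot_name, filepath) pair; return its path."""
    os.makedirs(output_dir, exist_ok=True)
    index_path = os.path.join(output_dir, "index.html")
    items = []
    for name, path in created:
        # Relative, forward-slash links stay valid if the folder is moved
        rel = os.path.relpath(path, output_dir).replace(os.sep, "/")
        items.append(f'<li><a href="{html.escape(rel)}">{html.escape(name)}</a></li>')
    with open(index_path, "w", encoding="utf-8") as fh:
        fh.write("<html><body><ul>\n" + "\n".join(items) + "\n</ul></body></html>\n")
    return index_path

with tempfile.TemporaryDirectory() as out:
    created = [("Training Progress", os.path.join(out, "training.html")),
               ("DAG Network", os.path.join(out, "dag.html"))]
    index = write_index(created, out)
    with open(index, encoding="utf-8") as fh:
        page = fh.read()
```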