"""
Response curve fitting for Marketing Mix Modeling using Hill equation.
This module provides the ResponseCurveFit class for fitting saturation curves
to the relationship between media spend/impressions and predicted outcomes.
The implementation follows modern Python standards:
- Type hints for all parameters and return values
- Private methods prefixed with underscore (_)
- Keyword-only arguments for better API clarity
- Comprehensive docstrings in NumPy style
- Backward compatibility maintained for legacy code
"""
from typing import Optional, Literal, Tuple, List
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from scipy.optimize import curve_fit
from sklearn.metrics import r2_score
from tqdm import tqdm
import logging
logger = logging.getLogger('deepcausalmmm')
[docs]
class ResponseCurveFit:
"""
Fit Hill equation response curves to marketing mix model predictions.
The Hill equation models saturation effects:
y = bottom + (top - bottom) * x^slope / (saturation^slope + x^slope)
Parameters
----------
data : pd.DataFrame
DataFrame with columns: 'week_monday', 'spend', 'impressions', 'predicted'
For DMA-level: also needs 'dmacode' column
bottom_param : bool, default=False
Whether to fit a non-zero intercept (bottom parameter)
For MMM, typically False (response at zero spend = 0)
Modellevel : str, default='Overall'
'Overall': Single aggregated curve across all regions
'DMA': Separate curves for each DMA
Datecol : str, default='week_monday'
Name of the date column
Attributes
----------
top : float
Maximum response (saturation level)
bottom : float
Minimum response (typically 0)
saturation : float
Spend level at half-maximum response
slope : float
Steepness of the curve
r_2 : float
R-squared score of the fitted curve
equation : str
String representation of the fitted equation
figure : go.Figure
Plotly figure object (if generate_figure=True)
Examples
--------
>>> # Prepare data
>>> data = pd.DataFrame({
... 'week_monday': dates,
... 'spend': spend_values,
... 'impressions': impression_values,
... 'predicted': model_predictions
... })
>>>
>>> # Fit overall response curve
>>> fitter = ResponseCurveFit(data, Modellevel='Overall')
>>> fitter.fit_model(
... title="Response Curve",
... x_label="Impressions",
... y_label="Predicted Visits",
... generate_figure=True,
... save_figure=True,
... output_path='response_curve.html'
... )
>>> print(f"R²: {fitter.r_2:.3f}")
>>> print(f"Slope: {fitter.slope:.3f}")
"""
[docs]
def __init__(
self,
data: pd.DataFrame,
*,
bottom_param: bool = False,
model_level: Literal['Overall', 'DMA'] = 'Overall',
date_col: str = 'week_monday'
) -> None:
"""
Initialize ResponseCurveFit.
Parameters
----------
data : pd.DataFrame
Input data with required columns
bottom_param : bool, default=False
Whether to fit non-zero intercept
model_level : {'Overall', 'DMA'}, default='Overall'
Aggregation level for fitting
date_col : str, default='week_monday'
Name of date column
"""
self.data = data
self.bottom_param = bottom_param
self.model_level = model_level
self.date_col = date_col
# Backward compatibility
self.Modellevel = model_level
self.Datecol = date_col
def _hill_equation(self, X: np.ndarray, *params) -> np.ndarray:
"""
Hill equation for saturation modeling.
Parameters
----------
X : np.ndarray
Input values (spend/impressions)
params : tuple
(top, bottom, saturation, slope)
Returns
-------
np.ndarray
Predicted response values
"""
self.top = params[0]
self.bottom = params[1] if self.bottom_param else 0
self.saturation = params[2]
self.slope = params[3]
return self.bottom + (self.top - self.bottom) * X**self.slope / (
self.saturation**self.slope + X**self.slope
)
# Backward compatibility
[docs]
def Hill(self, X: np.ndarray, *params) -> np.ndarray:
"""Backward compatibility wrapper for _hill_equation."""
return self._hill_equation(X, *params)
def _get_initial_params(self, curve_fit_kws: dict) -> dict:
"""
Get initial parameter guesses and bounds.
Parameters
----------
curve_fit_kws : dict
Additional keyword arguments for scipy.optimize.curve_fit
Returns
-------
dict
Updated curve_fit_kws with p0 and bounds
"""
min_data = np.amin(self._y_data)
max_data = np.amax(self._y_data)
h = abs(max_data - min_data)
param_initial = [max_data, min_data, 0.5 * (self._X_data[-1] - self._X_data[0]), 1]
param_bounds = (
[max_data - 0.5 * h, min_data - 0.5 * h, self._X_data[0] * 0.1, 0.01],
[max_data + 0.5 * h, min_data + 0.5 * h, self._X_data[-1] * 10, 100],
)
curve_fit_kws.setdefault("p0", param_initial)
curve_fit_kws.setdefault("bounds", param_bounds)
return curve_fit_kws
def _fit_curve(self, curve_fit_kws: dict) -> List[float]:
"""
Fit the Hill curve to data.
Parameters
----------
curve_fit_kws : dict
Keyword arguments for scipy.optimize.curve_fit
Returns
-------
list
Fitted parameters [top, bottom, saturation, slope]
"""
curve_fit_kws = self._get_initial_params(curve_fit_kws)
popt, _ = curve_fit(self._hill_equation, self._X_data, self._y_data, **curve_fit_kws)
if not self.bottom_param:
popt[1] = 0
return [float(param) for param in popt]
# Backward compatibility
[docs]
def get_param(self, curve_fit_kws: dict) -> List[float]:
"""Backward compatibility wrapper for _fit_curve."""
return self._fit_curve(curve_fit_kws)
def _calculate_r2_and_plot(
self,
x_fit: np.ndarray,
y_fit: np.ndarray,
x_label: str,
y_label: str,
title: str,
sigfigs: int,
log_x: bool,
print_r_sqr: bool,
generate_figure: bool,
view_figure: bool,
*params,
) -> None:
"""
Calculate R² and generate visualization.
Parameters
----------
x_fit : np.ndarray
X values for fitted curve
y_fit : np.ndarray
Y values for fitted curve
x_label : str
X-axis label
y_label : str
Y-axis label
title : str
Plot title
sigfigs : int
Significant figures for equation display
log_x : bool
Whether to use log scale for x-axis
print_r_sqr : bool
Whether to print R² score
generate_figure : bool
Whether to generate visualization
view_figure : bool
Whether to display the figure
params : tuple
Fitted parameters
"""
corrected_y_data = self._hill_equation(self._X_data, *params)
self.r_2 = r2_score(self._y_data, corrected_y_data)
if generate_figure:
self.figure = go.Figure()
self.figure.add_trace(go.Scatter(
x=self._X_data, y=self._y_data,
name='Observed Data',
mode='markers'
))
self.figure.add_trace(go.Scatter(
x=x_fit, y=y_fit,
name='Fitted Model',
mode='lines'
))
self.figure.update_layout(
title=title,
xaxis_title=x_label,
yaxis_title=y_label,
legend_title="Legend",
width=1500,
height=900,
yaxis_zeroline=False,
xaxis_zeroline=False
)
if print_r_sqr:
logger.info(f" R² Score: {self.r_2:.4f}")
if view_figure:
self.figure.show()
# Backward compatibility
[docs]
def regression(self, x_fit, y_fit, x_label, y_label, title, sigfigs, log_x, print_r_sqr, generate_figure, view_figure, *params) -> None:
"""Backward compatibility wrapper for _calculate_r2_and_plot."""
return self._calculate_r2_and_plot(x_fit, y_fit, x_label, y_label, title, sigfigs, log_x, print_r_sqr, generate_figure, view_figure, *params)
[docs]
def fit(
self,
*,
x_label: str = "x",
y_label: str = "y",
title: str = "Fitted Hill equation",
sigfigs: int = 6,
log_x: bool = False,
print_r_sqr: bool = True,
generate_figure: bool = True,
view_figure: bool = False,
save_figure: bool = False,
output_path: Optional[str] = None,
curve_fit_kws: Optional[dict] = None,
) -> Optional[pd.DataFrame]:
"""
Fit Hill equation to the data.
Parameters
----------
x_label : str, default='x'
X-axis label
y_label : str, default='y'
Y-axis label
title : str, default='Fitted Hill equation'
Plot title
sigfigs : int, default=6
Significant figures for equation display
log_x : bool, default=False
Whether to use log scale for x-axis
print_r_sqr : bool, default=True
Whether to print R² score
generate_figure : bool, default=True
Whether to generate visualization
view_figure : bool, default=False
Whether to display the figure
save_figure : bool, default=False
Whether to save the figure
output_path : str, optional
Path to save the figure (if save_figure=True)
curve_fit_kws : dict, optional
Additional keyword arguments for scipy.optimize.curve_fit
Returns
-------
pd.DataFrame or None
For DMA-level: DataFrame with parameters for each DMA
For Overall: None (parameters stored as attributes)
"""
if self.model_level == 'Overall':
cpi = self.data['spend'].sum() / self.data['impressions'].sum()
self.data_agg = self.data[[self.date_col, 'impressions', 'predicted']].groupby(self.date_col).sum()
self.data_agg['spend'] = self.data_agg['impressions'] * cpi
self.data_agg.sort_values(by='spend', inplace=True)
self._X_data = np.array(self.data_agg['spend'])
self._y_data = np.array(self.data_agg['predicted'])
if self._X_data[0] > self._X_data[-1]:
raise ValueError(
f"The first point {self._X_data[0]} and the last point {self._X_data[-1]} are not amenable with the scipy.curvefit function."
)
if curve_fit_kws is None:
curve_fit_kws = {}
self.generate_figure = generate_figure
self.x_fit = np.logspace(
np.log10(self._X_data[0]), np.log10(self._X_data[-1]), len(self._y_data)
)
self.fit_flag = False
try:
params = self._fit_curve(curve_fit_kws)
self.y_fit = self._hill_equation(self.x_fit, *params)
self.equation = f"{np.round(self.bottom, sigfigs)} + ({np.round(self.top, sigfigs)}-{np.round(self.bottom, sigfigs)})*x**{(np.round(self.slope, sigfigs))} / ({np.round(self.saturation, sigfigs)}**{(np.round(self.slope, sigfigs))} + x**{(np.round(self.slope, sigfigs))})"
self._calculate_r2_and_plot(
self.x_fit,
self.y_fit,
x_label,
y_label,
title,
sigfigs,
log_x,
print_r_sqr,
generate_figure,
view_figure,
*params,
)
if save_figure and output_path and generate_figure:
self.figure.write_html(output_path)
logger.info(f" Figure saved to: {output_path}")
self.fit_flag = True
except RuntimeError as re:
self.r_2 = 0
logger.warning(f" Fitting failed: {re}")
elif self.model_level == 'DMA':
self.generate_figure = False
self.view_figure = False
self.print_r_sqr = False
cpi = self.data[['dmacode', 'impressions', 'spend']].groupby('dmacode').sum().reset_index()
cpi['cpi'] = cpi['spend'] / cpi['impressions']
self.data_agg = self.data[['dmacode', self.date_col, 'impressions', 'predicted']].groupby(['dmacode', self.date_col]).sum()
self.data_agg = self.data_agg.merge(cpi[['dmacode', 'cpi']], on='dmacode')
self.data_agg['spend'] = self.data_agg['impressions'] * self.data_agg['cpi']
self.data_agg.sort_values(by=['dmacode', 'spend'], inplace=True)
dmas = self.data_agg['dmacode'].unique()
bottom = []
top = []
slope = []
saturation = []
r_2 = []
for i in tqdm(dmas, desc='Running...'):
self._X_data = np.array(self.data_agg[self.data_agg['dmacode'] == i]['spend'])
self._y_data = np.array(self.data_agg[self.data_agg['dmacode'] == i]['predicted'])
if self._X_data[0] > self._X_data[-1]:
raise ValueError(
f"The first point {self._X_data[0]} and the last point {self._X_data[-1]} are not amenable with the scipy.curvefit function."
)
curve_fit_kws = {}
try:
params = self._fit_curve(curve_fit_kws)
corrected_y_data = self._hill_equation(self._X_data, *params)
self.r_2 = r2_score(self._y_data, corrected_y_data)
bottom.append(self.bottom)
top.append(self.top)
slope.append(self.slope)
saturation.append(self.saturation)
r_2.append(self.r_2)
except ValueError as e:
bottom.append(0)
top.append(0)
slope.append(0)
saturation.append(0)
r_2.append(0)
continue
except RuntimeError as re:
bottom.append(0)
top.append(0)
slope.append(0)
saturation.append(0)
r_2.append(0)
re = "for dmacode: " + str(i) + " " + str(re)
logger.warning(re)
continue
return pd.DataFrame({
'dmacode': list(dmas),
'bottom': bottom,
'top': top,
'slope': slope,
'saturation': saturation,
'r_2': r_2
})
# Backward compatibility
[docs]
def fit_model(self, **kwargs) -> Optional[pd.DataFrame]:
"""Backward compatibility wrapper for fit()."""
return self.fit(**kwargs)
[docs]
def predict(self, X: np.ndarray) -> np.ndarray:
"""
Predict response for new spend levels (Overall level only).
Parameters
----------
X : np.ndarray
New spend/impression values
Returns
-------
np.ndarray
Predicted response values
"""
if self.fit_flag:
return self.bottom + (self.top - self.bottom) * X**self.slope / (
self.saturation**self.slope + X**self.slope
)
[docs]
def get_summary(self):
"""
Get summary of fitted parameters.
Returns
-------
dict
Dictionary with 'params', 'r2', and 'equation'
"""
return {
'params': {
'top': self.top,
'bottom': self.bottom,
'saturation': self.saturation,
'slope': self.slope
},
'r2': self.r_2,
'equation': self.equation if hasattr(self, 'equation') else None
}
# Backward compatibility alias
ResponseCurveFitter = ResponseCurveFit