interpret support for model predictions with response levels (#732)
* interpret predictions support for multiple prediction response levels

* slopes and comparisons support for predictions > 1 response level

* change arg. names and explicit data type cast in enforce_dtypes

* inline comment for if-else statement

* if-else logic for response_dim

* categorical regression model to test interpret functions with n-dim preds

* add docstring and inline comments for explainability

* add plot_predictions to food choice categorical regression

* run black
GStechschulte authored Oct 11, 2023
1 parent 3aaebca commit 77a8fa1
Showing 5 changed files with 416 additions and 226 deletions.
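For orientation, here is a hedged usage sketch of what this commit enables: the interpret functions applied to a model whose predictions carry one value per response level. The data, formula, and column names below are invented, and the import path for predictions is assumed from the module layout in this diff; this is an illustration, not code from the commit.

import numpy as np
import pandas as pd
import bambi as bmb
from bambi.interpret import predictions  # assumed public import for the function defined in effects.py

# Illustrative only: a small categorical-regression model. With a categorical
# family the posterior mean carries an extra response-level dimension, which is
# the case this commit teaches predictions/comparisons/slopes to summarize.
rng = np.random.default_rng(0)
data = pd.DataFrame(
    {
        "choice": rng.choice(["fish", "meat", "veggie"], size=200),
        "income": rng.normal(50, 10, size=200),
    }
)
model = bmb.Model("choice ~ income", data, family="categorical")
idata = model.fit(draws=500, chains=2)

# After this change the summary gains an "estimate_dim" column, i.e. one row
# per (covariate value, response level) pair.
summary = predictions(model, idata, covariates=["income"])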
96 changes: 70 additions & 26 deletions bambi/interpret/effects.py
@@ -14,7 +14,9 @@
from bambi.interpret.utils import (
average_over,
ConditionalInfo,
enforce_dtypes,
identity,
merge,
VariableInfo,
)
from bambi.utils import get_aliased_name, listify
@@ -319,16 +321,28 @@ def get_estimate(

return self

def get_summary_df(self) -> pd.DataFrame:
def get_summary_df(self, response_dim: np.ndarray) -> pd.DataFrame:
"""
Builds the summary dataframe for 'comparisons' and 'slopes' effects. If
the number of values passed for the variable of interest is less then 2
for 'comparisons' and 'slopes', then a subset of the 'preds' data is used
to build the summary. If the effect kind is 'comparisons' and more than
2 values are being compared, then the entire 'preds' data is used. If the
effect kind is 'slopes' and more than 2 values are being compared, then
only a subset of the 'preds' data is used to build the summary.
Builds the summary dataframe for 'comparisons' and 'slopes' effects.
There are four scenarios to consider:
1.) If the effect kind is 'comparisons' and more than 2 values are being
compared, then the entire 'preds' data is used.
2.) If the model predictions have multiple response levels, then the 'preds' data
needs to be duplicated to match the number of response levels. E.g., if 'preds'
has 100 rows and there are 3 response levels, then the summary dataframe will have
300 rows, since the model made a prediction for each response level for each
sample in 'preds'.
3.) If the effect kind is 'slopes' and more than 2 values are being compared, then
only a subset of the 'preds' data is used to build the summary.
4.) If the number of values passed for the variable of interest is less than 2
for 'comparisons' and 'slopes', then a subset of the 'preds' data is used
to build the summary.
"""
# Scenario 1
if len(self.variable.values) > 2 and self.kind == "comparisons":
summary_df = self.preds_data.drop(columns=self.variable.name).drop_duplicates()
covariates_cols = summary_df.columns
@@ -339,6 +353,18 @@ def get_summary_df(self) -> pd.DataFrame:
contrast_values, summary_df.shape[0] // len(contrast_values), axis=0
)
contrast_values = [tuple(elem) for elem in contrast_values]
# Scenario 2
elif len(response_dim) > 1:
summary_df = self.preds_data.drop(columns=self.variable.name).drop_duplicates()
covariates_cols = summary_df.columns
contrast_values = self.variable.values.flatten()
covariate_vals = np.repeat(summary_df.T, len(response_dim))
summary_df = pd.DataFrame(data=covariate_vals.T, columns=covariates_cols)
summary_df["estimate_dim"] = np.tile(
response_dim, summary_df.shape[0] // len(response_dim)
)
contrast_values = [tuple(contrast_values)] * summary_df.shape[0]
# Scenario 3 & 4
else:
wrt = {}
for idx, _ in enumerate(self.variable.values):
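As an aside, a minimal pandas/NumPy sketch of the "Scenario 2" duplication described in the docstring above, using an invented covariate grid and invented response levels rather than the real 'preds' data:

import numpy as np
import pandas as pd

# Toy covariate grid and response levels; names are illustrative only.
grid = pd.DataFrame({"income": [30.0, 45.0], "age": [25, 40]})
response_dim = np.array(["fish", "meat", "veggie"])

# Repeat every covariate row once per response level, then tile the level
# labels so each row is tagged with the response dimension it belongs to.
summary = grid.loc[grid.index.repeat(len(response_dim))].reset_index(drop=True)
summary["estimate_dim"] = np.tile(response_dim, len(grid))
print(summary.shape)  # (6, 3): 2 covariate rows x 3 response levels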
@@ -473,12 +499,10 @@ def predictions(

assert 1 <= len(covariates) <= 3

if transforms is None:
transforms = {}
transforms = transforms if transforms is not None else {}

if prob is None:
prob = az.rcParams["stats.hdi_prob"]

if not 0 < prob < 1:
raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")

@@ -525,9 +549,20 @@ def predictions(
upper_bound = 1 - lower_bound
response.lower_bound, response.upper_bound = lower_bound, upper_bound

cap_data["estimate"] = y_hat_mean
cap_data[response.lower_bound_name] = y_hat_bounds[0]
cap_data[response.upper_bound_name] = y_hat_bounds[1]
if y_hat_mean.ndim > 1:
cap_data = merge(y_hat_mean, y_hat_bounds, cap_data)
cap_data = cap_data.rename(
columns={
f"{response.name}_dim": "estimate_dim",
f"{response.name_target}": "estimate",
f"{response.name_target}_x": response.lower_bound_name,
f"{response.name_target}_y": response.upper_bound_name,
}
)
else:
cap_data["estimate"] = y_hat_mean
cap_data[response.lower_bound_name] = y_hat_bounds[0]
cap_data[response.upper_bound_name] = y_hat_bounds[1]

return cap_data

@@ -630,8 +665,7 @@ def comparisons(
contrast_info = VariableInfo(model, contrast, "comparisons", eps=0.5)
conditional_info = ConditionalInfo(model, conditional)

if transforms is None:
transforms = {}
transforms = transforms if transforms is not None else {}

response_name = get_aliased_name(model.response_component.response_term)
response = ResponseInfo(
@@ -647,6 +681,13 @@ def comparisons(
idata, data=comparisons_data, sample_new_groups=sample_new_groups, inplace=False
)

# returns empty array if model predictions do not have multiple dimensions
response_dim_key = response.name + "_dim"
if response_dim_key in idata.posterior.coords:
response_dim = idata.posterior.coords[response_dim_key].values
else:
response_dim = np.empty(0)

predictive_difference = PredictiveDifferences(
model,
comparisons_data,
@@ -658,12 +699,12 @@ def comparisons(
)
comparisons_summary = predictive_difference.get_estimate(
idata, response_transform, comparison_type, prob=prob
).get_summary_df()
).get_summary_df(response_dim)

if average_by:
comparisons_summary = predictive_difference.average_by(variable=average_by)

return comparisons_summary
return enforce_dtypes(comparisons_data, comparisons_summary)


def slopes(
@@ -769,10 +810,7 @@ def slopes(
# 'slopes' should not be limited to ("main", "group", "panel")
conditional_info = ConditionalInfo(model, conditional)

grid = False
if conditional_info.covariates:
grid = True

grid = bool(conditional_info.covariates)
# if wrt is categorical or string dtype, call 'comparisons' to compute the
# difference between group means as the slope
effect_type = "slopes"
@@ -784,8 +822,7 @@ def slopes(
lower_bound = round((1 - prob) / 2, 4)
upper_bound = 1 - lower_bound

if transforms is None:
transforms = {}
transforms = transforms if transforms is not None else {}

response_name = get_aliased_name(model.response_component.response_term)
response = ResponseInfo(response_name, "mean", lower_bound, upper_bound)
@@ -798,14 +835,21 @@ def slopes(
idata, data=slopes_data, sample_new_groups=sample_new_groups, inplace=False
)

# returns empty array if model predictions do not have multiple dimensions
response_dim_key = response.name + "_dim"
if response_dim_key in idata.posterior.coords:
response_dim = idata.posterior.coords[response_dim_key].values
else:
response_dim = np.empty(0)

predictive_difference = PredictiveDifferences(
model, slopes_data, wrt_info, conditional_info, response, use_hdi, effect_type
)
slopes_summary = predictive_difference.get_estimate(
idata, response_transform, "diff", slope, eps
).get_summary_df()
).get_summary_df(response_dim)

if average_by:
slopes_summary = predictive_difference.average_by(variable=average_by)

return slopes_summary
return enforce_dtypes(slopes_data, slopes_summary)
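The response-dimension lookup shared by comparisons and slopes above can be exercised in isolation; here is a sketch with a stand-in posterior dataset (the coordinate and level names are hypothetical):

import numpy as np
import xarray as xr

# Stand-in for idata.posterior: a multi-level response exposes an extra
# "<response>_dim" coordinate, a univariate response does not.
posterior = xr.Dataset(coords={"choice_dim": ["fish", "meat", "veggie"]})

response_dim_key = "choice" + "_dim"
if response_dim_key in posterior.coords:
    response_dim = posterior.coords[response_dim_key].values
else:
    response_dim = np.empty(0)  # empty array signals the single-level case

print(response_dim)  # ['fish' 'meat' 'veggie']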
87 changes: 82 additions & 5 deletions bambi/interpret/plot_types.py
@@ -42,11 +42,33 @@ def plot_numeric(
y_hat_mean = plot_data["estimate"]
y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)

# if "estimate_dim" column exists, then model predictions has multiple dimensions
if "estimate_dim" in plot_data.columns:
y_hat_dims = plot_data["estimate_dim"].unique()
y_hat_ndim = len(y_hat_dims)
else:
y_hat_ndim = 1

if len(covariates) == 1:
ax = axes[0]
values_main = transform_main(plot_data[main])
ax.plot(values_main, y_hat_mean, solid_capstyle="butt", color="C0")
ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
if y_hat_ndim > 1:
for i, clr in enumerate(y_hat_dims):
idx = plot_data["estimate_dim"] == clr
values_main = transform_main(plot_data.loc[idx, main])
ax.plot(
values_main, y_hat_mean[idx], color=f"C{i}", label=clr, solid_capstyle="butt"
)
ax.fill_between(
values_main,
y_hat_bounds[0][idx],
y_hat_bounds[1][idx],
alpha=0.4,
color=f"C{i}",
)
else:
values_main = transform_main(plot_data[main])
ax.plot(values_main, y_hat_mean, solid_capstyle="butt", color="C0")
ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
elif "group" in covariates and not "panel" in covariates:
ax = axes[0]
colors = np.unique(plot_data[color])
@@ -100,6 +122,25 @@ def plot_numeric(
)
ax.set(title=f"{panel} = {pnl}")

if y_hat_ndim > 1:
if "group" not in covariates and legend:
handles = [
(
Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
Patch(color=f"C{i}", alpha=0.4, lw=1),
)
for i in range(len(y_hat_dims))
]
for ax in axes.ravel():
ax.legend(
handles,
tuple(y_hat_dims),
title="estimate_dim",
handlelength=1.3,
handleheight=1,
loc="best",
)

if "group" in covariates and legend:
handles = [
(
@@ -112,6 +153,7 @@ def plot_numeric(
ax.legend(
handles, tuple(colors), title=color, handlelength=1.3, handleheight=1, loc="best"
)

return axes
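A hedged matplotlib sketch of the multi-level branch above: one line and one interval ribbon per response level, sharing the C0/C1/... color cycle (data and level names are invented):

import numpy as np
import matplotlib.pyplot as plt

levels = ["fish", "meat", "veggie"]  # hypothetical response levels
x = np.linspace(0, 1, 50)

fig, ax = plt.subplots()
for i, level in enumerate(levels):
    mean = (i + 1) * x               # fake posterior mean per level
    lower, upper = mean - 0.2, mean + 0.2
    ax.plot(x, mean, color=f"C{i}", label=level, solid_capstyle="butt")
    ax.fill_between(x, lower, upper, alpha=0.4, color=f"C{i}")
ax.legend(title="estimate_dim", loc="best")
plt.show()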


@@ -144,6 +186,13 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool
y_hat_mean = plot_data["estimate"]
y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)

# if "estimate_dim" column exists, then model predictions has multiple dimensions
if "estimate_dim" in plot_data.columns:
y_hat_dims = plot_data["estimate_dim"].unique()
y_hat_ndim = len(y_hat_dims)
else:
y_hat_ndim = 1

if "group" in covariates:
colors = np.unique(plot_data[color])
colors_n = len(colors)
@@ -155,8 +204,17 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool

if len(covariates) == 1:
ax = axes[0]
ax.scatter(idxs_main, y_hat_mean, color="C0")
ax.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1], color="C0")
if y_hat_ndim > 1:
offset_bounds = get_group_offset(y_hat_ndim)
colors_offset = np.linspace(-offset_bounds, offset_bounds, y_hat_ndim)
for i, clr in enumerate(y_hat_dims):
idx = plot_data["estimate_dim"] == clr
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
else:
ax.scatter(idxs_main, y_hat_mean, color="C0")
ax.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1], color="C0")
elif "group" in covariates and not "panel" in covariates:
ax = axes[0]
for i, clr in enumerate(colors):
@@ -187,6 +245,25 @@ def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
ax.set(title=f"{panel} = {pnl}")

if y_hat_ndim > 1:
if "group" not in covariates and legend:
handles = [
(
Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
Patch(color=f"C{i}", alpha=0.4, lw=1),
)
for i in range(len(y_hat_dims))
]
for ax in axes.ravel():
ax.legend(
handles,
tuple(y_hat_dims),
title="estimate_dim",
handlelength=1.3,
handleheight=1,
loc="best",
)

if "group" in covariates and legend:
handles = [
Line2D([], [], c=f"C{i}", marker="o", label=level) for i, level in enumerate(colors)
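Similarly, for the categorical case, a sketch of the per-level x-axis offset used above so that point estimates and interval lines for different response levels do not overlap (all values invented):

import numpy as np
import matplotlib.pyplot as plt

levels = ["fish", "meat", "veggie"]            # hypothetical response levels
idxs_main = np.arange(4)                       # four categories on the x axis
offsets = np.linspace(-0.2, 0.2, len(levels))  # small shift per level
rng = np.random.default_rng(0)

fig, ax = plt.subplots()
for i, level in enumerate(levels):
    mean = rng.uniform(0.2, 0.8, size=idxs_main.size)  # fake point estimates
    ax.scatter(idxs_main + offsets[i], mean, color=f"C{i}", label=level)
    ax.vlines(idxs_main + offsets[i], mean - 0.1, mean + 0.1, color=f"C{i}")
ax.legend(title="estimate_dim", loc="best")
plt.show()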
43 changes: 36 additions & 7 deletions bambi/interpret/utils.py
@@ -9,6 +9,7 @@
from formulae.terms.call import Call
import pandas as pd
from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
import xarray as xr

from bambi import Model
from bambi.utils import listify
@@ -257,20 +258,23 @@ def get_covariates(covariates: dict) -> Covariates:
return Covariates(main, group, panel)


def enforce_dtypes(data: pd.DataFrame, df: pd.DataFrame, except_col=None) -> pd.DataFrame:
def enforce_dtypes(
observed_df: pd.DataFrame, new_df: pd.DataFrame, except_col=None
) -> pd.DataFrame:
"""
Enforce dtypes of the original data to the new data.
Enforce dtypes of the observed data to the new data.
"""
observed_dtypes = data.dtypes
for col in df.columns:
observed_dtypes = observed_df.dtypes
for col in new_df.columns:
if col in observed_dtypes.index and not except_col:
if observed_dtypes[col] == "category":
# explicitly converts to category dtype
df[col] = df[col].astype("category")
new_df[col] = new_df[col].astype("category")
else:
# casts the original dtype to the new data
df[col] = df[col].astype(observed_dtypes[col])
return df
new_df[col] = new_df[col].astype(observed_dtypes[col])

return new_df
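A small usage sketch of the renamed enforce_dtypes arguments: a freshly built summary frame typically comes back with object/float columns, and the observed data's dtypes are cast onto it (the frames below are illustrative):

import pandas as pd
from bambi.interpret.utils import enforce_dtypes

# Observed data with a categorical column and an integer column.
observed_df = pd.DataFrame({"group": pd.Categorical(["a", "b", "b"]), "x": [1, 2, 3]})
# A new summary frame whose dtypes were lost along the way.
new_df = pd.DataFrame({"group": ["a", "a", "b"], "x": [1.0, 2.0, 2.0]})

new_df = enforce_dtypes(observed_df, new_df)
print(new_df.dtypes)  # group -> category, x -> int64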


def make_group_panel_values(
@@ -399,3 +403,28 @@ def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:

def identity(x):
return x


def merge(y_hat_mean: xr.DataArray, y_hat_bounds: xr.DataArray, data: pd.DataFrame) -> pd.DataFrame:
"""
Convert predictions ('y_hat_mean' and 'y_hat_bounds') into dataframes and join
with the original data used to perform predictions. This will "duplicate" the
data to ensure that the original data is aligned with each response dimension
(level).
"""

idx_names = y_hat_mean.to_dataframe().index.names

yhat_df = y_hat_mean.to_dataframe().reset_index().set_index(idx_names)
lower_df = y_hat_bounds.sel(hdi="lower").to_dataframe().reset_index().set_index(idx_names)
higher_df = y_hat_bounds.sel(hdi="higher").to_dataframe().reset_index().set_index(idx_names)
bounds_df = pd.merge(left=lower_df, right=higher_df, left_index=True, right_index=True)
preds_df = (
pd.merge(left=yhat_df, right=bounds_df, left_index=True, right_index=True)
.reset_index()
.set_index(idx_names[0])
)

summary_df = pd.merge(left=data, right=preds_df, left_index=True, right_index=True)

return summary_df.drop(columns=["hdi_x", "hdi_y"])
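For intuition about why merge "duplicates" the data: converting a DataArray that carries an extra response dimension to a dataframe already yields one row per (observation, level) pair, which is what the joins above align with the covariate grid. A sketch with made-up names and values:

import numpy as np
import xarray as xr

y_hat_mean = xr.DataArray(
    np.arange(6, dtype=float).reshape(2, 3),
    dims=("choice_obs", "choice_dim"),
    coords={"choice_obs": [0, 1], "choice_dim": ["fish", "meat", "veggie"]},
    name="choice_mean",
)
df = y_hat_mean.to_dataframe().reset_index()
print(df.shape)  # (6, 3): 2 observations x 3 response levels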