Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add panel covariates to plot_cap() and make it more flexible #596

Merged
merged 7 commits into from
Dec 2, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 199 additions & 94 deletions bambi/plots/plot_cap.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from statistics import mode

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from arviz.plots.backends.matplotlib import create_axes_grid
from arviz.plots.plot_utils import default_grid
from formulae.terms.call import Call
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
Expand Down Expand Up @@ -45,46 +46,45 @@ def create_cap_data(model, covariates, grid_n=200, groups_n=5):
When either the main or the group covariates are not numeric or categoric.
"""
data = model.data
covariates = listify(covariates)

if len(covariates) not in [1, 2]:
raise ValueError(f"The number of covariates must be 1 or 2. It's {len(covariates)}.")

main = covariates[0]

# If available, take the name of the grouping variable
if len(covariates) == 1:
group = None
else:
group = covariates[1]
main = covariates.get("horizontal")
group = covariates.get("color", None)
panel = covariates.get("panel", None)

# Obtain data for main variable
data_main = data[main]
if is_numeric_dtype(data_main):
main_values = np.linspace(np.min(data_main), np.max(data_main), grid_n)
elif is_string_dtype(data_main) or is_categorical_dtype(data_main):
main_values = np.unique(data_main)
else:
raise ValueError("Main covariate must be numeric or categoric.")
main_values = make_main_values(data[main], grid_n)
main_n = len(main_values)

# If available, obtain groups for grouping variable
if group:
group_data = data[group]
if is_string_dtype(group_data) or is_categorical_dtype(group_data):
group_values = np.unique(group_data)
elif is_numeric_dtype(group_data):
group_values = np.quantile(group_data, np.linspace(0, 1, groups_n))
else:
raise ValueError("Group covariate must be numeric or categoric.")

# Reshape accordingly
group_values = make_group_values(data[group], groups_n)
group_n = len(group_values)
main_n = len(main_values)

# If available, obtain groups for the panel variable. The same logic as for the grouping variable applies.
if panel:
panel_values = make_group_values(data[panel], groups_n)
panel_n = len(panel_values)

data_dict = {main: main_values}

if group and not panel:
main_values = np.tile(main_values, group_n)
group_values = np.repeat(group_values, main_n)
data_dict = {main: main_values, group: group_values}
else:
data_dict = {main: main_values}
data_dict.update({main: main_values, group: group_values})
elif not group and panel:
main_values = np.tile(main_values, panel_n)
panel_values = np.repeat(panel_values, main_n)
data_dict.update({main: main_values, panel: panel_values})
elif group and panel:
if group == panel:
main_values = np.tile(main_values, group_n)
group_values = np.repeat(group_values, main_n)
data_dict.update({main: main_values, group: group_values})
else:
main_values = np.tile(np.tile(main_values, group_n), panel_n)
group_values = np.tile(np.repeat(group_values, main_n), panel_n)
panel_values = np.repeat(panel_values, main_n * group_n)
data_dict.update({main: main_values, group: group_values, panel: panel_values})

# Construct dictionary of terms that are in the model
terms = {}
Expand Down Expand Up @@ -122,7 +122,15 @@ def create_cap_data(model, covariates, grid_n=200, groups_n=5):


def plot_cap(
model, idata, covariates, use_hdi=True, hdi_prob=None, transforms=None, legend=True, ax=None
model,
idata,
covariates,
use_hdi=True,
hdi_prob=None,
transforms=None,
legend=True,
ax=None,
fig_kwargs=None,
):
"""Plot Conditional Adjusted Predictions

Expand All @@ -147,8 +155,8 @@ def plot_cap(
Transformations that are applied to each of the variables being plotted. The keys are the
name of the variables, and the values are functions to be applied. Defaults to ``None``.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
name of the variables, and the values are functions to be applied. Defaults to ``None``.
name of the variables, and the values are functions to be applied. Defaults to `None`.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, just curious: why are the double backticks needed? I see them in other docstrings too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember correctly, it's because it doesn't render properly with only a single backtick. Maybe something has changed and it now works with a single backtick, but I remember trying a single backtick previously and it didn't work.

ax : matplotlib.axes._subplots.AxesSubplot, optional
A matplotlib axes object. If None, this function instantiates a new axes object.
Defaults to ``None``.
A matplotlib axes object or a sequence of them. If None, this function instantiates a
new axes object. Defaults to ``None``.

Returns
-------
Expand All @@ -162,8 +170,15 @@ def plot_cap(
When the main covariate is not numeric or categoric.
"""

covariates = listify(covariates)
assert len(covariates) in [1, 2]
covariate_kinds = ("horizontal", "color", "panel")
if not isinstance(covariates, dict):
covariates = listify(covariates)
covariates = dict(zip(covariate_kinds, covariates))
else:
assert covariate_kinds[0] in covariates
assert set(covariates).issubset(set(covariate_kinds))

assert 1 <= len(covariates) <= 3

cap_data = create_cap_data(model, covariates)
idata = model.predict(idata, data=cap_data, inplace=False)
Expand All @@ -177,7 +192,6 @@ def plot_cap(
if transforms is None:
transforms = {}

# If passed, use transformation
response_transform = transforms.get(model.response.name, identity)

y_hat = response_transform(idata.posterior[f"{model.response.name}_mean"])
Expand All @@ -191,102 +205,193 @@ def plot_cap(
y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))

if ax is None:
fig, ax = plt.subplots()
fig_kwargs = {} if fig_kwargs is None else fig_kwargs
panel = covariates.get("panel", None)
panels_n = len(np.unique(cap_data[panel])) if panel else 1
rows, cols = default_grid(panels_n)
fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
axes = np.atleast_1d(axes)
else:
fig = ax.get_figure()
axes = np.atleast_1d(ax)
fig = axes[0].get_figure()

main = covariates[0]
main = covariates.get("horizontal")
if is_numeric_dtype(cap_data[main]):
ax = _plot_cap_numeric(
covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, ax
axes = _plot_cap_numeric(
covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes
)
elif is_categorical_dtype(cap_data[main]) or is_string_dtype(cap_data[main]):
ax = _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, ax)
axes = _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes)
else:
raise ValueError("Main covariate must be numeric or categoric.")

ax.set(xlabel=main, ylabel=model.response.name)
return fig, ax
for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
ax.set(xlabel=main, ylabel=model.response.name)

return fig, axes

def _plot_cap_numeric(covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, ax):
main = covariates[0]
# Extract transform

def _plot_cap_numeric(covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes):
main = covariates.get("horizontal")
transform_main = transforms.get(main, identity)

if len(covariates) == 1:
ax = axes[0]
values_main = transform_main(cap_data[main])
ax.plot(values_main, y_hat_mean, solid_capstyle="butt")
ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.5)
else:
group = covariates[1]
groups = get_unique_levels(cap_data[group])

for i, grp in enumerate(groups):
idx = (cap_data[group] == grp).values
elif "color" in covariates and not "panel" in covariates:
ax = axes[0]
color = covariates.get("color")
colors = get_unique_levels(cap_data[color])
for i, clr in enumerate(colors):
idx = (cap_data[color] == clr).to_numpy()
values_main = transform_main(cap_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
ax.fill_between(
values_main,
y_hat_bounds[0][idx],
y_hat_bounds[1][idx],
alpha=0.3,
alpha=0.5,
color=f"C{i}",
)

if legend:
handles = [
(
Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
Patch(color=f"C{i}", alpha=0.3, lw=1),
elif not "color" in covariates and "panel" in covariates:
panel = covariates.get("panel")
panels = get_unique_levels(cap_data[panel])
for ax, pnl in zip(axes.ravel(), panels):
idx = (cap_data[panel] == pnl).to_numpy()
values_main = transform_main(cap_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], solid_capstyle="butt")
ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.5)
ax.set(title=f"{panel} = {pnl}")
elif "color" in covariates and "panel" in covariates:
color = covariates.get("color")
panel = covariates.get("panel")
colors = get_unique_levels(cap_data[color])
panels = get_unique_levels(cap_data[panel])
if color == panel:
for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
idx = (cap_data[panel] == pnl).to_numpy()
values_main = transform_main(cap_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
ax.fill_between(
values_main,
y_hat_bounds[0][idx],
y_hat_bounds[1][idx],
alpha=0.5,
color=f"C{i}",
)
for i in range(len(groups))
]
ax.set(title=f"{panel} = {pnl}")
else:
for ax, pnl in zip(axes.ravel(), panels):
for i, clr in enumerate(colors):
idx = ((cap_data[panel] == pnl) & (cap_data[color] == clr)).to_numpy()
values_main = transform_main(cap_data.loc[idx, main])
ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
ax.fill_between(
values_main,
y_hat_bounds[0][idx],
y_hat_bounds[1][idx],
alpha=0.5,
color=f"C{i}",
)
ax.set(title=f"{panel} = {pnl}")

if "color" in covariates and legend:
handles = [
(
Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
Patch(color=f"C{i}", alpha=0.5, lw=1),
)
for i in range(len(colors))
]
for ax in axes.ravel():
ax.legend(
handles,
tuple(groups),
title=group,
handlelength=1.3,
handleheight=1,
bbox_to_anchor=(1.03, 0.5),
loc="center left",
handles, tuple(colors), title=color, handlelength=1.3, handleheight=1, loc="best"
)
return axes

return ax


def _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, ax):
main = covariates[0]
def _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes):
main = covariates.get("horizontal")
main_levels = get_unique_levels(cap_data[main])
main_levels_n = len(main_levels)
idxs_main = np.arange(main_levels_n)

if "color" in covariates:
color = covariates.get("color")
colors = get_unique_levels(cap_data[color])
colors_n = len(colors)
offset_bounds = get_group_offset(colors_n)
colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)

if "panel" in covariates:
panel = covariates.get("panel")
panels = get_unique_levels(cap_data[panel])

if len(covariates) == 1:
ax = axes[0]
ax.scatter(idxs_main, y_hat_mean)
ax.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1])
else:
group = covariates[1]
group_levels = get_unique_levels(cap_data[group])
group_levels_n = len(group_levels)
offset_bounds = get_group_offset(group_levels_n)
offset_groups = np.linspace(-offset_bounds, offset_bounds, group_levels_n)

for i, grp in enumerate(group_levels):
idx = (cap_data[group] == grp).values
idxs = idxs_main + offset_groups[i]
elif "color" in covariates and not "panel" in covariates:
ax = axes[0]
for i, clr in enumerate(colors):
idx = (cap_data[color] == clr).to_numpy()
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
elif not "color" in covariates and "panel" in covariates:
for ax, pnl in zip(axes.ravel(), panels):
idx = (cap_data[panel] == pnl).to_numpy()
ax.scatter(idxs_main, y_hat_mean[idx])
ax.vlines(idxs_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx])
ax.set(title=f"{panel} = {pnl}")
elif "color" in covariates and "panel" in covariates:
if color == panel:
for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
idx = (cap_data[panel] == pnl).to_numpy()
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
ax.set(title=f"{panel} = {pnl}")
else:
for ax, pnl in zip(axes.ravel(), panels):
for i, clr in enumerate(colors):
idx = ((cap_data[panel] == pnl) & (cap_data[color] == clr)).to_numpy()
idxs = idxs_main + colors_offset[i]
ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
ax.set(title=f"{panel} = {pnl}")

if "color" in covariates and legend:
handles = [
Line2D([], [], c=f"C{i}", marker="o", label=level) for i, level in enumerate(colors)
]
for ax in axes.ravel():
ax.legend(handles=handles, title=color, loc="best")

if legend:
handles = [
Line2D([], [], c=f"C{i}", marker="o", label=level)
for i, level in enumerate(group_levels)
]
ax.legend(handles=handles, title=group, bbox_to_anchor=(1.03, 0.5), loc="center left")
for ax in axes.ravel():
ax.set_xticks(idxs_main)
ax.set_xticklabels(main_levels)

ax.set_xticks(idxs_main)
ax.set_xticklabels(main_levels)
return ax
return axes


def identity(x):
    """Return ``x`` unchanged.

    Used as the default transformation when no transform is supplied for a variable.
    """
    return x


def make_main_values(x, grid_n):
    """Compute the values of the main (horizontal) covariate used to build the prediction grid.

    Parameters
    ----------
    x : array-like
        Observed values of the main covariate. It is called with a column of ``model.data``.
    grid_n : int
        Number of equally spaced points to generate when ``x`` is numeric.

    Returns
    -------
    np.ndarray
        ``grid_n`` equally spaced values between the minimum and the maximum of ``x`` when it is
        numeric, or the sorted unique values of ``x`` when it is string or categorical.

    Raises
    ------
    ValueError
        When ``x`` is neither numeric, string, nor categorical.
    """
    if is_numeric_dtype(x):
        return np.linspace(np.min(x), np.max(x), grid_n)

    # ``is_categorical_dtype`` is deprecated in pandas; checking the dtype class is the
    # recommended replacement and does not emit a warning.
    is_categorical = isinstance(getattr(x, "dtype", None), pd.CategoricalDtype)
    if is_string_dtype(x) or is_categorical:
        return np.unique(x)

    raise ValueError("Main covariate must be numeric or categoric.")


def make_group_values(x, groups_n):
    """Compute the values of a grouping (color or panel) covariate used in the prediction grid.

    Parameters
    ----------
    x : array-like
        Observed values of the grouping covariate. It is called with a column of ``model.data``.
    groups_n : int
        Number of equally spaced quantiles (from 0 to 1) to compute when ``x`` is numeric.

    Returns
    -------
    np.ndarray
        The sorted unique values of ``x`` when it is string or categorical. When ``x`` is
        numeric, the distinct values among ``groups_n`` quantiles of ``x``; there are fewer than
        ``groups_n`` values when ties in the data make several quantiles coincide.

    Raises
    ------
    ValueError
        When ``x`` is neither numeric, string, nor categorical.
    """
    # ``is_categorical_dtype`` is deprecated in pandas; checking the dtype class is the
    # recommended replacement and does not emit a warning.
    is_categorical = isinstance(getattr(x, "dtype", None), pd.CategoricalDtype)
    if is_string_dtype(x) or is_categorical:
        return np.unique(x)

    if is_numeric_dtype(x):
        quantiles = np.quantile(x, np.linspace(0, 1, groups_n))
        # Repeated quantiles would produce identical groups, i.e. duplicated rows in the
        # prediction grid and curves drawn more than once on top of each other.
        return np.unique(quantiles)

    raise ValueError("Group covariate must be numeric or categoric.")