Skip to content

Commit

Permalink
plot_cap() can have different model parameters as output (#627)
Browse files Browse the repository at this point in the history
* plot_cap gains a new argument 'target'

* Add seed to test and update changelog

* Update tests/test_plots.py

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Update tests/test_plots.py

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* Update tests/test_plots.py

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>

* remove old 'print'

* fix indentation

Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com>
  • Loading branch information
tomicapretto and canyon289 authored Jan 14, 2023
1 parent 6fec04f commit 9ba92e1
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 5 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

* Refactored the codebase to support distributional models (#607)
* Added a default method to handle posterior predictive sampling for custom families (#625)
* `plot_cap()` gains a new argument `target` that allows to plot different parameters of the response distribution (#627)

### Maintenance and fixes

Expand Down
11 changes: 7 additions & 4 deletions bambi/plots/plot_cap.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def plot_cap(
model,
idata,
covariates,
target="mean",
use_hdi=True,
hdi_prob=None,
transforms=None,
Expand All @@ -152,6 +153,8 @@ def plot_cap(
and the third is mapped to different plot panels.
If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values
are the names of the variables.
target : str
Which model parameter to plot. Defaults to 'mean'.
use_hdi : bool, optional
Whether to compute the highest density interval (defaults to True) or the quantiles.
hdi_prob : float, optional
Expand Down Expand Up @@ -203,11 +206,11 @@ def plot_cap(
response_name = get_aliased_name(model.response_component.response_term)
response_transform = transforms.get(response_name, identity)

y_hat = response_transform(idata.posterior[f"{response_name}_mean"])
y_hat = response_transform(idata.posterior[f"{response_name}_{target}"])
y_hat_mean = y_hat.mean(("chain", "draw"))

if use_hdi:
y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_mean"].T
y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_{target}"].T
else:
lower_bound = round((1 - hdi_prob) / 2, 4)
upper_bound = 1 - lower_bound
Expand All @@ -222,7 +225,6 @@ def plot_cap(
axes = np.atleast_1d(axes)
else:
axes = np.atleast_1d(ax)
print(axes)
if isinstance(axes[0], np.ndarray):
fig = axes[0][0].get_figure()
else:
Expand All @@ -238,8 +240,9 @@ def plot_cap(
else:
raise ValueError("Main covariate must be numeric or categoric.")

ylabel = response_name if target == "mean" else target
for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
ax.set(xlabel=main, ylabel=response_name)
ax.set(xlabel=main, ylabel=ylabel)

return fig, axes

Expand Down
22 changes: 21 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import matplotlib.pyplot as plt
import pytest

from bambi.models import Model
from bambi.models import Model, Formula
from bambi.plots import plot_cap


Expand Down Expand Up @@ -150,3 +150,23 @@ def test_transforms(mtcars):

transforms = {"mpg": np.log, "hp": np.log}
plot_cap(model, idata, ["hp"], transforms=transforms)


def test_multiple_outputs():
"""Test plot cap default and specified values for target argument"""
rng = np.random.default_rng(121195)
N = 200
a, b = 0.5, 1.1
x = rng.uniform(-1.5, 1.5, N)
shape = np.exp(0.3 + x * 0.5 + rng.normal(scale=0.1, size=N))
y = rng.gamma(shape, np.exp(a + b * x) / shape, N)
data_gamma = pd.DataFrame({"x": x, "y": y})


formula = Formula("y ~ x", "alpha ~ x")
model = Model(formula, data_gamma, family="gamma")
idata = model.fit(tune=100, draws=100, random_seed=1234)
# Test default target
plot_cap(model, idata, "x")
# Test user supplied target argument
plot_cap(model, idata, "x", "alpha")

0 comments on commit 9ba92e1

Please sign in to comment.