From 9ba92e1b8fa2833370468de46d55bffac3fda101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Capretto?= Date: Sat, 14 Jan 2023 16:39:12 -0300 Subject: [PATCH] `plot_cap()` can have different model parameters as output (#627) * plot_cap gains a new argument 'target' * Add seed to test and update changelog * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * Update tests/test_plots.py Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> * remove old 'print' * fix indentation Co-authored-by: Ravin Kumar <7213793+canyon289@users.noreply.github.com> --- Changelog.md | 1 + bambi/plots/plot_cap.py | 11 +++++++---- tests/test_plots.py | 22 +++++++++++++++++++++- 3 files changed, 29 insertions(+), 5 deletions(-) diff --git a/Changelog.md b/Changelog.md index 60f0e41e3..f8f1f39fd 100644 --- a/Changelog.md +++ b/Changelog.md @@ -6,6 +6,7 @@ * Refactored the codebase to support distributional models (#607) * Added a default method to handle posterior predictive sampling for custom families (#625) +* `plot_cap()` gains a new argument `target` that allows to plot different parameters of the response distribution (#627) ### Maintenance and fixes diff --git a/bambi/plots/plot_cap.py b/bambi/plots/plot_cap.py index 8d35aa2bc..58ffd3bb5 100644 --- a/bambi/plots/plot_cap.py +++ b/bambi/plots/plot_cap.py @@ -128,6 +128,7 @@ def plot_cap( model, idata, covariates, + target="mean", use_hdi=True, hdi_prob=None, transforms=None, @@ -152,6 +153,8 @@ def plot_cap( and the third is mapped to different plot panels. If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values are the names of the variables. + target : str + Which model parameter to plot. Defaults to 'mean'. use_hdi : bool, optional Whether to compute the highest density interval (defaults to True) or the quantiles. hdi_prob : float, optional @@ -203,11 +206,11 @@ def plot_cap( response_name = get_aliased_name(model.response_component.response_term) response_transform = transforms.get(response_name, identity) - y_hat = response_transform(idata.posterior[f"{response_name}_mean"]) + y_hat = response_transform(idata.posterior[f"{response_name}_{target}"]) y_hat_mean = y_hat.mean(("chain", "draw")) if use_hdi: - y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_mean"].T + y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_{target}"].T else: lower_bound = round((1 - hdi_prob) / 2, 4) upper_bound = 1 - lower_bound @@ -222,7 +225,6 @@ def plot_cap( axes = np.atleast_1d(axes) else: axes = np.atleast_1d(ax) - print(axes) if isinstance(axes[0], np.ndarray): fig = axes[0][0].get_figure() else: @@ -238,8 +240,9 @@ def plot_cap( else: raise ValueError("Main covariate must be numeric or categoric.") + ylabel = response_name if target == "mean" else target for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local - ax.set(xlabel=main, ylabel=response_name) + ax.set(xlabel=main, ylabel=ylabel) return fig, axes diff --git a/tests/test_plots.py b/tests/test_plots.py index ec4a56366..f50b3656c 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt import pytest -from bambi.models import Model +from bambi.models import Model, Formula from bambi.plots import plot_cap @@ -150,3 +150,23 @@ def test_transforms(mtcars): transforms = {"mpg": np.log, "hp": np.log} plot_cap(model, idata, ["hp"], transforms=transforms) + + +def test_multiple_outputs(): + """Test plot cap default and specified values for target argument""" + rng = np.random.default_rng(121195) + N = 200 + a, b = 0.5, 1.1 + x = rng.uniform(-1.5, 1.5, N) + shape = np.exp(0.3 + x * 0.5 + rng.normal(scale=0.1, size=N)) + y = rng.gamma(shape, np.exp(a + b * x) / shape, N) + data_gamma = pd.DataFrame({"x": x, "y": y}) + + + formula = Formula("y ~ x", "alpha ~ x") + model = Model(formula, data_gamma, family="gamma") + idata = model.fit(tune=100, draws=100, random_seed=1234) + # Test default target + plot_cap(model, idata, "x") + # Test user supplied target argument + plot_cap(model, idata, "x", "alpha") \ No newline at end of file