From 0fa0b6f270ee33ca118c53dcbc2834222435d4c7 Mon Sep 17 00:00:00 2001
From: Gabriel Stechschulte <63432018+GStechschulte@users.noreply.github.com>
Date: Sun, 9 Jul 2023 14:54:21 +0200
Subject: [PATCH] Add comparisons and plot comparisons (#684)
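
This patch replaces `bambi/plots/plot_cap.py` with a set of modules
(`create_data.py`, `effects.py`, `plot_types.py`, `plotting.py`) and adds
`comparisons()`, `predictions()`, and `plot_comparison()` alongside the
existing `plot_cap()`.

A minimal usage sketch (the formula and `df` below are hypothetical, not part
of this patch):

    import bambi as bmb
    from bambi.plots import comparisons, plot_comparison

    model = bmb.Model("y ~ x + g", df)
    idata = model.fit()
    summary_df = comparisons(model, idata, contrast={"x": [1, 2]}, conditional="g")
    fig, axes = plot_comparison(model, idata, contrast="x", conditional="g")
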
---
bambi/models.py | 3 +-
bambi/plots/__init__.py | 6 +-
bambi/plots/create_data.py | 129 ++++++++++
bambi/plots/effects.py | 449 ++++++++++++++++++++++++++++++++++
bambi/plots/plot_cap.py | 425 --------------------------------
bambi/plots/plot_types.py | 201 +++++++++++++++
bambi/plots/plotting.py | 306 +++++++++++++++++++++++
bambi/plots/utils.py | 322 +++++++++++++++++++++++-
docs/notebooks/plot_cap.ipynb | 270 ++++++++++++++------
tests/test_plots.py | 395 ++++++++++++++++++------------
10 files changed, 1838 insertions(+), 668 deletions(-)
create mode 100644 bambi/plots/create_data.py
create mode 100644 bambi/plots/effects.py
delete mode 100644 bambi/plots/plot_cap.py
create mode 100644 bambi/plots/plot_types.py
create mode 100644 bambi/plots/plotting.py
diff --git a/bambi/models.py b/bambi/models.py
index 2299adab4..2d184ef99 100644
--- a/bambi/models.py
+++ b/bambi/models.py
@@ -542,7 +542,6 @@ def set_alias(self, aliases):
assert component_name in self.distributional_components
component = self.distributional_components[component_name]
for name, alias in component_aliases.items():
-
is_used = False
if name in component.terms:
@@ -661,7 +660,7 @@ def plot_priors(
unobserved_rvs_names = []
flat_rvs = []
for unobserved in self.backend.model.unobserved_RVs:
- if "Flat" in unobserved.__str__():
+ if "Flat" in str(unobserved):
flat_rvs.append(unobserved.name)
else:
unobserved_rvs_names.append(unobserved.name)
diff --git a/bambi/plots/__init__.py b/bambi/plots/__init__.py
index 92da1d8f1..fa53374db 100644
--- a/bambi/plots/__init__.py
+++ b/bambi/plots/__init__.py
@@ -1,3 +1,5 @@
-from .plot_cap import create_cap_data, plot_cap
+from bambi.plots.effects import comparisons, predictions
+from bambi.plots.plotting import plot_cap, plot_comparison
-__all__ = ["create_cap_data", "plot_cap"]
+
+__all__ = ["comparisons", "predictions", "plot_cap", "plot_comparison"]
diff --git a/bambi/plots/create_data.py b/bambi/plots/create_data.py
new file mode 100644
index 000000000..0d6e6899d
--- /dev/null
+++ b/bambi/plots/create_data.py
@@ -0,0 +1,129 @@
+import itertools
+
+import numpy as np
+import pandas as pd
+
+from bambi.models import Model
+from bambi.plots.utils import (
+ ConditionalInfo,
+ ContrastInfo,
+ enforce_dtypes,
+ get_covariates,
+ get_model_covariates,
+ make_group_panel_values,
+ make_main_values,
+ set_default_values,
+)
+
+
+def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
+ """Create data for a Conditional Adjusted Predictions
+
+ Parameters
+ ----------
+ model : bambi.Model
+ An instance of a Bambi model
+ covariates : dict
+ A dictionary of length between one and three.
+        Keys must be taken from ("main", "group", "panel").
+ The values indicate the names of variables.
+
+ Returns
+ -------
+ pandas.DataFrame
+        The data used to obtain the Conditional Adjusted Predictions and/or
+        to make the plot.
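+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> cap_df = create_cap_data(model, {"main": "x", "group": "g"})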
+ """
+ data = model.data
+ covariates = get_covariates(covariates)
+ main, group, panel = covariates.main, covariates.group, covariates.panel
+
+ # Obtain data for main variable
+ main_values = make_main_values(data[main])
+ data_dict = {main: main_values}
+
+ # Obtain data for group and panel variables if not None
+ data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
+ data_dict = set_default_values(model, data_dict, kind="predictions")
+ return enforce_dtypes(data, pd.DataFrame(data_dict))
+
+
+def create_comparisons_data(
+ condition: ConditionalInfo, contrast: ContrastInfo, user_passed: bool = False
+) -> pd.DataFrame:
+ """Create data for a Conditional Adjusted Comparisons
+
+ Parameters
+ ----------
+    condition : ConditionalInfo
+        A dataclass instance containing the model and the conditional covariates
+        to be used in the comparisons.
+    contrast : ContrastInfo
+        A dataclass instance containing the model and the contrast name and values.
+    user_passed : bool, optional
+        Whether the user passed their own 'conditional' data. Defaults to ``False``.
+
+ Returns
+ -------
+ pd.DataFrame
+        The data used to obtain the Conditional Adjusted Comparisons and/or
+        to make the plot.
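+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> contrast = ContrastInfo(model, {"x": [1, 2]})
+    >>> condition = ConditionalInfo(model, "g")
+    >>> comparisons_df = create_comparisons_data(condition, contrast)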
+ """
+
+ def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
+ """
+ Creates the data for grid-level contrasts by using the covariates passed
+ into the `conditional` arg. Values for the grid are either: (1) computed
+        using an equally spaced grid, the mean, and/or the mode (depending on the
+        covariate dtype), or (2) a user specified value or range of values.
+ """
+ covariates = get_covariates(condition.covariates)
+
+ if user_passed:
+ data_dict = {**condition.conditional}
+ else:
+ main_values = make_main_values(condition.model.data[covariates.main])
+ data_dict = {covariates.main: main_values}
+ data_dict = make_group_panel_values(
+ condition.model.data,
+ data_dict,
+ covariates.main,
+ covariates.group,
+ covariates.panel,
+ kind="comparison",
+ )
+
+ data_dict[contrast.name] = contrast.values
+ comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
+ # use cartesian product (cross join) to create contrasts
+ keys, values = zip(*comparison_data.items())
+ contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]
+
+ return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))
+
+ def _unit_level(contrast: ContrastInfo):
+ """
+ Creates the data for unit-level contrasts by using the observed (empirical)
+ data. All covariates in the model are included in the data, except for the
+ contrast predictor. The contrast predictor is replaced with either: (1) the
+ default contrast value, or (2) the user specified contrast value.
+ """
+ covariates = get_model_covariates(contrast.model)
+ df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)
+
+ contrast_vals = np.array(contrast.values)[..., None]
+ contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)
+
+ contrast_df_dict = {}
+ for idx, value in enumerate(contrast_vals):
+ contrast_df_dict[f"contrast_{idx}"] = df.copy()
+ contrast_df_dict[f"contrast_{idx}"][contrast.name] = value
+
+ return pd.concat(contrast_df_dict.values())
+
+ if not condition.conditional:
+ df = _unit_level(contrast)
+ else:
+ df = _grid_level(condition, contrast)
+
+ return df
diff --git a/bambi/plots/effects.py b/bambi/plots/effects.py
new file mode 100644
index 000000000..e6005814a
--- /dev/null
+++ b/bambi/plots/effects.py
@@ -0,0 +1,449 @@
+# pylint: disable = protected-access
+# pylint: disable = too-many-function-args
+# pylint: disable = too-many-nested-blocks
+from dataclasses import dataclass, field
+import itertools
+from typing import Dict, Union
+
+import arviz as az
+import numpy as np
+import pandas as pd
+import xarray as xr
+
+from bambi.models import Model
+from bambi.plots.create_data import create_cap_data, create_comparisons_data
+from bambi.plots.utils import average_over, ConditionalInfo, ContrastInfo, enforce_dtypes, identity
+from bambi.utils import get_aliased_name, listify
+
+
+@dataclass
+class ResponseInfo:
+ name: str
+ target: Union[str, None] = None
+ lower_bound: float = 0.03
+ upper_bound: float = 0.97
+ name_target: str = field(init=False)
+ name_obs: str = field(init=False)
+ lower_bound_name: str = field(init=False)
+ upper_bound_name: str = field(init=False)
+
+ def __post_init__(self):
+ """
+ Assigns commonly used f-strings for indexing and column names as attributes.
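+
+        For example (illustrative, not part of this patch): ``ResponseInfo("y", "mean")``
+        sets ``name_target="y_mean"`` and ``name_obs="y_obs"``; the bound names are built
+        from ``lower_bound`` and ``upper_bound`` expressed as percentages.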
+ """
+ if self.target is None:
+ self.name_target = self.name
+ else:
+ self.name_target = f"{self.name}_{self.target}"
+
+ self.name_obs = f"{self.name}_obs"
+ self.lower_bound_name = f"lower_{self.lower_bound * 100}%"
+ self.upper_bound_name = f"upper_{self.upper_bound * 100}%"
+
+
+def predictions(
+ model: Model,
+ idata: az.InferenceData,
+ covariates: Union[str, dict, list],
+ target: str = "mean",
+ pps: bool = False,
+ use_hdi: bool = True,
+ prob=None,
+ transforms=None,
+) -> pd.DataFrame:
+ """Compute Conditional Adjusted Predictions
+
+ Parameters
+ ----------
+ model : bambi.Model
+        The model for which we want to compute the predictions.
+ idata : arviz.InferenceData
+ The InferenceData object that contains the samples from the posterior distribution of
+ the model.
+ covariates : list or dict
+ A sequence of between one and three names of variables or a dict of length between one
+ and three.
+ If a sequence, the first variable is taken as the main variable and is mapped to the
+ horizontal axis. If present, the second name is a coloring/grouping variable,
+ and the third is mapped to different plot panels.
+ If a dictionary, keys must be taken from ("main", "group", "panel") and the values
+ are the names of the variables.
+ target : str
+        Which model parameter to compute predictions for. Defaults to 'mean'. Passing a
+        parameter into target only works when pps is False as the target may not be available
+        in the posterior predictive distribution.
+    pps : bool, optional
+        Whether to use the posterior predictive samples. Defaults to ``False``.
+ use_hdi : bool, optional
+ Whether to compute the highest density interval (defaults to True) or the quantiles.
+ prob : float, optional
+ The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
+        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
+ transforms : dict, optional
+ Transformations that are applied to each of the variables being plotted. The keys are the
+ name of the variables, and the values are functions to be applied. Defaults to ``None``.
+
+ Returns
+ -------
+ cap_data : pandas.DataFrame
+        A DataFrame with the data returned by ``create_cap_data`` and the model predictions.
+
+ Raises
+ ------
+ ValueError
+ If ``pps`` is ``True`` and ``target`` is not ``"mean"``.
+        If the passed ``covariates`` are not in the correct key-value format.
+ If length of ``covariates`` is not between 1 and 3.
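+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> idata = model.fit()
+    >>> summary_df = predictions(model, idata, ["x", "g"])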
+ """
+
+ if pps and target != "mean":
+ raise ValueError("When passing 'pps=True', target must be 'mean'")
+
+ covariate_kinds = ("main", "group", "panel")
+ if not isinstance(covariates, dict):
+ covariates = listify(covariates)
+ covariates = dict(zip(covariate_kinds, covariates))
+ else:
+ assert covariate_kinds[0] in covariates
+ assert set(covariates).issubset(set(covariate_kinds))
+
+ assert 1 <= len(covariates) <= 3
+
+ if transforms is None:
+ transforms = {}
+
+ if prob is None:
+ prob = az.rcParams["stats.hdi_prob"]
+
+ if not 0 < prob < 1:
+ raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")
+
+ cap_data = create_cap_data(model, covariates)
+
+ if target != "mean":
+ component = model.components[target]
+ if component.alias:
+ # use only the aliased name (without appended target)
+ response_name = get_aliased_name(component)
+ target = None
+ else:
+ # use the default response "y" and append target
+ response_name = get_aliased_name(model.response_component.response_term)
+ else:
+ response_name = get_aliased_name(model.response_component.response_term)
+
+ response = ResponseInfo(response_name, target)
+ response_transform = transforms.get(response_name, identity)
+
+ if pps:
+ idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
+ y_hat = response_transform(idata.posterior_predictive[response.name])
+ y_hat_mean = y_hat.mean(("chain", "draw"))
+ else:
+ idata = model.predict(idata, data=cap_data, inplace=False)
+ y_hat = response_transform(idata.posterior[response.name_target])
+ y_hat_mean = y_hat.mean(("chain", "draw"))
+
+ if use_hdi and pps:
+ y_hat_bounds = az.hdi(y_hat, prob)[response.name].T
+ elif use_hdi:
+ y_hat_bounds = az.hdi(y_hat, prob)[response.name_target].T
+ else:
+ lower_bound = round((1 - prob) / 2, 4)
+ upper_bound = 1 - lower_bound
+ y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))
+
+ lower_bound = round((1 - prob) / 2, 4)
+ upper_bound = 1 - lower_bound
+ response.lower_bound, response.upper_bound = lower_bound, upper_bound
+
+ cap_data["estimate"] = y_hat_mean
+ cap_data[response.lower_bound_name] = y_hat_bounds[0]
+ cap_data[response.upper_bound_name] = y_hat_bounds[1]
+
+ return cap_data
+
+
+@dataclass
+class ContrastEstimate:
+ comparison: Dict[str, xr.DataArray]
+ hdi: Dict[str, xr.Dataset]
+
+
+def comparisons(
+ model: Model,
+ idata: az.InferenceData,
+ contrast: Union[str, dict, list],
+ conditional: Union[str, dict, list, None] = None,
+ average_by: Union[str, list, bool, None] = None,
+ comparison_type: str = "diff",
+ use_hdi: bool = True,
+ prob=None,
+ transforms=None,
+) -> pd.DataFrame:
+ """Compute Conditional Adjusted Comparisons
+
+ Parameters
+ ----------
+ model : bambi.Model
+        The model for which we want to compute the comparisons.
+ idata : arviz.InferenceData
+ The InferenceData object that contains the samples from the posterior distribution of
+ the model.
+ contrast : str, dict, list
+ The predictor name whose contrast we would like to compare.
+ conditional : str, dict, list
+ The covariates we would like to condition on.
+    average_by : str, list, bool, optional
+        The covariates we would like to average by. Estimates are marginalized over all
+        other covariates in the model while the passed covariate(s) are retained. If ``True``,
+        the estimates are averaged over all covariates in the model to obtain a single
+        average estimate. Defaults to ``None``.
+ comparison_type : str, optional
+        The type of comparison to compute. Either 'diff' or 'ratio'. Defaults to 'diff'.
+ use_hdi : bool, optional
+ Whether to compute the highest density interval (defaults to True) or the quantiles.
+ prob : float, optional
+ The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
+        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
+ transforms : dict, optional
+ Transformations that are applied to each of the variables being plotted. The keys are the
+ name of the variables, and the values are functions to be applied. Defaults to ``None``.
+
+ Returns
+ -------
+ pandas.DataFrame
+ A dataframe with the comparison values, highest density interval, contrast name,
+ contrast value, and conditional values.
+
+ Raises
+ ------
+ ValueError
+ If length of ``contrast`` is greater than 1.
+ If ``contrast`` is not a string, dictionary, or list.
+ If ``comparison_type`` is not 'diff' or 'ratio'.
+ If ``prob`` is not > 0 and < 1.
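+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> idata = model.fit()
+    >>> summary_df = comparisons(model, idata, contrast={"x": [1, 2]}, conditional="g")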
+ """
+
+ if not isinstance(contrast, (dict, list, str)):
+ raise ValueError("'contrast' must be a string, dictionary, or list.")
+ if isinstance(contrast, (dict, list)):
+ if len(contrast) > 1:
+ raise ValueError(
+ f"Only one contrast predictor can be passed. {len(contrast)} were passed."
+ )
+
+ if comparison_type not in ("diff", "ratio"):
+ raise ValueError("'comparison_type' must be 'diff' or 'ratio'")
+
+ if prob is None:
+ prob = az.rcParams["stats.hdi_prob"]
+ if not 0 < prob < 1:
+ raise ValueError(f"'prob' must be greater than 0 and smaller than 1. It is {prob}.")
+
+ comparison_functions = {"diff": lambda x, y: x - y, "ratio": lambda x, y: x / y}
+ lower_bound = round((1 - prob) / 2, 4)
+ upper_bound = 1 - lower_bound
+
+ contrast_info = ContrastInfo(model, contrast)
+ conditional_info = ConditionalInfo(model, conditional)
+
+ # 'comparisons' should not be restricted to ("main", "group", "panel")
+ comparisons_df = create_comparisons_data(
+ conditional_info, contrast_info, user_passed=conditional_info.user_passed
+ )
+
+ if transforms is None:
+ transforms = {}
+
+ response_name = get_aliased_name(model.response_component.response_term)
+ response = ResponseInfo(
+ response_name, target="mean", lower_bound=lower_bound, upper_bound=upper_bound
+ )
+
+ # perform predictions on new data
+ idata = model.predict(idata, data=comparisons_df, inplace=False)
+
+ def _compute_contrast_estimate(
+ contrast: ContrastInfo,
+ response: ResponseInfo,
+ comparisons_df: pd.DataFrame,
+ idata: az.InferenceData,
+ ) -> ContrastEstimate:
+ """
+ Computes the contrast comparison estimate and highest density interval
+ for a given contrast and response by first subsetting posterior draws
+ using a contrast mask. Then, pairwise comparisons are computed for the
+ contrast values. Finally, the mean comparison and lower/upper bounds
+ are computed for each pairwise comparison.
+ """
+ function = comparison_functions[comparison_type]
+
+ draws = {}
+ for idx, val in enumerate(contrast.values):
+ mask = np.array(comparisons_df[contrast.name] == contrast.values[idx])
+ select_draw = idata.posterior[response.name_target].sel({response.name_obs: mask})
+ select_draw = select_draw.assign_coords(
+ {response.name_obs: np.arange(len(select_draw.coords[response.name_obs]))}
+ )
+ draws[val] = select_draw
+
+ pairwise_contrasts = list(itertools.combinations(contrast.values, 2))
+
+ comparison_mean = {}
+ comparison_bounds = {}
+ for idx, pair in enumerate(pairwise_contrasts):
+ comparison_estimate = function(draws[pair[1]], draws[pair[0]])
+ comparison_mean[pair] = comparison_estimate.mean(("chain", "draw"))
+ if use_hdi:
+ comparison_bounds[pair] = az.hdi(comparison_estimate, prob)
+ else:
+ comparison_bounds[pair] = comparison_estimate.quantile(
+ q=(response.lower_bound, response.upper_bound), dim=("chain", "draw")
+ )
+
+ return ContrastEstimate(comparison_mean, comparison_bounds)
+
+ def _build_contrasts_df(
+ contrast: ContrastInfo,
+ condition: ConditionalInfo,
+ response: ResponseInfo,
+ comparisons_df: pd.DataFrame,
+ idata: az.InferenceData,
+ average_by,
+ ) -> pd.DataFrame:
+ """
+ Builds a dataframe with the comparison values and lower / upper bounds from
+ ``_compute_contrast_estimate`` along with the contrast name, contrast value,
+ and conditional values.
+ """
+ contrast_estimate = _compute_contrast_estimate(contrast, response, comparisons_df, idata)
+
+ # if two contrast values, then can drop duplicates to build contrast_df
+ if len(contrast.values) < 3:
+ if not any(condition.covariates.values()):
+ contrast_df = model.data[comparisons_df.columns].drop(columns=contrast.name)
+ num_rows = contrast_df.shape[0]
+ contrast_df.insert(0, "term", contrast.name)
+ contrast_df.insert(
+ 1, "contrast", list(np.tile(contrast.values, num_rows).reshape(num_rows, 2))
+ )
+ contrast_df["estimate"] = contrast_estimate.comparison[
+ tuple(contrast.values)
+ ].to_numpy()
+ else:
+ contrast_df = comparisons_df.drop_duplicates(
+ list(condition.covariates.values())
+ ).reset_index(drop=True)
+ contrast_df = contrast_df.drop(columns=contrast.name)
+ num_rows = contrast_df.shape[0]
+ contrast_df.insert(0, "term", contrast.name)
+ contrast_df.insert(
+ 1, "contrast", list(np.tile(contrast.values, num_rows).reshape(num_rows, 2))
+ )
+ contrast_df["estimate"] = contrast_estimate.comparison[
+ tuple(contrast.values)
+ ].to_numpy()
+
+ if use_hdi:
+ contrast_df[response.lower_bound_name] = (
+ contrast_estimate.hdi[tuple(contrast.values)][response.name_target]
+ .sel(hdi="lower")
+ .values
+ )
+ contrast_df[response.upper_bound_name] = (
+ contrast_estimate.hdi[tuple(contrast.values)][response.name_target]
+ .sel(hdi="higher")
+ .values
+ )
+ else:
+ contrast_df[response.lower_bound_name] = contrast_estimate.hdi[
+ tuple(contrast.values)
+ ].sel(quantile=lower_bound)
+ contrast_df[response.upper_bound_name] = contrast_estimate.hdi[
+ tuple(contrast.values)
+ ].sel(quantile=upper_bound)
+
+ # if > 2 contrast values, then need the full dataframe to build contrast_df
+ elif len(contrast.values) >= 3:
+ contrast_keys = [list(elem) for elem in list(contrast_estimate.comparison.keys())]
+ covariate_cols = comparisons_df.drop(columns=contrast.name).columns
+ covariate_vals = (
+ comparisons_df.drop(columns=contrast.name).drop_duplicates().reset_index(drop=True)
+ ).values
+ covariate_vals = np.tile(np.transpose(covariate_vals), len(contrast.values))
+
+ contrast_df = (
+ pd.DataFrame(contrast_estimate.comparison)
+ .unstack()
+ .reset_index()
+ .rename(columns={0: "estimate"})
+ )
+
+ # this hardcoded subset will not work for cross-contrasts
+ contrast_df.insert(0, "term", contrast.name)
+ contrast_df.insert(
+ 1, "contrast", tuple(zip(contrast_df["level_0"], contrast_df["level_1"]))
+ )
+ contrast_df = contrast_df.drop(["level_0", "level_1", "level_2"], axis=1)
+
+ lower = []
+ upper = []
+ for pair in contrast_keys:
+ if use_hdi:
+ lower.append(
+ (
+ contrast_estimate.hdi[tuple(pair)][response.name_target]
+ .sel(hdi="lower")
+ .values
+ )
+ )
+ upper.append(
+ (
+ contrast_estimate.hdi[tuple(pair)][response.name_target]
+ .sel(hdi="higher")
+ .values
+ )
+ )
+ else:
+ lower.append(contrast_estimate.hdi[tuple(pair)].sel(quantile=lower_bound))
+ upper.append(contrast_estimate.hdi[tuple(pair)].sel(quantile=upper_bound))
+
+ contrast_df[covariate_cols] = np.transpose(covariate_vals)
+ contrast_df[response.lower_bound_name] = np.array(lower).flatten()
+ contrast_df[response.upper_bound_name] = np.array(upper).flatten()
+ contrast_df.insert(
+ len(contrast_df.columns) - 3, "estimate", contrast_df.pop("estimate")
+ )
+ contrast_df = enforce_dtypes(model.data, contrast_df)
+
+ contrast_df["contrast"] = contrast_df["contrast"].apply(tuple)
+
+ if average_by:
+ if average_by is True:
+ contrast_df_avg = average_over(contrast_df, None)
+ contrast_df_avg.insert(0, "term", contrast.name)
+ contrast_df_avg.insert(
+ 1,
+ "contrast",
+ np.tile(contrast_df["contrast"].drop_duplicates(), len(contrast_df_avg)),
+ )
+ else:
+ contrast_df_avg = average_over(contrast_df, average_by)
+ contrast_df_avg.insert(0, "term", contrast.name)
+ contrast_df_avg.insert(
+ 1,
+ "contrast",
+ np.tile(contrast_df["contrast"].drop_duplicates(), len(contrast_df_avg)),
+ )
+ return contrast_df_avg.reset_index(drop=True)
+ else:
+ return contrast_df.reset_index(drop=True)
+
+ return _build_contrasts_df(
+ contrast_info,
+ conditional_info,
+ response,
+ comparisons_df,
+ idata,
+ average_by,
+ )
diff --git a/bambi/plots/plot_cap.py b/bambi/plots/plot_cap.py
deleted file mode 100644
index afa515a2c..000000000
--- a/bambi/plots/plot_cap.py
+++ /dev/null
@@ -1,425 +0,0 @@
-# pylint: disable = protected-access
-# pylint: disable = too-many-function-args
-# pylint: disable = too-many-nested-blocks
-from statistics import mode
-
-import arviz as az
-import numpy as np
-import pandas as pd
-
-from arviz.plots.backends.matplotlib import create_axes_grid
-from arviz.plots.plot_utils import default_grid
-from formulae.terms.call import Call
-from matplotlib.lines import Line2D
-from matplotlib.patches import Patch
-from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
-
-from bambi.utils import listify, get_aliased_name
-from bambi.plots.utils import get_group_offset, get_unique_levels
-
-
-def create_cap_data(model, covariates, grid_n=200, groups_n=5):
- """Create data for a Conditional Adjusted Predictions plot
-
- Parameters
- ----------
- model : bambi.Model
- An instance of a Bambi model
- covariates : dict
- A dictionary of length between one and three.
- Keys must be taken from ("horizontal", "color", "panel").
- The values indicate the names of variables.
- grid_n : int, optional
- The number of points used to evaluate the main covariate. Defaults to 200.
- groups_n : int, optional
- The number of groups to create when the grouping variable is numeric. Groups are based on
- equally spaced points. Defaults to 5.
-
- Returns
- -------
- pandas.DataFrame
- The data for the Conditional Adjusted Predictions plot.
-
- Raises
- ------
- ValueError
- When the number of covariates is larger than 2.
- When either the main or the group covariates are not numeric or categoric.
- """
- data = model.data
-
- main = covariates.get("horizontal")
- group = covariates.get("color", None)
- panel = covariates.get("panel", None)
-
- # Obtain data for main variable
- main_values = make_main_values(data[main], grid_n)
- main_n = len(main_values)
-
- # If available, obtain groups for grouping variable
- if group:
- group_values = make_group_values(data[group], groups_n)
- group_n = len(group_values)
-
- # If available, obtain groups for panel variable. Same logic than grouping applies
- if panel:
- panel_values = make_group_values(data[panel], groups_n)
- panel_n = len(panel_values)
-
- data_dict = {main: main_values}
-
- if group and not panel:
- main_values = np.tile(main_values, group_n)
- group_values = np.repeat(group_values, main_n)
- data_dict.update({main: main_values, group: group_values})
- elif not group and panel:
- main_values = np.tile(main_values, panel_n)
- panel_values = np.repeat(panel_values, main_n)
- data_dict.update({main: main_values, panel: panel_values})
- elif group and panel:
- if group == panel:
- main_values = np.tile(main_values, group_n)
- group_values = np.repeat(group_values, main_n)
- data_dict.update({main: main_values, group: group_values})
- else:
- main_values = np.tile(np.tile(main_values, group_n), panel_n)
- group_values = np.tile(np.repeat(group_values, main_n), panel_n)
- panel_values = np.repeat(panel_values, main_n * group_n)
- data_dict.update({main: main_values, group: group_values, panel: panel_values})
-
- # Construct dictionary of terms that are in the model.
- # See it includes the terms for _all_ the distributional components, not just the response
- terms = {}
- for component in model.distributional_components.values():
- if component.design.common:
- terms.update(component.design.common.terms)
-
- if component.design.group:
- terms.update(component.design.group.terms)
-
- # Get default values for each variable in the model
- for term in terms.values():
- if hasattr(term, "components"):
- for component in term.components:
- # If the component is a function call, use the argument names
- if isinstance(component, Call):
- names = [arg.name for arg in component.call.args]
- else:
- names = [component.name]
-
- for name in names:
- if name not in data_dict:
- # For numeric predictors, select the mean.
- if component.kind == "numeric":
- data_dict[name] = np.mean(data[name])
- # For categoric predictors, select the most frequent level.
- elif component.kind == "categoric":
- data_dict[name] = mode(data[name])
-
- cap_data = pd.DataFrame(data_dict)
-
- # Make sure new types are same types than the original columns
- for column in cap_data:
- cap_data[column] = cap_data[column].astype(data[column].dtype)
- return cap_data
-
-
-def plot_cap(
- model,
- idata,
- covariates,
- target="mean",
- pps=False,
- use_hdi=True,
- hdi_prob=None,
- transforms=None,
- legend=True,
- ax=None,
- fig_kwargs=None,
-):
- """Plot Conditional Adjusted Predictions
-
- Parameters
- ----------
- model : bambi.Model
- The model for which we want to plot the predictions.
- idata : arviz.InferenceData
- The InferenceData object that contains the samples from the posterior distribution of
- the model.
- covariates : list or dict
- A sequence of between one and three names of variables or a dict of length between one
- and three.
- If a sequence, the first variable is taken as the main variable,
- mapped to the horizontal axis. If present, the second name is a coloring/grouping variable,
- and the third is mapped to different plot panels.
- If a dictionary, keys must be taken from ("horizontal", "color", "panel") and the values
- are the names of the variables.
- target : str
- Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only
- works when pps is False as the target may not be available in the posterior predictive
- distribution.
- pps: bool, optional
- Whether to plot the posterior predictive samples. Defaults to ``False``.
- use_hdi : bool, optional
- Whether to compute the highest density interval (defaults to True) or the quantiles.
- hdi_prob : float, optional
- The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
- Changing the global variable ``az.rcParam["stats.hdi_prob"]`` affects this default.
- legend : bool, optional
- Whether to automatically include a legend in the plot. Defaults to ``True``.
- transforms : dict, optional
- Transformations that are applied to each of the variables being plotted. The keys are the
- name of the variables, and the values are functions to be applied. Defaults to ``None``.
- ax : matplotlib.axes._subplots.AxesSubplot, optional
- A matplotlib axes object or a sequence of them. If None, this function instantiates a
- new axes object. Defaults to ``None``.
-
- Returns
- -------
- matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot
- A tuple with the figure and the axes.
-
- Raises
- ------
- ValueError
- When ``level`` is not within 0 and 1.
- When the main covariate is not numeric or categoric.
- """
-
- covariate_kinds = ("horizontal", "color", "panel")
- if not isinstance(covariates, dict):
- covariates = listify(covariates)
- covariates = dict(zip(covariate_kinds, covariates))
- else:
- assert covariate_kinds[0] in covariates
- assert set(covariates).issubset(set(covariate_kinds))
-
- assert 1 <= len(covariates) <= 3
-
- if hdi_prob is None:
- hdi_prob = az.rcParams["stats.hdi_prob"]
-
- if not 0 < hdi_prob < 1:
- raise ValueError(f"'hdi_prob' must be greater than 0 and smaller than 1. It is {hdi_prob}.")
-
- cap_data = create_cap_data(model, covariates)
-
- if transforms is None:
- transforms = {}
-
- response_name = get_aliased_name(model.response_component.response_term)
- response_transform = transforms.get(response_name, identity)
-
- if pps:
- idata = model.predict(idata, data=cap_data, inplace=False, kind="pps")
- y_hat = response_transform(idata.posterior_predictive[response_name])
- y_hat_mean = y_hat.mean(("chain", "draw"))
- else:
- idata = model.predict(idata, data=cap_data, inplace=False)
- y_hat = response_transform(idata.posterior[f"{response_name}_{target}"])
- y_hat_mean = y_hat.mean(("chain", "draw"))
-
- if use_hdi and pps:
- y_hat_bounds = az.hdi(y_hat, hdi_prob)[response_name].T
- elif use_hdi:
- y_hat_bounds = az.hdi(y_hat, hdi_prob)[f"{response_name}_{target}"].T
- else:
- lower_bound = round((1 - hdi_prob) / 2, 4)
- upper_bound = 1 - lower_bound
- y_hat_bounds = y_hat.quantile(q=(lower_bound, upper_bound), dim=("chain", "draw"))
-
- if ax is None:
- fig_kwargs = {} if fig_kwargs is None else fig_kwargs
- panel = covariates.get("panel", None)
- panels_n = len(np.unique(cap_data[panel])) if panel else 1
- rows, cols = default_grid(panels_n)
- fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
- axes = np.atleast_1d(axes)
- else:
- axes = np.atleast_1d(ax)
- if isinstance(axes[0], np.ndarray):
- fig = axes[0][0].get_figure()
- else:
- fig = axes[0].get_figure()
-
- main = covariates.get("horizontal")
- if is_numeric_dtype(cap_data[main]):
- axes = _plot_cap_numeric(
- covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes
- )
- elif is_categorical_dtype(cap_data[main]) or is_string_dtype(cap_data[main]):
- axes = _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes)
- else:
- raise ValueError("Main covariate must be numeric or categoric.")
-
- ylabel = response_name if target == "mean" else target
- for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
- ax.set(xlabel=main, ylabel=ylabel)
-
- return fig, axes
-
-
-def _plot_cap_numeric(covariates, cap_data, y_hat_mean, y_hat_bounds, transforms, legend, axes):
- main = covariates.get("horizontal")
- transform_main = transforms.get(main, identity)
-
- if len(covariates) == 1:
- ax = axes[0]
- values_main = transform_main(cap_data[main])
- ax.plot(values_main, y_hat_mean, solid_capstyle="butt")
- ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
- elif "color" in covariates and not "panel" in covariates:
- ax = axes[0]
- color = covariates.get("color")
- colors = get_unique_levels(cap_data[color])
- for i, clr in enumerate(colors):
- idx = (cap_data[color] == clr).to_numpy()
- values_main = transform_main(cap_data.loc[idx, main])
- ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
- ax.fill_between(
- values_main,
- y_hat_bounds[0][idx],
- y_hat_bounds[1][idx],
- alpha=0.4,
- color=f"C{i}",
- )
- elif not "color" in covariates and "panel" in covariates:
- panel = covariates.get("panel")
- panels = get_unique_levels(cap_data[panel])
- for ax, pnl in zip(axes.ravel(), panels):
- idx = (cap_data[panel] == pnl).to_numpy()
- values_main = transform_main(cap_data.loc[idx, main])
- ax.plot(values_main, y_hat_mean[idx], solid_capstyle="butt")
- ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.4)
- ax.set(title=f"{panel} = {pnl}")
- elif "color" in covariates and "panel" in covariates:
- color = covariates.get("color")
- panel = covariates.get("panel")
- colors = get_unique_levels(cap_data[color])
- panels = get_unique_levels(cap_data[panel])
- if color == panel:
- for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
- idx = (cap_data[panel] == pnl).to_numpy()
- values_main = transform_main(cap_data.loc[idx, main])
- ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
- ax.fill_between(
- values_main,
- y_hat_bounds[0][idx],
- y_hat_bounds[1][idx],
- alpha=0.4,
- color=f"C{i}",
- )
- ax.set(title=f"{panel} = {pnl}")
- else:
- for ax, pnl in zip(axes.ravel(), panels):
- for i, clr in enumerate(colors):
- idx = ((cap_data[panel] == pnl) & (cap_data[color] == clr)).to_numpy()
- values_main = transform_main(cap_data.loc[idx, main])
- ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
- ax.fill_between(
- values_main,
- y_hat_bounds[0][idx],
- y_hat_bounds[1][idx],
- alpha=0.4,
- color=f"C{i}",
- )
- ax.set(title=f"{panel} = {pnl}")
-
- if "color" in covariates and legend:
- handles = [
- (
- Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
- Patch(color=f"C{i}", alpha=0.4, lw=1),
- )
- for i in range(len(colors))
- ]
- for ax in axes.ravel():
- ax.legend(
- handles, tuple(colors), title=color, handlelength=1.3, handleheight=1, loc="best"
- )
- return axes
-
-
-def _plot_cap_categoric(covariates, cap_data, y_hat_mean, y_hat_bounds, legend, axes):
- main = covariates.get("horizontal")
- main_levels = get_unique_levels(cap_data[main])
- main_levels_n = len(main_levels)
- idxs_main = np.arange(main_levels_n)
-
- if "color" in covariates:
- color = covariates.get("color")
- colors = get_unique_levels(cap_data[color])
- colors_n = len(colors)
- offset_bounds = get_group_offset(colors_n)
- colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)
-
- if "panel" in covariates:
- panel = covariates.get("panel")
- panels = get_unique_levels(cap_data[panel])
-
- if len(covariates) == 1:
- ax = axes[0]
- ax.scatter(idxs_main, y_hat_mean)
- ax.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1])
- elif "color" in covariates and not "panel" in covariates:
- ax = axes[0]
- for i, clr in enumerate(colors):
- idx = (cap_data[color] == clr).to_numpy()
- idxs = idxs_main + colors_offset[i]
- ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
- ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
- elif not "color" in covariates and "panel" in covariates:
- for ax, pnl in zip(axes.ravel(), panels):
- idx = (cap_data[panel] == pnl).to_numpy()
- ax.scatter(idxs_main, y_hat_mean[idx])
- ax.vlines(idxs_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx])
- ax.set(title=f"{panel} = {pnl}")
- elif "color" in covariates and "panel" in covariates:
- if color == panel:
- for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
- idx = (cap_data[panel] == pnl).to_numpy()
- idxs = idxs_main + colors_offset[i]
- ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
- ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
- ax.set(title=f"{panel} = {pnl}")
- else:
- for ax, pnl in zip(axes.ravel(), panels):
- for i, clr in enumerate(colors):
- idx = ((cap_data[panel] == pnl) & (cap_data[color] == clr)).to_numpy()
- idxs = idxs_main + colors_offset[i]
- ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
- ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
- ax.set(title=f"{panel} = {pnl}")
-
- if "color" in covariates and legend:
- handles = [
- Line2D([], [], c=f"C{i}", marker="o", label=level) for i, level in enumerate(colors)
- ]
- for ax in axes.ravel():
- ax.legend(handles=handles, title=color, loc="best")
-
- for ax in axes.ravel():
- ax.set_xticks(idxs_main)
- ax.set_xticklabels(main_levels)
-
- return axes
-
-
-def identity(x):
- return x
-
-
-def make_main_values(x, grid_n):
- if is_numeric_dtype(x):
- return np.linspace(np.min(x), np.max(x), grid_n)
- elif is_string_dtype(x) or is_categorical_dtype(x):
- return np.unique(x)
- raise ValueError("Main covariate must be numeric or categoric.")
-
-
-def make_group_values(x, groups_n):
- if is_string_dtype(x) or is_categorical_dtype(x):
- return np.unique(x)
- elif is_numeric_dtype(x):
- return np.quantile(x, np.linspace(0, 1, groups_n))
- raise ValueError("Group covariate must be numeric or categoric.")
diff --git a/bambi/plots/plot_types.py b/bambi/plots/plot_types.py
new file mode 100644
index 000000000..11526b2d1
--- /dev/null
+++ b/bambi/plots/plot_types.py
@@ -0,0 +1,201 @@
+from matplotlib.lines import Line2D
+from matplotlib.patches import Patch
+import numpy as np
+import pandas as pd
+
+from bambi.plots.utils import Covariates, get_unique_levels, get_group_offset, identity
+
+
+def plot_numeric(
+ covariates: Covariates,
+ plot_data: pd.DataFrame,
+ transforms: dict,
+ legend: bool = True,
+ axes=None,
+):
+ """Plotting of numeric data types.
+
+ Parameters
+ ----------
+ covariates : Covariates
+        An instance of the ``Covariates`` dataclass with attributes ``main``, ``group``, and ``panel``.
+ plot_data : pd.DataFrame
+ The data created by the `create_cap_data` or `create_comparisons_data`
+ function.
+ transforms : dict
+ Transformations that are applied to each of the variables being plotted. The keys are the
+        name of the variables, and the values are functions to be applied.
+ legend : bool, optional
+        Whether to include a legend in the plot. Defaults to ``True``.
+ axes : np.ndarray, optional
+ Array of axes. Defaults to `None`.
+
+ Returns
+ -------
+ axes : np.ndarray
+ Array of axes.
+ """
+
+ main, color, panel = covariates.main, covariates.group, covariates.panel
+ covariates = {k: v for k, v in vars(covariates).items() if v is not None}
+ transform_main = transforms.get(main, identity)
+ y_hat_mean = plot_data["estimate"]
+ y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)
+
+ if len(covariates) == 1:
+ ax = axes[0]
+ values_main = transform_main(plot_data[main])
+ ax.plot(values_main, y_hat_mean, solid_capstyle="butt")
+ ax.fill_between(values_main, y_hat_bounds[0], y_hat_bounds[1], alpha=0.4)
+ elif "group" in covariates and not "panel" in covariates:
+ ax = axes[0]
+ colors = get_unique_levels(plot_data[color])
+ for i, clr in enumerate(colors):
+ idx = (plot_data[color] == clr).to_numpy()
+ values_main = transform_main(plot_data.loc[idx, main])
+ ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
+ ax.fill_between(
+ values_main,
+ y_hat_bounds[0][idx],
+ y_hat_bounds[1][idx],
+ alpha=0.4,
+ color=f"C{i}",
+ )
+ elif not "group" in covariates and "panel" in covariates:
+ panels = get_unique_levels(plot_data[panel])
+ for ax, pnl in zip(axes.ravel(), panels):
+ idx = (plot_data[panel] == pnl).to_numpy()
+ values_main = transform_main(plot_data.loc[idx, main])
+ ax.plot(values_main, y_hat_mean[idx], solid_capstyle="butt")
+ ax.fill_between(values_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx], alpha=0.4)
+ ax.set(title=f"{panel} = {pnl}")
+ elif "group" in covariates and "panel" in covariates:
+ colors = get_unique_levels(plot_data[color])
+ panels = get_unique_levels(plot_data[panel])
+ if color == panel:
+ for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
+ idx = (plot_data[panel] == pnl).to_numpy()
+ values_main = transform_main(plot_data.loc[idx, main])
+ ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
+ ax.fill_between(
+ values_main,
+ y_hat_bounds[0][idx],
+ y_hat_bounds[1][idx],
+ alpha=0.4,
+ color=f"C{i}",
+ )
+ ax.set(title=f"{panel} = {pnl}")
+ else:
+ for ax, pnl in zip(axes.ravel(), panels):
+ for i, clr in enumerate(colors):
+ idx = ((plot_data[panel] == pnl) & (plot_data[color] == clr)).to_numpy()
+ values_main = transform_main(plot_data.loc[idx, main])
+ ax.plot(values_main, y_hat_mean[idx], color=f"C{i}", solid_capstyle="butt")
+ ax.fill_between(
+ values_main,
+ y_hat_bounds[0][idx],
+ y_hat_bounds[1][idx],
+ alpha=0.4,
+ color=f"C{i}",
+ )
+ ax.set(title=f"{panel} = {pnl}")
+
+ if "group" in covariates and legend:
+ handles = [
+ (
+ Line2D([], [], color=f"C{i}", solid_capstyle="butt"),
+ Patch(color=f"C{i}", alpha=0.4, lw=1),
+ )
+ for i in range(len(colors))
+ ]
+ for ax in axes.ravel():
+ ax.legend(
+ handles, tuple(colors), title=color, handlelength=1.3, handleheight=1, loc="best"
+ )
+ return axes
+
+
+def plot_categoric(covariates: Covariates, plot_data: pd.DataFrame, legend: bool = True, axes=None):
+ """Plotting of categorical data types.
+
+ Parameters
+ ----------
+ covariates : Covariates
+        An instance of the ``Covariates`` dataclass with attributes ``main``, ``group``, and ``panel``.
+ plot_data : pd.DataFrame
+ The data created by the `create_cap_data` or `create_comparisons_data`
+ function.
+ legend : bool, optional
+        Whether to include a legend in the plot. Defaults to ``True``.
+ axes : np.ndarray, optional
+ Array of axes. Defaults to `None`.
+
+ Returns
+ -------
+ axes : np.ndarray
+ Array of axes.
+ """
+
+ main, color, panel = covariates.main, covariates.group, covariates.panel
+ covariates = {k: v for k, v in vars(covariates).items() if v is not None}
+ main_levels = get_unique_levels(plot_data[main])
+ main_levels_n = len(main_levels)
+ idxs_main = np.arange(main_levels_n)
+ y_hat_mean = plot_data["estimate"]
+ y_hat_bounds = np.transpose(plot_data[plot_data.columns[-2:]].values)
+
+ if "group" in covariates:
+ colors = get_unique_levels(plot_data[color])
+ colors_n = len(colors)
+ offset_bounds = get_group_offset(colors_n)
+ colors_offset = np.linspace(-offset_bounds, offset_bounds, colors_n)
+
+ if "panel" in covariates:
+ panels = get_unique_levels(plot_data[panel])
+
+ if len(covariates) == 1:
+ ax = axes[0]
+ ax.scatter(idxs_main, y_hat_mean)
+ ax.vlines(idxs_main, y_hat_bounds[0], y_hat_bounds[1])
+ elif "group" in covariates and not "panel" in covariates:
+ ax = axes[0]
+ for i, clr in enumerate(colors):
+ idx = (plot_data[color] == clr).to_numpy()
+ idxs = idxs_main + colors_offset[i]
+ ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
+ ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
+ elif not "group" in covariates and "panel" in covariates:
+ for ax, pnl in zip(axes.ravel(), panels):
+ idx = (plot_data[panel] == pnl).to_numpy()
+ ax.scatter(idxs_main, y_hat_mean[idx])
+ ax.vlines(idxs_main, y_hat_bounds[0][idx], y_hat_bounds[1][idx])
+ ax.set(title=f"{panel} = {pnl}")
+ elif "group" in covariates and "panel" in covariates:
+ if color == panel:
+ for i, (ax, pnl) in enumerate(zip(axes.ravel(), panels)):
+ idx = (plot_data[panel] == pnl).to_numpy()
+ idxs = idxs_main + colors_offset[i]
+ ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
+ ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
+ ax.set(title=f"{panel} = {pnl}")
+ else:
+ for ax, pnl in zip(axes.ravel(), panels):
+ for i, clr in enumerate(colors):
+ idx = ((plot_data[panel] == pnl) & (plot_data[color] == clr)).to_numpy()
+ idxs = idxs_main + colors_offset[i]
+ ax.scatter(idxs, y_hat_mean[idx], color=f"C{i}")
+ ax.vlines(idxs, y_hat_bounds[0][idx], y_hat_bounds[1][idx], color=f"C{i}")
+ ax.set(title=f"{panel} = {pnl}")
+
+ if "group" in covariates and legend:
+ handles = [
+ Line2D([], [], c=f"C{i}", marker="o", label=level) for i, level in enumerate(colors)
+ ]
+ for ax in axes.ravel():
+ ax.legend(handles=handles, title=color, loc="best")
+
+ for ax in axes.ravel():
+ ax.set_xticks(idxs_main)
+ ax.set_xticklabels(main_levels)
+
+ return axes
diff --git a/bambi/plots/plotting.py b/bambi/plots/plotting.py
new file mode 100644
index 000000000..a5f9376e6
--- /dev/null
+++ b/bambi/plots/plotting.py
@@ -0,0 +1,306 @@
+# pylint: disable = protected-access
+# pylint: disable = too-many-function-args
+# pylint: disable = too-many-nested-blocks
+from typing import Union
+
+import arviz as az
+from arviz.plots.backends.matplotlib import create_axes_grid
+from arviz.plots.plot_utils import default_grid
+import numpy as np
+from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
+
+from bambi.models import Model
+from bambi.plots.effects import comparisons, predictions
+from bambi.plots.plot_types import plot_categoric, plot_numeric
+from bambi.plots.utils import get_covariates, ConditionalInfo
+from bambi.utils import get_aliased_name, listify
+
+
+def plot_cap(
+ model: Model,
+ idata: az.InferenceData,
+ covariates: Union[str, list],
+ target: str = "mean",
+ pps: bool = False,
+ use_hdi: bool = True,
+ prob=None,
+ transforms=None,
+ legend: bool = True,
+ ax=None,
+ fig_kwargs=None,
+ subplot_kwargs=None,
+):
+ """Plot Conditional Adjusted Predictions
+
+ Parameters
+ ----------
+ model : bambi.Model
+ The model for which we want to plot the predictions.
+ idata : arviz.InferenceData
+ The InferenceData object that contains the samples from the posterior distribution of
+ the model.
+    covariates : str or list
+ A sequence of between one and three names of variables in the model.
+ target : str
+ Which model parameter to plot. Defaults to 'mean'. Passing a parameter into target only
+ works when pps is False as the target may not be available in the posterior predictive
+ distribution.
+ pps: bool, optional
+ Whether to plot the posterior predictive samples. Defaults to ``False``.
+ use_hdi : bool, optional
+ Whether to compute the highest density interval (defaults to True) or the quantiles.
+ prob : float, optional
+ The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
+        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
+ legend : bool, optional
+ Whether to automatically include a legend in the plot. Defaults to ``True``.
+ transforms : dict, optional
+ Transformations that are applied to each of the variables being plotted. The keys are the
+ name of the variables, and the values are functions to be applied. Defaults to ``None``.
+ ax : matplotlib.axes._subplots.AxesSubplot, optional
+ A matplotlib axes object or a sequence of them. If None, this function instantiates a
+ new axes object. Defaults to ``None``.
+ fig_kwargs : optional
+ Keyword arguments passed to the matplotlib figure function as a dict. For example,
+        ``fig_kwargs=dict(figsize=(11, 8), sharey=True)`` would make the figure 11 inches wide
+ by 8 inches high and would share the y-axis values.
+ subplot_kwargs : optional
+ Keyword arguments used to determine the covariates used for the horizontal, group,
+ and panel axes. For example, ``subplot_kwargs=dict(main="x", group="y", panel="z")`` would
+ plot the horizontal axis as ``x``, the color (hue) as ``y``, and the panel axis as ``z``.
+
+ Returns
+ -------
+ matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot
+ A tuple with the figure and the axes.
+
+ Raises
+ ------
+ ValueError
+        When ``prob`` is not within 0 and 1.
+ When the main covariate is not numeric or categoric.
+
+ TypeError
+ When ``covariates`` is not a string or a list of strings.
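+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> idata = model.fit()
+    >>> fig, axes = plot_cap(model, idata, ["x", "g"])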
+ """
+
+ covariate_kinds = ("main", "group", "panel")
+ if isinstance(covariates, dict):
+ raise TypeError("covariates must be a string or a list of strings.")
+
+ if not isinstance(covariates, dict):
+ covariates = listify(covariates)
+ covariates = dict(zip(covariate_kinds, covariates))
+ else:
+ assert covariate_kinds[0] in covariates
+ assert set(covariates).issubset(set(covariate_kinds))
+
+ assert 1 <= len(covariates) <= 3
+
+ if transforms is None:
+ transforms = {}
+
+ cap_data = predictions(
+ model,
+ idata,
+ covariates,
+ target=target,
+ pps=pps,
+ use_hdi=use_hdi,
+ prob=prob,
+ transforms=transforms,
+ )
+
+ response_name = get_aliased_name(model.response_component.response_term)
+ covariates = get_covariates(covariates)
+
+ if subplot_kwargs:
+ for key, value in subplot_kwargs.items():
+ setattr(covariates, key, value)
+
+ if ax is None:
+ fig_kwargs = {} if fig_kwargs is None else fig_kwargs
+ panels_n = len(np.unique(cap_data[covariates.panel])) if covariates.panel else 1
+ rows, cols = default_grid(panels_n)
+ fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
+ axes = np.atleast_1d(axes)
+ else:
+ axes = np.atleast_1d(ax)
+ if isinstance(axes[0], np.ndarray):
+ fig = axes[0][0].get_figure()
+ else:
+ fig = axes[0].get_figure()
+
+ if is_numeric_dtype(cap_data[covariates.main]):
+ axes = plot_numeric(covariates, cap_data, transforms, legend, axes)
+ elif is_categorical_dtype(cap_data[covariates.main]) or is_string_dtype(
+ cap_data[covariates.main]
+ ):
+ axes = plot_categoric(covariates, cap_data, legend, axes)
+ else:
+ raise ValueError("Main covariate must be numeric or categoric.")
+
+ ylabel = response_name if target == "mean" else target
+ for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
+ ax.set(xlabel=covariates.main, ylabel=ylabel)
+
+ return fig, axes
+
+
+def plot_comparison(
+ model: Model,
+ idata: az.InferenceData,
+ contrast: Union[str, dict, list],
+ conditional: Union[str, dict, list, None] = None,
+ average_by: Union[str, list] = None,
+ comparison_type: str = "diff",
+ use_hdi: bool = True,
+ prob=None,
+ legend: bool = True,
+ transforms=None,
+ ax=None,
+ fig_kwargs=None,
+ subplot_kwargs=None,
+):
+ """Plot Conditional Adjusted Comparisons
+
+ Parameters
+ ----------
+ model : bambi.Model
+        The model for which we want to plot the comparisons.
+ idata : arviz.InferenceData
+ The InferenceData object that contains the samples from the posterior distribution of
+ the model.
+ contrast : str, dict, list
+ The predictor name whose contrast we would like to compare.
+ conditional : str, dict, list
+ The covariates we would like to condition on.
+    average_by : str, list, optional
+        The covariates we would like to average by. Estimates are marginalized over all
+        other covariates in the model while the passed covariate(s) are retained.
+        Defaults to ``None``.
+ comparison_type : str, optional
+        The type of comparison to plot. Either 'diff' or 'ratio'. Defaults to 'diff'.
+ use_hdi : bool, optional
+ Whether to compute the highest density interval (defaults to True) or the quantiles.
+ prob : float, optional
+ The probability for the credibility intervals. Must be between 0 and 1. Defaults to 0.94.
+        Changing the global variable ``az.rcParams["stats.hdi_prob"]`` affects this default.
+ legend : bool, optional
+ Whether to automatically include a legend in the plot. Defaults to ``True``.
+ transforms : dict, optional
+ Transformations that are applied to each of the variables being plotted. The keys are the
+ name of the variables, and the values are functions to be applied. Defaults to ``None``.
+ ax : matplotlib.axes._subplots.AxesSubplot, optional
+ A matplotlib axes object or a sequence of them. If None, this function instantiates a
+ new axes object. Defaults to ``None``.
+ fig_kwargs : optional
+ Keyword arguments passed to the matplotlib figure function as a dict. For example,
+        ``fig_kwargs=dict(figsize=(11, 8), sharey=True)`` would make the figure 11 inches wide
+ by 8 inches high and would share the y-axis values.
+ subplot_kwargs : optional
+ Keyword arguments used to determine the covariates used for the horizontal, group,
+ and panel axes. For example, ``subplot_kwargs=dict(main="x", group="y", panel="z")`` would
+ plot the horizontal axis as ``x``, the color (hue) as ``y``, and the panel axis as ``z``.
+
+ Returns
+ -------
+ matplotlib.figure.Figure, matplotlib.axes._subplots.AxesSubplot
+ A tuple with the figure and the axes.
+
+ Raises
+ ------
+ ValueError
+ If ``conditional`` and ``average_by`` are both ``None``.
+ If length of ``conditional`` is greater than 3 and ``average_by`` is ``None``.
+        If ``average_by`` is ``True``.
+        If ``contrast`` has more than 2 values.
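+
+    Examples
+    --------
+    A minimal sketch, assuming ``import bambi as bmb`` and a hypothetical pandas
+    DataFrame ``df`` with columns ``y``, ``x`` and ``g``:
+
+    >>> model = bmb.Model("y ~ x + g", df)
+    >>> idata = model.fit()
+    >>> fig, axes = plot_comparison(model, idata, contrast="x", conditional="g")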
+ """
+ if conditional is None and average_by is None:
+ raise ValueError("Must specify at least one of 'conditional' or 'average_by'.")
+ if conditional is not None:
+ if not isinstance(conditional, str):
+ if len(conditional) > 3 and average_by is None:
+ raise ValueError(
+ "Must specify a covariate to 'average_by' when number of covariates"
+ "passed to 'conditional' is greater than 3."
+ )
+ if average_by is True:
+ raise ValueError(
+ "Plotting when 'average_by = True' is not possible as 'True' marginalizes "
+ "over all covariates resulting in a single comparison estimate. "
+ "Please specify a covariate(s) to 'average_by'."
+ )
+
+ if isinstance(contrast, dict):
+ contrast_name, contrast_level = next(iter(contrast.items()))
+ if len(contrast_level) > 2:
+ raise ValueError(
+ f"Plotting when 'contrast' has > 2 values is not supported. "
+ f"{contrast_name} has {len(contrast_level)} values."
+ )
+
+ contrast_df = comparisons(
+ model=model,
+ idata=idata,
+ contrast=contrast,
+ conditional=conditional,
+ average_by=average_by,
+ comparison_type=comparison_type,
+ use_hdi=use_hdi,
+ prob=prob,
+ transforms=transforms,
+ )
+
+ conditional_info = ConditionalInfo(model, conditional)
+
+ if (subplot_kwargs and not average_by) or (subplot_kwargs and average_by):
+ for key, value in subplot_kwargs.items():
+ conditional_info.covariates.update({key: value})
+ covariates = get_covariates(conditional_info.covariates)
+ elif average_by and not subplot_kwargs:
+ if not isinstance(average_by, list):
+ average_by = listify(average_by)
+ covariate_kinds = ("main", "group", "panel")
+ average_by = dict(zip(covariate_kinds, average_by))
+ covariates = get_covariates(average_by)
+ else:
+ covariates = get_covariates(conditional_info.covariates)
+
+ if transforms is None:
+ transforms = {}
+
+ response_name = get_aliased_name(model.response_component.response_term)
+
+ if ax is None:
+ fig_kwargs = {} if fig_kwargs is None else fig_kwargs
+ panels_n = len(np.unique(contrast_df[covariates.panel])) if covariates.panel else 1
+ rows, cols = default_grid(panels_n)
+ fig, axes = create_axes_grid(panels_n, rows, cols, backend_kwargs=fig_kwargs)
+ axes = np.atleast_1d(axes)
+ else:
+ axes = np.atleast_1d(ax)
+ if isinstance(axes[0], np.ndarray):
+ fig = axes[0][0].get_figure()
+ else:
+ fig = axes[0].get_figure()
+
+ if is_numeric_dtype(contrast_df[covariates.main]):
+        # the main conditional variable can be numeric but take only a few unique
+        # values, in which case it is treated as categoric
+ if np.unique(contrast_df[covariates.main]).shape[0] <= 5:
+ axes = plot_categoric(covariates, contrast_df, legend, axes)
+ else:
+ axes = plot_numeric(covariates, contrast_df, transforms, legend, axes)
+ elif is_categorical_dtype(contrast_df[covariates.main]) or is_string_dtype(
+ contrast_df[covariates.main]
+ ):
+ axes = plot_categoric(covariates, contrast_df, legend, axes)
+ else:
+ raise TypeError("Main covariate must be numeric or categoric.")
+
+ response_name = get_aliased_name(model.response_component.response_term)
+ for ax in axes.ravel(): # pylint: disable = redefined-argument-from-local
+ ax.set(xlabel=covariates.main, ylabel=response_name)
+
+ return fig, axes
diff --git a/bambi/plots/utils.py b/bambi/plots/utils.py
index 9f54039ed..0d5deb8fa 100644
--- a/bambi/plots/utils.py
+++ b/bambi/plots/utils.py
@@ -1,7 +1,321 @@
+from dataclasses import dataclass, field
+from statistics import mode
+from typing import Union
+
import numpy as np
+from formulae.terms.call import Call
+import pandas as pd
+from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_string_dtype
+
+from bambi import Model
+from bambi.utils import listify
+
+
+@dataclass
+class ContrastInfo:
+ model: Model
+ contrast: Union[str, dict, list]
+ name: str = field(init=False)
+ values: Union[int, float] = field(init=False)
+
+ def __post_init__(self):
+ """ """
+ if isinstance(self.contrast, dict):
+ self.values = list(self.contrast.values())[0]
+ self.name = list(self.contrast.keys())[0]
+ elif isinstance(self.contrast, (list, str)):
+ if isinstance(self.contrast, list):
+ self.name = " ".join(self.contrast)
+ else:
+ self.name = self.contrast
+ self.values = set_default_contrast_values(self.model, self.name)
+ elif not isinstance(self.contrast, (list, dict, str)):
+ raise TypeError("`contrast` must be a list, dict, or string")
+
+
+@dataclass
+class ConditionalInfo:
+ model: Model
+ conditional: Union[str, dict, list]
+ covariates: dict = field(init=False)
+ user_passed: bool = field(init=False)
+
+ def __post_init__(self):
+ """
+ Sets the covariates attributes based on if the user passed a dictionary
+ or not.
+ """
+ covariate_kinds = ("main", "group", "panel")
+
+ if not isinstance(self.conditional, dict):
+ self.covariates = listify(self.conditional)
+ self.covariates = dict(zip(covariate_kinds, self.covariates))
+ self.user_passed = False
+ elif isinstance(self.conditional, dict):
+ self.covariates = {k: listify(v) for k, v in self.conditional.items()}
+ self.covariates = dict(zip(covariate_kinds, self.conditional))
+ self.user_passed = True
+
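Similarly, a sketch of how `ConditionalInfo` maps the user's `conditional` argument onto the ("main", "group", "panel") slots, reusing the hypothetical `model` from the sketch above.

    from bambi.plots.utils import ConditionalInfo

    # list or string input: names are assigned to main, group, panel in order
    info = ConditionalInfo(model, ["x", "g"])
    info.covariates, info.user_passed  # ({"main": "x", "group": "g"}, False)

    # dict input: the keys become the main/group/panel covariate names
    info = ConditionalInfo(model, {"x": [1.0, 2.0, 3.0]})
    info.covariates, info.user_passed  # ({"main": "x"}, True)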
+
+@dataclass
+class Covariates:
+ main: str
+ group: Union[str, None]
+ panel: Union[str, None]
+
+
+def average_over(data: pd.DataFrame, covariate: Union[str, list, None]) -> pd.DataFrame:
+ """
+ Average estimates by specified covariate in the model. data.columns[-3:] are
+ the columns: 'estimate', 'lower', and 'upper'.
+ """
+ if covariate is None:
+ return pd.DataFrame(data[data.columns[-3:]].mean()).T
+ else:
+ return data.groupby(covariate, as_index=False)[data.columns[-3:]].mean()
+
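A quick illustration of `average_over` on a toy summary dataframe whose last three columns follow the 'estimate', 'lower', 'upper' convention.

    import pandas as pd
    from bambi.plots.utils import average_over

    summary = pd.DataFrame(
        {
            "g": ["a", "a", "b", "b"],
            "estimate": [1.0, 2.0, 3.0, 4.0],
            "lower": [0.5, 1.5, 2.5, 3.5],
            "upper": [1.5, 2.5, 3.5, 4.5],
        }
    )

    average_over(summary, "g")   # one row per level of "g"
    average_over(summary, None)  # a single row averaging over all rows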
+
+def get_model_terms(model: Model) -> dict:
+ """
+ Loops through the distributional components of a bambi model and
+ returns a dictionary of terms.
+ """
+ terms = {}
+ for component in model.distributional_components.values():
+ if component.design.common:
+ terms.update(component.design.common.terms)
+
+ if component.design.group:
+ terms.update(component.design.group.terms)
+
+ return terms
+
+
+def get_model_covariates(model: Model) -> np.ndarray:
+    """
+    Return covariates specified in the model.
+    """
+    terms = get_model_terms(model)
+    names = []
+    for term in terms.values():
+        if hasattr(term, "components"):
+            for component in term.components:
+                # If the component is a function call, use the argument names
+                if isinstance(component, Call):
+                    names.extend(arg.name for arg in component.call.args)
+                else:
+                    names.append(component.name)
+
+    return np.unique(names)
+
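A sketch of what `get_model_covariates` is expected to return; the formula is hypothetical and uses a function call term to show that argument names, not the call itself, are collected.

    import bambi as bmb
    import numpy as np
    import pandas as pd
    from bambi.plots.utils import get_model_covariates

    rng = np.random.default_rng(0)
    df = pd.DataFrame(
        {"y": rng.normal(size=100), "x": rng.normal(size=100), "g": rng.choice(["a", "b"], size=100)}
    )
    model = bmb.Model("y ~ center(x) + g", df)

    get_model_covariates(model)  # array(["g", "x"], ...): "x" is recovered from center(x)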
+
+def get_covariates(covariates: dict) -> Covariates:
+ """
+ Obtain the main, group, and panel covariates from the user's
+ conditional dict.
+ """
+ covariate_kinds = ("main", "group", "panel")
+ if any(key in covariate_kinds for key in covariates.keys()):
+ # default if user did not pass their own conditional dict
+ main = covariates.get("main")
+ group = covariates.get("group", None)
+ panel = covariates.get("panel", None)
+ else:
+ # assign main, group, panel based on the number of variables
+ # passed by the user in their conditional dict
+ length = len(covariates.keys())
+ if length == 1:
+            (main,) = covariates.keys()
+ group = None
+ panel = None
+ elif length == 2:
+ main, group = covariates.keys()
+ panel = None
+ elif length == 3:
+ main, group, panel = covariates.keys()
+
+ return Covariates(main, group, panel)
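A small sketch of `get_covariates` with the two kinds of dictionaries it accepts.

    from bambi.plots.utils import get_covariates

    # dict built internally from a list such as ["x", "g"]
    cov = get_covariates({"main": "x", "group": "g", "panel": None})
    cov.main, cov.group, cov.panel  # ("x", "g", None)

    # user-passed conditional dict: assignment is based on key order
    cov = get_covariates({"x": [1.0, 2.0], "g": ["a", "b"]})
    cov.main, cov.group, cov.panel  # ("x", "g", None)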
+
+
+def enforce_dtypes(data: pd.DataFrame, df: pd.DataFrame) -> pd.DataFrame:
+ """
+ Enforce dtypes of the original data to the new data.
+ """
+ observed_dtypes = data.dtypes
+ for col in df.columns:
+ if col in observed_dtypes.index:
+ df[col] = df[col].astype(observed_dtypes[col])
+ return df
-def get_unique_levels(x):
+
+
+def make_group_panel_values(
+ data: pd.DataFrame,
+ data_dict: dict,
+ main: str,
+ group: Union[str, None],
+ panel: Union[str, None],
+ kind: str,
+ groups_n: int = 5,
+) -> dict:
+ """
+ Compute group and panel values based on original data.
+ """
+
+ # If available, obtain groups for grouping variable
+ if group:
+ group_values = make_group_values(data[group], groups_n)
+ group_n = len(group_values)
+
+    # If available, obtain groups for panel variable. Same logic as for the grouping variable
+ if panel:
+ panel_values = make_group_values(data[panel], groups_n)
+ panel_n = len(panel_values)
+
+ main_values = data_dict[main]
+ main_n = len(main_values)
+
+ if kind == "predictions":
+ if group and not panel:
+ main_values = np.tile(main_values, group_n)
+ group_values = np.repeat(group_values, main_n)
+ data_dict.update({main: main_values, group: group_values})
+ elif not group and panel:
+ main_values = np.tile(main_values, panel_n)
+ panel_values = np.repeat(panel_values, main_n)
+ data_dict.update({main: main_values, panel: panel_values})
+ elif group and panel:
+ if group == panel:
+ main_values = np.tile(main_values, group_n)
+ group_values = np.repeat(group_values, main_n)
+ data_dict.update({main: main_values, group: group_values})
+ else:
+ main_values = np.tile(np.tile(main_values, group_n), panel_n)
+ group_values = np.tile(np.repeat(group_values, main_n), panel_n)
+ panel_values = np.repeat(panel_values, main_n * group_n)
+ data_dict.update({main: main_values, group: group_values, panel: panel_values})
+ elif kind == "comparison":
+ # for comparisons, we need unique values for numeric and categorical
+ # group/panel covariates since we iterate over pairwise combinations of values
+ if group and not panel:
+ data_dict.update({group: np.unique(group_values)})
+ elif group and panel:
+ data_dict.update({group: np.unique(group_values), panel: np.unique(panel_values)})
+
+ return data_dict
+
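A sketch of the tiling logic for the 'predictions' grid: main values are tiled once per group value and group values are repeated once per main value, so the dict can be turned into a long dataframe. The toy data below is hypothetical.

    import numpy as np
    import pandas as pd
    from bambi.plots.utils import make_group_panel_values, make_main_values

    data = pd.DataFrame({"x": np.arange(10.0), "g": list("ababababab")})
    data_dict = {"x": make_main_values(data["x"], grid_n=3)}  # [0.0, 4.5, 9.0]
    data_dict = make_group_panel_values(
        data, data_dict, main="x", group="g", panel=None, kind="predictions"
    )
    # {"x": array([0., 4.5, 9., 0., 4.5, 9.]), "g": array(["a", "a", "a", "b", "b", "b"], ...)}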
+
+def set_default_values(model: Model, data_dict: dict, kind: str):
+ """
+ Set default values for each variable in the model if the user did not
+ pass them in the data_dict.
+ """
+ assert kind in [
+ "comparison",
+ "predictions",
+ ], "kind must be either 'comparison' or 'predictions'"
+
+ terms = get_model_terms(model)
+
+ # Get default values for each variable in the model
+ # pylint: disable=R1702
+ for term in terms.values():
+ if hasattr(term, "components"):
+ for component in term.components:
+ # If the component is a function call, use the argument names
+ if isinstance(component, Call):
+ names = [arg.name for arg in component.call.args]
+ else:
+ names = [component.name]
+ for name in names:
+ if name not in data_dict:
+ # For numeric predictors, select the mean.
+ if component.kind == "numeric":
+ data_dict[name] = np.mean(model.data[name])
+ # For categoric predictors, select the most frequent level.
+ elif component.kind == "categoric":
+ data_dict[name] = mode(model.data[name])
+
+ if kind == "comparison":
+ # if value in dict is not a list then convert to a list
+ for key, value in data_dict.items():
+ if not isinstance(value, (list, np.ndarray)):
+ data_dict[key] = [value]
+ return data_dict
+ elif kind == "predictions":
+ return data_dict
+ else:
+ return None
+
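A sketch of `set_default_values`: predictors missing from the data dict are filled in with their mean (numeric) or most frequent level (categoric), and for 'comparison' every value is wrapped in a list. Model and data are hypothetical.

    import bambi as bmb
    import numpy as np
    import pandas as pd
    from bambi.plots.utils import set_default_values

    rng = np.random.default_rng(0)
    df = pd.DataFrame(
        {"y": rng.normal(size=100), "x": rng.normal(size=100), "g": rng.choice(["a", "b"], size=100)}
    )
    model = bmb.Model("y ~ x + g", df)

    # "x" was not specified, so it defaults to its mean and is wrapped in a list
    set_default_values(model, {"g": ["a", "b"]}, kind="comparison")
    # {"g": ["a", "b"], "x": [<mean of x>]}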
+
+def set_default_contrast_values(model: Model, contrast_predictor: str) -> Union[list, np.ndarray]:
+ """
+ Set the default contrast value for the contrast predictor based on the
+ contrast predictor dtype.
+ """
+
+ def _numeric_difference(x):
+ """
+ Centered difference for numeric predictors results in a default contrast
+ of a 1 unit increase
+ """
+ return np.array([x - 0.5, x + 0.5])
+
+ terms = get_model_terms(model)
+ contrast_dtype = model.data[contrast_predictor].dtype
+
+ # Get default values for each variable in the model
+ # pylint: disable=R1702
+ for term in terms.values():
+ if hasattr(term, "components"):
+ for component in term.components:
+ # If the component is a function call, use the argument names
+ if isinstance(component, Call):
+ names = [arg.name for arg in component.call.args]
+ else:
+ names = [component.name]
+ for name in names:
+ if name == contrast_predictor:
+ # For numeric predictors, select the mean.
+ if component.kind == "numeric":
+ contrast = _numeric_difference(np.mean(model.data[name])).astype(
+ contrast_dtype
+ )
+ # For categoric predictors, select the most frequent level.
+ elif component.kind == "categoric":
+ contrast = get_unique_levels(model.data[name])
+
+ return contrast
+
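And the default contrast values themselves: a centered one unit difference around the mean for numeric predictors, the unique levels for categoric predictors. Reusing the hypothetical `model` from the sketch above.

    from bambi.plots.utils import set_default_contrast_values

    set_default_contrast_values(model, "x")  # array([mean(x) - 0.5, mean(x) + 0.5])
    set_default_contrast_values(model, "g")  # the unique levels, e.g. ["a", "b"]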
+
+def make_main_values(x: np.ndarray, grid_n: int = 50) -> np.ndarray:
+ """
+    Compute main values based on original data using a grid of evenly spaced
+ values for numeric predictors and unique levels for categoric predictors.
+ """
+ if is_numeric_dtype(x):
+ return np.linspace(np.min(x), np.max(x), grid_n)
+ elif is_string_dtype(x) or is_categorical_dtype(x):
+ return np.unique(x)
+ raise ValueError("Main covariate must be numeric or categoric.")
+
+
+def make_group_values(x: np.ndarray, groups_n: int = 5) -> np.ndarray:
+ """
+ Compute group values based on original data using unique levels for
+ categoric predictors and quantiles for numeric predictors.
+ """
+ if is_string_dtype(x) or is_categorical_dtype(x):
+ return np.unique(x)
+ elif is_numeric_dtype(x):
+ return np.quantile(x, np.linspace(0, 1, groups_n))
+ raise ValueError("Group covariate must be numeric or categoric.")
+
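The value-generation helpers in isolation: an evenly spaced grid for numeric main covariates, unique levels for categoric ones, and quantiles for numeric group covariates.

    import numpy as np
    from bambi.plots.utils import make_group_values, make_main_values

    x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
    make_main_values(x, grid_n=5)     # array([0., 1., 2., 3., 4.])
    make_group_values(x, groups_n=3)  # quantiles: array([0., 2., 4.])

    g = np.array(["a", "b", "b", "c"])
    make_main_values(g)   # array(["a", "b", "c"], dtype="<U1")
    make_group_values(g)  # array(["a", "b", "c"], dtype="<U1")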
+
+def get_unique_levels(x: np.ndarray) -> Union[list, np.ndarray]:
+ """
+ Get unique levels of a categoric variable.
+ """
if hasattr(x, "dtype") and hasattr(x.dtype, "categories"):
levels = list(x.dtype.categories)
else:
@@ -9,7 +323,7 @@ def get_unique_levels(x):
return levels
-def get_group_offset(n, lower=0.05, upper=0.4):
+def get_group_offset(n, lower: float = 0.05, upper: float = 0.4) -> np.ndarray:
# Complementary log log function, scaled.
# See following code to have an idea of how this function looks like
# lower, upper = 0.05, 0.4
@@ -22,3 +336,7 @@ def get_group_offset(n, lower=0.05, upper=0.4):
# ax.axhline(upper, color="k", ls="--")
intercept, slope = 3.25, 1
return lower + np.exp(-np.exp(intercept - slope * n)) * (upper - lower)
+
+
+def identity(x):
+ return x
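Finally, a short numeric check of `get_group_offset`, the scaled complementary log-log curve used to offset grouped points: the offset starts near `lower` for few groups and approaches `upper` as the number of groups grows.

    import numpy as np
    from bambi.plots.utils import get_group_offset

    for n in (1, 2, 5, 10):
        print(n, round(float(get_group_offset(n)), 3))
    # 1 -> 0.05, 2 -> 0.061, 5 -> 0.344, 10 -> 0.4 (approximately)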
diff --git a/docs/notebooks/plot_cap.ipynb b/docs/notebooks/plot_cap.ipynb
index 403d9b3f3..2a44fbb64 100644
--- a/docs/notebooks/plot_cap.ipynb
+++ b/docs/notebooks/plot_cap.ipynb
@@ -58,17 +58,29 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 9,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"import arviz as az\n",
"import bambi as bmb\n",
- "import numpy as np\n",
"import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
"import pandas as pd\n",
"\n",
- "from bambi.plots import plot_cap"
+ "from bambi.plots import plot_cap\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
]
},
{
@@ -94,7 +106,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 46,
"metadata": {},
"outputs": [
{
@@ -103,11 +115,54 @@
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
- "Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
- "NUTS: [mpg_sigma, hp, wt, hp:wt, cyl, gear]\n"
+ "NUTS: [response_sigma, hp, wt, hp:wt, cyl, gear]\n"
]
},
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "