Skip to content

Commit

Permalink
Add comparisons and plot comparisons (#684)
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte authored Jul 9, 2023
1 parent 6268ccf commit 0fa0b6f
Show file tree
Hide file tree
Showing 10 changed files with 1,838 additions and 668 deletions.
3 changes: 1 addition & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,6 @@ def set_alias(self, aliases):
assert component_name in self.distributional_components
component = self.distributional_components[component_name]
for name, alias in component_aliases.items():

is_used = False

if name in component.terms:
Expand Down Expand Up @@ -661,7 +660,7 @@ def plot_priors(
unobserved_rvs_names = []
flat_rvs = []
for unobserved in self.backend.model.unobserved_RVs:
if "Flat" in unobserved.__str__():
if "Flat" in str(unobserved):
flat_rvs.append(unobserved.name)
else:
unobserved_rvs_names.append(unobserved.name)
Expand Down
6 changes: 4 additions & 2 deletions bambi/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from .plot_cap import create_cap_data, plot_cap
from bambi.plots.effects import comparisons, predictions
from bambi.plots.plotting import plot_cap, plot_comparison

__all__ = ["create_cap_data", "plot_cap"]

__all__ = ["comparisons", "predictions", "plot_cap", "plot_comparison"]
129 changes: 129 additions & 0 deletions bambi/plots/create_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import itertools

import numpy as np
import pandas as pd

from bambi.models import Model
from bambi.plots.utils import (
ConditionalInfo,
ContrastInfo,
enforce_dtypes,
get_covariates,
get_model_covariates,
make_group_panel_values,
make_main_values,
set_default_values,
)


def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Predictions
Parameters
----------
model : bambi.Model
An instance of a Bambi model
covariates : dict
A dictionary of length between one and three.
Keys must be taken from ("horizontal", "color", "panel").
The values indicate the names of variables.
Returns
-------
pandas.DataFrame
The data for the Conditional Adjusted Predictions dataframe and or
plotting.
"""
data = model.data
covariates = get_covariates(covariates)
main, group, panel = covariates.main, covariates.group, covariates.panel

# Obtain data for main variable
main_values = make_main_values(data[main])
data_dict = {main: main_values}

# Obtain data for group and panel variables if not None
data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
data_dict = set_default_values(model, data_dict, kind="predictions")
return enforce_dtypes(data, pd.DataFrame(data_dict))


def create_comparisons_data(
condition: ConditionalInfo, contrast: ContrastInfo, user_passed: bool = False
) -> pd.DataFrame:
"""Create data for a Conditional Adjusted Comparisons
Parameters
----------
condition: ConditionalInfo
A dataclass instance containing the model, contrast, and conditional
covariates to be used in the comparisons.
contrast: ContrastInfo
A dataclass instance containing the model, and contrast name and values.
user_passed: bool, optional
Whether the user passed their own 'conditional' data. Defaults to False.
Returns
-------
pd.DataFrame
The data for the Conditional Adjusted Comparisons dataframe and or
plotting.
"""

def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
"""
Creates the data for grid-level contrasts by using the covariates passed
into the `conditional` arg. Values for the grid are either: (1) computed
using a equally spaced grid, mean, and or mode (depending on the covariate
dtype), and (2) a user specified value or range of values.
"""
covariates = get_covariates(condition.covariates)

if user_passed:
data_dict = {**condition.conditional}
else:
main_values = make_main_values(condition.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind="comparison",
)

data_dict[contrast.name] = contrast.values
comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
# use cartesian product (cross join) to create contrasts
keys, values = zip(*comparison_data.items())
contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]

return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))

def _unit_level(contrast: ContrastInfo):
"""
Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
default contrast value, or (2) the user specified contrast value.
"""
covariates = get_model_covariates(contrast.model)
df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)

contrast_vals = np.array(contrast.values)[..., None]
contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)

contrast_df_dict = {}
for idx, value in enumerate(contrast_vals):
contrast_df_dict[f"contrast_{idx}"] = df.copy()
contrast_df_dict[f"contrast_{idx}"][contrast.name] = value

return pd.concat(contrast_df_dict.values())

if not condition.conditional:
df = _unit_level(contrast)
else:
df = _grid_level(condition, contrast)

return df
Loading

0 comments on commit 0fa0b6f

Please sign in to comment.