Plot comparisons #684

Merged 85 commits on Jul 9, 2023
07533f8
plot_cap draft outline for docs example
GStechschulte Apr 24, 2023
4780d36
intro. to GLMs and Negative Binomial model
GStechschulte May 8, 2023
dcd3ceb
added logistic regression and other model params. demo
GStechschulte May 9, 2023
8274c53
basic linear model demo
GStechschulte May 10, 2023
b844396
comparisons learning from marginaleffects
GStechschulte May 14, 2023
227d78d
comparison contrasts using make_cap_data code
GStechschulte May 16, 2023
2769969
CreateData class added to __init__.py
GStechschulte May 18, 2023
7753a2b
CreateData class for all plotting functions
GStechschulte May 18, 2023
e1952e3
functions for computing and plotting comparisons
GStechschulte May 18, 2023
9c3c78d
plot_comparisons demo on categorical data
GStechschulte May 18, 2023
95d1ec1
logic of main, group, panel for building contrasts df
GStechschulte May 19, 2023
e4ec36d
add make_group_panel_values and enforce_dtypes functions
GStechschulte May 19, 2023
0371be9
plot_comparisons demo
GStechschulte May 19, 2023
096958b
cleanup demo notebook
GStechschulte May 19, 2023
c2533fe
cleanup demo notebook
GStechschulte May 20, 2023
0115451
move util functions to utils.py and renaming of modules and functions
GStechschulte May 22, 2023
deef5f3
re-run demo.
GStechschulte May 22, 2023
4360e79
use dataclass for returning covariates instead of dict
GStechschulte May 23, 2023
d3334ce
remove unused variables in plot_comparison
GStechschulte May 23, 2023
f0d2f71
type hinting and added dataclass for covariates
GStechschulte May 23, 2023
bde5470
module for plot kinds based on response
GStechschulte May 23, 2023
7489f84
plot_comparisons demo notebook update
GStechschulte May 23, 2023
1d2287d
changed and added modules to be treated as packages
GStechschulte May 23, 2023
288914c
modularize create cap, comparisons, and slopes data functions
GStechschulte May 23, 2023
6624f17
deleted and moved into plotting.py
GStechschulte May 23, 2023
ef168de
plot cap and comparisons and type of plots in separate modules
GStechschulte May 23, 2023
fb40672
commonly used functions for create_data.py
GStechschulte May 23, 2023
8566014
delete print statement
GStechschulte May 23, 2023
40b3e87
re-run demo notebook
GStechschulte May 23, 2023
27d03f5
comparisons numerical default
GStechschulte May 28, 2023
ed94906
default contrast level for numeric and char. variables
GStechschulte Jun 7, 2023
4bec48f
replace np.repeat with np.tile
GStechschulte Jun 7, 2023
2422e68
default contrast level for numeric and char. variables
GStechschulte Jun 7, 2023
002da59
plot_comparisons default numeric demo
GStechschulte Jun 7, 2023
e4ae726
re-run plot_cap.ipynb
GStechschulte Jun 7, 2023
07eeba4
modularize cap, comparisons, and slopes
GStechschulte Jun 7, 2023
98cce84
add Comparisons class to reduce redundant passing of args.
GStechschulte Jun 7, 2023
4b4ba43
re-run notebook to ensure everything still works
GStechschulte Jun 7, 2023
65c0c28
added class objects for attribute lookup, more informative contrast d…
GStechschulte Jun 8, 2023
0536412
re-run notebook
GStechschulte Jun 8, 2023
a1c460f
re-run notebook with default numeric contrast value
GStechschulte Jun 8, 2023
bd6173f
reduce number of args. in plotting functions and use dataclasses inst…
GStechschulte Jun 9, 2023
282ecb8
new examples with both numeric and categoric variables
GStechschulte Jun 9, 2023
3cf90ad
add comments for review
GStechschulte Jun 9, 2023
4501af9
Show working version and some ideas of plot_comparisons with xarray
tomicapretto Jun 10, 2023
29643e2
comparisons computed using entire posterior and better error handling
GStechschulte Jun 12, 2023
f369291
re-run comparisons notebook w/updated comparisons code
GStechschulte Jun 12, 2023
bdc3fb6
delete notebook
GStechschulte Jun 12, 2023
3a7b0b9
refactor plot_cap to work and pass Covariates class instead of dict t…
GStechschulte Jun 13, 2023
cc14c35
UserWarning if level > 2
GStechschulte Jun 13, 2023
2d1e876
re-run plot_cap and plot_comparisons notebooks
GStechschulte Jun 13, 2023
eefde59
assertions, docstrings, and type hints
GStechschulte Jun 15, 2023
c81bfe7
re-run to make sure ValueErrors work
GStechschulte Jun 15, 2023
23253d0
comparisons for > 1 contrast level
GStechschulte Jun 16, 2023
17b1c89
raise ValueError if user passes > 1 contrast level when plotting
GStechschulte Jun 16, 2023
64b817b
add predictions as sub-package
GStechschulte Jun 19, 2023
ff16e83
type hints, doctrings, and run black
GStechschulte Jun 19, 2023
a6a31e3
added comparison tests, and move cap and comparisons tests into classes
GStechschulte Jun 19, 2023
82396e7
GSoC code review 22.06 and added arg. for
GStechschulte Jun 24, 2023
218abb0
added documentation on new arg.
GStechschulte Jun 24, 2023
de784cd
added classes for organizing , and added tests for
GStechschulte Jun 24, 2023
f8b1dc5
delete print statement
GStechschulte Jun 24, 2023
9ad464c
add test_hdi_prob in class TestCommon
GStechschulte Jun 24, 2023
d122ac5
pylint C2801 use str() instead of __str__()
GStechschulte Jun 26, 2023
68e11af
resolve pylint error messages
GStechschulte Jun 26, 2023
5ae78cb
remove lambda expression as it is not needed
GStechschulte Jun 26, 2023
3ccb833
add support for unit-level contrasts and 'average_by=True'
GStechschulte Jun 27, 2023
40ecc50
improved OOP with dataclasses, error handling, and added unit-level c…
GStechschulte Jun 28, 2023
57f1a75
Allow predictions on new groups (#693)
tomicapretto Jun 29, 2023
dde435c
ran black
GStechschulte Jun 29, 2023
c303748
move isinstance logic to dataclass, improved error handling, and remo…
GStechschulte Jun 29, 2023
503d5c1
resolve pylint message codes
GStechschulte Jun 29, 2023
879b05f
remove imports that users should not have access to
GStechschulte Jul 2, 2023
a6fcdcc
fix/add docstrings
GStechschulte Jul 2, 2023
ba1918c
fix/add docstrings and f-string attributes to ResponseInfo class
GStechschulte Jul 2, 2023
ebd9d4d
fix/add docstrings
GStechschulte Jul 2, 2023
646ed7e
Prepare 0.12.0 release (#694)
tomicapretto Jul 2, 2023
9861601
dev version
tomicapretto Jul 2, 2023
ee4baf7
bug fix for building contrast_df when len(contrast values) > 3
GStechschulte Jul 5, 2023
b074df6
raise ValueError if user tries to plot with > 2 contrast values
GStechschulte Jul 6, 2023
614b511
pylinter error solved, make contrast_df column ordering consistent
GStechschulte Jul 6, 2023
0b1b7a3
logic added for subsetting InferenceData when non-parent param. is al…
GStechschulte Jul 8, 2023
9ccd588
raise ValueError if only average_by=True when plotting comparisons
GStechschulte Jul 8, 2023
a616e60
docstring describing indexing of last 3 columns
GStechschulte Jul 8, 2023
84f32b3
added test for plot_cap non-parent parameter when there is an alias
GStechschulte Jul 8, 2023
362 changes: 268 additions & 94 deletions bambi/model_components.py

Large diffs are not rendered by default.

31 changes: 23 additions & 8 deletions bambi/models.py
@@ -542,7 +542,6 @@ def set_alias(self, aliases):
                 assert component_name in self.distributional_components
                 component = self.distributional_components[component_name]
                 for name, alias in component_aliases.items():
-
                     is_used = False
 
                     if name in component.terms:
@@ -660,7 +659,7 @@ def plot_priors(
         unobserved_rvs_names = []
         flat_rvs = []
         for unobserved in self.backend.model.unobserved_RVs:
-            if "Flat" in unobserved.__str__():
+            if "Flat" in str(unobserved):
                 flat_rvs.append(unobserved.name)
             else:
                 unobserved_rvs_names.append(unobserved.name)
@@ -751,7 +750,15 @@ def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_
 
         return idata
 
-    def predict(self, idata, kind="mean", data=None, inplace=True, include_group_specific=True):
+    def predict(
+        self,
+        idata,
+        kind="mean",
+        data=None,
+        inplace=True,
+        include_group_specific=True,
+        sample_new_groups=False,
+    ):
         """Predict method for Bambi models
 
         Obtains in-sample and out-of-sample predictions from a fitted Bambi model.
@@ -769,16 +776,22 @@ def predict(self, idata, kind="mean", data=None, inplace=True, include_group_spe
         data : pandas.DataFrame or None
             An optional data frame with values for the predictors that are used to obtain
             out-of-sample predictions. If omitted, the original dataset is used.
-        include_group_specific : bool
-            If ``True`` make predictions including the group specific effects. Otherwise,
-            predictions are made with common effects only (i.e. group specific are set
-            to zero).
         inplace : bool
             If ``True`` it will modify ``idata`` in-place. Otherwise, it will return a copy of
             ``idata`` with the predictions added. If ``kind="mean"``, a new variable ending in
             ``"_mean"`` is added to the ``posterior`` group. If ``kind="pps"``, it appends a
             ``posterior_predictive`` group to ``idata``. If any of these already exist, it will be
             overwritten.
+        include_group_specific : bool
+            Determines if predictions incorporate group-specific effects. If ``False``, predictions
+            are made with common effects only (i.e. group specific are set to zero). Defaults to
+            ``True``.
+        sample_new_groups : bool
+            Specifies if it is allowed to obtain predictions for new groups of group-specific terms.
+            When ``True``, each posterior sample for the new groups is drawn from the posterior
+            draws of a randomly selected existing group. Since different groups may be selected at
+            each draw, the end result represents the variation across existing groups.
+            The method implemented is equivalent to `sample_new_levels="uncertainty"` in brms.
 
         Returns
         -------
@@ -806,7 +819,9 @@ def predict(self, idata, kind="mean", data=None, inplace=True, include_group_spe
             else:
                 var_name = f"{response_aliased_name}_{name}"
 
-            means_dict[var_name] = component.predict(idata, data, include_group_specific, hsgp_dict)
+            means_dict[var_name] = component.predict(
+                idata, data, include_group_specific, hsgp_dict, sample_new_groups
+            )
 
             # Drop var/dim if already present. Needed for out-of-sample predictions.
             if var_name in idata.posterior.data_vars:
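The `sample_new_groups` docstring in `bambi/models.py` describes drawing each posterior sample for a new group from a randomly selected existing group. A minimal sketch of that idea follows, using only the standard library. It is not Bambi's implementation (which works on the posterior stored in `InferenceData`); the function name `sample_new_group` and the nested-list layout `draws[d][g]` are assumptions made for illustration.

```python
import random


def sample_new_group(draws, rng):
    """Build posterior draws for an unseen group.

    draws[d][g] is the posterior value of existing group g at draw d
    (hypothetical layout, chosen for this sketch).
    """
    new_group = []
    for per_group in draws:
        # Pick an existing group independently at every draw, so the result
        # mixes groups and reflects the variation across them.
        g = rng.randrange(len(per_group))
        new_group.append(per_group[g])
    return new_group


# Four posterior draws for three existing groups.
draws = [
    [0.1, 1.1, 2.1],
    [0.2, 1.2, 2.2],
    [0.3, 1.3, 2.3],
    [0.4, 1.4, 2.4],
]
new_group = sample_new_group(draws, random.Random(0))
```

Because the group index is re-drawn at every posterior draw, the new group's values are not tied to any single existing group, which is the behavior the docstring compares to `sample_new_levels="uncertainty"` in brms.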
6 changes: 4 additions & 2 deletions bambi/plots/__init__.py
@@ -1,3 +1,5 @@
-from .plot_cap import create_cap_data, plot_cap
+from bambi.plots.effects import comparisons, predictions
+from bambi.plots.plotting import plot_cap, plot_comparison
 
-__all__ = ["create_cap_data", "plot_cap"]
+
+__all__ = ["comparisons", "predictions", "plot_cap", "plot_comparison"]
129 changes: 129 additions & 0 deletions bambi/plots/create_data.py
@@ -0,0 +1,129 @@
+import itertools
+
+import numpy as np
+import pandas as pd
+
+from bambi.models import Model
+from bambi.plots.utils import (
+    ConditionalInfo,
+    ContrastInfo,
+    enforce_dtypes,
+    get_covariates,
+    get_model_covariates,
+    make_group_panel_values,
+    make_main_values,
+    set_default_values,
+)
+
+
+def create_cap_data(model: Model, covariates: dict) -> pd.DataFrame:
+    """Create data for Conditional Adjusted Predictions
+
+    Parameters
+    ----------
+    model : bambi.Model
+        An instance of a Bambi model
+    covariates : dict
+        A dictionary of length between one and three.
+        Keys must be taken from ("horizontal", "color", "panel").
+        The values indicate the names of variables.
+
+    Returns
+    -------
+    pandas.DataFrame
+        The data for the Conditional Adjusted Predictions dataframe and/or
+        plotting.
+    """
+    data = model.data
+    covariates = get_covariates(covariates)
+    main, group, panel = covariates.main, covariates.group, covariates.panel
+
+    # Obtain data for main variable
+    main_values = make_main_values(data[main])
+    data_dict = {main: main_values}
+
+    # Obtain data for group and panel variables if not None
+    data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
+    data_dict = set_default_values(model, data_dict, kind="predictions")
+    return enforce_dtypes(data, pd.DataFrame(data_dict))
+
+
+def create_comparisons_data(
+    condition: ConditionalInfo, contrast: ContrastInfo, user_passed: bool = False
+) -> pd.DataFrame:
+    """Create data for Conditional Adjusted Comparisons
+
+    Parameters
+    ----------
+    condition: ConditionalInfo
+        A dataclass instance containing the model, contrast, and conditional
+        covariates to be used in the comparisons.
+    contrast: ContrastInfo
+        A dataclass instance containing the model, and contrast name and values.
+    user_passed: bool, optional
+        Whether the user passed their own 'conditional' data. Defaults to False.
+
+    Returns
+    -------
+    pd.DataFrame
+        The data for the Conditional Adjusted Comparisons dataframe and/or
+        plotting.
+    """
+
+    def _grid_level(condition: ConditionalInfo, contrast: ContrastInfo):
+        """
+        Creates the data for grid-level contrasts by using the covariates passed
+        into the `conditional` arg. Values for the grid are either: (1) computed
+        using an equally spaced grid, mean, and/or mode (depending on the covariate
+        dtype), or (2) a user specified value or range of values.
+        """
+        covariates = get_covariates(condition.covariates)
+
+        if user_passed:
+            data_dict = {**condition.conditional}
+        else:
+            main_values = make_main_values(condition.model.data[covariates.main])
+            data_dict = {covariates.main: main_values}
+            data_dict = make_group_panel_values(
+                condition.model.data,
+                data_dict,
+                covariates.main,
+                covariates.group,
+                covariates.panel,
+                kind="comparison",
+            )
+
+        data_dict[contrast.name] = contrast.values
+        comparison_data = set_default_values(condition.model, data_dict, kind="comparison")
+        # use cartesian product (cross join) to create contrasts
+        keys, values = zip(*comparison_data.items())
+        contrast_dict = [dict(zip(keys, v)) for v in itertools.product(*values)]
+
+        return enforce_dtypes(condition.model.data, pd.DataFrame(contrast_dict))
+
+    def _unit_level(contrast: ContrastInfo):
+        """
+        Creates the data for unit-level contrasts by using the observed (empirical)
+        data. All covariates in the model are included in the data, except for the
+        contrast predictor. The contrast predictor is replaced with either: (1) the
+        default contrast value, or (2) the user specified contrast value.
+        """
+        covariates = get_model_covariates(contrast.model)
+        df = contrast.model.data[covariates].drop(labels=contrast.name, axis=1)
+
+        contrast_vals = np.array(contrast.values)[..., None]
+        contrast_vals = np.repeat(contrast_vals, contrast.model.data.shape[0], axis=1)
+
+        contrast_df_dict = {}
+        for idx, value in enumerate(contrast_vals):
+            contrast_df_dict[f"contrast_{idx}"] = df.copy()
+            contrast_df_dict[f"contrast_{idx}"][contrast.name] = value
+
+        return pd.concat(contrast_df_dict.values())
+
+    if not condition.conditional:
+        df = _unit_level(contrast)
+    else:
+        df = _grid_level(condition, contrast)
+
+    return df
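
The two strategies in `create_comparisons_data` can be illustrated without Bambi or pandas. The sketch below is a simplified stand-in, not the PR's code: plain lists and dicts replace the DataFrames, and the names `grid_level`, `unit_level`, and the sample columns (`hp`, `am`, `wt`) are hypothetical. The grid-level version cross joins the conditional values with the contrast values, while the unit-level version copies the observed rows once per contrast value and overwrites the contrast column.

```python
import itertools


def grid_level(conditional, contrast_name, contrast_values):
    """Cross join every conditional value with every contrast value."""
    columns = {**conditional, contrast_name: contrast_values}
    keys = list(columns)
    # itertools.product yields one tuple per combination of column values,
    # with the last column (the contrast) varying fastest.
    return [dict(zip(keys, combo)) for combo in itertools.product(*columns.values())]


def unit_level(observed_rows, contrast_name, contrast_values):
    """Copy the observed rows once per contrast value."""
    rows = []
    for value in contrast_values:
        for row in observed_rows:
            # Keep every observed covariate, replace only the contrast column.
            rows.append({**row, contrast_name: value})
    return rows


grid = grid_level({"hp": [100, 200], "am": [0, 1]}, "wt", [2.5, 3.5])
units = unit_level(
    [{"hp": 110, "wt": 2.6}, {"hp": 93, "wt": 2.3}],
    "wt",
    [2.5, 3.5],
)
```

Here `grid` holds 2 x 2 x 2 = 8 rows, one per combination, and `units` holds 2 observed rows x 2 contrast values = 4 rows, stacked block by block in the same order that `pd.concat` stacks the per-contrast copies in `_unit_level`.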