Skip to content

Commit

Permalink
interpret predictions enhancements (#736)
Browse files Browse the repository at this point in the history
* unit level and average by for

* black formatting

* add tests to TestPredictions for new functionality

* common functions for create difference and predictions data

* add args. and error handling for new predictions functionality

* update docs to reflect API change

* update to reflect new functionality

* add inline comment
  • Loading branch information
GStechschulte authored Oct 20, 2023
1 parent 77a8fa1 commit 9b6bec4
Show file tree
Hide file tree
Showing 8 changed files with 2,030 additions and 1,199 deletions.
160 changes: 105 additions & 55 deletions bambi/interpret/create_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import itertools
from typing import Union

import numpy as np
import pandas as pd

from bambi.models import Model
from bambi.interpret.utils import (
ConditionalInfo,
enforce_dtypes,
Expand All @@ -16,8 +16,32 @@
)


def _pairwise_grid(data_dict: dict) -> pd.DataFrame:
"""Creates a pairwise grid (cartesian product) of data by using the
key-values of the dictionary.
Parameters
----------
data_dict : dict
A dictionary containing the covariates as keys and their values as the
values.
Returns
-------
pd.DataFrame
A dataframe containing values used as input to the fitted Bambi model to
generate predictions.
"""
keys, values = zip(*data_dict.items())
data_grid = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
return data_grid


def _grid_level(
condition_info: ConditionalInfo, variable_info: VariableInfo, user_passed: bool, kind: str
condition_info: ConditionalInfo,
variable_info: Union[VariableInfo, None],
user_passed: bool,
kind: str,
) -> pd.DataFrame:
"""Creates a "grid" of data by using the covariates passed into the
`conditional` argument. Values for the grid are either: (1) computed
Expand All @@ -29,56 +53,80 @@ def _grid_level(
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
variable_info : VariableInfo
variable_info : VariableInfo, optional
Information about the variable of interest. This is `contrast` for
'comparisons' and `wrt` for 'slopes'.
'comparisons', `wrt` for 'slopes', and `None` for 'predictions'.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
kind : str
The kind of effect being computed. Either "comparisons" or "slopes".
The kind of effect being computed. Either "comparisons", "predictions",
or "slopes".
Returns
-------
pd.DataFrame
A dataframe containing a pairwise grid of values used as input to the
fitted Bambi model to generate predictions.
A dataframe containing values used as input to the fitted Bambi model to
generate predictions.
"""
covariates = get_covariates(condition_info.covariates)

if user_passed:
data_dict = {**condition_info.conditional}
if kind == "predictions":
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
for key, value in data_dict.items():
if not isinstance(value, (list, np.ndarray)):
data_dict[key] = [value]
data_grid = _pairwise_grid(data_dict)
else:
# Compute a grid of values
main_values = make_main_values(condition_info.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = pd.DataFrame(data_dict)
else:
main_values = make_main_values(condition_info.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)

data_dict[variable_info.name] = variable_info.values
comparison_data = set_default_values(condition_info.model, data_dict, kind=kind)
# use cartesian product (cross join) to create pairwise grid
keys, values = zip(*comparison_data.items())
pairwise_grid = pd.DataFrame([dict(zip(keys, v)) for v in itertools.product(*values)])
# can't enforce dtype on numeric 'wrt' as it may remove floating point epsilons
if kind == "comparisons":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid)
elif kind == "slopes":
pairwise_grid = enforce_dtypes(condition_info.model.data, pairwise_grid, variable_info.name)
# Compute pairwise grid of values if the user passed a dict.
if user_passed:
data_dict = {**condition_info.conditional}
else:
# Compute a grid of values
main_values = make_main_values(condition_info.model.data[covariates.main])
data_dict = {covariates.main: main_values}
data_dict = make_group_panel_values(
condition_info.model.data,
data_dict,
covariates.main,
covariates.group,
covariates.panel,
kind=kind,
)

data_dict[variable_info.name] = variable_info.values
data_dict = set_default_values(condition_info.model, data_dict, kind=kind)
data_grid = _pairwise_grid(data_dict)

# Can't enforce dtype on numeric 'wrt' for 'slopes' as it may remove floating point epsilons
except_col = None if kind in ("comparisons", "predictions") else {variable_info.name}
data_grid = enforce_dtypes(condition_info.model.data, data_grid, except_col)

# After computing default values, fractional values may have been computed.
# Enforcing the dtype of "int" may create duplicate rows as it will round
# the fractional values.
pairwise_grid = pairwise_grid.drop_duplicates()
data_grid = data_grid.drop_duplicates()

return pairwise_grid
return data_grid.reset_index(drop=True)


def _unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFrame:
def _differences_unit_level(variable_info: VariableInfo, kind: str) -> pd.DataFrame:
"""Creates the data for unit-level contrasts by using the observed (empirical)
data. All covariates in the model are included in the data, except for the
contrast predictor. The contrast predictor is replaced with either: (1) the
Expand Down Expand Up @@ -138,39 +186,41 @@ def create_differences_data(
Returns
-------
pd.DataFrame
A dataframe containing the data used to generate predictions.
A dataframe containing the data used to generate predictions. If no
covariates were passed, then the original data used to fit the model
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""

if not condition_info.covariates:
return _unit_level(variable_info, kind)
else:
return _grid_level(condition_info, variable_info, user_passed, kind)
return _differences_unit_level(variable_info, kind)

return _grid_level(condition_info, variable_info, user_passed, kind)

def create_predictions_data(model: Model, covariates: dict) -> pd.DataFrame:
"""Creates a data grid for conditional adjusted predictions using the covariates
passed by the user.

def create_predictions_data(condition_info: ConditionalInfo, user_passed: bool) -> pd.DataFrame:
"""Creates either unit level or grid level data for 'predictions' depending
on whether the user passed covariates.
Parameters
----------
model : Model
A fitted Bambi model.
covariates : dict
A dictionary of covariates passed by the user.
condition_info : ConditionalInfo
Information about the conditional argument passed into the plot
function.
user_passed : bool
Whether the user passed a value(s) for the `conditional` argument.
Returns
-------
pd.DataFrame
A dataframe containing the data used to generate predictions.
A dataframe containing the data used to generate predictions. If no
covariates were passed, then the original data used to fit the model
is returned. Otherwise, a grid of values is created using the covariates
passed into the `conditional` argument.
"""
data = model.data
covariates = get_covariates(covariates)
main, group, panel = covariates.main, covariates.group, covariates.panel

main_values = make_main_values(data[main])
data_dict = {main: main_values}

data_dict = make_group_panel_values(data, data_dict, main, group, panel, kind="predictions")
data_dict = set_default_values(model, data_dict, kind="predictions")
# Unit level data uses the observed (empirical) data
if not condition_info.covariates:
covariates = get_model_covariates(condition_info.model)
return condition_info.model.data[covariates]

return enforce_dtypes(data, pd.DataFrame(data_dict))
return _grid_level(condition_info, None, user_passed, "predictions")
Loading

0 comments on commit 9b6bec4

Please sign in to comment.