Skip to content

Commit

Permalink
Allow predictions on new groups (#693)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored Jun 29, 2023
1 parent e1b0e8c commit c753266
Show file tree
Hide file tree
Showing 9 changed files with 2,247 additions and 755 deletions.
362 changes: 268 additions & 94 deletions bambi/model_components.py

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,15 @@ def prior_predictive(self, draws=500, var_names=None, omit_offsets=True, random_

return idata

def predict(self, idata, kind="mean", data=None, inplace=True, include_group_specific=True):
def predict(
self,
idata,
kind="mean",
data=None,
inplace=True,
include_group_specific=True,
sample_new_groups=False,
):
"""Predict method for Bambi models
Obtains in-sample and out-of-sample predictions from a fitted Bambi model.
Expand All @@ -769,16 +777,22 @@ def predict(self, idata, kind="mean", data=None, inplace=True, include_group_spe
data : pandas.DataFrame or None
An optional data frame with values for the predictors that are used to obtain
out-of-sample predictions. If omitted, the original dataset is used.
include_group_specific : bool
If ``True`` make predictions including the group specific effects. Otherwise,
predictions are made with common effects only (i.e. group specific are set
to zero).
inplace : bool
If ``True`` it will modify ``idata`` in-place. Otherwise, it will return a copy of
``idata`` with the predictions added. If ``kind="mean"``, a new variable ending in
``"_mean"`` is added to the ``posterior`` group. If ``kind="pps"``, it appends a
``posterior_predictive`` group to ``idata``. If any of these already exist, it will be
overwritten.
include_group_specific : bool
Determines if predictions incorporate group-specific effects. If ``False``, predictions
are made with common effects only (i.e. group specific are set to zero). Defaults to
``True``.
sample_new_groups : bool
Specifies if it is allowed to obtain predictions for new groups of group-specific terms.
When ``True``, each posterior sample for the new groups is drawn from the posterior
draws of a randomly selected existing group. Since different groups may be selected at
each draw, the end result represents the variation across existing groups.
The method implemented is quivalent to `sample_new_levels="uncertainty"` in brms.
Returns
-------
Expand Down Expand Up @@ -806,7 +820,9 @@ def predict(self, idata, kind="mean", data=None, inplace=True, include_group_spe
else:
var_name = f"{response_aliased_name}_{name}"

means_dict[var_name] = component.predict(idata, data, include_group_specific, hsgp_dict)
means_dict[var_name] = component.predict(
idata, data, include_group_specific, hsgp_dict, sample_new_groups
)

# Drop var/dim if already present. Needed for out-of-sample predictions.
if var_name in idata.posterior.data_vars:
Expand Down
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

* Implement new families `"ordinal"` and `"sratio"` for modeling of ordinal responses (#678)
* Allow families to implement a custom `create_extra_pps_coord()` (#688)
* Allow predictions on new groups (#693)

### Maintenance and fixes

Expand Down
268 changes: 143 additions & 125 deletions docs/notebooks/categorical_regression.ipynb

Large diffs are not rendered by default.

275 changes: 225 additions & 50 deletions docs/notebooks/distributional_models.ipynb

Large diffs are not rendered by default.

Loading

0 comments on commit c753266

Please sign in to comment.