Skip to content

Commit

Permalink
Fix get_model_covariates() utility function (bambinos#801)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto authored and GStechschulte committed Apr 14, 2024
1 parent e3a4393 commit dd43450
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
### Maintenance and fixes

* Fix bug in predictions with models using HSGP (#780)
* Fix `get_model_covariates()` utility function (#801)

### Documentation

Expand Down
3 changes: 3 additions & 0 deletions bambi/interpret/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,9 @@ def get_model_covariates(model: Model) -> np.ndarray:

flatten_covariates = [item for sublist in covariates for item in sublist]

# Don't include non-covariate names (#797)
flatten_covariates = [name for name in flatten_covariates if name in model.data]

return np.unique(flatten_covariates)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"formulae>=0.5.3",
"graphviz",
"pandas>=1.0.0",
"pymc>=5.12.0",
"pymc>=5.12.0,<5.13.0",
]

[project.optional-dependencies]
Expand Down
17 changes: 17 additions & 0 deletions tests/test_interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
This module contains tests for the helper functions of the 'interpret' sub-package.
Tests here do not test any of the plotting functionality.
"""

import numpy as np
import pandas as pd
import pytest

import bambi as bmb
from bambi.interpret.helpers import data_grid, select_draws
from bambi.interpret.utils import get_model_covariates


CHAINS = 4
Expand Down Expand Up @@ -190,3 +192,18 @@ def test_select_draws_no_effect(request, mtcars, condition):
assert draws.shape == (CHAINS, DRAWS, 14)
elif id == "3":
assert draws.shape == (CHAINS, DRAWS, 2)


# ------------------------------------------------------------------------------------------------ #
# Tests for utils #
# ------------------------------------------------------------------------------------------------ #


def test_get_model_covariates():
"""Tests `get_model_covariates()` does not include non-covariate names"""
# See issue 797
df = pd.DataFrame({"y": np.arange(10), "x": np.random.normal(size=10)})
knots = np.linspace(np.min(df["x"]), np.max(df["x"]), 4 + 2)[1:-1]
formula = "y ~ 1 + bs(x, degree=3, knots=knots)"
model = bmb.Model(formula, df)
assert set(get_model_covariates(model)) == {"x"}
16 changes: 8 additions & 8 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def test_cell_means_parameterization(self, crossed_data):
def test_2_factors_saturated(self, crossed_data):
model = bmb.Model("Y ~ threecats*fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"Intercept",
"threecats",
"fourcats",
"threecats:fourcats",
"Y_sigma",
]
}
assert list(idata.posterior["threecats_dim"].values) == ["b", "c"]
assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"]
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
Expand All @@ -214,12 +214,12 @@ def test_2_factors_saturated(self, crossed_data):
def test_2_factors_no_intercept(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats*fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"threecats",
"fourcats",
"threecats:fourcats",
"Y_sigma",
]
}
assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"]
assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"]
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
Expand All @@ -235,7 +235,7 @@ def test_2_factors_no_intercept(self, crossed_data):
def test_2_factors_cell_means(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats:fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == ["threecats:fourcats", "Y_sigma"]
assert set(idata.posterior.data_vars) == {"threecats:fourcats", "Y_sigma"}
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
"a, a",
"a, b",
Expand All @@ -255,7 +255,7 @@ def test_2_factors_cell_means(self, crossed_data):
def test_cell_means_with_covariate(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats + continuous", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == ["threecats", "continuous", "Y_sigma"]
assert set(idata.posterior.data_vars) == {"threecats", "continuous", "Y_sigma"}
assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"]
self.predict_oos(model, idata)

Expand Down Expand Up @@ -477,15 +477,15 @@ def test_group_specific_categorical_interaction(self, crossed_data):
idata = self.fit(model)
self.predict_oos(model, idata)

assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"Intercept",
"continuous",
"Y_sigma",
"1|site_sigma",
"threecats:fourcats|site_sigma",
"1|site",
"threecats:fourcats|site",
]
}
assert list(idata.posterior["threecats:fourcats|site"].coords) == [
"chain",
"draw",
Expand Down

0 comments on commit dd43450

Please sign in to comment.