Skip to content

Commit

Permalink
categorical regression model to test interpret functions with n-dim p…
Browse files Browse the repository at this point in the history
…reds
  • Loading branch information
GStechschulte committed Oct 9, 2023
1 parent 4345e59 commit 2815c0f
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,38 @@ def sleep_study():
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata

@pytest.fixture(scope="module")
def food_choice():
"""
Model a categorical response using the 'categorical' family to test 'interpret'
plotting functions for a model whose predictions have multiple response
dimensions (levels).
"""
length = [
1.3, 1.32, 1.32, 1.4, 1.42, 1.42, 1.47, 1.47, 1.5, 1.52, 1.63, 1.65, 1.65, 1.65, 1.65,
1.68, 1.7, 1.73, 1.78, 1.78, 1.8, 1.85, 1.93, 1.93, 1.98, 2.03, 2.03, 2.31, 2.36, 2.46,
3.25, 3.28, 3.33, 3.56, 3.58, 3.66, 3.68, 3.71, 3.89, 1.24, 1.3, 1.45, 1.45, 1.55, 1.6,
1.6, 1.65, 1.78, 1.78, 1.8, 1.88, 2.16, 2.26, 2.31, 2.36, 2.39, 2.41, 2.44, 2.56, 2.67,
2.72, 2.79, 2.84
]
choice = [
"I", "F", "F", "F", "I", "F", "I", "F", "I", "I", "I", "O", "O", "I", "F", "F",
"I", "O", "F", "O", "F", "F", "I", "F", "I", "F", "F", "F", "F", "F", "O", "O",
"F", "F", "F", "F", "O", "F", "F", "I", "I", "I", "O", "I", "I", "I", "F", "I",
"O", "I", "I", "F", "F", "F", "F", "F", "F", "F", "O", "F", "I", "F", "F"
]
sex = ["Male"] * 32 + ["Female"] * 31
data = pd.DataFrame({"choice": choice, "length": length, "sex": sex})
data["choice"] = pd.Categorical(
data["choice"].map({"I": "Invertebrates", "F": "Fish", "O": "Other"}),
["Other", "Invertebrates", "Fish"],
ordered=True
)

model = bmb.Model("choice ~ length + sex", data, family="categorical")
idata = model.fit(tune=500, draws=500, random_seed=1234)
return model, idata


# Improvement:
# * Test the actual plots are what we are indeed the desired result.
Expand Down Expand Up @@ -99,7 +131,7 @@ def test_ax(self, mtcars, pps):
assert ax is ax_r[0]


class TestCap:
class TestPredictions:
"""
Tests the 'plot_predictions' function for different combinations of main, group,
and panel variables.
Expand Down Expand Up @@ -217,6 +249,18 @@ def test_group_effects(self, sleep_study):
):
# default: sample_new_groups=False
plot_predictions(model, idata, ["Days", "Subject"])

@pytest.mark.parametrize(
"covariates",
(
"length", # Main variable is numeric
"sex", # Main variable is categorical
["length", "sex"] # Using both covariates
),
)
def test_categorical_response(self, food_choice, covariates):
model, idata = food_choice
plot_predictions(model, idata, covariates)


class TestComparison:
Expand Down Expand Up @@ -303,6 +347,17 @@ def test_group_effects(self, sleep_study):
):
# default: sample_new_groups=False
plot_comparisons(model, idata, "Days", "Subject")

@pytest.mark.parametrize(
"contrast, conditional",
[
("sex", "length"), # Categorical & numeric
("length", "sex") # Numeric & categorical
]
)
def test_categorical_response(self, food_choice, contrast, conditional):
model, idata = food_choice
plot_comparisons(model, idata, contrast, conditional)


class TestSlopes:
Expand Down Expand Up @@ -397,3 +452,14 @@ def test_group_effects(self, sleep_study):
):
# default: sample_new_groups=False
plot_slopes(model, idata, "Days", "Subject")

@pytest.mark.parametrize(
"wrt, conditional",
[
("sex", "length"), # Categorical & numeric
("length", "sex") # Numeric & categorical
]
)
def test_categorical_response(self, food_choice, wrt, conditional):
model, idata = food_choice
plot_slopes(model, idata, wrt, conditional)

0 comments on commit 2815c0f

Please sign in to comment.