Skip to content

Commit

Permalink
Replace lists with sets when asserting .data_vars
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Apr 9, 2024
1 parent a8a4931 commit 1e7010c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def test_cell_means_parameterization(self, crossed_data):
def test_2_factors_saturated(self, crossed_data):
model = bmb.Model("Y ~ threecats*fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"Intercept",
"threecats",
"fourcats",
"threecats:fourcats",
"Y_sigma",
]
}
assert list(idata.posterior["threecats_dim"].values) == ["b", "c"]
assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"]
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
Expand All @@ -214,12 +214,12 @@ def test_2_factors_saturated(self, crossed_data):
def test_2_factors_no_intercept(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats*fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"threecats",
"fourcats",
"threecats:fourcats",
"Y_sigma",
]
}
assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"]
assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"]
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
Expand All @@ -235,7 +235,7 @@ def test_2_factors_no_intercept(self, crossed_data):
def test_2_factors_cell_means(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats:fourcats", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == ["threecats:fourcats", "Y_sigma"]
assert set(idata.posterior.data_vars) == {"threecats:fourcats", "Y_sigma"}
assert list(idata.posterior["threecats:fourcats_dim"].values) == [
"a, a",
"a, b",
Expand All @@ -255,7 +255,7 @@ def test_2_factors_cell_means(self, crossed_data):
def test_cell_means_with_covariate(self, crossed_data):
model = bmb.Model("Y ~ 0 + threecats + continuous", crossed_data)
idata = self.fit(model)
assert list(idata.posterior.data_vars) == ["threecats", "continuous", "Y_sigma"]
assert set(idata.posterior.data_vars) == {"threecats", "continuous", "Y_sigma"}
assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"]
self.predict_oos(model, idata)

Expand Down Expand Up @@ -477,15 +477,15 @@ def test_group_specific_categorical_interaction(self, crossed_data):
idata = self.fit(model)
self.predict_oos(model, idata)

assert list(idata.posterior.data_vars) == [
assert set(idata.posterior.data_vars) == {
"Intercept",
"continuous",
"Y_sigma",
"1|site_sigma",
"threecats:fourcats|site_sigma",
"1|site",
"threecats:fourcats|site",
]
}
assert list(idata.posterior["threecats:fourcats|site"].coords) == [
"chain",
"draw",
Expand Down

0 comments on commit 1e7010c

Please sign in to comment.