From 1e7010c7b69084598aa13e5de471cf6a7231e9e0 Mon Sep 17 00:00:00 2001 From: Tomas Capretto Date: Tue, 9 Apr 2024 11:30:58 -0300 Subject: [PATCH] Replace lists with sets when asserting .data_vars --- tests/test_models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 717abe7e3..3bdcf4a39 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -192,13 +192,13 @@ def test_cell_means_parameterization(self, crossed_data): def test_2_factors_saturated(self, crossed_data): model = bmb.Model("Y ~ threecats*fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "Intercept", "threecats", "fourcats", "threecats:fourcats", "Y_sigma", - ] + } assert list(idata.posterior["threecats_dim"].values) == ["b", "c"] assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"] assert list(idata.posterior["threecats:fourcats_dim"].values) == [ @@ -214,12 +214,12 @@ def test_2_factors_saturated(self, crossed_data): def test_2_factors_no_intercept(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats*fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "threecats", "fourcats", "threecats:fourcats", "Y_sigma", - ] + } assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"] assert list(idata.posterior["fourcats_dim"].values) == ["b", "c", "d"] assert list(idata.posterior["threecats:fourcats_dim"].values) == [ @@ -235,7 +235,7 @@ def test_2_factors_no_intercept(self, crossed_data): def test_2_factors_cell_means(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats:fourcats", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == ["threecats:fourcats", "Y_sigma"] + assert set(idata.posterior.data_vars) == {"threecats:fourcats", "Y_sigma"} assert list(idata.posterior["threecats:fourcats_dim"].values) == [ "a, a", "a, b", @@ -255,7 +255,7 @@ def test_2_factors_cell_means(self, crossed_data): def test_cell_means_with_covariate(self, crossed_data): model = bmb.Model("Y ~ 0 + threecats + continuous", crossed_data) idata = self.fit(model) - assert list(idata.posterior.data_vars) == ["threecats", "continuous", "Y_sigma"] + assert set(idata.posterior.data_vars) == {"threecats", "continuous", "Y_sigma"} assert list(idata.posterior["threecats_dim"].values) == ["a", "b", "c"] self.predict_oos(model, idata) @@ -477,7 +477,7 @@ def test_group_specific_categorical_interaction(self, crossed_data): idata = self.fit(model) self.predict_oos(model, idata) - assert list(idata.posterior.data_vars) == [ + assert set(idata.posterior.data_vars) == { "Intercept", "continuous", "Y_sigma", @@ -485,7 +485,7 @@ def test_group_specific_categorical_interaction(self, crossed_data): "threecats:fourcats|site_sigma", "1|site", "threecats:fourcats|site", - ] + } assert list(idata.posterior["threecats:fourcats|site"].coords) == [ "chain", "draw",