From af6a560f5fc9ef43e054bafb98cd8e2ae647df9b Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 30 Jun 2021 10:14:26 +0200 Subject: [PATCH] Make `TestBoundedContinuous` more readable --- pymc3/tests/test_distributions.py | 76 +++++++++++++++++++------------ 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index de0fb50323d..e9764215e1b 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -2737,27 +2737,35 @@ def get_dist_params_and_interval_bounds(self, model, rv_name): upper_interval, ) - def test_missing_lower_bound(self): + def test_upper_bounded(self): bounded_rv_name = "lower_bounded" with Model() as model: TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=None, upper=3) - dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) - assert dist_params[2].value == -np.inf - assert dist_params[3].value == 3 - assert lower is None + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + assert lower.value == -np.inf assert upper.value == 3 + assert lower_interval is None + assert upper_interval.value == 3 - def test_missing_upper_bound(self): + def test_lower_bounded(self): bounded_rv_name = "upper_bounded" with Model() as model: TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=-2, upper=None) - dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) - assert dist_params[2].value == -2 - assert dist_params[3].value == np.inf + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) assert lower.value == -2 - assert upper is None + assert upper.value == np.inf + assert lower_interval.value == -2 + assert upper_interval is None - def test_missing_upper_bound_array(self): + def test_vector_lower_bounded(self): bounded_rv_name = "upper_bounded" with Model() as model: TruncatedNormal( @@ -2767,40 +2775,52 @@ def test_missing_upper_bound_array(self): lower=np.array([-1.0, 0]), upper=None, ) - dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) - assert np.array_equal(dist_params[2].value, [-1, 0]) - assert dist_params[3].value == np.inf assert np.array_equal(lower.value, [-1, 0]) - assert upper is None + assert upper.value == np.inf + assert np.array_equal(lower_interval.value, [-1, 0]) + assert upper_interval is None - def test_missing_partial_upper_bound_array(self): + def test_lower_bounded_broadcasted(self): bounded_rv_name = "upper_bounded" with Model() as model: TruncatedNormal( bounded_rv_name, mu=np.array([1, 1]), sigma=np.array([2, 3]), - lower=np.array([-1.0, -np.inf]), - upper=None, + lower=-1, + upper=np.array([np.inf, np.inf]), ) - dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) - assert np.array_equal(dist_params[2].value, [-1, -np.inf]) - assert dist_params[3].value == np.inf - assert np.array_equal(lower.value, [-1, -np.inf]) - assert upper is None + assert lower.value == -1 + assert np.array_equal(upper.value, [np.inf, np.inf]) + assert lower_interval.value == -1 + assert upper_interval is None - def test_missing_upper_bound_with_richer_context(self): + def test_hierarchical(self): with Model() as model: sigma = TruncatedNormal("lower_bounded", mu=2, sigma=1.5, lower=0, upper=None) mu = TruncatedNormal("upper_bounded", mu=0, sigma=2, lower=None, upper=3) Normal("normal", mu=mu, sigma=sigma, observed=[1.3, -1.4, 2.0]) - dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, "upper_bounded") - assert dist_params[2].value == -np.inf - assert dist_params[3].value == 3 - assert lower is None + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, "upper_bounded") + assert lower.value == -np.inf assert upper.value == 3 + assert lower_interval is None + assert upper_interval.value == 3 @pytest.mark.xfail(reason="LaTeX repr and str no longer applicable")