Skip to content

Commit

Permalink
Make TestBoundedContinuous more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 30, 2021
1 parent 91a90e7 commit af6a560
Showing 1 changed file with 48 additions and 28 deletions.
76 changes: 48 additions & 28 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2737,27 +2737,35 @@ def get_dist_params_and_interval_bounds(self, model, rv_name):
upper_interval,
)

def test_missing_lower_bound(self):
def test_upper_bounded(self):
bounded_rv_name = "lower_bounded"
with Model() as model:
TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=None, upper=3)
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
assert dist_params[2].value == -np.inf
assert dist_params[3].value == 3
assert lower is None
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
assert lower.value == -np.inf
assert upper.value == 3
assert lower_interval is None
assert upper_interval.value == 3

def test_missing_upper_bound(self):
def test_lower_bounded(self):
bounded_rv_name = "upper_bounded"
with Model() as model:
TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=-2, upper=None)
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
assert dist_params[2].value == -2
assert dist_params[3].value == np.inf
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
assert lower.value == -2
assert upper is None
assert upper.value == np.inf
assert lower_interval.value == -2
assert upper_interval is None

def test_missing_upper_bound_array(self):
def test_vector_lower_bounded(self):
bounded_rv_name = "upper_bounded"
with Model() as model:
TruncatedNormal(
Expand All @@ -2767,40 +2775,52 @@ def test_missing_upper_bound_array(self):
lower=np.array([-1.0, 0]),
upper=None,
)
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)

assert np.array_equal(dist_params[2].value, [-1, 0])
assert dist_params[3].value == np.inf
assert np.array_equal(lower.value, [-1, 0])
assert upper is None
assert upper.value == np.inf
assert np.array_equal(lower_interval.value, [-1, 0])
assert upper_interval is None

def test_missing_partial_upper_bound_array(self):
def test_lower_bounded_broadcasted(self):
bounded_rv_name = "upper_bounded"
with Model() as model:
TruncatedNormal(
bounded_rv_name,
mu=np.array([1, 1]),
sigma=np.array([2, 3]),
lower=np.array([-1.0, -np.inf]),
upper=None,
lower=-1,
upper=np.array([np.inf, np.inf]),
)
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name)

assert np.array_equal(dist_params[2].value, [-1, -np.inf])
assert dist_params[3].value == np.inf
assert np.array_equal(lower.value, [-1, -np.inf])
assert upper is None
assert lower.value == -1
assert np.array_equal(upper.value, [np.inf, np.inf])
assert lower_interval.value == -1
assert upper_interval is None

def test_missing_upper_bound_with_richer_context(self):
def test_hierarchical(self):
with Model() as model:
sigma = TruncatedNormal("lower_bounded", mu=2, sigma=1.5, lower=0, upper=None)
mu = TruncatedNormal("upper_bounded", mu=0, sigma=2, lower=None, upper=3)
Normal("normal", mu=mu, sigma=sigma, observed=[1.3, -1.4, 2.0])
dist_params, lower, upper = self.get_dist_params_and_interval_bounds(model, "upper_bounded")
assert dist_params[2].value == -np.inf
assert dist_params[3].value == 3
assert lower is None
(
(_, _, lower, upper),
lower_interval,
upper_interval,
) = self.get_dist_params_and_interval_bounds(model, "upper_bounded")
assert lower.value == -np.inf
assert upper.value == 3
assert lower_interval is None
assert upper_interval.value == 3


@pytest.mark.xfail(reason="LaTeX repr and str no longer applicable")
Expand Down

0 comments on commit af6a560

Please sign in to comment.