Skip to content

Commit

Permalink
Fix typos in checks naming and add sanity check
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-pallini committed Apr 24, 2021
1 parent 3d28087 commit 56253dc
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ class BaseTestDistribution(SeededTest):
repeated_params_shape = 5

def test_distribution(self):
self.validate_tests_list()
self._instantiate_pymc_rv()
if self.reference_dist is not None:
self.reference_dist_draws = self.reference_dist()(
Expand All @@ -439,7 +440,7 @@ def check_pymc_draws_match_reference(self):
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
)

def check_pymc_params_match_rv_op(self) -> None:
def check_pymc_params_match_rv_op(self):
aesera_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
assert len(self.expected_rv_op_params) == len(aesera_dist_inputs)
for (expected_name, expected_value), actual_variable in zip(
Expand Down Expand Up @@ -476,6 +477,11 @@ def check_rv_size(self):
actual = change_rv_size(self.pymc_rv, size).eval().shape
assert actual == expected

def validate_tests_list(self):
assert len(self.tests_to_run) == len(
set(self.tests_to_run)
), "There are duplicates in the list of tests_to_run"


def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
return lambda self: functools.partial(
Expand All @@ -490,24 +496,24 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:


class TestDiscreteWeibull(BaseTestDistribution):
def discrete_weibul_rng_fn(self):
p = seeded_numpy_distribution_builder("uniform")
return (
lambda size, q, beta: np.ceil(
np.power(np.log(1 - p(self)(size=size)) / np.log(q), 1.0 / beta)
)
- 1
def discrete_weibul_rng_fn(self, size, q, beta, uniform_rgn_fct):
return np.ceil(np.power(np.log(1 - uniform_rgn_fct(size=size)) / np.log(q), 1.0 / beta)) - 1

def seeded_discrete_weibul_rng_fn(self):
uniform_rng_fct = functools.partial(
getattr(np.random.RandomState, "uniform"), self.get_random_state()
)
return functools.partial(self.discrete_weibul_rng_fn, uniform_rgn_fct=uniform_rng_fct)

pymc_dist = pm.DiscreteWeibull
pymc_dist_params = {"q": 0.25, "beta": 2.0}
expected_rv_op_params = {"q": 0.25, "beta": 2.0}
reference_dist_params = {"q": 0.25, "beta": 2.0}
reference_dist = discrete_weibul_rng_fn
reference_dist = seeded_discrete_weibul_rng_fn
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_pymc_dist_matches_reference",
"check_pymc_draws_match_reference",
]


Expand All @@ -521,7 +527,7 @@ class TestGumbel(BaseTestDistribution):
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_pymc_dist_matches_reference",
"check_pymc_draws_match_reference",
]


Expand All @@ -535,7 +541,7 @@ class TestNormal(BaseTestDistribution):
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_pymc_dist_matches_reference",
"check_pymc_draws_match_reference",
]


Expand Down Expand Up @@ -595,7 +601,7 @@ class TestBeta(BaseTestDistribution):
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_rv_size",
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
]


Expand Down

0 comments on commit 56253dc

Please sign in to comment.