Skip to content

Commit

Permalink
V4 update test framework for distributions random method (#4608)
Browse files Browse the repository at this point in the history
* Update tests following distributions refactoring

The distributions refactoring moves the random variable sampling to
aesara. This relies on numpy and scipy random variables implementation.
So, now the only thing we care about testing is that the parametrization
on the PyMC side is sendible given the one on the Aesara side
(effectively the numpy/scipy one)

More details can be found on issue #4554
#4554

* Change tests for more refactored distributions.

More details can be found on issue #4554
#4554

* Change tests for refactored distributions

More details can be found on issue #4554
#4554

* Remove tests for random variable samples shape and size

Most of the random variable logic has been moved to aesara, as well as
most of the relative tests. More details can be found on issue #4554

* Fix test for half cauchy, renmae mv normal tests and add test for
Bernoulli

* Add test checking PyMC samples match the aesara ones

Also mark test_categorical as expected to fail due to bug on aesara
side. The bug is going to be fixed with 2.0.5 release, so we need to
bump the version for categorical and the test to pass.

* Move Aesara to 2.0.5 to include Gumbel distribution

* Enamble exponential and gamma tests following bug-fix

* Enable categorical test following aesara version bump to 2.0.5 and relative bug-fix

* Few small cosmetic changes:
- replace list of tuples with dict
- rename 1 method
- move pymc_dist as first argument in function call
- replace list(params) with params.copy()

* Remove redundant tests

* Further refactoring

The refactoring should make it possible testing both the distribution
parametrization and sampled values according to need, as well as any
other future test. More details on PR #4608

* Add size tests to new rv testing framework

* Add tests for multivariate and for univariate multi-parameters

* remove test already covered in aesara

* fix few names

* Remove "distribution" from test class names

* Add discrete Weibull, improve Beta and some minor refactoring

* Fix typos in checks naming and add sanity check

Co-authored-by: Ricardo <ricardo.vieira1994@gmail.com>
  • Loading branch information
2 people authored and twiecki committed Jun 5, 2021
1 parent 51b22b4 commit 946cf34
Show file tree
Hide file tree
Showing 4 changed files with 396 additions and 327 deletions.
7 changes: 3 additions & 4 deletions pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,24 +713,23 @@ def NegBinom(a, m, x):

@classmethod
def dist(cls, mu=None, alpha=None, p=None, n=None, *args, **kwargs):
n, p = cls.get_mu_alpha(mu, alpha, p, n)
n, p = cls.get_n_p(mu, alpha, p, n)
n = at.as_tensor_variable(floatX(n))
p = at.as_tensor_variable(floatX(p))
return super().dist([n, p], *args, **kwargs)

@classmethod
def get_mu_alpha(cls, mu=None, alpha=None, p=None, n=None):
def get_n_p(cls, mu=None, alpha=None, p=None, n=None):
if n is None:
if alpha is not None:
n = at.as_tensor_variable(floatX(alpha))
n = alpha
else:
raise ValueError("Incompatible parametrization. Must specify either alpha or n.")
elif alpha is not None:
raise ValueError("Incompatible parametrization. Can't specify both alpha and n.")

if p is None:
if mu is not None:
mu = at.as_tensor_variable(floatX(mu))
p = n / (mu + n)
else:
raise ValueError("Incompatible parametrization. Must specify either mu or p.")
Expand Down
6 changes: 6 additions & 0 deletions pymc3/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

class SeededTest:
random_seed = 20160911
random_state = None

@classmethod
def setup_class(cls):
Expand All @@ -40,6 +41,11 @@ def setup_method(self):
def teardown_method(self):
set_at_rng(self.old_at_rng)

def get_random_state(self, reset=False):
if self.random_state is None or reset:
self.random_state = nr.RandomState(self.random_seed)
return self.random_state


class LoggingHandler(BufferingHandler):
def __init__(self, matcher):
Expand Down
Loading

0 comments on commit 946cf34

Please sign in to comment.