Skip to content

Commit

Permalink
Move Wald distribution to use the numpy rvs implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-pallini authored and ricardoV94 committed Jun 30, 2021
1 parent 73e932f commit 74aeafc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 65 deletions.
20 changes: 2 additions & 18 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,24 +905,8 @@ class WaldRV(RandomVariable):
_print_name = ("Wald", "\\operatorname{Wald}")

@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
mu: Union[np.ndarray, float],
lam: Union[np.ndarray, float],
alpha: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
v = rng.normal(size=size) ** 2
z = rng.uniform(size=size)
value = (
mu
+ (mu ** 2) * v / (2.0 * lam)
- mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2)
)
i = np.floor(z - mu / (mu + value)) * 2 + 1
value = (value ** -i) * (mu ** (i + 1))
return value + alpha
def rng_fn(cls, rng, mu, lam, alpha, size):
return getattr(np.random.RandomState, cls.name)(rng, mu, lam, size) + alpha


wald = WaldRV()
Expand Down
82 changes: 35 additions & 47 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,76 +621,64 @@ class TestTruncatedNormalUpperTau(BaseTestDistribution):
]


class BaseTestWald(BaseTestDistribution):
def wald_rng_fn(self, size, mu, lam, alpha, uniform_rng_fct, normal_rng_fct):
v = normal_rng_fct(size=size) ** 2
z = uniform_rng_fct(size=size)
value = (
mu
+ (mu ** 2) * v / (2.0 * lam)
- mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2)
)
i = np.floor(z - mu / (mu + value)) * 2 + 1
value = (value ** -i) * (mu ** (i + 1))
return value + alpha

def seeded_wald_rng_fn(self):
uniform_rng_fct = functools.partial(
getattr(np.random.RandomState, "uniform"), self.get_random_state()
)
normal_rng_fct = functools.partial(
getattr(np.random.RandomState, "normal"), self.get_random_state()
)
return functools.partial(
self.wald_rng_fn, uniform_rng_fct=uniform_rng_fct, normal_rng_fct=normal_rng_fct
)

class TestWald(BaseTestDistribution):
pymc_dist = pm.Wald
reference_dist = seeded_wald_rng_fn
mu, lam, alpha = 1.0, 1.0, 0.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha}
reference_dist_params = [mu, lam_rv]
reference_dist = seeded_numpy_distribution_builder("wald")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]

def test_distribution(self):
self.validate_tests_list()
self._instantiate_pymc_rv()
if self.reference_dist is not None:
self.reference_dist_draws = self.reference_dist()(
*self.reference_dist_params, self.size
)
for check_name in self.tests_to_run:
getattr(self, check_name)()

def check_pymc_draws_match_reference(self):
assert_array_almost_equal(
self.pymc_rv.eval(), self.reference_dist_draws + self.alpha, decimal=self.decimal
)

class TestWaldPureScipy(BaseTestDistribution):

class TestWaldAlpha(TestWald):
pymc_dist = pm.Wald
mu, lam, alpha = 1.0, 1.0, 2.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha}
reference_dist_params = {"loc": alpha}
reference_dist = seeded_scipy_distribution_builder("wald")
reference_dist_params = [mu, lam_rv]
reference_dist = seeded_numpy_distribution_builder("wald")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestWaldMuLam(BaseTestWald):
mu, lam, alpha = 1.0, 3.0, 0.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}


class TestWaldMuLamShifted(BaseTestWald):
mu, lam, alpha = 1.0, 3.0, 2.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 2.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 2.0}


class TestWaldMuPhi(BaseTestWald):
class TestWaldMuPhi(TestWald):
pymc_dist = pm.Wald
mu, phi, alpha = 1.0, 3.0, 0.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=None, phi=phi)
pymc_dist_params = {"mu": mu, "phi": phi, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": phi_rv, "alpha": 0.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha}
reference_dist_params = [mu, lam_rv]
reference_dist = seeded_numpy_distribution_builder("wald")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestSkewNormal(BaseTestDistribution):
Expand Down

0 comments on commit 74aeafc

Please sign in to comment.