Skip to content

Commit

Permalink
Refactor Wald distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-pallini committed Jun 13, 2021
1 parent a88dec7 commit fe86dcd
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 81 deletions.
135 changes: 77 additions & 58 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
nodes in PyMC.
"""

from typing import Union
from typing import List, Optional, Tuple, Union

import aesara.tensor as at
import numpy as np
Expand Down Expand Up @@ -896,6 +896,37 @@ def _distr_parameters_for_repr(self):
return ["sigma"]


class WaldRV(RandomVariable):
name = "wald"
ndim_supp = 0
ndims_params = [0, 0, 0]
dtype = "floatX"
_print_name = ("Wald", "\\operatorname{Wald}")

@classmethod
def rng_fn(
cls,
rng: np.random.RandomState,
mu: Union[np.ndarray, float],
lam: Union[np.ndarray, float],
alpha: Union[np.ndarray, float],
size: Optional[Union[List[int], int]],
) -> np.ndarray:
v = rng.normal(size=size) ** 2
z = rng.uniform(size=size)
value = (
mu
+ (mu ** 2) * v / (2.0 * lam)
- mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2)
)
i = np.floor(z - mu / (mu + value)) * 2 + 1
value = (value ** -i) * (mu ** (i + 1))
return value + alpha


wald = WaldRV()


class Wald(PositiveContinuous):
r"""
Wald log-likelihood.
Expand Down Expand Up @@ -974,27 +1005,33 @@ class Wald(PositiveContinuous):
.. [Giner2016] Göknur Giner, Gordon K. Smyth (2016)
statmod: Probability Calculations for the Inverse Gaussian Distribution
"""
rv_op = wald

def __init__(self, mu=None, lam=None, phi=None, alpha=0.0, *args, **kwargs):
super().__init__(*args, **kwargs)
mu, lam, phi = self.get_mu_lam_phi(mu, lam, phi)
self.alpha = alpha = at.as_tensor_variable(floatX(alpha))
self.mu = mu = at.as_tensor_variable(floatX(mu))
self.lam = lam = at.as_tensor_variable(floatX(lam))
self.phi = phi = at.as_tensor_variable(floatX(phi))

self.mean = self.mu + self.alpha
self.mode = (
self.mu * (at.sqrt(1.0 + (1.5 * self.mu / self.lam) ** 2) - 1.5 * self.mu / self.lam)
+ self.alpha
)
self.variance = (self.mu ** 3) / self.lam
@classmethod
def dist(
cls,
mu: Optional[Union[float, np.ndarray]] = None,
lam: Optional[Union[float, np.ndarray]] = None,
phi: Optional[Union[float, np.ndarray]] = None,
alpha: Union[float, np.ndarray] = 0.0,
*args,
**kwargs,
) -> RandomVariable:
mu, lam, phi = cls.get_mu_lam_phi(mu, lam, phi)
alpha = at.as_tensor_variable(floatX(alpha))
mu = at.as_tensor_variable(floatX(mu))
lam = at.as_tensor_variable(floatX(lam))

assert_negative_support(phi, "phi", "Wald")
assert_negative_support(mu, "mu", "Wald")
assert_negative_support(lam, "lam", "Wald")

def get_mu_lam_phi(self, mu, lam, phi):
return super().dist([mu, lam, alpha], **kwargs)

@staticmethod
def get_mu_lam_phi(
mu: Optional[float], lam: Optional[float], phi: Optional[float]
) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray], Union[float, np.ndarray]]:
if mu is None:
if lam is not None and phi is not None:
return lam / phi, lam, phi
Expand All @@ -1013,39 +1050,12 @@ def get_mu_lam_phi(self, mu, lam, phi):
"mu and lam, mu and phi, or lam and phi."
)

def _random(self, mu, lam, alpha, size=None):
v = np.random.normal(size=size) ** 2
value = (
mu
+ (mu ** 2) * v / (2.0 * lam)
- mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2)
)
z = np.random.uniform(size=size)
i = np.floor(z - mu / (mu + value)) * 2 + 1
value = (value ** -i) * (mu ** (i + 1))
return value + alpha

def random(self, point=None, size=None):
"""
Draw random values from Wald distribution.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
Returns
-------
array
"""
# mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha], point=point, size=size)
# return generate_samples(self._random, mu, lam, alpha, dist_shape=self.shape, size=size)

def logp(self, value):
def logp(
value,
mu: Union[float, np.ndarray, TensorVariable],
lam: Union[float, np.ndarray, TensorVariable],
alpha: Union[float, np.ndarray, TensorVariable],
) -> RandomVariable:
"""
Calculate log-probability of Wald distribution at specified value.
Expand All @@ -1054,14 +1064,17 @@ def logp(self, value):
value: numeric
Value(s) for which log-probability is calculated. If the log probabilities for multiple
values are desired the values must be provided in a numpy array or Aesara tensor
mu: float or TensorVariable
Mean of the distribution (mu > 0).
lam: float or TensorVariable
Relative precision (lam > 0).
alpha: float or TensorVariable
Shift/location parameter (alpha >= 0).
Returns
-------
TensorVariable
"""
mu = self.mu
lam = self.lam
alpha = self.alpha
centered_value = value - alpha
# value *must* be iid. Otherwise this is wrong.
return bound(
Expand All @@ -1077,7 +1090,12 @@ def logp(self, value):
def _distr_parameters_for_repr(self):
return ["mu", "lam", "alpha"]

def logcdf(self, value):
def logcdf(
value,
mu: Union[float, np.ndarray, TensorVariable],
lam: Union[float, np.ndarray, TensorVariable],
alpha: Union[float, np.ndarray, TensorVariable],
) -> RandomVariable:
"""
Compute the log of the cumulative distribution function for Wald distribution
at the specified value.
Expand All @@ -1087,16 +1105,17 @@ def logcdf(self, value):
value: numeric or np.ndarray or aesara.tensor
Value(s) for which log CDF is calculated. If the log CDF for multiple
values are desired the values must be provided in a numpy array or Aesara tensor.
mu: float or TensorVariable
Mean of the distribution (mu > 0).
lam: float or TensorVariable
Relative precision (lam > 0).
alpha: float or TensorVariable
Shift/location parameter (alpha >= 0).
Returns
-------
TensorVariable
"""
# Distribution parameters
mu = self.mu
lam = self.lam
alpha = self.alpha

value -= alpha
q = value / mu
l = lam * mu
Expand Down
4 changes: 0 additions & 4 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,6 @@ def test_chisquared_logcdf(self):
lambda value, nu: sp.chi2.logcdf(value, df=nu),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_wald_logp(self):
self.check_logp(
Wald,
Expand All @@ -1077,7 +1076,6 @@ def test_wald_logp(self):
decimal=select_by_precision(float64=6, float32=1),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.xfail(
condition=(aesara.config.floatX == "float32"),
reason="Poor CDF in SciPy. See scipy/scipy#869 for details.",
Expand All @@ -1090,7 +1088,6 @@ def test_wald_logcdf(self):
lambda value, mu, alpha: sp.invgauss.logcdf(value, mu=mu, loc=alpha),
)

@pytest.mark.xfail(reason="Distribution not refactored yet")
@pytest.mark.parametrize(
"value,mu,lam,phi,alpha,logp",
[
Expand All @@ -1110,7 +1107,6 @@ def test_wald_logcdf(self):
(50.0, 15.0, None, 0.666666, 10.0, -5.6481874),
],
)
@pytest.mark.xfail(reason="Distribution not refactored yet")
def test_wald_logp_custom_points(self, value, mu, lam, phi, alpha, logp):
# Log probabilities calculated using the dIG function from the R package gamlss.
# See e.g., doi: 10.1111/j.1467-9876.2005.00510.x, or
Expand Down
91 changes: 72 additions & 19 deletions pymc3/tests/test_distributions_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
default_shape = (1,)


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestWald(BaseTestCases.BaseTestCase):
distribution = pm.Wald
params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0}


@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
class TestAsymmetricLaplace(BaseTestCases.BaseTestCase):
distribution = pm.AsymmetricLaplace
Expand Down Expand Up @@ -575,6 +569,78 @@ class TestTruncatedNormalUpperTau(BaseTestDistribution):
]


class BaseTestWald(BaseTestDistribution):
def wald_rng_fn(self, size, mu, lam, alpha, uniform_rng_fct, normal_rng_fct):
v = normal_rng_fct(size=size) ** 2
z = uniform_rng_fct(size=size)
value = (
mu
+ (mu ** 2) * v / (2.0 * lam)
- mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2)
)
i = np.floor(z - mu / (mu + value)) * 2 + 1
value = (value ** -i) * (mu ** (i + 1))
return value + alpha

def seeded_wald_rng_fn(self):
uniform_rng_fct = functools.partial(
getattr(np.random.RandomState, "uniform"), self.get_random_state()
)
normal_rng_fct = functools.partial(
getattr(np.random.RandomState, "normal"), self.get_random_state()
)
return functools.partial(
self.wald_rng_fn, uniform_rng_fct=uniform_rng_fct, normal_rng_fct=normal_rng_fct
)

pymc_dist = pm.Wald
reference_dist = seeded_wald_rng_fn
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestWaldAlpha(BaseTestDistribution):
pymc_dist = pm.Wald
mu, lam, alpha = 1.0, 1.0, 2.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha}
reference_dist_params = {"loc": alpha}
reference_dist = seeded_scipy_distribution_builder("wald")
tests_to_run = [
"check_pymc_params_match_rv_op",
"check_pymc_draws_match_reference",
"check_rv_size",
]


class TestWaldMuLam(BaseTestWald):
mu, lam, alpha = 1.0, 3.0, 0.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}


class TestWaldMuLamShifted(BaseTestWald):
mu, lam, alpha = 1.0, 3.0, 2.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None)
pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 2.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 2.0}


class TestWaldMuPhi(BaseTestWald):
mu, phi, alpha = 1.0, 3.0, 0.0
mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=None, phi=phi)
pymc_dist_params = {"mu": mu, "phi": phi, "alpha": alpha}
expected_rv_op_params = {"mu": mu_rv, "lam": phi_rv, "alpha": 0.0}
reference_dist_params = {"mu": mu_rv, "lam": lam_rv, "alpha": 0.0}


class TestSkewNormal(BaseTestDistribution):
pymc_dist = pm.SkewNormal
pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0}
Expand Down Expand Up @@ -1263,19 +1329,6 @@ def ref_rand(size, alpha, mu, sigma):

pymc3_random(pm.SkewNormal, {"mu": R, "sigma": Rplus, "alpha": R}, ref_rand=ref_rand)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_wald(self):
# Cannot do anything too exciting as scipy wald is a
# location-scale model of the *standard* wald with mu=1 and lam=1
def ref_rand(size, mu, lam, alpha):
return st.wald.rvs(size=size, loc=alpha)

pymc3_random(
pm.Wald,
{"mu": Domain([1.0, 1.0, 1.0]), "lam": Domain([1.0, 1.0, 1.0]), "alpha": Rplus},
ref_rand=ref_rand,
)

@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
def test_laplace_asymmetric(self):
def ref_rand(size, kappa, b, mu):
Expand Down

0 comments on commit fe86dcd

Please sign in to comment.