diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index fad5b003b2..86f51e75d4 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -18,7 +18,7 @@ nodes in PyMC. """ -from typing import Union +from typing import List, Optional, Tuple, Union import aesara.tensor as at import numpy as np @@ -46,7 +46,7 @@ vonmises, ) from aesara.tensor.random.op import RandomVariable -from aesara.tensor.var import TensorVariable +from aesara.tensor.var import TensorConstant, TensorVariable from scipy import stats from scipy.interpolate import InterpolatedUnivariateSpline from scipy.special import expit @@ -162,8 +162,17 @@ def transform_params(rv_var): if cls.bound_args_indices[1] is not None: upper = args[cls.bound_args_indices[1]] - lower = at.as_tensor_variable(lower) if lower is not None else None - upper = at.as_tensor_variable(upper) if upper is not None else None + if lower is not None: + if isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf): + lower = None + else: + lower = at.as_tensor_variable(lower) + + if upper is not None: + if isinstance(upper, TensorConstant) and np.all(upper.value == np.inf): + upper = None + else: + upper = at.as_tensor_variable(upper) return lower, upper @@ -559,6 +568,36 @@ def logcdf(value, mu, sigma): ) +class TruncatedNormalRV(RandomVariable): + name = "truncated_normal" + ndim_supp = 0 + ndims_params = [0, 0, 0, 0] + dtype = "floatX" + _print_name = ("TruncatedNormal", "\\operatorname{TruncatedNormal}") + + @classmethod + def rng_fn( + cls, + rng: np.random.RandomState, + mu: Union[np.ndarray, float], + sigma: Union[np.ndarray, float], + lower: Union[np.ndarray, float], + upper: Union[np.ndarray, float], + size: Optional[Union[List[int], int]], + ) -> np.ndarray: + return stats.truncnorm.rvs( + a=(lower - mu) / sigma, + b=(upper - mu) / sigma, + loc=mu, + scale=sigma, + size=size, + random_state=rng, + ) + + +truncated_normal = TruncatedNormalRV() + + class TruncatedNormal(BoundedContinuous): r""" Univariate truncated normal log-likelihood. @@ -632,99 +671,50 @@ class TruncatedNormal(BoundedContinuous): """ - def __init__( - self, - mu=0, - sigma=None, - tau=None, - lower=None, - upper=None, - transform="auto", - sd=None, + rv_op = truncated_normal + bound_args_indices = (2, 3) # indexes for lower and upper args + + @classmethod + def dist( + cls, + mu: Optional[Union[float, np.ndarray]] = None, + sigma: Optional[Union[float, np.ndarray]] = None, + tau: Optional[Union[float, np.ndarray]] = None, + sd: Optional[Union[float, np.ndarray]] = None, + lower: Optional[Union[float, np.ndarray]] = None, + upper: Optional[Union[float, np.ndarray]] = None, + transform: str = "auto", *args, **kwargs, - ): - if sd is not None: - sigma = sd + ) -> RandomVariable: + sigma = sd if sd is not None else sigma tau, sigma = get_tau_sigma(tau=tau, sigma=sigma) - self.sigma = self.sd = at.as_tensor_variable(sigma) - self.tau = at.as_tensor_variable(tau) - self.lower_check = at.as_tensor_variable(floatX(lower)) if lower is not None else lower - self.upper_check = at.as_tensor_variable(floatX(upper)) if upper is not None else upper - self.lower = ( - at.as_tensor_variable(floatX(lower)) - if lower is not None - else at.as_tensor_variable(-np.inf) - ) - self.upper = ( - at.as_tensor_variable(floatX(upper)) - if upper is not None - else at.as_tensor_variable(np.inf) - ) - self.mu = at.as_tensor_variable(floatX(mu)) - - if self.lower_check is None and self.upper_check is None: - self._defaultval = mu - elif self.lower_check is None and self.upper_check is not None: - self._defaultval = self.upper - 1.0 - elif self.lower_check is not None and self.upper_check is None: - self._defaultval = self.lower + 1.0 - else: - self._defaultval = (self.lower + self.upper) / 2 - + sigma = at.as_tensor_variable(sigma) + tau = at.as_tensor_variable(tau) + mu = at.as_tensor_variable(floatX(mu)) assert_negative_support(sigma, "sigma", "TruncatedNormal") assert_negative_support(tau, "tau", "TruncatedNormal") - super().__init__( - defaults=("_defaultval",), - transform=transform, - lower=lower, - upper=upper, - *args, - **kwargs, - ) - - def random(self, point=None, size=None): - """ - Draw random values from TruncatedNormal distribution. - - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). - - Returns - ------- - array - """ - # mu, sigma, lower, upper = draw_values( - # [self.mu, self.sigma, self.lower, self.upper], point=point, size=size - # ) - # return generate_samples( - # self._random, - # mu=mu, - # sigma=sigma, - # lower=lower, - # upper=upper, - # dist_shape=self.shape, - # size=size, - # ) + # if lower is None and upper is None: + # initval = mu + # elif lower is None and upper is not None: + # initval = upper - 1.0 + # elif lower is not None and upper is None: + # initval = lower + 1.0 + # else: + # initval = (lower + upper) / 2 - def _random(self, mu, sigma, lower, upper, size): - """Wrapper around stats.truncnorm.rvs that converts TruncatedNormal's - parametrization to scipy.truncnorm. All parameter arrays should have - been broadcasted properly by generate_samples at this point and size is - the scipy.rvs representation. - """ - return stats.truncnorm.rvs( - a=(lower - mu) / sigma, b=(upper - mu) / sigma, loc=mu, scale=sigma, size=size - ) + lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf) + upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf) + return super().dist([mu, sigma, lower, upper], **kwargs) - def logp(self, value): + def logp( + value, + mu: Union[float, np.ndarray, TensorVariable], + sigma: Union[float, np.ndarray, TensorVariable], + lower: Union[float, np.ndarray, TensorVariable], + upper: Union[float, np.ndarray, TensorVariable], + ) -> RandomVariable: """ Calculate log-probability of TruncatedNormal distribution at specified value. @@ -738,40 +728,31 @@ def logp(self, value): ------- TensorVariable """ - mu = self.mu - sigma = self.sigma - - norm = self._normalization() - logp = Normal.dist(mu=mu, sigma=sigma).logp(value) - norm - - bounds = [sigma > 0] - if self.lower_check is not None: - bounds.append(value >= self.lower) - if self.upper_check is not None: - bounds.append(value <= self.upper) - return bound(logp, *bounds) - - def _normalization(self): - mu, sigma = self.mu, self.sigma - - if self.lower_check is None and self.upper_check is None: - return 0.0 - - if self.lower_check is not None and self.upper_check is not None: - lcdf_a = normal_lcdf(mu, sigma, self.lower) - lcdf_b = normal_lcdf(mu, sigma, self.upper) - lsf_a = normal_lccdf(mu, sigma, self.lower) - lsf_b = normal_lccdf(mu, sigma, self.upper) - - return at.switch(self.lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a)) - - if self.lower_check is not None: - return normal_lccdf(mu, sigma, self.lower) + unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf) + unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf) + + if not unbounded_lower and not unbounded_upper: + lcdf_a = normal_lcdf(mu, sigma, lower) + lcdf_b = normal_lcdf(mu, sigma, upper) + lsf_a = normal_lccdf(mu, sigma, lower) + lsf_b = normal_lccdf(mu, sigma, upper) + norm = at.switch(lower > 0, logdiffexp(lsf_a, lsf_b), logdiffexp(lcdf_b, lcdf_a)) + elif not unbounded_lower: + norm = normal_lccdf(mu, sigma, lower) + elif not unbounded_upper: + norm = normal_lcdf(mu, sigma, upper) else: - return normal_lcdf(mu, sigma, self.upper) - - def _distr_parameters_for_repr(self): - return ["mu", "sigma", "lower", "upper"] + norm = 0.0 + + logp = Normal.logp(value, mu=mu, sigma=sigma) - norm + bounds = [] + if not unbounded_lower: + bounds.append(value >= lower) + if not unbounded_upper: + bounds.append(value <= upper) + if not unbounded_lower and not unbounded_upper: + bounds.append(lower <= upper) + return bound(logp, *bounds) class HalfNormal(PositiveContinuous): @@ -900,6 +881,21 @@ def _distr_parameters_for_repr(self): return ["sigma"] +class WaldRV(RandomVariable): + name = "wald" + ndim_supp = 0 + ndims_params = [0, 0, 0] + dtype = "floatX" + _print_name = ("Wald", "\\operatorname{Wald}") + + @classmethod + def rng_fn(cls, rng, mu, lam, alpha, size): + return rng.wald(mu, lam, size=size) + alpha + + +wald = WaldRV() + + class Wald(PositiveContinuous): r""" Wald log-likelihood. @@ -978,27 +974,33 @@ class Wald(PositiveContinuous): .. [Giner2016] Göknur Giner, Gordon K. Smyth (2016) statmod: Probability Calculations for the Inverse Gaussian Distribution """ + rv_op = wald - def __init__(self, mu=None, lam=None, phi=None, alpha=0.0, *args, **kwargs): - super().__init__(*args, **kwargs) - mu, lam, phi = self.get_mu_lam_phi(mu, lam, phi) - self.alpha = alpha = at.as_tensor_variable(floatX(alpha)) - self.mu = mu = at.as_tensor_variable(floatX(mu)) - self.lam = lam = at.as_tensor_variable(floatX(lam)) - self.phi = phi = at.as_tensor_variable(floatX(phi)) - - self.mean = self.mu + self.alpha - self.mode = ( - self.mu * (at.sqrt(1.0 + (1.5 * self.mu / self.lam) ** 2) - 1.5 * self.mu / self.lam) - + self.alpha - ) - self.variance = (self.mu ** 3) / self.lam + @classmethod + def dist( + cls, + mu: Optional[Union[float, np.ndarray]] = None, + lam: Optional[Union[float, np.ndarray]] = None, + phi: Optional[Union[float, np.ndarray]] = None, + alpha: Union[float, np.ndarray] = 0.0, + *args, + **kwargs, + ) -> RandomVariable: + mu, lam, phi = cls.get_mu_lam_phi(mu, lam, phi) + alpha = at.as_tensor_variable(floatX(alpha)) + mu = at.as_tensor_variable(floatX(mu)) + lam = at.as_tensor_variable(floatX(lam)) assert_negative_support(phi, "phi", "Wald") assert_negative_support(mu, "mu", "Wald") assert_negative_support(lam, "lam", "Wald") - def get_mu_lam_phi(self, mu, lam, phi): + return super().dist([mu, lam, alpha], **kwargs) + + @staticmethod + def get_mu_lam_phi( + mu: Optional[float], lam: Optional[float], phi: Optional[float] + ) -> Tuple[Union[float, np.ndarray], Union[float, np.ndarray], Union[float, np.ndarray]]: if mu is None: if lam is not None and phi is not None: return lam / phi, lam, phi @@ -1017,39 +1019,12 @@ def get_mu_lam_phi(self, mu, lam, phi): "mu and lam, mu and phi, or lam and phi." ) - def _random(self, mu, lam, alpha, size=None): - v = np.random.normal(size=size) ** 2 - value = ( - mu - + (mu ** 2) * v / (2.0 * lam) - - mu / (2.0 * lam) * np.sqrt(4.0 * mu * lam * v + (mu * v) ** 2) - ) - z = np.random.uniform(size=size) - i = np.floor(z - mu / (mu + value)) * 2 + 1 - value = (value ** -i) * (mu ** (i + 1)) - return value + alpha - - def random(self, point=None, size=None): - """ - Draw random values from Wald distribution. - - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). - - Returns - ------- - array - """ - # mu, lam, alpha = draw_values([self.mu, self.lam, self.alpha], point=point, size=size) - # return generate_samples(self._random, mu, lam, alpha, dist_shape=self.shape, size=size) - - def logp(self, value): + def logp( + value, + mu: Union[float, np.ndarray, TensorVariable], + lam: Union[float, np.ndarray, TensorVariable], + alpha: Union[float, np.ndarray, TensorVariable], + ) -> RandomVariable: """ Calculate log-probability of Wald distribution at specified value. @@ -1058,14 +1033,17 @@ def logp(self, value): value: numeric Value(s) for which log-probability is calculated. If the log probabilities for multiple values are desired the values must be provided in a numpy array or Aesara tensor + mu: float or TensorVariable + Mean of the distribution (mu > 0). + lam: float or TensorVariable + Relative precision (lam > 0). + alpha: float or TensorVariable + Shift/location parameter (alpha >= 0). Returns ------- TensorVariable """ - mu = self.mu - lam = self.lam - alpha = self.alpha centered_value = value - alpha # value *must* be iid. Otherwise this is wrong. return bound( @@ -1081,7 +1059,12 @@ def logp(self, value): def _distr_parameters_for_repr(self): return ["mu", "lam", "alpha"] - def logcdf(self, value): + def logcdf( + value, + mu: Union[float, np.ndarray, TensorVariable], + lam: Union[float, np.ndarray, TensorVariable], + alpha: Union[float, np.ndarray, TensorVariable], + ) -> RandomVariable: """ Compute the log of the cumulative distribution function for Wald distribution at the specified value. @@ -1091,16 +1074,17 @@ def logcdf(self, value): value: numeric or np.ndarray or aesara.tensor Value(s) for which log CDF is calculated. If the log CDF for multiple values are desired the values must be provided in a numpy array or Aesara tensor. + mu: float or TensorVariable + Mean of the distribution (mu > 0). + lam: float or TensorVariable + Relative precision (lam > 0). + alpha: float or TensorVariable + Shift/location parameter (alpha >= 0). Returns ------- TensorVariable """ - # Distribution parameters - mu = self.mu - lam = self.lam - alpha = self.alpha - value -= alpha q = value / mu l = lam * mu diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index a91d96e265..32a7ca319b 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import functools import itertools import sys @@ -1019,7 +1019,6 @@ def test_normal(self): decimal=select_by_precision(float64=6, float32=1), ) - @pytest.mark.xfail(reason="Distribution not refactored yet") def test_truncated_normal(self): def scipy_logp(value, mu, sigma, lower, upper): return sp.truncnorm.logpdf( @@ -1034,6 +1033,22 @@ def scipy_logp(value, mu, sigma, lower, upper): decimal=select_by_precision(float64=6, float32=1), ) + self.check_logp( + TruncatedNormal, + R, + {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, + functools.partial(scipy_logp, lower=-np.inf), + decimal=select_by_precision(float64=6, float32=1), + ) + + self.check_logp( + TruncatedNormal, + R, + {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig}, + functools.partial(scipy_logp, upper=np.inf), + decimal=select_by_precision(float64=6, float32=1), + ) + def test_half_normal(self): self.check_logp( HalfNormal, @@ -1069,7 +1084,6 @@ def test_chisquared_logcdf(self): lambda value, nu: sp.chi2.logcdf(value, df=nu), ) - @pytest.mark.xfail(reason="Distribution not refactored yet") def test_wald_logp(self): self.check_logp( Wald, @@ -1079,7 +1093,6 @@ def test_wald_logp(self): decimal=select_by_precision(float64=6, float32=1), ) - @pytest.mark.xfail(reason="Distribution not refactored yet") @pytest.mark.xfail( condition=(aesara.config.floatX == "float32"), reason="Poor CDF in SciPy. See scipy/scipy#869 for details.", @@ -1092,7 +1105,6 @@ def test_wald_logcdf(self): lambda value, mu, alpha: sp.invgauss.logcdf(value, mu=mu, loc=alpha), ) - @pytest.mark.xfail(reason="Distribution not refactored yet") @pytest.mark.parametrize( "value,mu,lam,phi,alpha,logp", [ @@ -1112,7 +1124,6 @@ def test_wald_logcdf(self): (50.0, 15.0, None, 0.666666, 10.0, -5.6481874), ], ) - @pytest.mark.xfail(reason="Distribution not refactored yet") def test_wald_logp_custom_points(self, value, mu, lam, phi, alpha, logp): # Log probabilities calculated using the dIG function from the R package gamlss. # See e.g., doi: 10.1111/j.1467-9876.2005.00510.x, or @@ -2718,6 +2729,89 @@ def test_bound(): BoundPoissonPositionalArgs = Bound(Poisson, upper=6)("x", 2.0) +class TestBoundedContinuous: + def get_dist_params_and_interval_bounds(self, model, rv_name): + interval_rv = model.named_vars[f"{rv_name}_interval__"] + rv = model.named_vars[rv_name] + dist_params = rv.owner.inputs[3:] + lower_interval, upper_interval = interval_rv.tag.transform.param_extract_fn(rv) + return ( + dist_params, + lower_interval, + upper_interval, + ) + + def test_upper_bounded(self): + bounded_rv_name = "lower_bounded" + with Model() as model: + TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=None, upper=3) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + assert lower.value == -np.inf + assert upper.value == 3 + assert lower_interval is None + assert upper_interval.value == 3 + + def test_lower_bounded(self): + bounded_rv_name = "upper_bounded" + with Model() as model: + TruncatedNormal(bounded_rv_name, mu=1, sigma=2, lower=-2, upper=None) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + assert lower.value == -2 + assert upper.value == np.inf + assert lower_interval.value == -2 + assert upper_interval is None + + def test_lower_bounded_vector(self): + bounded_rv_name = "upper_bounded" + with Model() as model: + TruncatedNormal( + bounded_rv_name, + mu=np.array([1, 1]), + sigma=np.array([2, 3]), + lower=np.array([-1.0, 0]), + upper=None, + ) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + + assert np.array_equal(lower.value, [-1, 0]) + assert upper.value == np.inf + assert np.array_equal(lower_interval.value, [-1, 0]) + assert upper_interval is None + + def test_lower_bounded_broadcasted(self): + bounded_rv_name = "upper_bounded" + with Model() as model: + TruncatedNormal( + bounded_rv_name, + mu=np.array([1, 1]), + sigma=np.array([2, 3]), + lower=-1, + upper=np.array([np.inf, np.inf]), + ) + ( + (_, _, lower, upper), + lower_interval, + upper_interval, + ) = self.get_dist_params_and_interval_bounds(model, bounded_rv_name) + + assert lower.value == -1 + assert np.array_equal(upper.value, [np.inf, np.inf]) + assert lower_interval.value == -1 + assert upper_interval is None + + @pytest.mark.xfail(reason="LaTeX repr and str no longer applicable") class TestStrAndLatexRepr: def setup_class(self): diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index da1e894add..0122f97945 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -44,7 +44,6 @@ R, RandomPdMatrix, Rplus, - Rplusbig, Simplex, Vector, build_model, @@ -243,30 +242,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase): default_shape = (1,) -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") -class TestTruncatedNormal(BaseTestCases.BaseTestCase): - distribution = pm.TruncatedNormal - params = {"mu": 0.0, "tau": 1.0, "lower": -0.5, "upper": 0.5} - - -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") -class TestTruncatedNormalLower(BaseTestCases.BaseTestCase): - distribution = pm.TruncatedNormal - params = {"mu": 0.0, "tau": 1.0, "lower": -0.5} - - -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") -class TestTruncatedNormalUpper(BaseTestCases.BaseTestCase): - distribution = pm.TruncatedNormal - params = {"mu": 0.0, "tau": 1.0, "upper": 0.5} - - -@pytest.mark.xfail(reason="This distribution has not been refactored for v4") -class TestWald(BaseTestCases.BaseTestCase): - distribution = pm.Wald - params = {"mu": 1.0, "lam": 1.0, "alpha": 0.0} - - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase): distribution = pm.ZeroInflatedNegativeBinomial @@ -626,6 +601,121 @@ def seeded_kumaraswamy_rng_fn(self): ] +class TestTruncatedNormal(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, sigma = -2.0, 2.0, 0, 1.0 + pymc_dist_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + reference_dist_params = { + "loc": mu, + "scale": sigma, + "a": (lower - mu) / sigma, + "b": (upper - mu) / sigma, + } + reference_dist = seeded_scipy_distribution_builder("truncnorm") + tests_to_run = [ + "check_pymc_params_match_rv_op", + "check_pymc_draws_match_reference", + "check_rv_size", + ] + + +class TestTruncatedNormalTau(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, tau = -2.0, 2.0, 0, 1.0 + tau, sigma = get_tau_sigma(tau=tau, sigma=None) + pymc_dist_params = {"mu": mu, "tau": tau, "lower": lower, "upper": upper} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + tests_to_run = [ + "check_pymc_params_match_rv_op", + ] + + +class TestTruncatedNormalLowerTau(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, tau = -2.0, np.inf, 0, 1.0 + tau, sigma = get_tau_sigma(tau=tau, sigma=None) + pymc_dist_params = {"mu": mu, "tau": tau, "lower": lower} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + tests_to_run = [ + "check_pymc_params_match_rv_op", + ] + + +class TestTruncatedNormalUpperTau(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, tau = -np.inf, 2.0, 0, 1.0 + tau, sigma = get_tau_sigma(tau=tau, sigma=None) + pymc_dist_params = {"mu": mu, "tau": tau, "upper": upper} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + tests_to_run = [ + "check_pymc_params_match_rv_op", + ] + + +class TestTruncatedNormalUpperArray(BaseTestDistribution): + pymc_dist = pm.TruncatedNormal + lower, upper, mu, tau = ( + np.array([-np.inf, -np.inf]), + np.array([3, 2]), + np.array([0, 0]), + np.array( + [ + 1, + 1, + ] + ), + ) + size = (15, 2) + tau, sigma = get_tau_sigma(tau=tau, sigma=None) + pymc_dist_params = {"mu": mu, "tau": tau, "upper": upper} + expected_rv_op_params = {"mu": mu, "sigma": sigma, "lower": lower, "upper": upper} + tests_to_run = [ + "check_pymc_params_match_rv_op", + ] + + +class TestWald(BaseTestDistribution): + pymc_dist = pm.Wald + mu, lam, alpha = 1.0, 1.0, 0.0 + mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=lam, phi=None) + pymc_dist_params = {"mu": mu, "lam": lam, "alpha": alpha} + expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha} + reference_dist_params = [mu, lam_rv] + reference_dist = seeded_numpy_distribution_builder("wald") + tests_to_run = [ + "check_pymc_params_match_rv_op", + "check_pymc_draws_match_reference", + "check_rv_size", + ] + + def test_distribution(self): + self.validate_tests_list() + self._instantiate_pymc_rv() + if self.reference_dist is not None: + self.reference_dist_draws = self.reference_dist()( + *self.reference_dist_params, self.size + ) + for check_name in self.tests_to_run: + getattr(self, check_name)() + + def check_pymc_draws_match_reference(self): + assert_array_almost_equal( + self.pymc_rv.eval(), self.reference_dist_draws + self.alpha, decimal=self.decimal + ) + + +class TestWaldMuPhi(BaseTestDistribution): + pymc_dist = pm.Wald + mu, phi, alpha = 1.0, 3.0, 0.0 + mu_rv, lam_rv, phi_rv = pm.Wald.get_mu_lam_phi(mu=mu, lam=None, phi=phi) + pymc_dist_params = {"mu": mu, "phi": phi, "alpha": alpha} + expected_rv_op_params = {"mu": mu_rv, "lam": lam_rv, "alpha": alpha} + tests_to_run = [ + "check_pymc_params_match_rv_op", + ] + + class TestSkewNormal(BaseTestDistribution): pymc_dist = pm.SkewNormal pymc_dist_params = {"mu": 0.0, "sigma": 1.0, "alpha": 5.0} @@ -1519,56 +1609,12 @@ def ref_rand(size, tau): pymc3_random(BoundedNormal, {"tau": Rplus}, ref_rand=ref_rand) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_truncated_normal(self): - def ref_rand(size, mu, sigma, lower, upper): - return st.truncnorm.rvs( - (lower - mu) / sigma, (upper - mu) / sigma, size=size, loc=mu, scale=sigma - ) - - pymc3_random( - pm.TruncatedNormal, - {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig, "upper": Rplusbig}, - ref_rand=ref_rand, - ) - - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_truncated_normal_lower(self): - def ref_rand(size, mu, sigma, lower): - return st.truncnorm.rvs((lower - mu) / sigma, np.inf, size=size, loc=mu, scale=sigma) - - pymc3_random( - pm.TruncatedNormal, {"mu": R, "sigma": Rplusbig, "lower": -Rplusbig}, ref_rand=ref_rand - ) - - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_truncated_normal_upper(self): - def ref_rand(size, mu, sigma, upper): - return st.truncnorm.rvs(-np.inf, (upper - mu) / sigma, size=size, loc=mu, scale=sigma) - - pymc3_random( - pm.TruncatedNormal, {"mu": R, "sigma": Rplusbig, "upper": Rplusbig}, ref_rand=ref_rand - ) - def test_skew_normal(self): def ref_rand(size, alpha, mu, sigma): return st.skewnorm.rvs(size=size, a=alpha, loc=mu, scale=sigma) pymc3_random(pm.SkewNormal, {"mu": R, "sigma": Rplus, "alpha": R}, ref_rand=ref_rand) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") - def test_wald(self): - # Cannot do anything too exciting as scipy wald is a - # location-scale model of the *standard* wald with mu=1 and lam=1 - def ref_rand(size, mu, lam, alpha): - return st.wald.rvs(size=size, loc=alpha) - - pymc3_random( - pm.Wald, - {"mu": Domain([1.0, 1.0, 1.0]), "lam": Domain([1.0, 1.0, 1.0]), "alpha": Rplus}, - ref_rand=ref_rand, - ) - @pytest.mark.xfail(reason="This distribution has not been refactored for v4") def test_dirichlet_multinomial(self): def ref_rand(size, a, n): @@ -2099,7 +2145,6 @@ def test_Rice( ], ids=str, ) - @pytest.mark.xfail(reason="TruncatedNormal not yet refactored for v4") def test_TruncatedNormal( self, prior_samples, diff --git a/pymc3/tests/test_model.py b/pymc3/tests/test_model.py index f178e5400a..7968ff80fa 100644 --- a/pymc3/tests/test_model.py +++ b/pymc3/tests/test_model.py @@ -331,7 +331,6 @@ def test_aesara_switch_broadcast_edge_cases_1(self): np.log(0.5) * 10, ) - @pytest.mark.xfail(reason="TruncatedNormal not refactored for v4") def test_aesara_switch_broadcast_edge_cases_2(self): # Known issue 2: https://github.com/pymc-devs/pymc3/issues/4417 # fmt: off @@ -344,7 +343,7 @@ def test_aesara_switch_broadcast_edge_cases_2(self): mu = pm.Normal("mu", 0, 5) obs = pm.TruncatedNormal("obs", mu=mu, sigma=1, lower=-1, upper=2, observed=data) - npt.assert_allclose(m.dlogp([mu])({"mu": 0}), 2.499424682024436, rtol=1e-5) + npt.assert_allclose(m.dlogp([m.rvs_to_values[mu]])({"mu": 0}), 2.499424682024436, rtol=1e-5) @pytest.mark.xfail(reason="DensityDist not refactored for v4")