diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index fc85b4a86db..b260a82e17d 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -14,6 +14,7 @@ - `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)). - `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)). - Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)). +- Bugfix in logp and logcdf methods of `Triangular` distribution (see[#4470](https://github.com/pymc-devs/pymc3/pull/4470)). ## PyMC3 3.11.0 (21 January 2021) diff --git a/pymc3/distributions/continuous.py b/pymc3/distributions/continuous.py index 5c990bde3b5..234ed935f2b 100644 --- a/pymc3/distributions/continuous.py +++ b/pymc3/distributions/continuous.py @@ -29,7 +29,6 @@ from pymc3.distributions import transforms from pymc3.distributions.dist_math import ( SplineWrapper, - alltrue_elemwise, betaln, bound, clipped_beta_rvs, @@ -3649,18 +3648,14 @@ def logp(self, value): c = self.c lower = self.lower upper = self.upper - return tt.switch( - alltrue_elemwise([lower <= value, value < c]), - tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))), + return bound( tt.switch( - tt.eq(value, c), - tt.log(2 / (upper - lower)), - tt.switch( - alltrue_elemwise([c < value, value <= upper]), - tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))), - np.inf, - ), + tt.lt(value, c), + tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))), + tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))), ), + lower <= value, + value <= upper, ) def logcdf(self, value): @@ -3678,17 +3673,24 @@ def logcdf(self, value): ------- TensorVariable """ - l = self.lower - u = self.upper c = self.c - return tt.switch( - tt.le(value, l), - -np.inf, + lower = self.lower + upper = self.upper + return bound( tt.switch( - tt.le(value, c), - tt.log(((value - l) ** 2) / ((u - l) * (c - l))), - tt.switch(tt.lt(value, u), tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))), 0), + tt.le(value, lower), + -np.inf, + tt.switch( + tt.le(value, c), + tt.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))), + tt.switch( + tt.lt(value, upper), + tt.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))), + 0, + ), + ), ), + lower <= upper, ) diff --git a/pymc3/tests/test_distributions.py b/pymc3/tests/test_distributions.py index 2f9c9672aef..06efc90b8d8 100644 --- a/pymc3/tests/test_distributions.py +++ b/pymc3/tests/test_distributions.py @@ -802,6 +802,10 @@ def test_uniform(self): lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower), skip_paramdomain_outside_edge_test=True, ) + # Custom logp / logcdf check for invalid parameters + invalid_dist = Uniform.dist(lower=1, upper=0) + assert invalid_dist.logp(0.5).tag.test_value == -np.inf + assert invalid_dist.logcdf(2).tag.test_value == -np.inf def test_triangular(self): self.check_logp( @@ -817,6 +821,14 @@ def test_triangular(self): lambda value, c, lower, upper: sp.triang.logcdf(value, c - lower, lower, upper - lower), skip_paramdomain_outside_edge_test=True, ) + # Custom logp check for invalid value + valid_dist = Triangular.dist(lower=0, upper=1, c=2.0) + assert np.all(valid_dist.logp(np.array([1.9, 2.0, 2.1])).tag.test_value == -np.inf) + + # Custom logp / logcdf check for invalid parameters + invalid_dist = Triangular.dist(lower=1, upper=0, c=2.0) + assert invalid_dist.logp(0.5).tag.test_value == -np.inf + assert invalid_dist.logcdf(2).tag.test_value == -np.inf def test_bound_normal(self): PositiveNormal = Bound(Normal, lower=0.0) @@ -850,6 +862,10 @@ def test_discrete_unif(self): Rdunif, {"lower": -Rplusdunif, "upper": Rplusdunif}, ) + # Custom logp / logcdf check for invalid parameters + invalid_dist = DiscreteUniform.dist(lower=1, upper=0) + assert invalid_dist.logp(0.5).tag.test_value == -np.inf + assert invalid_dist.logcdf(2).tag.test_value == -np.inf def test_flat(self): self.check_logp(Flat, Runif, {}, lambda value: 0)