Skip to content

Commit

Permalink
Fix Triangular bounds (#4470)
Browse files Browse the repository at this point in the history
* Small fix Triangular logp and logcdf methods
* Add tests for invalid parameters Uniform, Triangular, DiscreteUniform
  • Loading branch information
ricardoV94 authored Feb 12, 2021
1 parent e46f490 commit 1d37d31
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 19 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).
- Bugfix in logp and logcdf methods of `Triangular` distribution (see[#4470](https://github.com/pymc-devs/pymc3/pull/4470)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
40 changes: 21 additions & 19 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pymc3.distributions import transforms
from pymc3.distributions.dist_math import (
SplineWrapper,
alltrue_elemwise,
betaln,
bound,
clipped_beta_rvs,
Expand Down Expand Up @@ -3649,18 +3648,14 @@ def logp(self, value):
c = self.c
lower = self.lower
upper = self.upper
return tt.switch(
alltrue_elemwise([lower <= value, value < c]),
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
return bound(
tt.switch(
tt.eq(value, c),
tt.log(2 / (upper - lower)),
tt.switch(
alltrue_elemwise([c < value, value <= upper]),
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
np.inf,
),
tt.lt(value, c),
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
),
lower <= value,
value <= upper,
)

def logcdf(self, value):
Expand All @@ -3678,17 +3673,24 @@ def logcdf(self, value):
-------
TensorVariable
"""
l = self.lower
u = self.upper
c = self.c
return tt.switch(
tt.le(value, l),
-np.inf,
lower = self.lower
upper = self.upper
return bound(
tt.switch(
tt.le(value, c),
tt.log(((value - l) ** 2) / ((u - l) * (c - l))),
tt.switch(tt.lt(value, u), tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))), 0),
tt.le(value, lower),
-np.inf,
tt.switch(
tt.le(value, c),
tt.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))),
tt.switch(
tt.lt(value, upper),
tt.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))),
0,
),
),
),
lower <= upper,
)


Expand Down
16 changes: 16 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,10 @@ def test_uniform(self):
lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower),
skip_paramdomain_outside_edge_test=True,
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = Uniform.dist(lower=1, upper=0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_triangular(self):
self.check_logp(
Expand All @@ -817,6 +821,14 @@ def test_triangular(self):
lambda value, c, lower, upper: sp.triang.logcdf(value, c - lower, lower, upper - lower),
skip_paramdomain_outside_edge_test=True,
)
# Custom logp check for invalid value
valid_dist = Triangular.dist(lower=0, upper=1, c=2.0)
assert np.all(valid_dist.logp(np.array([1.9, 2.0, 2.1])).tag.test_value == -np.inf)

# Custom logp / logcdf check for invalid parameters
invalid_dist = Triangular.dist(lower=1, upper=0, c=2.0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_bound_normal(self):
PositiveNormal = Bound(Normal, lower=0.0)
Expand Down Expand Up @@ -850,6 +862,10 @@ def test_discrete_unif(self):
Rdunif,
{"lower": -Rplusdunif, "upper": Rplusdunif},
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = DiscreteUniform.dist(lower=1, upper=0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_flat(self):
self.check_logp(Flat, Runif, {}, lambda value: 0)
Expand Down

0 comments on commit 1d37d31

Please sign in to comment.