Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Triangular bounds #4470

Merged
merged 3 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).
- `pm.sample` and `pm.find_MAP` no longer change the `start` argument (see [#4458](https://github.com/pymc-devs/pymc3/pull/4458)).
- Fixed `Dirichlet.logp` method to work with unit batch or event shapes (see [#4454](https://github.com/pymc-devs/pymc3/pull/4454)).
- Bugfix in logp and logcdf methods of `Triangular` distribution (see[#4470](https://github.com/pymc-devs/pymc3/pull/4470)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
40 changes: 21 additions & 19 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
from pymc3.distributions import transforms
from pymc3.distributions.dist_math import (
SplineWrapper,
alltrue_elemwise,
betaln,
bound,
clipped_beta_rvs,
Expand Down Expand Up @@ -3649,18 +3648,14 @@ def logp(self, value):
c = self.c
lower = self.lower
upper = self.upper
return tt.switch(
alltrue_elemwise([lower <= value, value < c]),
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
return bound(
tt.switch(
tt.eq(value, c),
tt.log(2 / (upper - lower)),
tt.switch(
alltrue_elemwise([c < value, value <= upper]),
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
np.inf,
),
tt.lt(value, c),
tt.log(2 * (value - lower) / ((upper - lower) * (c - lower))),
tt.log(2 * (upper - value) / ((upper - lower) * (upper - c))),
),
lower <= value,
value <= upper,
)

def logcdf(self, value):
Expand All @@ -3678,17 +3673,24 @@ def logcdf(self, value):
-------
TensorVariable
"""
l = self.lower
u = self.upper
c = self.c
return tt.switch(
tt.le(value, l),
-np.inf,
lower = self.lower
upper = self.upper
return bound(
tt.switch(
tt.le(value, c),
tt.log(((value - l) ** 2) / ((u - l) * (c - l))),
tt.switch(tt.lt(value, u), tt.log1p(-((u - value) ** 2) / ((u - l) * (u - c))), 0),
tt.le(value, lower),
-np.inf,
tt.switch(
tt.le(value, c),
tt.log(((value - lower) ** 2) / ((upper - lower) * (c - lower))),
tt.switch(
tt.lt(value, upper),
tt.log1p(-((upper - value) ** 2) / ((upper - lower) * (upper - c))),
0,
),
),
),
lower <= upper,
)


Expand Down
16 changes: 16 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,10 @@ def test_uniform(self):
lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower),
skip_paramdomain_outside_edge_test=True,
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = Uniform.dist(lower=1, upper=0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_triangular(self):
self.check_logp(
Expand All @@ -817,6 +821,14 @@ def test_triangular(self):
lambda value, c, lower, upper: sp.triang.logcdf(value, c - lower, lower, upper - lower),
skip_paramdomain_outside_edge_test=True,
)
# Custom logp check for invalid value
valid_dist = Triangular.dist(lower=0, upper=1, c=2.0)
assert np.all(valid_dist.logp(np.array([1.9, 2.0, 2.1])).tag.test_value == -np.inf)

# Custom logp / logcdf check for invalid parameters
invalid_dist = Triangular.dist(lower=1, upper=0, c=2.0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_bound_normal(self):
PositiveNormal = Bound(Normal, lower=0.0)
Expand Down Expand Up @@ -850,6 +862,10 @@ def test_discrete_unif(self):
Rdunif,
{"lower": -Rplusdunif, "upper": Rplusdunif},
)
# Custom logp / logcdf check for invalid parameters
invalid_dist = DiscreteUniform.dist(lower=1, upper=0)
assert invalid_dist.logp(0.5).tag.test_value == -np.inf
assert invalid_dist.logcdf(2).tag.test_value == -np.inf

def test_flat(self):
self.check_logp(Flat, Runif, {}, lambda value: 0)
Expand Down