Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix dependence of Uniform logp on bound method #4541

Merged
merged 6 commits into from
Mar 16, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def logp(self, value):
"""
lower = self.lower
upper = self.upper
return bound(-aet.log(upper - lower), value >= lower, value <= upper)
return bound(
aet.fill(value, -aet.log(upper - lower)),
value >= lower,
value <= upper,
)

def logcdf(self, value):
"""
Expand Down
6 changes: 5 additions & 1 deletion pymc3/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,7 +1278,11 @@ def logp(self, value):
"""
upper = self.upper
lower = self.lower
return bound(-aet.log(upper - lower + 1), lower <= value, value <= upper)
return bound(
aet.fill(value, -aet.log(upper - lower + 1)),
lower <= value,
value <= upper,
)

def logcdf(self, value):
"""
Expand Down
16 changes: 11 additions & 5 deletions pymc3/distributions/dist_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,23 @@
def bound(logp, *conditions, **kwargs):
"""
Bounds a log probability density with several conditions.
When conditions are not met, the logp values are replaced by -inf.

Note that bound should not be used to enforce the logic of the logp under the normal
support as it can be disabled by the user via check_bounds = False in pm.Model()

Parameters
----------
logp: float
*conditions: booleans
broadcast_conditions: bool (optional, default=True)
If True, broadcasts logp to match the largest shape of the conditions.
This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
is specified via the conditions.
If False, will return the same shape as logp.
This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.
If True, conditions are broadcasted and applied element-wise to each value in logp.
If False, conditions are collapsed via aet.all(). As a consequence the entire logp
array is either replaced by -inf or unchanged.

Setting broadcasts_conditions to False is necessary for most (all?) multivariate
distributions where the dimensions of the conditions do not unambigously match
that of the logp.

Returns
-------
Expand Down
33 changes: 17 additions & 16 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def logp(self, value):
# only defined for sum(value) == 1
return bound(
aet.sum(logpow(value, a - 1) - gammaln(a), axis=-1) + gammaln(aet.sum(a, axis=-1)),
aet.all(value >= 0),
aet.all(value <= 1),
aet.all(a > 0),
value >= 0,
value <= 1,
a > 0,
broadcast_conditions=False,
)

Expand Down Expand Up @@ -671,11 +671,11 @@ def logp(self, x):

return bound(
factln(n) + aet.sum(-factln(x) + logpow(p, x), axis=-1, keepdims=True),
aet.all(x >= 0),
aet.all(aet.eq(aet.sum(x, axis=-1, keepdims=True), n)),
aet.all(p <= 1),
aet.all(aet.eq(aet.sum(p, axis=-1), 1)),
aet.all(aet.ge(n, 0)),
x >= 0,
aet.eq(aet.sum(x, axis=-1, keepdims=True), n),
p <= 1,
aet.eq(aet.sum(p, axis=-1), 1),
n >= 0,
broadcast_conditions=False,
)

Expand Down Expand Up @@ -823,10 +823,10 @@ def logp(self, value):
# and that each observation value_i sums to n_i.
return bound(
result,
aet.all(aet.ge(value, 0)),
aet.all(aet.gt(a, 0)),
aet.all(aet.ge(n, 0)),
aet.all(aet.eq(value.sum(axis=-1, keepdims=True), n)),
value >= 0,
a > 0,
n >= 0,
aet.eq(value.sum(axis=-1, keepdims=True), n),
broadcast_conditions=False,
)

Expand Down Expand Up @@ -1575,8 +1575,8 @@ def logp(self, x):
result += (eta - 1.0) * aet.log(det(X))
return bound(
result,
aet.all(X <= 1),
aet.all(X >= -1),
X >= -1,
X <= 1,
matrix_pos_def(X),
eta > 0,
broadcast_conditions=False,
Expand Down Expand Up @@ -2204,9 +2204,10 @@ def logp(self, value):
logquad = (self.tau * delta * tau_dot_delta).sum(axis=-1)
return bound(
0.5 * (logtau + logdet - logquad),
aet.all(self.alpha <= 1),
aet.all(self.alpha >= -1),
self.alpha >= -1,
self.alpha <= 1,
self.tau > 0,
broadcast_conditions=False,
)

def random(self, point=None, size=None):
Expand Down
3 changes: 2 additions & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,8 @@ class Model(Factor, WithMemoization, metaclass=ContextMeta):
Ensure that input parameters to distributions are in a valid
range. If your model is built in a way where you know your
parameters can only take on valid values you can set this to
False for increased speed.
False for increased speed. This should not be used if your model
contains discrete variables.

Examples
--------
Expand Down
11 changes: 11 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,6 +2667,17 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
assert actual_a.shape == (X.shape[0],)
pass

def test_issue_4499(self):
# Test for bug in Uniform and DiscreteUniform logp when setting check_bounds = False
# https://github.com/pymc-devs/pymc3/issues/4499
with pm.Model(check_bounds=False) as m:
x = pm.Uniform("x", 0, 2, shape=10, transform=None)
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)

with pm.Model(check_bounds=False) as m:
x = pm.DiscreteUniform("x", 0, 1, shape=10)
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)


def test_serialize_density_dist():
def func(x):
Expand Down