Fix bug in which TruncatedNormal returns -inf for all values if any value is out of bounds #6128

Merged
merged 8 commits into from
Sep 17, 2022

@adrn adrn commented Sep 14, 2022

What is this PR about?

With pymc v4.1.7, I found that evaluating TruncatedNormal's logp on an array of values returns -inf for every element, including values that lie inside the bounds. For example:

import numpy as np
import pymc as pm

p = {'mu': 1, 'sigma': 0.1}
with pm.Model() as model:    
    dist1 = pm.TruncatedNormal.dist(**p, lower=0.)
    dist2 = pm.Normal.dist(**p)
    
    x_grid = np.linspace(-5, 10, 1024)
    dist1_logp = pm.Deterministic('dist1', pm.logp(dist1, x_grid))
    dist2_logp = pm.Deterministic('dist2', pm.logp(dist2, x_grid))

func1 = model.compile_fn(dist1_logp, inputs=[])
func2 = model.compile_fn(dist2_logp, inputs=[])

print(func1({}))
print(func2({}))

Output:

[-inf -inf -inf ... -inf -inf -inf]
[-1798.61635344 -1789.8294493  -1781.06404481 ... -4022.26639085
 -4035.43062232 -4048.61635344]

In the output above, the first line should not be -inf everywhere, because the grid we evaluate on includes values inside the allowed range (at or above lower=0).

With @ricardoV94's help, we tracked this down to the way that TruncatedNormal.logp was enforcing the value bounds:
https://github.com/pymc-devs/pymc/blob/main/pymc/distributions/continuous.py#L779

I noticed this comment in the check_parameters() docstring: "Note that check_parameter should not be used to enforce the logic of the logp expression under the normal parameter support as it can be disabled by the user via check_bounds = False in pm.Model()" and indeed the above example works as expected with check_bounds=False.

This PR instead follows the implementation in other truncated distributions (for example, HalfStudentT) and uses a switch statement to enforce the bounds elementwise. I also added a regression test for the example case above.
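
To make the difference concrete, here is a minimal pure-Python sketch of the two behaviors. It is not the PyMC implementation: the function names are hypothetical, plain lists stand in for tensors, and the all-or-nothing check is only my reading of how the parameter check behaved here.

```python
import math


def truncated_normal_logp(x, mu, sigma, lower):
    """Log-density at scalar x of Normal(mu, sigma) truncated below at `lower`.

    No bound check is applied here; callers decide how to handle x < lower.
    """
    z = (x - mu) / sigma
    normal_logp = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    # Log of the probability mass above `lower`: log(1 - Phi((lower - mu) / sigma)).
    log_norm = math.log(0.5 * math.erfc((lower - mu) / (sigma * math.sqrt(2.0))))
    return normal_logp - log_norm


def logp_all_or_nothing(values, mu, sigma, lower):
    """Assumed buggy behavior: a single out-of-bounds value makes every entry -inf."""
    if not all(v >= lower for v in values):
        return [-math.inf] * len(values)
    return [truncated_normal_logp(v, mu, sigma, lower) for v in values]


def logp_elementwise(values, mu, sigma, lower):
    """Switch-style behavior: only the out-of-bounds entries become -inf."""
    return [
        truncated_normal_logp(v, mu, sigma, lower) if v >= lower else -math.inf
        for v in values
    ]


values = [-1.0, 0.5, 1.0]
print(logp_all_or_nothing(values, mu=1.0, sigma=0.1, lower=0.0))
print(logp_elementwise(values, mu=1.0, sigma=0.1, lower=0.0))
```

The first call returns -inf for every entry because one value is below the bound; the second returns -inf only for that value and finite log-densities for the others.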

See https://discourse.pymc.io/t/truncatednormal-logp-returning-all-inf/10398 for more context.

Checklist

Major / Breaking Changes

  • n/a

Bugfixes / New features

  • Fixed a bug in which TruncatedNormal would return -inf for all logp values if any input value was outside of the bounds.

Docs / Maintenance

  • n/a

adrn and others added 2 commits September 14, 2022 12:37
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
codecov bot commented Sep 14, 2022

Codecov Report

Merging #6128 (8c9dbd2) into main (ec27b5c) will increase coverage by 1.14%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #6128      +/-   ##
==========================================
+ Coverage   90.90%   92.05%   +1.14%     
==========================================
  Files          99      102       +3     
  Lines       20543    21299     +756     
==========================================
+ Hits        18675    19607     +932     
+ Misses       1868     1692     -176     
Impacted Files Coverage Δ
pymc/distributions/continuous.py 97.50% <100.00%> (-0.01%) ⬇️
pymc/tests/distributions/test_continuous.py 99.76% <100.00%> (+<0.01%) ⬆️
pymc/tests/distributions/test_shape_utils.py 99.73% <0.00%> (-0.01%) ⬇️
pymc/__init__.py 100.00% <0.00%> (ø)
pymc/exceptions.py 100.00% <0.00%> (ø)
pymc/distributions/bound.py 100.00% <0.00%> (ø)
pymc/distributions/__init__.py 100.00% <0.00%> (ø)
pymc/tests/distributions/test_bound.py 100.00% <0.00%> (ø)
pymc/tests/distributions/test_distribution.py 97.83% <0.00%> (ø)
pymc/distributions/truncated.py 99.30% <0.00%> (ø)
... and 19 more

@@ -777,11 +777,13 @@ def logp(
norm = 0.0

logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm
logp = at.switch(
Member

Actually to be equivalent to what we had before we should do something like this:

if is_lower_bounded:
    logp = at.switch(value < lower, -np.inf, logp)
if is_upper_bounded:
    logp = at.switch(value <= upper, logp, -np.inf)

Contributor Author

@adrn adrn Sep 14, 2022

Isn't that equivalent to what is implemented here because of the default values for lower and upper?

lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)

Member

We retrieve the None case here:

unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf)
unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf)

So in those cases we avoid introducing the useless switch. It's a small optimization, but I don't see any reason to modify it.
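
As a rough illustration of that optimization, the sketch below applies each bound only when it is finite, so an unbounded side adds no extra operation. This is a plain-Python stand-in rather than the actual tensor code: `apply_bounds` is a hypothetical name, and the list comprehensions play the role of `at.switch`.

```python
import math


def apply_bounds(values, logps, lower=-math.inf, upper=math.inf):
    """Mask out-of-bounds entries with -inf, skipping any side that is unbounded."""
    is_lower_bounded = lower != -math.inf
    is_upper_bounded = upper != math.inf

    out = list(logps)
    if is_lower_bounded:
        # Mirrors: switch(value < lower, -inf, logp)
        out = [-math.inf if v < lower else lp for v, lp in zip(values, out)]
    if is_upper_bounded:
        # Mirrors: switch(value <= upper, logp, -inf)
        out = [lp if v <= upper else -math.inf for v, lp in zip(values, out)]
    return out


values = [-1.0, 0.0, 2.0]
logps = [-3.0, -2.0, -1.0]
print(apply_bounds(values, logps))                        # no bounds: unchanged
print(apply_bounds(values, logps, lower=0.0))             # only the lower mask runs
print(apply_bounds(values, logps, lower=0.0, upper=1.0))  # both masks run
```

With no finite bound the input passes through untouched, which corresponds to skipping the switch entirely.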

Contributor Author

Oh I see. OK

Contributor Author

The last commit makes the implementation here more analogous to the general truncated case you linked above - thanks for that pointer!

Contributor Author

Thanks @ricardoV94 for taking a look and helping out already! Let me know if the implementation in the latest few commits looks ok. Also, it looks like this PR is waiting for approval to run the full test suite.

@ricardoV94
Member

Pre-commit is complaining, otherwise looks good

@adrn adrn force-pushed the truncatednormal-bug branch from 05c5981 to 8c9dbd2 Compare September 16, 2022 13:08
@adrn
Contributor Author

adrn commented Sep 16, 2022

Cool - I forgot to run pre-commit install, so I just ran it manually and pushed the changes. I think the workflows need to be approved to run again?

@ricardoV94 ricardoV94 merged commit 5236d3e into pymc-devs:main Sep 17, 2022
@ricardoV94
Member

Thanks @adrn !
