-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix ExGaussian logp #4049
Fix ExGaussian logp #4049
Conversation
Codecov Report
@@ Coverage Diff @@
## master #4049 +/- ##
=======================================
Coverage 86.79% 86.80%
=======================================
Files 88 88
Lines 14143 14147 +4
=======================================
+ Hits 12276 12280 +4
Misses 1867 1867
|
pymc3/distributions/continuous.py
Outdated
lp = tt.switch( | ||
tt.gt(nu, 0.05 * sigma), | ||
-tt.log(nu) | ||
+ (mu - value) / nu | ||
+ 0.5 * (sigma / nu) ** 2 | ||
+ logpow( | ||
tt.switch( | ||
tt.eq(std_cdf((value - mu) / sigma - sigma / nu), 0), | ||
np.finfo(float).eps, | ||
std_cdf((value - mu) / sigma - sigma / nu) | ||
), | ||
1.0 | ||
), | ||
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * ((value - mu) / sigma) ** 2, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lp = tt.switch( | |
tt.gt(nu, 0.05 * sigma), | |
-tt.log(nu) | |
+ (mu - value) / nu | |
+ 0.5 * (sigma / nu) ** 2 | |
+ logpow( | |
tt.switch( | |
tt.eq(std_cdf((value - mu) / sigma - sigma / nu), 0), | |
np.finfo(float).eps, | |
std_cdf((value - mu) / sigma - sigma / nu) | |
), | |
1.0 | |
), | |
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * ((value - mu) / sigma) ** 2, | |
) | |
standardized_val = (value - mu) / sigma | |
cdf_val = std_cdf(standardized_val - sigma / nu) | |
cdf_val_safe = tt.switch( | |
tt.eq(cdf_val, 0), np.finfo(pm.floatX).eps, cdf_val) | |
lp = tt.switch( | |
tt.gt(nu, 0.05 * sigma), | |
-tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu) ** 2 + logpow(cdf_val_safe, 1.0), | |
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * standardized_val ** 2, | |
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just one note @junpenglao : when I use np.finfo(pm.floatX).eps
instead of np.finfo(float).eps
, it raises ValueError: data type <class 'numpy.object_'> not inexact
. Is it important to use our float type?
Full traceback:
Traceback (most recent call last):
File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3319, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-3-cf016555abd8>", line 1, in <module>
pm.ExGaussian.dist(0., .25, 1. / 6).logp(y).eval().round(1)
File "/Users/alex_andorra/tptm_alex/pymc3/pymc3/distributions/continuous.py", line 3344, in logp
cdf_val_safe = tt.switch(tt.eq(cdf_val, 0), np.finfo(floatX).eps, cdf_val)
File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/numpy/core/getlimits.py", line 381, in __new__
raise ValueError("data type %r not inexact" % (dtype))
ValueError: data type <class 'numpy.object_'> not inexact
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think ideally we use the float type that consistent with the rest of the pm.Model - IIUC user set the global dtype and that should reflect in the pm.floatX
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it just converts to theano floatX type:
def floatX(X):
"""
Convert a theano tensor or numpy array to theano.config.floatX type.
"""
try:
return X.astype(theano.config.floatX)
except AttributeError:
# Scalar passed
return np.asarray(X, dtype=theano.config.floatX)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok, I think this does what you're talking about: np.finfo(theano.config.floatX)
Pushing, and if it's ok for you we can merge 👌
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
huge nitpick LOL
This is actually much better, thanks @junpenglao ! |
Merge branch master to update release notes.
@AlexAndorra Blackify introduced a lot of change - we should do that in a separate PR (including the clean up of the file like import order etc) |
Oops 😨 Do you know if I can revert specific commits but not the commits that came after them? From what you say, I have to revert f9930d8 and da52490 but not the ones in-between... |
I would copy the change and start afresh from master... |
Ah yeah, good point. Let me close this, start a new PR with just the fix, and then another PR to fix the formatting |
As suggested by @junpenglao in #4045, this PR adds a
tt.switch
statement in theExGaussian
logp to replace 0 with epsilon. That way,std_cdf
never returns 0, andlogpow
never returns -inf.I'm not sure what I did is very pythonic/theanoesque, so feel free to comment and suggest improvements! It does seem to work though:
pm.ExGaussian.dist(0., .25, 1./6).logp(y).eval()
doesn't contain -inf anymore, and the model in Discourse doesn't raise aBadInitialEnergy
error.Once these changes are validated, I'll blackify the file for better readibility and update the release notes.
Thanks for the reviews 🖖