Fix ExGaussian logp #4049

Closed
wants to merge 8 commits into from

AlexAndorra
Contributor

As suggested by @junpenglao in #4045, this PR adds a tt.switch statement in the ExGaussian logp to replace 0 with epsilon. That way, std_cdf never returns 0, and logpow never returns -inf.

I'm not sure what I did is very pythonic/theanoesque, so feel free to comment and suggest improvements! It does seem to work though: pm.ExGaussian.dist(0., .25, 1./6).logp(y).eval() doesn't contain -inf anymore, and the model in Discourse doesn't raise a BadInitialEnergy error.

Once these changes are validated, I'll blackify the file for better readability and update the release notes.
Thanks for the reviews 🖖
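
To illustrate the failure mode described above, here is a minimal pure-Python sketch (standard library only, not the Theano code from this PR). The helper names `std_cdf`, `log_or_neg_inf` and `exgaussian_logp` and the `clamp` flag are assumptions made for this illustration; the formula mirrors the expression quoted in the diff, and `sys.float_info.epsilon` plays the role of `np.finfo(float).eps`.

```python
import math
import sys

# Machine epsilon for Python floats (same value as np.finfo(float).eps).
EPS = sys.float_info.epsilon


def std_cdf(x):
    """Standard normal CDF; underflows to exactly 0.0 for very negative x."""
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def log_or_neg_inf(p):
    """log(p), returning -inf at p == 0 instead of raising (illustrative helper)."""
    return math.log(p) if p > 0.0 else -math.inf


def exgaussian_logp(value, mu, sigma, nu, clamp=True):
    """Scalar sketch of the logp expression in the diff (hypothetical helper)."""
    z = (value - mu) / sigma
    if nu > 0.05 * sigma:
        cdf_val = std_cdf(z - sigma / nu)
        if clamp and cdf_val == 0.0:
            # The fix discussed in this PR: replace an exact 0 with epsilon.
            cdf_val = EPS
        return (
            -math.log(nu)
            + (mu - value) / nu
            + 0.5 * (sigma / nu) ** 2
            + log_or_neg_inf(cdf_val)
        )
    # Fallback branch: plain normal log-density when nu is small relative to sigma.
    return -math.log(sigma * math.sqrt(2 * math.pi)) - 0.5 * z ** 2


# A value far in the left tail makes std_cdf underflow to 0.
print(exgaussian_logp(-12.0, 0.0, 0.25, 1.0 / 6, clamp=False))
print(exgaussian_logp(-12.0, 0.0, 0.25, 1.0 / 6, clamp=True))
```

With `clamp=False` the far-left-tail value gives `-inf`; with `clamp=True` the result stays finite. The clamped value is not an accurate tail density, it only keeps the sampler's initial energy finite.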

@AlexAndorra AlexAndorra linked an issue Aug 11, 2020 that may be closed by this pull request
@codecov

codecov bot commented Aug 11, 2020

Codecov Report

Merging #4049 into master will increase coverage by 0.00%.
The diff coverage is n/a.

@@           Coverage Diff           @@
##           master    #4049   +/-   ##
=======================================
  Coverage   86.79%   86.80%           
=======================================
  Files          88       88           
  Lines       14143    14147    +4     
=======================================
+ Hits        12276    12280    +4     
  Misses       1867     1867           
Impacted Files Coverage Δ
pymc3/distributions/continuous.py 80.09% <ø> (+0.07%) ⬆️

Comment on lines 3272 to 3286
lp = tt.switch(
    tt.gt(nu, 0.05 * sigma),
    -tt.log(nu)
    + (mu - value) / nu
    + 0.5 * (sigma / nu) ** 2
    + logpow(
        tt.switch(
            tt.eq(std_cdf((value - mu) / sigma - sigma / nu), 0),
            np.finfo(float).eps,
            std_cdf((value - mu) / sigma - sigma / nu)
        ),
        1.0
    ),
    -tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * ((value - mu) / sigma) ** 2,
)
Member

Suggested change (replacing lines 3272 to 3286):
standardized_val = (value - mu) / sigma
cdf_val = std_cdf(standardized_val - sigma / nu)
cdf_val_safe = tt.switch(
    tt.eq(cdf_val, 0), np.finfo(pm.floatX).eps, cdf_val)
lp = tt.switch(
    tt.gt(nu, 0.05 * sigma),
    -tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu) ** 2 + logpow(cdf_val_safe, 1.0),
    -tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * standardized_val ** 2,
)

Contributor Author

@AlexAndorra AlexAndorra Aug 12, 2020

Just one note @junpenglao : when I use np.finfo(pm.floatX).eps instead of np.finfo(float).eps, it raises ValueError: data type <class 'numpy.object_'> not inexact. Is it important to use our float type?
Full traceback:

Traceback (most recent call last):
  File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3319, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-cf016555abd8>", line 1, in <module>
    pm.ExGaussian.dist(0., .25, 1. / 6).logp(y).eval().round(1)
  File "/Users/alex_andorra/tptm_alex/pymc3/pymc3/distributions/continuous.py", line 3344, in logp
    cdf_val_safe = tt.switch(tt.eq(cdf_val, 0), np.finfo(floatX).eps, cdf_val)
  File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/numpy/core/getlimits.py", line 381, in __new__
    raise ValueError("data type %r not inexact" % (dtype))
ValueError: data type <class 'numpy.object_'> not inexact

Member

I think ideally we use the float type that is consistent with the rest of the pm.Model. IIUC, the user sets the global dtype, and that should be reflected in pm.floatX?

Contributor Author

I think it just converts to theano floatX type:

def floatX(X):
    """
    Convert a theano tensor or numpy array to theano.config.floatX type.
    """
    try:
        return X.astype(theano.config.floatX)
    except AttributeError:
        # Scalar passed
        return np.asarray(X, dtype=theano.config.floatX)

Contributor Author

Ok, I think this does what you're talking about: np.finfo(theano.config.floatX)
Pushing, and if it's ok for you we can merge 👌
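
The difference can be sketched with NumPy alone. As the snippet quoted above shows, pm.floatX is a conversion function, whereas np.finfo expects a dtype; theano.config.floatX is a dtype name string. In the sketch below, `float_x` is a stand-in for that setting, assumed here to be "float32" so the example runs without Theano.

```python
import numpy as np

# Stand-in for theano.config.floatX (assumption: the user configured "float32").
float_x = "float32"

# np.finfo accepts a dtype or a dtype name, so the configured precision
# determines which epsilon is used.
eps_model = np.finfo(float_x).eps

# np.finfo(float) always describes Python's double-precision float.
eps_double = np.finfo(float).eps

print(eps_model, eps_double)
```

Under this assumption the epsilon follows the model's configured precision instead of always being the double-precision value.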

Member

@junpenglao junpenglao left a comment

huge nitpick LOL

@AlexAndorra
Contributor Author

This is actually much better, thanks @junpenglao !
Just made the changes. I'll blackify and update release notes once #4048 is merged to avoid merge conflicts. Will ping you when it's pushed 😉

@junpenglao
Member

@AlexAndorra Blackify introduced a lot of changes; we should do that in a separate PR (including the cleanup of the file, like import order, etc.)

@AlexAndorra
Contributor Author

Oops 😨 Do you know if I can revert specific commits but not the commits that came after them? From what you say, I have to revert f9930d8 and da52490 but not the ones in-between...
That being said, while reviewing you can specify which commit(s) you want GitHub to display: would that be comfortable enough for you to review? Note that I didn't change anything more than what we said in our discussion above -- the Black and import changes have no impact on functionality.

@junpenglao
Member

I would copy the change and start afresh from master...
Looking at the Black formatting, it seems to surface quite a few places where we have this kind of repetitive pattern in the code. It does not impact performance, as Theano will merge these nodes, but for readability I think we should do a cleanup.

@AlexAndorra
Contributor Author

Ah yeah, good point. Let me close this, start a new PR with just the fix, and then another PR to fix the formatting

This was referenced Aug 14, 2020
Successfully merging this pull request may close these issues.

ExGaussian logp is numerical unstable
2 participants