Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ExGaussian logp #4049

Closed
wants to merge 8 commits into from
23 changes: 17 additions & 6 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -3268,12 +3268,23 @@ def logp(self, value):
sigma = self.sigma
nu = self.nu

# This condition suggested by exGAUS.R from gamlss
lp = tt.switch(tt.gt(nu, 0.05 * sigma),
- tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu)**2
+ logpow(std_cdf((value - mu) / sigma - sigma / nu), 1.),
- tt.log(sigma * tt.sqrt(2 * np.pi))
- 0.5 * ((value - mu) / sigma)**2)
# This condition is suggested by exGAUS.R from gamlss
lp = tt.switch(
tt.gt(nu, 0.05 * sigma),
-tt.log(nu)
+ (mu - value) / nu
+ 0.5 * (sigma / nu) ** 2
+ logpow(
tt.switch(
tt.eq(std_cdf((value - mu) / sigma - sigma / nu), 0),
np.finfo(float).eps,
std_cdf((value - mu) / sigma - sigma / nu)
),
1.0
),
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * ((value - mu) / sigma) ** 2,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
lp = tt.switch(
tt.gt(nu, 0.05 * sigma),
-tt.log(nu)
+ (mu - value) / nu
+ 0.5 * (sigma / nu) ** 2
+ logpow(
tt.switch(
tt.eq(std_cdf((value - mu) / sigma - sigma / nu), 0),
np.finfo(float).eps,
std_cdf((value - mu) / sigma - sigma / nu)
),
1.0
),
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * ((value - mu) / sigma) ** 2,
)
standardized_val = (value - mu) / sigma
cdf_val = std_cdf(standardized_val - sigma / nu)
cdf_val_safe = tt.switch(
tt.eq(cdf_val, 0), np.finfo(pm.floatX).eps, cdf_val)
lp = tt.switch(
tt.gt(nu, 0.05 * sigma),
-tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu) ** 2 + logpow(cdf_val_safe, 1.0),
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * standardized_val ** 2,
)

Copy link
Contributor Author

@AlexAndorra AlexAndorra Aug 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just one note @junpenglao : when I use np.finfo(pm.floatX).eps instead of np.finfo(float).eps, it raises ValueError: data type <class 'numpy.object_'> not inexact. Is it important to use our float type?
Full traceback:

Traceback (most recent call last):
  File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3319, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-cf016555abd8>", line 1, in <module>
    pm.ExGaussian.dist(0., .25, 1. / 6).logp(y).eval().round(1)
  File "/Users/alex_andorra/tptm_alex/pymc3/pymc3/distributions/continuous.py", line 3344, in logp
    cdf_val_safe = tt.switch(tt.eq(cdf_val, 0), np.finfo(floatX).eps, cdf_val)
  File "/Users/alex_andorra/opt/anaconda3/envs/pymc-dev/lib/python3.8/site-packages/numpy/core/getlimits.py", line 381, in __new__
    raise ValueError("data type %r not inexact" % (dtype))
ValueError: data type <class 'numpy.object_'> not inexact

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think ideally we use the float type that consistent with the rest of the pm.Model - IIUC user set the global dtype and that should reflect in the pm.floatX?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it just converts to theano floatX type:

def floatX(X):
    """
    Convert a theano tensor or numpy array to theano.config.floatX type.
    """
    try:
        return X.astype(theano.config.floatX)
    except AttributeError:
        # Scalar passed
        return np.asarray(X, dtype=theano.config.floatX)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think this does what you're talking about: np.finfo(theano.config.floatX)
Pushing, and if it's ok for you we can merge 👌


return bound(lp, sigma > 0., nu > 0.)

def _repr_latex_(self, name=None, dist=None):
Expand Down