Fix a bug at MCMC.print_summary #821
@neerajprad Could you help me review this PR? This also fixes an issue of |
@@ -24,7 +24,7 @@ def _get_codomain(bijector):
         return constraints.positive
     elif bijector.__class__.__name__ == "GeneralizedPareto":
         loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
-        if not_jax_tracer(concentration) and np.all(concentration < 0):
+        if not_jax_tracer(concentration) and np.all(np.less(concentration, 0)):
Currently,

    import jax
    from numpyro.util import not_jax_tracer
    import numpy as np
    import jax.numpy as jnp

    def f(x):
        if not_jax_tracer(y):
            if np.all(y <= 0):
                print("nooo")
        return x + y

    y = jnp.arange(10.)
    jax.jit(f)(1)

will fail because, although `concentration` (here `y`) is not a tracer, `concentration < 0` will be a tracer if `concentration` is a DeviceArray. For validation checks like this, we need to use numpy ops throughout.
These nuances are getting very subtle compared to the earlier behavior, and I am not sure I fully understand when to use `numpy` vs `jax.numpy`, and now there is this additional distinction between numpy and python ops. :)
Yeah, for a device array, the python op `concentration < 0` will be deferred to something like `jnp.less(concentration, 0)` or so. The rule of thumb now is to always use `numpy` ops (not python ops). It is quite unfortunate. >___<
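A minimal sketch of the rule of thumb above, assuming a JAX version with the staging behavior described in this thread (0.2 or later). The names `y`, `outcome`, `check_with_python_op`, and `check_with_numpy_op` are made up for illustration and are not part of numpyro:

```python
import jax
import jax.numpy as jnp
import numpy as np

# A concrete (non-traced) JAX array that the jitted functions close over.
y = jnp.arange(3.0) - 5.0  # every entry is negative

outcome = {}

def check_with_python_op(x):
    try:
        # `y < 0` dispatches to a JAX op, which is staged while tracing,
        # so turning the result into a Python bool is expected to fail.
        outcome["python_op"] = bool(np.all(y < 0))
    except Exception as err:  # record the error type instead of crashing
        outcome["python_op"] = type(err).__name__
    return x + y

def check_with_numpy_op(x):
    # np.less converts `y` to a NumPy array first, so the check stays concrete.
    outcome["numpy_op"] = bool(np.all(np.less(y, 0)))
    return x + y

jax.jit(check_with_python_op)(1.0)
jax.jit(check_with_numpy_op)(1.0)
print(outcome)
```

Under that assumption, the numpy-op check should yield a plain Python bool, while the python-op check should record a JAX tracer conversion error.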
Thanks for explaining! Any idea what's the larger motivation for this change? Are we moving towards a world where everything will be JITed by default?
I am not sure. This behavior has been present since jax 0.2. Until now, I hadn't noticed that `a < 0` will trigger a jax op if `a` is a DeviceArray.
@@ -329,8 +329,7 @@ def signed_stick_breaking_tril(t):
     # we omit the step of computing s = z * z_cumprod by using the fact:
     # y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
     z = r ** 2
-    z1m_cumprod = jnp.cumprod(1 - z, axis=-1)
-    z1m_cumprod_sqrt = jnp.sqrt(z1m_cumprod)
+    z1m_cumprod_sqrt = jnp.cumprod(jnp.sqrt(1 - z), axis=-1)
+1!
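For reference, the change above relies on the square root distributing over a product of non-negative factors, so taking the square root before or after the cumulative product should give the same values. A small pure-Python check, with made-up values for `r` in (-1, 1):

```python
import math
import operator
from itertools import accumulate

# Hypothetical inputs; any r in (-1, 1) keeps 1 - r**2 non-negative.
r = [0.3, -0.5, 0.8, -0.1]
one_minus_z = [1 - ri ** 2 for ri in r]

# Old form: cumulative product first, then the square root.
sqrt_of_cumprod = [math.sqrt(p) for p in accumulate(one_minus_z, operator.mul)]

# New form: square root first, then the cumulative product.
cumprod_of_sqrt = list(accumulate((math.sqrt(v) for v in one_minus_z), operator.mul))

same = all(
    math.isclose(a, b, rel_tol=1e-12)
    for a, b in zip(sqrt_of_cumprod, cumprod_of_sqrt)
)
print(same)
```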
# XXX: there might be the case that state.z is not a dictionary but
# its postprocessed value `sites` is a dictionary.
# TODO: in general, when both `sites` and `state.z` are dictionaries,
# they can have different key names, not necessary due to deterministic
Could you explain this further - when will `sites` and `state.z` have different names?
This depends on the `post_processing` code, e.g. `state.z = {'ABC_b': 1}` while `sites = {'b': 1}`. The code below will return an empty dict because it thinks `'b'` is a deterministic site.
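To make the failure mode concrete, here is a rough sketch that assumes the filtering works by key membership. The helper `drop_deterministic` is hypothetical and is not the actual numpyro code:

```python
def drop_deterministic(sites, z):
    """Hypothetical filter: keep only the entries of `sites` whose key is also in `z`.

    Any key missing from `z` is treated as a deterministic site and dropped.
    """
    return {name: value for name, value in sites.items() if name in z}

# Post-processing renamed the latent site, so the key sets no longer match.
z = {"ABC_b": 1}
sites = {"b": 1}
renamed = drop_deterministic(sites, z)

# When the names agree, the latent site survives and only 'det' is dropped.
matching = drop_deterministic({"b": 1, "det": 2}, {"b": 1})
print(renamed, matching)
```

With renamed keys, every site looks deterministic under this kind of check, which is how the summary could end up empty.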
LGTM, but I am thinking that we should keep track of deterministic and latent nodes separately so that we don't need to bother with all this downstream. We can refactor this later.
There, we assume `sample_field` is `z`, which might not be true for other MCMC kernels.
Tested on the MH example, where I changed `sample_field` to `u`.
I also revised the `scale` handler docstring to mention that `scale` can be an ndarray. cc @xidulu