Fix a bug at MCMC.print_summary #821

Merged: 5 commits, Dec 2, 2020

fehiepsi (Member) commented Nov 20, 2020

In MCMC.print_summary, we assume that sample_field is z, which might not be true for other MCMC kernels.

Tested on MH example, where I changed sample_field to u.

I also revised the scale handler docstring to mention that scale can be an ndarray. cc @xidulu
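To illustrate the sample_field point, here is a minimal sketch. The state classes and the `get_samples` helper are hypothetical stand-ins, not NumPyro's actual classes or the code in this PR; they only show why hard-coding `.z` breaks for a kernel whose samples live in another field such as `u`.

```python
from collections import namedtuple

# Hypothetical state types, for illustration only (not NumPyro's real classes).
HMCLikeState = namedtuple("HMCLikeState", ["z", "potential_energy"])
MHLikeState = namedtuple("MHLikeState", ["u", "rng_key"])


def get_samples(state, sample_field):
    """Return the field of `state` that holds the samples (hypothetical helper)."""
    return getattr(state, sample_field)


hmc_state = HMCLikeState(z={"x": 0.5}, potential_energy=1.2)
mh_state = MHLikeState(u={"x": 0.5}, rng_key=None)

# Looking the field up by name works for both kernels ...
print(get_samples(hmc_state, "z"))
print(get_samples(mh_state, "u"))
# ... whereas code that always reads `.z` would fail for the MH-like state.
print(hasattr(mh_state, "z"))
```

Looking the field up by name keeps the summary code independent of any particular kernel's state layout.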

fehiepsi (Member, Author) commented Dec 2, 2020

@neerajprad Could you help me review this PR? This also fixes an issue of scale like the one in this forum thread.

@@ -24,7 +24,7 @@ def _get_codomain(bijector):
         return constraints.positive
     elif bijector.__class__.__name__ == "GeneralizedPareto":
         loc, scale, concentration = bijector.loc, bijector.scale, bijector.concentration
-        if not_jax_tracer(concentration) and np.all(concentration < 0):
+        if not_jax_tracer(concentration) and np.all(np.less(concentration, 0)):
fehiepsi (Member, Author) commented:

Currently,

import jax
from numpyro.util import not_jax_tracer
import numpy as np
import jax.numpy as jnp

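def f(x):
    # y is a closed-over DeviceArray, so it is not a tracer here ...
    if not_jax_tracer(y):
        # ... but the Python operator `y <= 0` triggers a jax op, whose
        # result is a tracer under jit, so np.all fails on it.
        if np.all(y <= 0):
            print("nooo")
    return x + y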

y = jnp.arange(10.)
jax.jit(f)(1)

will fail because, although concentration (y in the snippet above) is not a tracer, concentration < 0 will be a tracer if concentration is a DeviceArray. For validation checks like this, we need to use numpy ops throughout.
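For comparison, a minimal sketch of the same check written with a numpy op, mirroring the `np.less` change in the diff. The function name `f_checked` and the error message are made up for illustration, and the `not_jax_tracer` guard is omitted because `y` is known to be a concrete array here.

```python
import jax
import jax.numpy as jnp
import numpy as np

y = jnp.arange(10.0)  # concrete jax array, captured by the function below


def f_checked(x):
    # np.less converts y to a NumPy array first, so both the comparison and
    # the reduction run eagerly in NumPy, even while f_checked is being traced.
    if np.all(np.less(y, 0)):
        raise ValueError("all entries of y are negative")
    return x + y


out = jax.jit(f_checked)(1)
print(out)
```

The check never produces a tracer, so it can sit inside a jitted function as long as the value being validated is itself concrete.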

neerajprad (Member) commented:

These nuances are getting very subtle as compared to the earlier behavior, and I am not sure if I fully understand when to use numpy vs jax.numpy, and now there is this additional distinction between numpy and python ops. :)

fehiepsi (Member, Author) commented:

Yeah, for a DeviceArray, the Python op concentration < 0 is dispatched to something like jnp.less(concentration, 0). The rule of thumb now is to always use numpy ops (not Python ops) in these checks. It is quite unfortunate. >___<
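A small sketch of that distinction outside of jit. The variable names are illustrative; the point is only which library ends up producing the result.

```python
import jax.numpy as jnp
import numpy as np

a = jnp.array([-1.0, 2.0])

py_result = a < 0          # Python operator: handled by the jax array, returns a jax array
np_result = np.less(a, 0)  # NumPy ufunc: converts `a` to NumPy, returns an np.ndarray

print(type(py_result).__name__, type(np_result).__name__)
```

Under jit, the first form is staged as a jax op, while the second stays a plain NumPy computation on a concrete value.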

neerajprad (Member) commented:

Thanks for explaining! Any idea what's the larger motivation for this change? Are we moving towards a world where everything will be JITed by default?

fehiepsi (Member, Author) commented:

I am not sure. This behavior appeared in jax 0.2. Until now, I hadn't noticed that a < 0 triggers a jax op if a is a DeviceArray.

@@ -329,8 +329,7 @@ def signed_stick_breaking_tril(t):
     # we omit the step of computing s = z * z_cumprod by using the fact:
     # y = sign(r) * s = sign(r) * sqrt(z * z_cumprod) = r * sqrt(z_cumprod)
     z = r ** 2
-    z1m_cumprod = jnp.cumprod(1 - z, axis=-1)
-    z1m_cumprod_sqrt = jnp.sqrt(z1m_cumprod)
+    z1m_cumprod_sqrt = jnp.cumprod(jnp.sqrt(1 - z), axis=-1)
neerajprad (Member) commented:

+1!
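A quick numerical check that the two forms in this hunk agree. This sketch uses NumPy rather than jax.numpy and random inputs with |r| < 1 (an assumption, so that 1 - z stays nonnegative); it relies on sqrt being multiplicative over nonnegative factors.

```python
import numpy as np

rng = np.random.default_rng(0)
r = rng.uniform(-0.9, 0.9, size=(4, 5))  # assumed range: |r| < 1
z = r ** 2

# Old form: square root of the cumulative product.
old = np.sqrt(np.cumprod(1 - z, axis=-1))
# New form: cumulative product of the square roots.
new = np.cumprod(np.sqrt(1 - z), axis=-1)

print(np.allclose(old, new))
```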

# XXX: there might be the case that state.z is not a dictionary but
# its postprocessed value `sites` is a dictionary.
# TODO: in general, when both `sites` and `state.z` are dictionaries,
# they can have different key names, not necessary due to deterministic
neerajprad (Member) commented:

Could you explain this further - when will sites and state.z have different names?

fehiepsi (Member, Author) commented:

This depends on the postprocessing code, e.g. state.z = {'ABC_b': 1} while sites = {'b': 1}. The code below will return an empty dict because it thinks 'b' is a deterministic site.
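A minimal sketch of how such a key mismatch empties the result. The filter below is an assumed form of "keep only sites that also appear in state.z", written for illustration; it is not the actual NumPyro code.

```python
state_z = {"ABC_b": 1}  # key names as stored in the kernel state
sites = {"b": 1}        # key names after postprocessing

# Hypothetical filter: treat any site whose name is missing from state.z
# as deterministic and drop it.
latent_sites = {name: value for name, value in sites.items() if name in state_z}

print(latent_sites)
```

Because 'b' is not a key of state.z, it is classified as deterministic and dropped, even though it corresponds to the latent site 'ABC_b'.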

neerajprad (Member) commented:

LGTM, but I am thinking that we should keep track of deterministic and latent nodes separately so that we don't need to bother with all this downstream. We can refactor this later.

@neerajprad neerajprad merged commit b0fa27b into pyro-ppl:master Dec 2, 2020