-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix chain detection in progress bar #1077
Conversation
I think you can skip this because CI runs in the terminal (though I think there would be some workarounds for testing). :) Could you add a warning if |
Having looked into this a bit further, the issue is the following:
So, to avoid the issue from #1075, I think there'd need to be a way detecting whether |
numpyro/util.py
Outdated
@@ -21,6 +21,7 @@ | |||
from jax.tree_util import tree_flatten, tree_map | |||
|
|||
_DISABLE_CONTROL_FLOW_PRIM = False | |||
CHAIN_RE = re.compile(r"\d+$") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add underscore for this global parameter?
Makes sense to me. It seems to not possible to me. How about adding a note to MCMC docs mentioning the issue and provides link/solution for readers? |
numpyro/util.py
Outdated
@@ -21,7 +21,7 @@ | |||
from jax.tree_util import tree_flatten, tree_map | |||
|
|||
_DISABLE_CONTROL_FLOW_PRIM = False | |||
CHAIN_RE = re.compile(r"\d+$") | |||
CHAIN_RE = re.compile(r"(?<=_)\d+$") # e.g. get '3' from 'TFRT_CPU_3' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops, sorry, I meant _CHAIN_RE
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ah, no worries, have updated (I guess it doesn't hurt to keep the underscore in the pattern too?)
The previous CI run had this failure:
=================================== FAILURES ===================================
________________________ test_functional_map[pmap-NUTS] ________________________
algo = 'NUTS', map_fn = <function pmap at 0x7fd3df8802f0>
@pytest.mark.parametrize("algo", ["HMC", "NUTS"])
@pytest.mark.parametrize("map_fn", [vmap, pmap])
@pytest.mark.skipif(
"XLA_FLAGS" not in os.environ,
reason="without this mark, we have duplicated tests in Travis",
)
def test_functional_map(algo, map_fn):
if map_fn is pmap and xla_bridge.device_count() == 1:
pytest.skip("pmap test requires device_count greater than 1.")
true_mean, true_std = 1.0, 2.0
num_warmup, num_samples = 1000, 8000
def potential_fn(z):
return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)
init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
init_params = jnp.array([0.0, -1.0])
rng_keys = random.split(random.PRNGKey(0), 2)
init_kernel_map = map_fn(
lambda init_param, rng_key: init_kernel(
init_param, trajectory_length=9, num_warmup=num_warmup, rng_key=rng_key
)
)
init_states = init_kernel_map(init_params, rng_keys)
fori_collect_map = map_fn(
lambda hmc_state: fori_collect(
0,
num_samples,
sample_kernel,
hmc_state,
transform=lambda x: x.z,
progbar=False,
)
)
chain_samples = fori_collect_map(init_states)
assert_allclose(
> jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06
)
E AssertionError:
E Not equal to tolerance rtol=0.06, atol=0
E
E Mismatched elements: 1 / 2 (50%)
E Max absolute difference: 50.5725
E Max relative difference: 50.5725
E x: array([ 1.020785, -49.5725 ], dtype=float32)
E y: array([1., 1.], dtype=float32)
test/infer/test_mcmc.py:752: AssertionError
I presume it's unrelated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It passed this time, so I guess so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
jax program is typically deterministic (unless there are versioning updates). I'll take a closer look when this happens again. Thanks for pointing out!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix, @MarcoGorelli!
closes #1075
Without ipywidgets installed, I get:
With ipywidgets installed, it now works fine:
Any tips on how to write a test for this? Should
device
be mocked out somehow?