Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix chain detection in progress bar #1077

Merged
merged 3 commits into from
Jun 28, 2021

Conversation

MarcoGorelli
Copy link
Contributor

closes #1075

Without ipywidgets installed, I get:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-5-0c678769c1ff> in <module>
      8 kernel = NUTS(model)
      9 mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=4)
---> 10 mcmc.run(rng_key_, marriage=dset.MarriageScaled.values, divorce=dset.DivorceScaled.values)
     11 mcmc.print_summary()
     12 samples_1 = mcmc.get_samples()

~/numpyro-dev/numpyro/infer/mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    565                 states, last_state = _laxmap(partial_map_fn, map_args)
    566             elif self.chain_method == "parallel":
--> 567                 states, last_state = pmap(partial_map_fn)(map_args)
    568             else:
    569                 assert self.chain_method == "vectorized"

    [... skipping hidden 12 frame]

~/numpyro-dev/numpyro/infer/mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
    371             else collection_size // self.thinning
    372         )
--> 373         collect_vals = fori_collect(
    374             lower_idx,
    375             upper_idx,

~/numpyro-dev/numpyro/util.py in fori_collect(lower, upper, body_fun, init_val, transform, progbar, return_last_val, collection_size, thinning, **progbar_opts)
    334         )
    335     elif num_chains > 1:
--> 336         progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
    337         _body_fn_pbar = progress_bar_fori_loop(_body_fn)
    338         last_val, collection, _, _ = fori_loop(

~/numpyro-dev/numpyro/util.py in progress_bar_factory(num_samples, num_chains)
    186     finished_chains = []
    187     for chain in range(num_chains):
--> 188         tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
    189         tqdm_bars[chain].set_description("Compiling.. ", refresh=True)
    190 

~/numpyro-dev/venv/lib/python3.8/site-packages/tqdm/notebook.py in __init__(self, *args, **kwargs)
    237         unit_scale = 1 if self.unit_scale is True else self.unit_scale or 1
    238         total = self.total * unit_scale if self.total else self.total
--> 239         self.container = self.status_printer(self.fp, total, self.desc, self.ncols)
    240         self.container.pbar = self
    241         self.displayed = False

~/numpyro-dev/venv/lib/python3.8/site-packages/tqdm/notebook.py in status_printer(_, total, desc, ncols)
    110         # Prepare IPython progress bar
    111         if IProgress is None:  # #187 #451 #558 #872
--> 112             raise ImportError(
    113                 "IProgress not found. Please update jupyter and ipywidgets."
    114                 " See https://ipywidgets.readthedocs.io/en/stable"

ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html

With ipywidgets installed, it now works fine:

4 chains fixed

Any tips on how to write a test for this? Should device be mocked out somehow?

@fehiepsi
Copy link
Member

Any tips on how to write a test for this?

I think you can skip this because CI runs in the terminal (though I think there would be some workarounds for testing). :) Could you add a warning if ipywidgets not available?

@MarcoGorelli
Copy link
Contributor Author

MarcoGorelli commented Jun 27, 2021

Could you add a warning if ipywidgets not available?

Having looked into this a bit further, the issue is the following:

  • if the kernel in which I installed NumPyro doesn't have ipywidgets installed, then the error
    ImportError: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
    
    will be shown, so there's no need for a warning;
  • if the kernel in which I installed NumPyro does have ipywidgets installed but the environment from which I launched jupyter lab doesn't, then the progress bar will not render correctly (as per the second gif I posted in Progress bars only update with num_chains=1 #1075)

So, to avoid the issue from #1075, I think there'd need to be a way detecting whether ipywidgets is installed in the environment from which jupyter lab was launched (rather than in the running kernel, which is what import ipywidgets would detect) - do you know if that's possible?

numpyro/util.py Outdated
@@ -21,6 +21,7 @@
from jax.tree_util import tree_flatten, tree_map

_DISABLE_CONTROL_FLOW_PRIM = False
CHAIN_RE = re.compile(r"\d+$")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add underscore for this global parameter?

@fehiepsi
Copy link
Member

Makes sense to me. It seems to not possible to me. How about adding a note to MCMC docs mentioning the issue and provides link/solution for readers?

@MarcoGorelli
Copy link
Contributor Author

Sure, done, here's a screenshot of the updated docs:

image

numpyro/util.py Outdated
@@ -21,7 +21,7 @@
from jax.tree_util import tree_flatten, tree_map

_DISABLE_CONTROL_FLOW_PRIM = False
CHAIN_RE = re.compile(r"\d+$")
CHAIN_RE = re.compile(r"(?<=_)\d+$") # e.g. get '3' from 'TFRT_CPU_3'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, sorry, I meant _CHAIN_RE

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, no worries, have updated (I guess it doesn't hurt to keep the underscore in the pattern too?)

The previous CI run had this failure:

=================================== FAILURES ===================================
________________________ test_functional_map[pmap-NUTS] ________________________

algo = 'NUTS', map_fn = <function pmap at 0x7fd3df8802f0>

    @pytest.mark.parametrize("algo", ["HMC", "NUTS"])
    @pytest.mark.parametrize("map_fn", [vmap, pmap])
    @pytest.mark.skipif(
        "XLA_FLAGS" not in os.environ,
        reason="without this mark, we have duplicated tests in Travis",
    )
    def test_functional_map(algo, map_fn):
        if map_fn is pmap and xla_bridge.device_count() == 1:
            pytest.skip("pmap test requires device_count greater than 1.")
    
        true_mean, true_std = 1.0, 2.0
        num_warmup, num_samples = 1000, 8000
    
        def potential_fn(z):
            return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)
    
        init_kernel, sample_kernel = hmc(potential_fn, algo=algo)
        init_params = jnp.array([0.0, -1.0])
        rng_keys = random.split(random.PRNGKey(0), 2)
    
        init_kernel_map = map_fn(
            lambda init_param, rng_key: init_kernel(
                init_param, trajectory_length=9, num_warmup=num_warmup, rng_key=rng_key
            )
        )
        init_states = init_kernel_map(init_params, rng_keys)
    
        fori_collect_map = map_fn(
            lambda hmc_state: fori_collect(
                0,
                num_samples,
                sample_kernel,
                hmc_state,
                transform=lambda x: x.z,
                progbar=False,
            )
        )
        chain_samples = fori_collect_map(init_states)
    
        assert_allclose(
>           jnp.mean(chain_samples, axis=1), jnp.repeat(true_mean, 2), rtol=0.06
        )
E       AssertionError: 
E       Not equal to tolerance rtol=0.06, atol=0
E       
E       Mismatched elements: 1 / 2 (50%)
E       Max absolute difference: 50.5725
E       Max relative difference: 50.5725
E        x: array([  1.020785, -49.5725  ], dtype=float32)
E        y: array([1., 1.], dtype=float32)

test/infer/test_mcmc.py:752: AssertionError

I presume it's unrelated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passed this time, so I guess so

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

jax program is typically deterministic (unless there are versioning updates). I'll take a closer look when this happens again. Thanks for pointing out!

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix, @MarcoGorelli!

@fehiepsi fehiepsi merged commit ab18282 into pyro-ppl:master Jun 28, 2021
@MarcoGorelli MarcoGorelli deleted the progress-bar branch June 28, 2021 14:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Progress bars only update with num_chains=1
2 participants