-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable progress bar under pmap #712
Conversation
I don't like how the bars get shuffled. I worry that it would be hard to
follow the progress. Maybe the approaches can be combined somehow
…On Sun, 4 Aug 2024, 06:45 Junpeng Lao, ***@***.***> wrote:
Thanks! I like this better than #655
<#655> because it does not
require explicit num_chain. @zaxtax <https://github.com/zaxtax> thoughts?
—
Reply to this email directly, view it on GitHub
<#712 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUNIJYC34Z7BSLGMM5LZPWWUZAVCNFSM6AAAAABL5V26F6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENRXGMYTMMRQHE>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
The initial random shuffle is a bit wonky. I believe it comes from fastprogress - the bars should always be in sorted order if we used the change from fastprogress to tqdm. That's based on similar tests with numpyro's progressbar and tqdm - but I'll verify it. Would that be better? |
Yes - if that works, how about I merge this PR and then @zaxtax you can update #655? |
I had it wrong - the randomization was a race condition. Added a lock and now the bars are always ordered by most to least complete |
But do the bars move around if one chain overtakes another?
…On Mon, 5 Aug 2024, 20:44 andrewdipper, ***@***.***> wrote:
I had it wrong - the randomization was a race condition. Added a lock and
now the bars are always ordered by most to least complete
—
Reply to this email directly, view it on GitHub
<#712 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCULA6YA2N4AI3UANSQDZP7BYVAVCNFSM6AAAAABL5V26F6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENRZGY4DKOJSGI>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
No, the order will always be in decreasing order of chain progress. As such the chain to progress bar mapping will change if one overtakes another. We'd need the equivalent of a chain/device id to avoid this - but from what I've seen so far the new callbacks don't support identifying the underlying device. |
So IIUC, we cannot have identifier like |
In my PR, I handle this by making a chain id that's carried along with the
sample iteration. It's a bit inelegant but straightforward and doesn't
require special machinery from the io callback code
…On Tue, 6 Aug 2024, 09:03 Junpeng Lao, ***@***.***> wrote:
No, the order will always be in decreasing order of chain progress. As
such the chain to progress bar mapping will change if one overtakes
another. We'd need the equivalent of a chain/device id to avoid this - but
from what I've seen so far the new callbacks don't support identifying the
underlying device.
So IIUC, we cannot have identifier like Chain 0, Chain 1 ...? I am not
too bothered by the sorting of the chains, but would be great if there is
some identifier, so users are not surprised by the behavior.
—
Reply to this email directly, view it on GitHub
<#712 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUIH2POMBI7KYHPLZCDZQBYK5AVCNFSM6AAAAABL5V26F6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENZQGUZTEMJWGU>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Yeah but then it requires an explicit user input of chain_id - I would rather not pushing this handling to user, hence my preference to the current solution here. |
We could make a helper that hides that plumbing?
…On Tue, 6 Aug 2024, 11:22 Junpeng Lao, ***@***.***> wrote:
In my PR, I handle this by making a chain id that's carried along with the
sample iteration. It's a bit inelegant but straightforward and doesn't
require special machinery from the io callback code
Yeah but then it requires an explicit user input of chain_id - I would
rather not pushing this handling to user, hence my preference to the
current solution here.
—
Reply to this email directly, view it on GitHub
<#712 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUJE5YQU5PFB37G4LWDZQCIWDAVCNFSM6AAAAABL5V26F6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENZQHAYDMOJZHA>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
Correct. Progress is being used as a surrogate id.
If I'm not mistaken I think this helper has to reside outside the pmap and wouldn't be completely transparent. But #655 is more straightforward from a handling perspective. It might be possible to use the |
The latest commit uses the return value from the However, this changes the carry used in progress_bar_scan which is a problem if it is used externally. That'd have to be a slow change. But it does avoid forcing the user to do the device identification with part of the input. A wrapper over scan itself might clean things up a bit. |
@junpenglao I'm now ok with the PR as is and definitely think we should explore a wrapper over scan in the future. |
Great, thank you @andrewdipper for the contribution! |
Oh should we update any examples to show how the new functionality works @andrewdipper ? Or does this not change the API? |
I didn't see any examples using |
No it looks like the examples in the sampling book don't
https://blackjax-devs.github.io/sampling-book/algorithms/cyclical_sgld.html
…On Thu, 8 Aug 2024, 17:34 andrewdipper, ***@***.***> wrote:
I didn't see any examples using progress_bar_scan and only the two uses
within blackjax so we should be good.
—
Reply to this email directly, view it on GitHub
<#712 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCULMAXONNYETUY42UDTZQOFZVAVCNFSM6AAAAABL5V26F6VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDENZWGEZDEMBXGI>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
* Update README.md (blackjax-devs#638) * Update README.md Update citation. * Update README.md * Indexing the notebook showing how to reproduce the GIF. (blackjax-devs#640) Co-authored-by: Junpeng Lao <junpenglao@gmail.com> * Bump python version (blackjax-devs#645) * Bump python version * update bool inverse * SMC: allow each mutation kernel to have different parameters. (blackjax-devs#649) * vmaping over parameters in base * switch from mcmc_factory to just passing in parameters * pre-commit and typing * CRU and docs improvement * pre-commit * code review updates * pre-commit * rename test * Migrate from deprecated `host_callback` to `io_callback` (blackjax-devs#651) * Migrate from deprecated `host_callback` to `io_callback` Co-Authored-By: George Necula <gnecula@users.noreply.github.com> * Format file * Fix bug * Fix MALA transition energy (blackjax-devs#653) * Fix MALA transition energy * Use a different logic. * Change variable names (blackjax-devs#654) * Replace iterative RNG split and carry with `jax.random.fold_in` (blackjax-devs#656) * Replace iterative RNG split and carry with `jax.random.fold_in` * revert unintended change * file formatting * change `jax.tree_map` to `jax.tree.map` * revert unintended file * fiddle with rng_key * seed again * Removal of Algorithm classes. (blackjax-devs#657) * more * removing export * removal of classes, tests passing * linter * fix on test * linter * removing parametrization on test * code review updates * exporting as_top_level_api in dynamic_hmc * linter * code review update: replace imports * Fix deprecated call to jnp.clip (blackjax-devs#664) * Update jax version requirements (blackjax-devs#666) Fix blackjax-devs#665 * Make tests pass on `aarch64-linux` (blackjax-devs#671) * Enable fitlering of AdaptationInfo (blackjax-devs#674) * enable AdaptationInfo filtering * revert progress_bar * fix pre-commit * fix empty sets * enable adapt info filtering for all adaptation algorithms * fix precommit /progressbar=True * change filter tuple to use tree_map * Update `run_inference_algorithm` to split `initial_position` and `initial_state` (blackjax-devs#672) * UPDATE DOCSTRING * ADD STREAMING VERSION * UPDATE TESTS * ADD DOCSTRING * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * ADD INITIAL_POSITION * FIX TEST * RENAME O * FIX DOCSTRING * PUT EXPECTATION AFTER TRANSFORM * Preconditioned mclmc (blackjax-devs#673) * TESTS * TESTS * UPDATE DOCSTRING * ADD STREAMING VERSION * ADD PRECONDITIONING TO MCLMC * ADD PRECONDITIONING TO TUNING FOR MCLMC * UPDATE GITIGNORE * UPDATE GITIGNORE * UPDATE TESTS * UPDATE TESTS * ADD DOCSTRING * ADD TEST * STREAMING AVERAGE * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * GITIGNORE * PRECOMMIT CLEAN UP * ADD INITIAL_POSITION * FIX TEST * ADD TEST * REMOVE BENCHMARKS * BUG FIX * CHANGE PRECISION * CHANGE PRECISION * RENAME O * UPDATE STREAMING AVG * UPDATE PR * RENAME STD_MAT * New integrator, and add some metadata to integrators.py (blackjax-devs#681) * TESTS * TESTS * UPDATE DOCSTRING * ADD STREAMING VERSION * ADD PRECONDITIONING TO MCLMC * ADD PRECONDITIONING TO TUNING FOR MCLMC * UPDATE GITIGNORE * UPDATE GITIGNORE * UPDATE TESTS * UPDATE TESTS * ADD DOCSTRING * ADD TEST * STREAMING AVERAGE * ADD TEST * REFACTOR RUN_INFERENCE_ALGORITHM * UPDATE DOCSTRING * Precommit * CLEAN TESTS * GITIGNORE * PRECOMMIT CLEAN UP * FIX SPELLING, ADD OMELYAN, EXPORT COEFFICIENTS * TEMPORARILY ADD BENCHMARKS * ADD INITIAL_POSITION * FIX TEST * CLEAN UP * REMOVE BENCHMARKS * ADD TEST * REMOVE BENCHMARKS * BUG FIX * CHANGE PRECISION * CHANGE PRECISION * ADD OMELYAN TEST * RENAME O * UPDATE STREAMING AVG * UPDATE PR * RENAME STD_MAT * MERGE MAIN * REMOVE COEFFICIENT EXPORTS * Minor formatting (blackjax-devs#685) * Minor formatting * formatting * fix test * formatting * MAKE WINDOW ADAPTATION TAKE INTEGRATOR AS ARGUMENT (blackjax-devs#687) * FIX KWARG BUG (blackjax-devs#686) * FIX KWARG BUG * FIX KWARG BUG * Change isokinetic_integrator generation API (blackjax-devs#689) * Apply function on pytree directly. (blackjax-devs#692) * Apply function on pytree directly. Avoiding unnecssary unpacking * Fix kwarg * Fix sampling test. (blackjax-devs#693) * Enable shared mcmc parameters with tempered smc (blackjax-devs#694) * add parameter filtering * fix parameter split + docstring * change extend_paramss * convert to bit twiddling (blackjax-devs#696) * Remove nightly release (blackjax-devs#699) * Fix doc mistakes (blackjax-devs#701) * Fix equation formatting * Clarify JAX gradient error * Fix punctuation + capitalization * Fix grammar Should not begin sentence with "i.e." in English. * Fix math formatting error * Fix typo Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain adaptation. * Add SVGD citation to appear in doc Currently the SVGD paper is only cited in the `kernel` function, which is defined _within_ the `build_kernel` function. Because of this nested function format, the SVGD paper is _not_ cited in the documentation. To fix this, I added a citation to the SVGD paper in the `as_top_level_api` docstring. * Fix grammar + clarify doc * Fix typo --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com> * Update index.md (blackjax-devs#711) The jitted step remained unused, leading to the example running with an uncompiled nuts.step. Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed. * Enable progress bar under pmap (blackjax-devs#712) * enable pmap progbar * fix bar creation * add locking * fix formatting * switch to using chain state * remove labels (blackjax-devs#716) * Simplify `run_inference_algorithm` (blackjax-devs#714) * fix minor type errors * storing only expectation values * fixed memory efficient sampling * clean up * renaming vars * precommit fixes * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * merge main * burn in and fix tests * burn in and fix tests * minor fixes * minor fixes * minor fixes --------- Co-authored-by: jakob.robnik@gmail.com <jakob.robnik@gmail.com> * Harmonize Quickstart example (blackjax-devs#717) * Update README.md (blackjax-devs#719) --------- Co-authored-by: Junpeng Lao <junpenglao@gmail.com> Co-authored-by: Carlos Iguaran <ciguaran@users.noreply.github.com> Co-authored-by: ksnxr <70186663+ksnxr@users.noreply.github.com> Co-authored-by: Gaétan Lepage <33058747+GaetanLepage@users.noreply.github.com> Co-authored-by: Alberto Cabezas <a.cabezasgonzalez@lancaster.ac.uk> Co-authored-by: andrewdipper <andrewpratt333@gmail.com> Co-authored-by: Reuben <reubenharry@users.noreply.github.com> Co-authored-by: Gilad Turok <36947659+gil2rok@users.noreply.github.com> Co-authored-by: johannahaffner <38662446+johannahaffner@users.noreply.github.com> Co-authored-by: jakob.robnik@gmail.com <jakob.robnik@gmail.com>
Enables proper progress bar behavior of
progress_bar_scan
under pmap. However the progress bars are not tied to physical devices sinceio_callback
won't expose the device id. The bars behave as if they are physically tied but are always sorted by most to least complete (they're actually randomly shuffled but one bar will always be first and another last).I haven't visually tested this under a multi-gpu setup as I don't have an identical pair of gpus - just using
JAX_PLATFORMS='cpu'
#655 - Related: changes progress bar to tqdm and requires the user to label devices - a different approach
A few important guidelines and requirements before we can merge your PR:
main
commit;pre-commit
is installed and configured on your machine, and you ran it before opening the PR;