Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple Progress Bars #25

Merged
merged 6 commits into from
Oct 7, 2024
Merged

Multiple Progress Bars #25

merged 6 commits into from
Oct 7, 2024

Conversation

zombie-einstein
Copy link
Collaborator

@zombie-einstein zombie-einstein commented Oct 5, 2024

Issue #24 (thanks also to @andrewlesak for the prototype) and some of my own projects have highlighted that it would be useful to be able to print multiple progress bars when the scan or loop is inside a vmap. Currently this will push all the updates to a single progress bar (as the scanned function is compiled once for the vmap, and so only one progress bar is initialised), showing false progress updates when using vmap.

My approach here has been to allow the user to inject an id for the progress bar, which is then carried through the loop or scan, which then allows updates to be routed to individual progress bars. For example with a scan

from jax_tqdm import BarId, scan_tqdm

@scan_tqdm(n, print_rate=10)
def step(carry, _):
    return carry + 1, carry + 1

@jax.jit
def inner(i):
    init = BarId(i=i, carry=0)
    _final, _all_numbers = jax.lax.scan(step, init, jnp.arange(n))
    return _final.carry, _all_numbers

last_numbers, all_numbers = jax.vmap(inner)(jax.numpy.arange(5))

Here BarId is a simple class that contain the id of the progress bar and the carried value. The user needs to wrap the init of the scan in a BarId and provide an integer id, and conversely unpack the results afterwards. Using a class like this is an easy way of detecting when a user has provided a bar-id.

This approach means the original scan function can still be used without modification (you can just provide the carry without wrapping), and overall it has a fairly low impact on the overall API. I guess it adds some complication (mainly arising from needing to initialise the bars at runtime rather than compile time) requiring users to bundle up the scan arguments.

Welcome any thoughts on this API change!

*Note that some terminals don't support updating multiple bars (for example the PyCharm terminal) this is a known bug with TQDM

@zombie-einstein zombie-einstein added the enhancement New feature or request label Oct 5, 2024
Copy link
Owner

@jeremiecoullon jeremiecoullon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! 🚀

@zombie-einstein
Copy link
Collaborator Author

LGTM! 🚀

Cheers, I just need to add instructions to the README then I can release.

@zombie-einstein zombie-einstein marked this pull request as ready for review October 7, 2024 19:34
@zombie-einstein zombie-einstein changed the title Multiple Bars Multiple Progress Bars Oct 7, 2024
@zombie-einstein zombie-einstein merged commit f216dcb into main Oct 7, 2024
3 checks passed
@zombie-einstein zombie-einstein deleted the multiple_bars branch October 7, 2024 19:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants