-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix issue where first update of the progress bar is 2 times the print rate #26
fix issue where first update of the progress bar is 2 times the print rate #26
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think this looks correct to me, good spot. @jeremiecoullon look good to you? If so I can release later today.
@andrewlesak looks like you just need to run the linter over this (or I can if you're not set up) |
My apologies but it looks like I haven't entirely solved this problem since the use of sleep() to simulate some computation conveniently masked another issue. When I run the following matrix multiplication at each iter to simulate an expensive computation, I see that the progress bar now jumps from 8 to 10/10 at the last iter so I think i've just pushed the problem to the end of the loop. from jax_tqdm import PBar, scan_tqdm
import jax
import jax.numpy as jnp
import jax.random as jr
n = 10
print_rate = 1
arr_size = 13000
@scan_tqdm(n, print_rate)
def step(carry, stuff):
rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
mat_mul = rand_mat @ rand_mat.T
return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)
def map_func(i):
# Wrap the initial value and pass the
# progress bar index
init = PBar(id=i, carry=0)
final_value, _all_numbers = jax.lax.scan(
step, init, jnp.arange(n)
)
return (
final_value.carry,
_all_numbers,
)
num_bars = 5
last_numbers, all_numbers = jax.vmap(map_func)(jnp.arange(num_bars)) I've actually already implemented this in my wrapper for a while loop (will draft a PR soon), but the solution to this problem is to initialize all the progress bars as soon as the wrapper is called so that we don't initialize in the loop tqdm_bars = dict()
# immediately initialize progress bars
def _define_tqdm(bar_id):
bar_id = int(bar_id)
tqdm_bars[bar_id] = pbar(range(n), position=bar_id + position_offset, **kwargs)
tqdm_bars[bar_id].set_description(message)
callback(_define_tqdm, bar_id, ordered=True) Then for the update functions we would just have: def _update_tqdm(arg, bar_id: int):
tqdm_bars[int(bar_id)].update(int(arg))
def _update_progress_bar(iter_num, bar_id: int = 0):
"""Updates tqdm from a JAX scan or loop"""
_ = jax.lax.cond(
# update tqdm every multiple of `print_rate`
(iter_num % print_rate == 0),
lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
lambda _: None,
operand=None,
) The minor problem with this is that we now have to pass from jax_tqdm import PBar, scan_tqdm
import jax
import jax.numpy as jnp
import jax.random as jr
n = 10
print_rate = 1
arr_size = 13000
def run_loop(i):
@scan_tqdm(n, print_rate, bar_id=i)
def step(carry, stuff):
rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size))
mat_mul = rand_mat @ rand_mat.T
return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)
# Wrap the initial value and pass the
# progress bar index
init = PBar(id=i, carry=0)
final_value, _all_numbers = jax.lax.scan(step, init, jnp.arange(n))
return final_value.carry, _all_numbers
num_bars = 5
bar_idxs = jnp.arange(num_bars)
last_numbers, all_numbers = jax.vmap(run_loop)(bar_idxs) I can commit these changes for review sometime later today. I already confirmed this scheme worked when I built the while loop wrapper but I'll double check when I have some time. I'm also new to this whole properly using github thing, so I'm not entirely sure how I would run the linter. I see some test results above so I assume you were able to do it for me? Thanks. |
Not sure what the issue here is now. If I run your example on the current main branch
it seems to run fine and run as expected, i.e. it prints each step in sequence for 5 progress bars. Would using |
I think I've figured out the issue with using |
This PR is now superseded by #27. Can I close this PR now? |
|
Sure this can be closed |
Hello, thanks for the recent update allowing for multiple progress bars! I've been playing with the code and noticed that the progress bars don't update as expected. The issues are:
2 * print_rate
for first updateIssue 1 is easily fixed with the addition of
(iter_num != 0)
in the control flow for_update_progress_bar()
To explain issue 2, consider the following example where we have implemented the fix for issue 1.
We choose
n = 10
andprint_rate = 3
which impliesremainder = 1
. I would expect the progress bar to update in the following order:iter = 0
: initialize pbar 0/10iter = 3
: update pbar to 3/10iter = 6
: update pbar at 6/10iter = 9
: update pbar at 9/10What I actually observe is (with the addition of the fix for issue 1):
iter = 0
: initialize pbar 0/10iter = 3
: update pbar to 3/10iter = 6
: update pbar at 6/10iter = 9
: update pbar at 7/10This happens because at
iter = 9
we haven - remainder = iter
which then updates the progress bar byremainder = 1
, going from 6/10 --> 7/10 where the progress bar is closed.To fix this, I removed the
iter_num == n - remainder
check from_update_progress_bar()
and instead deal with the final update at closeThese changes get the progress bars working as expected! I tested both scan and for_i wrappers with and without vmap for varying print rates and confirmed the pbars function properly.