Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue where first update of the progress bar is 2 times the print rate #26

Conversation

andrewlesak
Copy link

Hello, thanks for the recent update allowing for multiple progress bars! I've been playing with the code and noticed that the progress bars don't update as expected. The issues are:

  1. iteration number jumps from 0 to 2 * print_rate for first update
  2. after fixing 1, progress bar terminates incorrectly at end of loop/scan

Issue 1 is easily fixed with the addition of (iter_num != 0) in the control flow for _update_progress_bar()

    _ = jax.lax.cond(
        (iter_num != 0) & (iter_num % print_rate == 0) & (iter_num != n - remainder),  # fix issue 1
        lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
        lambda _: None,
        operand=None,
    )

To explain issue 2, consider the following example where we have implemented the fix for issue 1.

from jax_tqdm import PBar, scan_tqdm
import jax
import time

n = 10
print_rate = 3

def _sleep(t):
    time.sleep(float(t)) 

@scan_tqdm(n, print_rate)
def step(carry, _):
    jax.debug.callback(_sleep, 1, ordered=True)
    return carry + 1, carry + 1

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=i)
    final_value, _all_numbers = jax.lax.scan(
        step, init, jax.numpy.arange(n)
    )
    return (
        final_value.carry,
        _all_numbers,
    )

num_bars = 5
last_numbers, all_numbers = jax.vmap(map_func)(jax.numpy.arange(num_bars))

We choose n = 10 and print_rate = 3 which implies remainder = 1. I would expect the progress bar to update in the following order:

  • start of iter = 0: initialize pbar 0/10
  • start of iter = 3: update pbar to 3/10
  • start of iter = 6: update pbar at 6/10
  • start of iter = 9: update pbar at 9/10
  • at end of loop: update pbar to 10/10 and close

What I actually observe is (with the addition of the fix for issue 1):

  • start of iter = 0: initialize pbar 0/10
  • start of iter = 3: update pbar to 3/10
  • start of iter = 6: update pbar at 6/10
  • start of iter = 9: update pbar at 7/10
  • at end of loop: close pbar at 7/10

This happens because at iter = 9 we have n - remainder = iter which then updates the progress bar by remainder = 1, going from 6/10 --> 7/10 where the progress bar is closed.

To fix this, I removed the iter_num == n - remainder check from _update_progress_bar() and instead deal with the final update at close

    def _update_progress_bar(iter_num, bar_id: int = 0):
        """Updates tqdm from a JAX scan or loop"""
        _ = jax.lax.cond(
            iter_num == 0,
            lambda _: callback(_define_tqdm, None, bar_id, ordered=True),
            lambda _: None,
            operand=None,
        )

        _ = jax.lax.cond(
            (iter_num != 0) & (iter_num % print_rate == 0)  # fix issue 1
            lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
            lambda _: None,
            operand=None,
        )

    def _close_tqdm(iter_num, bar_id: int):
        dif = n - tqdm_bars[int(bar_id)].n  # complete remaining updates on pbar before close
        tqdm_bars[int(bar_id)].update(int(dif))
        tqdm_bars[int(bar_id)].close()

These changes get the progress bars working as expected! I tested both scan and for_i wrappers with and without vmap for varying print rates and confirmed the pbars function properly.

Copy link
Collaborator

@zombie-einstein zombie-einstein left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think this looks correct to me, good spot. @jeremiecoullon look good to you? If so I can release later today.

@zombie-einstein
Copy link
Collaborator

@andrewlesak looks like you just need to run the linter over this (or I can if you're not set up)

@andrewlesak
Copy link
Author

andrewlesak commented Oct 10, 2024

My apologies but it looks like I haven't entirely solved this problem since the use of sleep() to simulate some computation conveniently masked another issue. When I run the following matrix multiplication at each iter to simulate an expensive computation, I see that the progress bar now jumps from 8 to 10/10 at the last iter so I think i've just pushed the problem to the end of the loop.

from jax_tqdm import PBar, scan_tqdm
import jax
import jax.numpy as jnp
import jax.random as jr

n = 10
print_rate = 1
arr_size = 13000

@scan_tqdm(n, print_rate)
def step(carry, stuff):
    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=0)
    final_value, _all_numbers = jax.lax.scan(
        step, init, jnp.arange(n)
    )
    return (
        final_value.carry,
        _all_numbers,
    )

num_bars = 5
last_numbers, all_numbers = jax.vmap(map_func)(jnp.arange(num_bars))

I've actually already implemented this in my wrapper for a while loop (will draft a PR soon), but the solution to this problem is to initialize all the progress bars as soon as the wrapper is called so that we don't initialize in the loop _update_progress_bar when iter_num == 0. Code would look like this:

tqdm_bars = dict()

# immediately initialize progress bars
def _define_tqdm(bar_id):
    bar_id = int(bar_id)
    tqdm_bars[bar_id] = pbar(range(n), position=bar_id + position_offset, **kwargs)
    tqdm_bars[bar_id].set_description(message)
    
callback(_define_tqdm, bar_id, ordered=True)

Then for the update functions we would just have:

def _update_tqdm(arg, bar_id: int):
    tqdm_bars[int(bar_id)].update(int(arg))

def _update_progress_bar(iter_num, bar_id: int = 0):
    """Updates tqdm from a JAX scan or loop"""    
    _ = jax.lax.cond(
        # update tqdm every multiple of `print_rate`
        (iter_num % print_rate == 0),
        lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
        lambda _: None,
        operand=None,
    )

The minor problem with this is that we now have to pass bar_id as an input to the build_tqdm function. We can have it default to bar_id = 0 when we only want a single progress bar, but if we want to vmap then we'll have to put the entire loop in a function like so:

from jax_tqdm import PBar, scan_tqdm
import jax
import jax.numpy as jnp
import jax.random as jr

n = 10
print_rate = 1
arr_size = 13000

def run_loop(i):

    @scan_tqdm(n, print_rate, bar_id=i)
    def step(carry, stuff):
        rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size)) 
        mat_mul = rand_mat @ rand_mat.T
        return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)

    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=0)
    final_value, _all_numbers = jax.lax.scan(step, init, jnp.arange(n))
    return final_value.carry, _all_numbers

num_bars = 5
bar_idxs = jnp.arange(num_bars)
last_numbers, all_numbers = jax.vmap(run_loop)(bar_idxs)

I can commit these changes for review sometime later today. I already confirmed this scheme worked when I built the while loop wrapper but I'll double check when I have some time.

I'm also new to this whole properly using github thing, so I'm not entirely sure how I would run the linter. I see some test results above so I assume you were able to do it for me? Thanks.

@zombie-einstein
Copy link
Collaborator

zombie-einstein commented Oct 10, 2024

Not sure what the issue here is now. If I run your example on the current main branch

import jax
import jax.numpy as jnp
import jax.random as jr

n = 10
print_rate = 1
arr_size = 13000

@scan_tqdm(n, print_rate)
def step(carry, stuff):
    rand_mat = jr.normal(jr.PRNGKey(carry), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    return carry + 1, stuff + 1e-8 * jnp.sum(mat_mul)

def map_func(i):
    # Wrap the initial value and pass the
    # progress bar index
    init = PBar(id=i, carry=0)
    final_value, _all_numbers = jax.lax.scan(
        step, init, jnp.arange(n)
    )
    return (
        final_value.carry,
        _all_numbers,
    )

num_bars = 5
last_numbers, all_numbers = jax.vmap(map_func)(jnp.arange(num_bars))

it seems to run fine and run as expected, i.e. it prints each step in sequence for 5 progress bars.

Would using sleep() to test this actually work, I would have thought this would have just been called when JAX is inspecting the Python code, not when actually running the compiled function? I'm curious where the issue arises with the current implementation.

@andrewlesak
Copy link
Author

andrewlesak commented Oct 10, 2024

Just so we're on the same page, this is what I'm seeing when I run the example code where I sleep for 1 second each iter on my branch. The progress bars appear to work as expected, updating each iter at the same rate.
test_5_iters_sleep_curr_branch

And this is what I see when I run the example code that does a matrix multiplication each iter and doesn't sleep. We now jump past the last iter!
test_5iters_matmul_curr_branch

Im not exactly sure how sleep would be processed in jax, but it seems like its masking the issue that becomes evident when performing an actual computation.

I think that calling define_tqdm when iter_num == 0 is somehow preventing proper updates for the first iter. If I instead change the code to allow for updates at iter_num == 0 (which naively one might think would fix this issue)

def _update_progress_bar(iter_num, bar_id: int = 0):
    """Updates tqdm from a JAX scan or loop"""
    _ = jax.lax.cond(
        iter_num == 0,
        lambda _: callback(_define_tqdm, None, bar_id, ordered=True),
        lambda _: None,
        operand=None,
    )
    
    _ = jax.lax.cond(
        # update tqdm every multiple of `print_rate`
        (iter_num % print_rate == 0), # allow updates for `iter_num == 0`
        lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
        lambda _: None,
        operand=None,
    )

Then this is what I see when running the matrix multiplication example.
test_5_iters_matmul_update_iter0_curr_branch
This results in the original issue I tried to fix where we jump from 0 to 2 * print_rate at the first update of the progress bar. Thats why I think its necessary to initialize all the progress bars outside of the loop.

I can open a new PR soon with my suggested changes and reference this PR.

@andrewlesak
Copy link
Author

I think I've figured out the issue with using sleep() to test. I believe the progress bar gets updated as soon as sleep() is called instead of after. So note to self, I probably shouldn't use sleep() to test jax loops in the future.

@andrewlesak
Copy link
Author

This PR is now superseded by #27. Can I close this PR now?

@zombie-einstein
Copy link
Collaborator

I think I've figured out the issue with using sleep() to test. I believe the progress bar gets updated as soon as sleep() is called instead of after. So note to self, I probably shouldn't use sleep() to test jax loops in the future.

sleep() would be called at compile time, then ignored by JAX at runtime unless wrapped in a callback

@zombie-einstein
Copy link
Collaborator

This PR is now superseded by #27. Can I close this PR now?

Sure this can be closed

@andrewlesak andrewlesak deleted the fix_iter_updates_and_termination branch October 17, 2024 19:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants