Bounded while loop #30

Closed
wants to merge 3 commits into from

zombie-einstein
Collaborator

This is currently a prototype and may still have some bugs!

A prototype of the API for the bounded while loop mentioned in #24.

Since the progress bar needs to be closed when the end condition is met, I think the simplest way to implement this is to decorate the condition function (as it is also called at every step). The user then needs to ensure the iteration number is part of the carried value.

This would look like:

import jax

# bounded_while_tqdm is the prototype decorator proposed in this PR
n_total = 10_000
n_stop = 5_000

@bounded_while_tqdm(n_total)
def cond_fun(x):
    return x < n_stop

def body_fun(x):
    return x + 1

result = jax.lax.while_loop(cond_fun, body_fun, 0)

prints

50%|████████████████████▌                    | 5000/10000 [00:00<00:00, 3593474.98it/s]

At each iteration it performs the same update checks as the scan and loop versions, but it also checks whether the condition has been met and, if so, closes the progress bar manually.
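To make the mechanism concrete, here is a minimal sketch of a decorated condition function that reports to the host on every evaluation. It is not the PR's `bounded_while_tqdm`: the names `logged_cond`, `_host_update` and `events` are hypothetical, a plain list stands in for the tqdm bar, and the update/close decision is made on the host rather than in JAX.

```python
import jax
from jax import lax

events = []  # host-side log standing in for a tqdm progress bar


def logged_cond(n_total):
    """Hypothetical decorator (not the PR's API) that logs each cond evaluation."""

    def _host_update(iter_num, running):
        # Runs on the host with concrete values, so plain Python branching is fine.
        iter_num, running = int(iter_num), bool(running)
        if not running or iter_num >= n_total:
            events.append(("close", iter_num))
        else:
            events.append(("update", iter_num))

    def decorator(cond_fun):
        def wrapper(x):
            running = cond_fun(x)
            # The carried value x is assumed to be the iteration number.
            jax.debug.callback(_host_update, x, running, ordered=True)
            return running

        return wrapper

    return decorator


n_total, n_stop = 10, 3


@logged_cond(n_total)
def cond_fun(x):
    return x < n_stop


def body_fun(x):
    return x + 1


result = lax.while_loop(cond_fun, body_fun, 0)
jax.effects_barrier()  # wait for any pending host callbacks
```

Because the condition is evaluated once more than the body, the final evaluation (where the condition is false) is the natural place to close the bar.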

Any thoughts, @andrewlesak?

@andrewlesak

andrewlesak commented Oct 19, 2024

I've also tried wrapping the cond_fun, but there are issues associated with doing so:

  1. Passing the iteration number makes it a traced value when mapped, which makes lax.cond structures execute both branches, leading to errors. You can address this by pushing all the conditionals into one long if-else sequence in a single callback.
  2. Wrapping cond_fun combined with vmap has a weird effect where all updates still happen even if some conditions have already been met. For example, say we map over init_vals = jnp.array([4,3,2]) and subtract 1 from each value while it is > 0. You would expect the instances to take 4, 3, and 2 iterations respectively, returning final_vals = ([0,0,0]). However, I've found that mapping forces updates on all values until the last value meets the condition, so it returns final_vals = ([0,-1,-2]), since it forces 4 iterations on all values. Strangely, wrapping the body_fun does not have this effect.
  3. Using the PBar class is a bit awkward when wrapping cond_fun, since we cannot return an object, just a bool. So you have to explicitly deal with returning a PBar object in the body_fun if you want to work with it.

Number 2 is really the killer here. No idea how to address this currently.
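For reference, the expected behaviour in point 2 can be checked against a plain, unwrapped while_loop under vmap. This is only a baseline sketch with no progress bar or callback involved (the function name `countdown` is made up for illustration), so it does not reproduce the problem seen with the wrapped cond_fun.

```python
import jax
import jax.numpy as jnp
from jax import lax


def countdown(init_val):
    """Subtract 1 while the value is > 0, counting the iterations taken."""

    def cond_fun(carry):
        val, _ = carry
        return val > 0

    def body_fun(carry):
        val, n_iters = carry
        return val - 1, n_iters + 1

    return lax.while_loop(cond_fun, body_fun, (init_val, jnp.int32(0)))


init_vals = jnp.array([4, 3, 2])
# The batched loop keeps running until every element is done, but elements
# whose condition is already false keep their carried values unchanged.
final_vals, n_iters = jax.vmap(countdown)(init_vals)
```

Here each element ends at 0 with its own iteration count, which is the result one would hope to keep once the condition function is wrapped.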

I've also taken a look at the source code for while_loop and found the function _while_lowering, which allows you to pass extra arguments into the cond_fun (so possibly a bar_id?). I'm not really familiar with constructing custom jaxprs, and I'm still trying to figure out how to even call the function correctly. If this is something you're familiar with, please take a shot at figuring out how to call it. Here's a mess of code to look at.

from jax import make_jaxpr

n = 10

def cond(token, val):
    pred = val < n 
    return token, pred

def body(token, x):
    x += 1
    return token, x

cond_jaxpr = make_jaxpr(cond)
body_jaxpr = make_jaxpr(body)

# ------------ #

# have to create context to call while_lowering?
from jax._src.core import JaxprEqnContext

# Example values for the arguments
compute_type = "example_type"  # Replace with a relevant compute type
threefry_partitionable = True  # Set as needed

# Initialize the JaxprEqnContext with the required arguments
context = JaxprEqnContext(compute_type=compute_type, 
                           threefry_partitionable=threefry_partitionable)

# Print the context to verify
print(context)

# ------------ #

# try to use _while_lowering()
# this DOES NOT work yet!
from jax._src.lax.control_flow.loops import _while_lowering
_while_lowering(
    context, 
    cond_jaxpr=cond_jaxpr(True, 3.0),
    body_jaxpr=body_jaxpr(True, 3.0),
    cond_nconsts=0, 
    body_nconsts=0
)

Of course, figuring out how to pass an id explicitly to the cond_fun using this method probably isn't necessary, especially if we can't figure out point 2 and we're forced to wrap the body anyway.

@andrewlesak

andrewlesak commented Oct 19, 2024

For context, here's my attempt:

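import typing

import chex
import jax.numpy as jnp
import tqdm.auto
import tqdm.notebook
import tqdm.std
from jax.debug import callback


@chex.dataclass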
class PBar:
    id: int
    carry: typing.Any

    def state(self):
        return self.id, self.carry

def bounded_while_tqdm_v2(
    n: int,
    print_rate: typing.Optional[int] = None,
    tqdm_type: str = "auto",
    **kwargs,
) -> typing.Callable:
    """
    tqdm progress bar for a JAX while_loop

    Parameters
    ----------
    n: int
        Maximum number of iterations.
    print_rate: int
        Optional integer rate at which the progress bar will be updated,
        by default the print rate will be 1/20th of the total number of steps.
    tqdm_type: str
        Type of progress-bar, should be one of "auto", "std", or "notebook".
    **kwargs
        Extra keyword arguments to pass to tqdm.

    Returns
    -------
    typing.Callable:
        Progress bar wrapping function.
    """

    update_progress_bar, close_tqdm = build_tqdm_v2(n, print_rate, tqdm_type, **kwargs)

    def _bounded_while_tqdm(cond_fun) -> typing.Callable:
        """
        Decorator that adds a tqdm progress bar to `cond_fun` used in
        `jax.lax.while_loop`. The conditional should be bounded (i.e.
        have a maximum number of updates). Note that the iteration
        number `i` must be the last element in the `carry` tuple.
        """
        def cond_fun_wrapper(carry) -> bool:

            if isinstance(carry, PBar):
                bar_id = carry.id
                carry = carry.carry

                if isinstance(carry, tuple):
                    *_, i = carry
                else:
                    i = carry

                cond = cond_fun(carry)
                carry = update_progress_bar(carry, i, update=cond, bar_id=bar_id)
                return close_tqdm(cond, i - 1, update=cond, bar_id=bar_id)
            else:
                if isinstance(carry, tuple):
                    *_, i = carry
                else:
                    i = carry        
                
                cond = cond_fun(carry)
                carry = update_progress_bar(carry, i, update=cond)
                return close_tqdm(cond, i - 1, update=cond)


        return cond_fun_wrapper

    return _bounded_while_tqdm


def build_tqdm_v2(
    n: int,
    print_rate: typing.Optional[int],
    tqdm_type: str,
    **kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
    """
    Build the tqdm progress bar on the host

    Parameters
    ----------
    n: int
        Number of updates
    print_rate: int
        Optional integer rate at which the progress bar will be updated,
        If ``None`` the print rate will be 1/20th of the total number of steps.
    tqdm_type: str
        Type of progress-bar, should be one of "auto", "std", or "notebook".
    **kwargs
        Extra keyword arguments to pass to tqdm.
    """

    if tqdm_type not in ("auto", "std", "notebook"):
        raise ValueError(
            'tqdm_type should be one of "auto", "std", or "notebook" '
            f'but got "{tqdm_type}"'
        )
    pbar = getattr(tqdm, tqdm_type).tqdm

    desc = kwargs.pop("desc", f"Running for {n:,} iterations")
    message = kwargs.pop("message", desc)
    position_offset = kwargs.pop("position", 0)

    for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
        kwargs.pop(kwarg, None)

    tqdm_bars = dict()

    if print_rate is None:
        if n > 20:
            print_rate = int(n / 20)
        else:
            print_rate = 1
    else:
        if print_rate < 1:
            raise ValueError(f"Print rate should be > 0 got {print_rate}")
        elif print_rate > n:
            raise ValueError(
                "Print rate should be less than the "
                f"number of steps {n}, got {print_rate}"
            )

    def _update_tqdm(iter_num: int, update: bool, bar_id: int):
        """Update progress bars"""
        bar_id = int(bar_id)

        if iter_num == 0:
            tqdm_bars[bar_id] = pbar(
                total=n,
                position=bar_id + position_offset,
                desc=message,
                **kwargs,
            )
        elif (iter_num % print_rate == 0) & update:
            tqdm_bars[bar_id].update(print_rate)

    def update_progress_bar(
        carry: typing.Any, iter_num: int, update: bool = True, bar_id: int = 0
    ):
        """Updates tqdm from a JAX scan, fori or while loop"""

        def inner_update(_carry):
            callback(_update_tqdm, iter_num, update, bar_id, ordered=True)
            return _carry

        carry = inner_update(carry)
        return carry

    def _close_tqdm(iter_num: int, update: bool, bar_id: int):
        if (iter_num + 1 == n) | jnp.logical_not(update):
            # pbar = tqdm_bars.pop(int(bar_id))
            pbar = tqdm_bars[int(bar_id)]
            dif = iter_num - pbar.n
            pbar.update(int(dif))
            pbar.clear()
            pbar.close()

    def close_tqdm(
        result: typing.Any, iter_num: int, update: bool = True, bar_id: int = 0
    ):
        def inner_close(_result):
            callback(_close_tqdm, iter_num + 1, update, bar_id, ordered=True)
            return _result

        result = inner_close(result)
        return result

    return update_progress_bar, close_tqdm

And I've tested it with this code:

from jax_tqdm import PBar, bounded_while_tqdm_v2
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jr

n = 10
print_rate = 1
arr_size = 9000

@bounded_while_tqdm_v2(n, print_rate)
def cond_fun(carry):
    val, iter_num = carry
    return (val < 7) & (iter_num < n)

def body_fun(carry):
    bar_id, carry = carry.state()
    # jax.debug.print("carry = {}",carry)
    val, iter_num = carry
    rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size)) 
    mat_mul = rand_mat @ rand_mat.T
    val += 1e-8 * jnp.sum(mat_mul)
    carry = (val, iter_num + 1)
    return PBar(id=bar_id, carry=carry)

def map_func(i, val):
    init = PBar(id=i, carry=(val, 0))
    final_carry = lax.while_loop(cond_fun, body_fun, init)
    return final_carry.carry

bar_idxs = jnp.arange(4)
init_vals = jnp.array([2,3,4,4.5],dtype=float)
final_carry = jax.vmap(map_func, in_axes=(0,0))(bar_idxs, init_vals)

@zombie-einstein
Collaborator Author

I think the issue here is more fundamental to JAX. As described in jax-ml/jax#8409, a conditional inside a vmap evaluates both branches. I'm not sure that, at least with the current mechanics, the progress bar could be implemented with vmap. One option could be to push the conditional checks into the Python callback, but then you lose the performance benefit of checking in JAX before calling the slower Python.
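A small sketch of that behaviour, assuming a recent JAX version: when the predicate of lax.cond is batched by vmap, the traced program no longer chooses a branch but computes both and picks between the results with a select. The function name `step` is made up for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # With a scalar predicate (no vmap) only one branch is executed.
    return lax.cond(x > 0, lambda v: v + 1.0, lambda v: v - 1.0, x)


xs = jnp.array([-1.0, 2.0])

# Under vmap the predicate is batched, so both branches are evaluated for
# every element and the outputs are combined with a select operation.
batched_jaxpr = str(jax.make_jaxpr(jax.vmap(step))(xs))
out = jax.vmap(step)(xs)
```

The numerical result is still correct per element, but any side effect placed in a branch (such as a progress-bar callback) is no longer guarded by the predicate, which is why the create/update/close checks break under vmap.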

@zombie-einstein
Collaborator Author

zombie-einstein commented Oct 20, 2024

BTW my thinking was to have a more functional pattern like

cond, body = bounded_while_tqdm(cond, body, n=n)

meaning you can wrap both functions, allowing structures like PBar to be unpacked in both. But either way, it seems that if a conditional statement is used to check when to create/update/close the progress bar, then we're a bit stuck.

@zombie-einstein
Collaborator Author

Pushed some more prototype changes for this, but closing as it seems this isn't actually feasible given how JAX currently functions.

2 participants