-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Bounded while loop #30
Conversation
I've also tried wrapping the
Number 2 is really the killer here. No idea how to address this currently. I've also taken a look at the source code for from jax import make_jaxpr
n = 10
def cond(token, val):
pred = val < n
return token, pred
def body(token, x):
x += 1
return token, x
cond_jaxpr = make_jaxpr(cond)
body_jaxpr = make_jaxpr(body)
# ------------ #
# have to create context to call while_lowering?
from jax._src.core import JaxprEqnContext
# Example values for the arguments
compute_type = "example_type" # Replace with a relevant compute type
threefry_partitionable = True # Set as needed
# Initialize the JaxprEqnContext with the required arguments
context = JaxprEqnContext(compute_type=compute_type,
threefry_partitionable=threefry_partitionable)
# Print the context to verify
print(context)
# ------------ #
# try to use _while_lowering()
# this DOES NOT work yet!
from jax._src.lax.control_flow.loops import _while_lowering
_while_lowering(
context,
cond_jaxpr=cond_jaxpr(True, 3.0),
body_jaxpr=body_jaxpr(True, 3.0),
cond_nconsts=0,
body_nconsts=0
) Of course, figuring out how to pass an id explicitly to the |
For context heres my attempt @chex.dataclass
class PBar:
id: int
carry: typing.Any
def state(self):
return self.id, self.carry
def bounded_while_tqdm_v2(
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Callable:
"""
tqdm progress bar for a JAX while_loop
Parameters
----------
n: int
Maximum number of iterations.
print_rate: int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
tqdm_type: str
Type of progress-bar, should be one of "auto", "std", or "notebook".
**kwargs
Extra keyword arguments to pass to tqdm.
Returns
-------
typing.Callable:
Progress bar wrapping function.
"""
update_progress_bar, close_tqdm = build_tqdm_v2(n, print_rate, tqdm_type, **kwargs)
def _bounded_while_tqdm(cond_fun) -> typing.Callable:
"""
Decorator that adds a tqdm progress bar to `cond_fun` used in
`jax.lax.while_loop`. The conditional should be bounded (i.e
have a maximum number of updates). Note that the iteration
number `i` must be the last element in the `carry` tuple.
"""
def cond_fun_wrapper(carry) -> bool:
if isinstance(carry, PBar):
bar_id = carry.id
carry = carry.carry
if isinstance(carry, tuple):
*_, i = carry
else:
i = carry
cond = cond_fun(carry)
carry = update_progress_bar(carry, i, update=cond, bar_id=bar_id)
return close_tqdm(cond, i - 1, update=cond, bar_id=bar_id)
else:
if isinstance(carry, tuple):
*_, i = carry
else:
i = carry
cond = cond_fun(carry)
carry = update_progress_bar(carry, i, update=cond)
return close_tqdm(cond, i - 1, update=cond)
return cond_fun_wrapper
return _bounded_while_tqdm
def build_tqdm_v2(
n: int,
print_rate: typing.Optional[int],
tqdm_type: str,
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:
"""
Build the tqdm progress bar on the host
Parameters
----------
n: int
Number of updates
print_rate: int
Optional integer rate at which the progress bar will be updated,
If ``None`` the print rate will 1/20th of the total number of steps.
tqdm_type: str
Type of progress-bar, should be one of "auto", "std", or "notebook".
**kwargs
Extra keyword arguments to pass to tqdm.
"""
if tqdm_type not in ("auto", "std", "notebook"):
raise ValueError(
'tqdm_type should be one of "auto", "std", or "notebook" '
f'but got "{tqdm_type}"'
)
pbar = getattr(tqdm, tqdm_type).tqdm
desc = kwargs.pop("desc", f"Running for {n:,} iterations")
message = kwargs.pop("message", desc)
position_offset = kwargs.pop("position", 0)
for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
kwargs.pop(kwarg, None)
tqdm_bars = dict()
if print_rate is None:
if n > 20:
print_rate = int(n / 20)
else:
print_rate = 1
else:
if print_rate < 1:
raise ValueError(f"Print rate should be > 0 got {print_rate}")
elif print_rate > n:
raise ValueError(
"Print rate should be less than the "
f"number of steps {n}, got {print_rate}"
)
def _update_tqdm(iter_num: int, update: bool, bar_id: int):
"""Update progress bars"""
bar_id = int(bar_id)
if iter_num == 0:
tqdm_bars[bar_id] = pbar(
total=n,
position=bar_id + position_offset,
desc=message,
**kwargs,
)
elif (iter_num % print_rate == 0) & update:
tqdm_bars[bar_id].update(print_rate)
def update_progress_bar(
carry: typing.Any, iter_num: int, update: bool = True, bar_id: int = 0
):
"""Updates tqdm from a JAX scan, fori or while loop"""
def inner_update(_carry):
callback(_update_tqdm, iter_num, update, bar_id, ordered=True)
return _carry
carry = inner_update(carry)
return carry
def _close_tqdm(iter_num: int, update: bool, bar_id: int):
if (iter_num + 1 == n) | jnp.logical_not(update):
# pbar = tqdm_bars.pop(int(bar_id))
pbar = tqdm_bars[int(bar_id)]
dif = iter_num - pbar.n
pbar.update(int(dif))
pbar.clear()
pbar.close()
def close_tqdm(
result: typing.Any, iter_num: int, update: bool = True, bar_id: int = 0
):
def inner_close(_result):
callback(_close_tqdm, iter_num + 1, update, bar_id, ordered=True)
return _result
result = inner_close(result)
return result
return update_progress_bar, close_tqdm And i've tested with this code: from jax_tqdm import PBar, bounded_while_tqdm_v2
import jax
from jax import lax
import jax.numpy as jnp
import jax.random as jr
n = 10
print_rate = 1
arr_size = 9000
@bounded_while_tqdm_v2(n, print_rate)
def cond_fun(carry):
val, iter_num = carry
return (val < 7) & (iter_num < n)
def body_fun(carry):
bar_id, carry = carry.state()
# jax.debug.print("carry = {}",carry)
val, iter_num = carry
rand_mat = jr.normal(jr.PRNGKey(iter_num), (arr_size, arr_size))
mat_mul = rand_mat @ rand_mat.T
val += 1e-8 * jnp.sum(mat_mul)
carry = (val, iter_num + 1)
return PBar(id=bar_id, carry=carry)
def map_func(i, val):
init = PBar(id=i, carry=(val, 0))
final_carry = lax.while_loop(cond_fun, body_fun, init)
return final_carry.carry
bar_idxs = jnp.arange(4)
init_vals = jnp.array([2,3,4,4.5],dtype=float)
final_carry = jax.vmap(map_func, in_axes=(0,0))(bar_idxs, init_vals) |
I think the issue here is more fundamental to JAX from this issue jax-ml/jax#8409 the conditional inside the vmap evaluates both branches. Not sure, at least for the current mechanics, this would mean the progress bar with vmap could be implemented. One option could be to push conditional checks into the Python callback. But then you lose the performance benefit of checking in JAX before calling the slower Python. |
BTW my thinking was to have a more functional pattern like
meaning you can wrap both functions, allowing structures like |
Pushed some more prototype changes for this, but closing as it seems it's not actually currently feasible given how JAX currently functions |
This is currently a prototype, may still have some bugs!
Prototype of the API for bounded while loop mentioned in #24
Since the progress-bar needs to be closed if the end condition is met, I think the simplest way to implement this is to decorate the condition function (as this is also called every step), and then the user needs to ensure the iteration number is part of the carried value.
This would look like
prints
At each iteration it performs the same update checks as the scan and look, but it checks if the condition has been met, and terminates the progress bar manually if so.
Any thoughts at @andrewlesak?