Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bounded while loop #30

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jax_tqdm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from jax_tqdm.pbar import PBar, loop_tqdm, scan_tqdm
from jax_tqdm.pbar import PBar, bounded_while_tqdm, loop_tqdm, scan_tqdm
64 changes: 48 additions & 16 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

@chex.dataclass
class PBar:
id: int
carry: typing.Any
id: int = 0
iter: int = 0


def scan_tqdm(
Expand Down Expand Up @@ -51,7 +52,7 @@ def _scan_tqdm(func):
"""

def wrapper_progress_bar(carry, x):
if type(x) is tuple:
if isinstance(x, tuple):
iter_num, *_ = x
else:
iter_num = x
Expand All @@ -62,11 +63,11 @@ def wrapper_progress_bar(carry, x):
carry, x = update_progress_bar((carry, x), iter_num, bar_id=bar_id)
result = func(carry, x)
result = (PBar(id=bar_id, carry=result[0]), result[1])
return close_tqdm(result, iter_num, bar_id=bar_id)
return close_tqdm(result, n, iter_num, bar_id=bar_id)
else:
carry, x = update_progress_bar((carry, x), iter_num)
result = func(carry, x)
return close_tqdm(result, iter_num)
return close_tqdm(result, n, iter_num)

return wrapper_progress_bar

Expand Down Expand Up @@ -115,17 +116,49 @@ def wrapper_progress_bar(i, val):
i, val = update_progress_bar((i, val), i, bar_id=bar_id)
result = func(i, val)
result = PBar(id=bar_id, carry=result)
return close_tqdm(result, i, bar_id=bar_id)
return close_tqdm(result, n, i, bar_id=bar_id)
else:
i, val = update_progress_bar((i, val), i)
result = func(i, val)
return close_tqdm(result, i)
return close_tqdm(result, n, i)

return wrapper_progress_bar

return _loop_tqdm


def bounded_while_tqdm(
cond_fun: typing.Callable,
body_fun: typing.Callable,
n: int,
print_rate: typing.Optional[int] = None,
tqdm_type: str = "auto",
**kwargs,
) -> typing.Tuple[typing.Callable, typing.Callable]:

update_progress_bar, close_tqdm = build_tqdm(n, print_rate, tqdm_type, **kwargs)

def cond_fun_wrapper(val: PBar) -> bool:
return cond_fun(val.carry)

def close_bar(val, iter_num, bar_id):
return close_tqdm(val, iter_num, iter_num - 1, bar_id=bar_id)

def cont(val, _iter_num, _bar_id):
return val

def body_fun_wrapper(val: PBar) -> PBar:
iter_num = val.iter
bar_id = val.id
val = val.carry
val = update_progress_bar(val, iter_num, bar_id=bar_id)
val = body_fun(val)
val = jax.lax.cond(cond_fun(val), close_bar, cont, val, iter_num, bar_id)
return PBar(carry=val, id=bar_id, iter=iter_num + 1)

return cond_fun_wrapper, body_fun_wrapper


def build_tqdm(
n: int,
print_rate: typing.Optional[int],
Expand Down Expand Up @@ -178,9 +211,6 @@ def build_tqdm(
f"number of steps {n}, got {print_rate}"
)

remainder = n % print_rate
remainder = remainder if remainder > 0 else print_rate

def _define_tqdm(bar_id: int):
bar_id = int(bar_id)
tqdm_bars[bar_id] = pbar(
Expand All @@ -193,9 +223,9 @@ def _define_tqdm(bar_id: int):
def _update_tqdm(bar_id: int):
tqdm_bars[int(bar_id)].update(print_rate)

def _close_tqdm(bar_id: int):
def _close_tqdm(bar_id: int, final_value: int):
_pbar = tqdm_bars.pop(int(bar_id))
_pbar.update(remainder)
_pbar.update(int(final_value) - _pbar.n)
_pbar.clear()
_pbar.close()

Expand All @@ -214,22 +244,24 @@ def _inner_update(i, _carry):
)
return _carry

cond = iter_num > 0

carry = jax.lax.cond(
iter_num == 0,
_inner_init,
cond,
_inner_update,
_inner_init,
iter_num,
carry,
)

return carry

def close_tqdm(result: typing.Any, iter_num: int, bar_id: int = 0):
def close_tqdm(result: typing.Any, target: int, iter_num: int, bar_id: int = 0):
def _inner_close(_result):
callback(_close_tqdm, bar_id, ordered=True)
callback(_close_tqdm, bar_id, target, ordered=True)
return _result

result = jax.lax.cond(iter_num + 1 == n, _inner_close, lambda r: r, result)
result = jax.lax.cond(iter_num + 1 == target, _inner_close, lambda r: r, result)
return result

return update_progress_bar, close_tqdm
19 changes: 18 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import jax.numpy as jnp
import pytest

from jax_tqdm import PBar, loop_tqdm, scan_tqdm
from jax_tqdm import PBar, bounded_while_tqdm, loop_tqdm, scan_tqdm


@pytest.mark.parametrize("print_rate", [None, 1, 10])
Expand Down Expand Up @@ -61,6 +61,23 @@ def inner(i):
assert jnp.array_equal(all_numbers, jnp.tile(1 + jnp.arange(n), (n_maps, 1)))


def test_bounded_while_loop():
n_total = 10_000
n_stop = 5_000

def cond_fun(x):
return x < n_stop

def body_fun(x):
return x + 1

cond_fun, body_fun = bounded_while_tqdm(cond_fun, body_fun, n=n_total)
init_val = PBar(carry=0)
result = jax.lax.while_loop(cond_fun, body_fun, init_val)

assert result == 5_000


@pytest.mark.parametrize("print_rate", [None, 1, 10])
def test_vmap_w_loop(print_rate):
n = 10_000
Expand Down