Skip to content

Commit

Permalink
Multiple Progress Bars (#25)
Browse files Browse the repository at this point in the history
* Update docstrings

* Prototype multiple progress bar API

* Rename class and fill out tests

* Add multi-bar example to README

* Fix case where position keyword passed

* Bump minor version
  • Loading branch information
zombie-einstein authored Oct 7, 2024
1 parent b08c9f6 commit f216dcb
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ jobs:
installer-parallel: true
- run: poetry install
- run: pip install --upgrade "jax[cpu]"
- run: pytest -vv
- run: pytest -vvs
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,41 @@ def step(i, val):
last_number = lax.fori_loop(0, n, step, 0)
```

### Scans & Loops Inside VMAP

For scans and loops inside a map, jax-tqdm can print stacked progress bars
showing the individual progress of each process. To do this you can wrap
the initial value of the loop or scan inside a `PBar` class, along with the
index of the progress bar. For example

```python
from jax_tqdm import PBar, scan_tqdm
import jax

n = 10_000

@scan_tqdm(n)
def step(carry, _):
return carry + 1, carry + 1

def map_func(i):
# Wrap the initial value and pass the
# progress bar index
init = PBar(id=i, carry=0)
final_value, _all_numbers = jax.lax.scan(
step, init, jax.numpy.arange(n)
)
return (
final_value.carry,
_all_numbers,
)

last_numbers, all_numbers = jax.vmap(map_func)(jax.numpy.arange(10))
```

The indices of the progress bars should be contiguous integers starting
from 0.

### Print Rate

By default, the progress bar is updated 20 times over the course of the scan/loop
Expand Down
2 changes: 1 addition & 1 deletion jax_tqdm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from jax_tqdm.pbar import loop_tqdm, scan_tqdm
from jax_tqdm.pbar import PBar, loop_tqdm, scan_tqdm
73 changes: 52 additions & 21 deletions jax_tqdm/pbar.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import typing

import chex
import jax
import tqdm.auto
import tqdm.notebook
import tqdm.std
from jax.debug import callback


@chex.dataclass
class PBar:
id: int
carry: typing.Any


def scan_tqdm(
n: int,
print_rate: typing.Optional[int] = None,
Expand All @@ -23,6 +30,8 @@ def scan_tqdm(
print_rate : int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
tqdm_type: str
Type of progress-bar, should be one of "auto", "std", or "notebook".
**kwargs
Extra keyword arguments to pass to tqdm.
Expand All @@ -46,9 +55,18 @@ def wrapper_progress_bar(carry, x):
iter_num, *_ = x
else:
iter_num = x
_update_progress_bar(iter_num)
result = func(carry, x)
return close_tqdm(result, iter_num)

if isinstance(carry, PBar):
bar_id = carry.id
carry = carry.carry
_update_progress_bar(iter_num, bar_id=bar_id)
result = func(carry, x)
result = (PBar(id=bar_id, carry=result[0]), result[1])
return close_tqdm(result, iter_num, bar_id=bar_id)
else:
_update_progress_bar(iter_num)
result = func(carry, x)
return close_tqdm(result, iter_num)

return wrapper_progress_bar

Expand All @@ -71,6 +89,8 @@ def loop_tqdm(
print_rate: int
Optional integer rate at which the progress bar will be updated,
by default the print rate will 1/20th of the total number of steps.
tqdm_type: str
Type of progress-bar, should be one of "auto", "std", or "notebook".
**kwargs
Extra keyword arguments to pass to tqdm.
Expand All @@ -89,9 +109,17 @@ def _loop_tqdm(func):
"""

def wrapper_progress_bar(i, val):
_update_progress_bar(i)
result = func(i, val)
return close_tqdm(result, i)
if isinstance(val, PBar):
bar_id = val.id
val = val.carry
_update_progress_bar(i, bar_id=bar_id)
result = func(i, val)
result = PBar(id=bar_id, carry=result)
return close_tqdm(result, i, bar_id=bar_id)
else:
_update_progress_bar(i)
result = func(i, val)
return close_tqdm(result, i)

return wrapper_progress_bar

Expand All @@ -117,10 +145,12 @@ def build_tqdm(

desc = kwargs.pop("desc", f"Running for {n:,} iterations")
message = kwargs.pop("message", desc)
position_offset = kwargs.pop("position", 0)

for kwarg in ("total", "mininterval", "maxinterval", "miniters"):
kwargs.pop(kwarg, None)

tqdm_bars = {}
tqdm_bars = dict()

if print_rate is None:
if n > 20:
Expand All @@ -138,45 +168,46 @@ def build_tqdm(

remainder = n % print_rate

def _define_tqdm(arg, transform):
tqdm_bars[0] = pbar(range(n), **kwargs)
tqdm_bars[0].set_description(message, refresh=False)
def _define_tqdm(_arg, bar_id: int):
bar_id = int(bar_id)
tqdm_bars[bar_id] = pbar(range(n), position=bar_id + position_offset, **kwargs)
tqdm_bars[bar_id].set_description(message, refresh=False)

def _update_tqdm(arg, transform):
tqdm_bars[0].update(int(arg))
def _update_tqdm(arg, bar_id: int):
tqdm_bars[int(bar_id)].update(int(arg))

def _update_progress_bar(iter_num):
"Updates tqdm from a JAX scan or loop"
def _update_progress_bar(iter_num, bar_id: int = 0):
"""Updates tqdm from a JAX scan or loop"""
_ = jax.lax.cond(
iter_num == 0,
lambda _: callback(_define_tqdm, None, None, ordered=True),
lambda _: callback(_define_tqdm, None, bar_id, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) & (iter_num != n - remainder),
lambda _: callback(_update_tqdm, print_rate, None, ordered=True),
lambda _: callback(_update_tqdm, print_rate, bar_id, ordered=True),
lambda _: None,
operand=None,
)

_ = jax.lax.cond(
# update tqdm by `remainder`
iter_num == n - remainder,
lambda _: callback(_update_tqdm, remainder, None, ordered=True),
lambda _: callback(_update_tqdm, remainder, bar_id, ordered=True),
lambda _: None,
operand=None,
)

def _close_tqdm(arg, transform):
tqdm_bars[0].close()
def _close_tqdm(_arg, bar_id: int):
tqdm_bars[int(bar_id)].close()

def close_tqdm(result, iter_num):
def close_tqdm(result, iter_num, bar_id: int = 0):
_ = jax.lax.cond(
iter_num == n - 1,
lambda _: callback(_close_tqdm, None, None, ordered=True),
lambda _: callback(_close_tqdm, None, bar_id, ordered=True),
lambda _: None,
operand=None,
)
Expand Down
75 changes: 74 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "jax-tqdm"
version = "0.2.2"
version = "0.3.0"
description = "Tqdm progress bar for JAX scans and loops"
authors = [
"Jeremie Coullon <jeremie.coullon@gmail.com>",
Expand All @@ -17,6 +17,7 @@ license = "MIT"
python = ">=3.9,<4.0"
tqdm = "^4.64.1"
jax = ">=0.4.12"
chex = "^0.1.87"


[tool.poetry.group.dev.dependencies]
Expand Down
Loading

0 comments on commit f216dcb

Please sign in to comment.