From 4646d06d87ebb6eb9bee8c2dbf3a2281eb02c151 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 1 Apr 2024 14:32:24 +0200 Subject: [PATCH 1/2] Migrate progress bar from fastprogress to tqdm, and multiple chain support --- blackjax/progress_bar.py | 66 ++++++++++++++++++++++++---------------- pyproject.toml | 2 +- 2 files changed, 41 insertions(+), 27 deletions(-) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index ac509b9b6..3b50f0c61 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -14,14 +14,16 @@ """Progress bar decorators for use with step functions. Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. """ -from fastprogress.fastprogress import progress_bar +import tqdm +from tqdm.auto import tqdm as tqdm_auto +import jax from jax import lax -from jax.experimental import io_callback -def progress_bar_scan(num_samples, print_rate=None): - "Progress bar for a JAX scan" - progress_bars = {} +def progress_bar_scan(num_samples, num_chains=1, print_rate=None): + """Factory that builds a progress bar decorator along + with the `set_tqdm_description` and `close_tqdm` functions + """ if print_rate is None: if num_samples > 20: @@ -29,38 +31,43 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg): - del arg - progress_bars[0] = progress_bar(range(num_samples)) - progress_bars[0].update(0) + remainder = num_samples % print_rate - def _update_bar(arg): - progress_bars[0].update_bar(arg + 1) + tqdm_bars = {} + for chain in range(num_chains): + tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain) + tqdm_bars[chain].set_description("Compiling.. ", refresh=True) - def _close_bar(arg): - del arg - progress_bars[0].on_iter_end() + def _update_tqdm(arg, chain): + chain = int(chain) + tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False) + tqdm_bars[chain].update(arg) + + def _close_tqdm(arg, chain): + chain = int(chain) + tqdm_bars[chain].update(arg) + tqdm_bars[chain].close() + + def _update_progress_bar(iter_num, chain): + """Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate + Usage: carry = progress_bar((iter_num, print_rate), carry) + """ - def _update_progress_bar(iter_num): - "Updates progress bar of a JAX scan or loop" _ = lax.cond( iter_num == 0, - lambda _: io_callback(_define_bar, None, iter_num), + lambda _: jax.debug.callback(_update_tqdm, iter_num, chain), lambda _: None, operand=None, ) - _ = lax.cond( - # update every multiple of `print_rate` except at the end - (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), - lambda _: io_callback(_update_bar, None, iter_num), + (iter_num % print_rate) == 0, + lambda _: jax.debug.callback(_update_tqdm, print_rate, chain), lambda _: None, operand=None, ) - _ = lax.cond( iter_num == num_samples - 1, - lambda _: io_callback(_close_bar, None, None), + lambda _: jax.debug.callback(_close_tqdm, remainder, chain), lambda _: None, operand=None, ) @@ -68,16 +75,23 @@ def _update_progress_bar(iter_num): def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. Note that `body_fun` must either be looping over `np.arange(num_samples)`, - or be looping over a tuple who's first element is `np.arange(num_samples)` + looping over a tuple whose elements are `np.arange(num_samples), and a + chain id defined as `chain * np.ones(num_samples)`, or be looping over a + tuple who's first element and second elements include iter_num and chain. This means that `iter_num` is the current iteration number """ def wrapper_progress_bar(carry, x): if type(x) is tuple: - iter_num, *_ = x + if num_chains > 1: + iter_num, chain, *_ = x + else: + iter_num, *_ = x + chain = 0 else: iter_num = x - _update_progress_bar(iter_num) + chain = 0 + _update_progress_bar(iter_num, chain) return func(carry, x) return wrapper_progress_bar diff --git a/pyproject.toml b/pyproject.toml index 0739361e2..b33b24aa4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,11 +31,11 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] dependencies = [ - "fastprogress>=1.0.0", "jax>=0.4.16", "jaxlib>=0.4.16", "jaxopt>=0.8", "optax>=0.1.7", + "tqdm", "typing-extensions>=4.4.0", ] dynamic = ["version"] From b2d22731caaf089e9c5e0b351e78c9b737d9725f Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Mon, 1 Apr 2024 14:55:33 +0200 Subject: [PATCH 2/2] Lint fix --- blackjax/progress_bar.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py index 3b50f0c61..fae968d7b 100644 --- a/blackjax/progress_bar.py +++ b/blackjax/progress_bar.py @@ -14,10 +14,9 @@ """Progress bar decorators for use with step functions. Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`. """ -import tqdm -from tqdm.auto import tqdm as tqdm_auto import jax from jax import lax +from tqdm.auto import tqdm as tqdm_auto def progress_bar_scan(num_samples, num_chains=1, print_rate=None):