Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate progress bar from fastprogress to tqdm #655

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,70 +14,83 @@
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from fastprogress.fastprogress import progress_bar
import jax
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from jax.debug import callback?

from jax import lax
from jax.experimental import io_callback
from tqdm.auto import tqdm as tqdm_auto


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
def progress_bar_scan(num_samples, num_chains=1, print_rate=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC in the usage we need to specify the num_chains for pmap to work properly. Could you explain a bit more how you are planning to change the API for downstream application so that part works?

one_step = progress_bar_scan(num_steps)(_one_step)

one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not fully committed to this API, but I was thinking something where along with passing an array of iteration numbers, you also pass in the chain you are currently in. I think this is better than the numpyro design where you are using regexes on device objects to guess what chain to put the computation on.

def inference_loop(rng_key, kernel, initial_state, chain, num_samples, num_chains):

    def _one_step(state, xs):
        _, _, rng_key = xs
        state, _ = kernel(rng_key, state)
        return state, state
    one_step = jax.jit(progress_bar_factory(num_samples, num_chains)(_one_step))

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(
        one_step,
        initial_state,
        (np.arange(num_samples), chain * np.ones(num_samples), keys),
    )

    return states
    
inference_loop_multiple_chains = jax.pmap(
    inference_loop,
    in_axes=(0, None, 0, 0, None, None),
    static_broadcasted_argnums=(1, 4, 5),
    devices=jax.devices(),
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For downstream applications that don't use multiple chains, I have included logic to maintain backward compatibility. Though I'm not sure how actual code is implementing progress bars for multiple chains today.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, could you share a small jupyter notebook how it looks like?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's very helpful. Let me think about it a bit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfectly happy to rework the API. This is was an attempt to make something simple and backwards compatible.

"""Factory that builds a progress bar decorator along
with the `set_tqdm_description` and `close_tqdm` functions
"""

if print_rate is None:
if num_samples > 20:
print_rate = int(num_samples / 20)
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)
remainder = num_samples % print_rate

def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)
tqdm_bars = {}
for chain in range(num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()
def _update_tqdm(arg, chain):
chain = int(chain)
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
tqdm_bars[chain].update(arg)

def _close_tqdm(arg, chain):
chain = int(chain)
tqdm_bars[chain].update(arg)
tqdm_bars[chain].close()

def _update_progress_bar(iter_num, chain):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Usage: carry = progress_bar((iter_num, print_rate), carry)
"""

def _update_progress_bar(iter_num):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: jax.debug.callback(_update_tqdm, iter_num, chain),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
(iter_num % print_rate) == 0,
lambda _: jax.debug.callback(_update_tqdm, print_rate, chain),
lambda _: None,
operand=None,
)

_ = lax.cond(
iter_num == num_samples - 1,
lambda _: io_callback(_close_bar, None, None),
lambda _: jax.debug.callback(_close_tqdm, remainder, chain),
lambda _: None,
operand=None,
)

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
Note that `body_fun` must either be looping over `np.arange(num_samples)`,
or be looping over a tuple who's first element is `np.arange(num_samples)`
looping over a tuple whose elements are `np.arange(num_samples), and a
chain id defined as `chain * np.ones(num_samples)`, or be looping over a
tuple who's first element and second elements include iter_num and chain.
This means that `iter_num` is the current iteration number
"""

def wrapper_progress_bar(carry, x):
if type(x) is tuple:
iter_num, *_ = x
if num_chains > 1:
iter_num, chain, *_ = x

Check warning on line 86 in blackjax/progress_bar.py

View check run for this annotation

Codecov / codecov/patch

blackjax/progress_bar.py#L86

Added line #L86 was not covered by tests
else:
iter_num, *_ = x
chain = 0
else:
iter_num = x
_update_progress_bar(iter_num)
chain = 0

Check warning on line 92 in blackjax/progress_bar.py

View check run for this annotation

Codecov / codecov/patch

blackjax/progress_bar.py#L92

Added line #L92 was not covered by tests
_update_progress_bar(iter_num, chain)
return func(carry, x)

return wrapper_progress_bar
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
dependencies = [
"fastprogress>=1.0.0",
"jax>=0.4.16",
"jaxlib>=0.4.16",
"jaxopt>=0.8",
"optax>=0.1.7",
"tqdm",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
Expand Down
Loading