Skip to content

Commit

Permalink
switch to simple resource counter
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Aug 8, 2024
1 parent 61c0d98 commit b2a2c35
Showing 1 changed file with 37 additions and 39 deletions.
76 changes: 37 additions & 39 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,72 +202,64 @@ def progress_bar_factory(num_samples, num_chains):

remainder = num_samples % print_rate

idx_map = {}
idx_counter = 0
tqdm_bars = {}
finished_chains = []
# lock serializes access to idx_map and finished_chains to avoid races
# lock serializes access to idx_counter
lock = Lock()
for chain in range(num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

# Uses resource counting for each iter_num value. Chains are assigned to progress
# bars based on order of arrival to each iter_num value
# Uses resource counting to assign chain ids
def _calc_chain_idx(iter_num):
nonlocal idx_counter
with lock:
try:
idx = idx_map[iter_num]
except KeyError:
idx = 0
idx_map[iter_num] = 0

if idx + 1 == num_chains:
del idx_map[iter_num]
else:
idx_map[iter_num] += 1
return idx
idx = idx_counter
idx_counter += 1
return idx

def _update_tqdm(iter_num, increment):
def _update_tqdm(iter_num, increment, chain):
iter_num = int(iter_num)
increment = int(increment)
chain = _calc_chain_idx(iter_num)
chain = int(chain)
if iter_num == 0:
chain = _calc_chain_idx(iter_num)
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
tqdm_bars[chain].update(increment)
return chain

def _close_tqdm(iter_num, increment):
iter_num = int(iter_num)
def _close_tqdm(increment, chain):
increment = int(increment)
chain = _calc_chain_idx(iter_num + 1) # +1 so no collision in idx_map
chain = int(chain)
tqdm_bars[chain].update(increment)
finished_chains.append(chain)
with lock:
if len(finished_chains) == num_chains:
for chain in range(num_chains):
tqdm_bars[chain].close()
tqdm_bars[chain].close()

def _update_progress_bar(iter_num):
def _update_progress_bar(iter_num, chain):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Usage: carry = progress_bar((iter_num, print_rate), carry)
"""

_ = lax.cond(
chain = lax.cond(
iter_num == 1,
lambda _: io_callback(_update_tqdm, None, 0, 0),
lambda _: None,
lambda _: io_callback(_update_tqdm, jnp.array(0), 0, 0, chain),
lambda _: chain,
operand=None,
)
_ = lax.cond(
chain = lax.cond(
iter_num % print_rate == 0,
lambda _: io_callback(_update_tqdm, None, iter_num, print_rate),
lambda _: None,
lambda _: io_callback(
_update_tqdm, jnp.array(0), iter_num, print_rate, chain
),
lambda _: chain,
operand=None,
)
_ = lax.cond(
iter_num == num_samples,
lambda _: io_callback(_close_tqdm, None, iter_num, remainder),
lambda _: io_callback(_close_tqdm, None, remainder, chain),
lambda _: None,
operand=None,
)
return chain

def progress_bar_fori_loop(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.fori_loop`.
Expand All @@ -276,9 +268,10 @@ def progress_bar_fori_loop(func):
"""

def wrapper_progress_bar(i, vals):
result = func(i, vals)
_update_progress_bar(i + 1)
return result
(subvals, chain) = vals
result = func(i, subvals)
chain = _update_progress_bar(i + 1, chain)
return (result, chain)

return wrapper_progress_bar

Expand Down Expand Up @@ -393,10 +386,15 @@ def loop_fn(collection):

def loop_fn(collection):
return fori_loop(
0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
0,
upper,
_body_fn_pbar,
((init_val, collection, start_idx, thinning), -1), # -1 for chain id
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)
(last_val, collection, _, _), _ = maybe_jit(loop_fn, donate_argnums=0)(
collection
)

else:
diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
Expand Down

0 comments on commit b2a2c35

Please sign in to comment.