Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Memory reduction fixes for MCMC sampler #1802

Merged
merged 17 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion numpyro/contrib/einstein/steinvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import operator

from jax import grad, jacfwd, numpy as jnp, random, vmap
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map

from numpyro import handlers
Expand All @@ -25,7 +26,7 @@
from numpyro.infer.autoguide import AutoGuide
from numpyro.infer.util import _guess_max_plate_nesting, transform_fn
from numpyro.optim import _NumPyroOptim
from numpyro.util import fori_collect, ravel_pytree
from numpyro.util import fori_collect
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved

SteinVIState = namedtuple("SteinVIState", ["optim_state", "rng_key"])
SteinVIRunResult = namedtuple("SteinRunResult", ["params", "state", "losses"])
Expand Down
23 changes: 14 additions & 9 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,15 @@ def _compile(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
except TypeError:
pass

def _get_states_flat(self):
if self._states_flat is None:
self._states_flat = tree_map(
# need to calculate first dimension manually; see issue #1328
lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
self._states,
)
return self._states_flat

@property
def post_warmup_state(self):
"""
Expand Down Expand Up @@ -675,14 +684,10 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
states, last_state = partial_map_fn(map_args)
# swap num_samples x num_chains to num_chains x num_samples
states = tree_map(lambda x: jnp.swapaxes(x, 0, 1), states)
states_flat = tree_map(
# need to calculate first dimension manually; see issue #1328
lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],) + x.shape[2:]),
states,
)

self._last_state = last_state
self._states = states
self._states_flat = states_flat
self._states_flat = None
self._set_collection_params()

def get_samples(self, group_by_chain=False):
Expand All @@ -708,7 +713,7 @@ def get_samples(self, group_by_chain=False):
return (
self._states[self._sample_field]
if group_by_chain
else self._states_flat[self._sample_field]
else self._get_states_flat()[self._sample_field]
)

def get_extra_fields(self, group_by_chain=False):
Expand All @@ -720,7 +725,7 @@ def get_extra_fields(self, group_by_chain=False):
:return: Extra fields keyed by field names which are specified in the
`extra_fields` keyword of :meth:`run`.
"""
states = self._states if group_by_chain else self._states_flat
states = self._states if group_by_chain else self._get_states_flat()
return {k: v for k, v in states.items() if k != self._sample_field}

def print_summary(self, prob=0.9, exclude_deterministic=True):
Expand Down Expand Up @@ -758,7 +763,7 @@ def transfer_states_to_host(self):
Reduce the memory footprint of collected samples by transferring them to the host device.
"""
self._states = device_get(self._states)
self._states_flat = device_get(self._states_flat)
self._states_flat = device_get(self._get_states_flat())

def __getstate__(self):
state = self.__dict__.copy()
Expand Down
94 changes: 64 additions & 30 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
import inspect
from itertools import zip_longest
import os
Expand All @@ -18,7 +19,6 @@
from jax import device_put, jit, lax, vmap
from jax.core import Tracer
from jax.experimental import host_callback
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map

Expand Down Expand Up @@ -111,6 +111,13 @@ def control_flow_prims_disabled():
_DISABLE_CONTROL_FLOW_PRIM = stored_flag


def maybe_jit(fn, *args, **kwargs):
    """Wrap ``fn`` with :func:`jax.jit` (forwarding ``*args``/``**kwargs``
    to ``jit``), unless ``_DISABLE_CONTROL_FLOW_PRIM`` is set, in which
    case ``fn`` is returned unchanged.
    """
    return fn if _DISABLE_CONTROL_FLOW_PRIM else jit(fn, *args, **kwargs)


def cond(pred, true_operand, true_fun, false_operand, false_fun):
if _DISABLE_CONTROL_FLOW_PRIM:
if pred:
Expand Down Expand Up @@ -168,13 +175,14 @@ def cached_by(outer_fn, *keys):

def _wrapped(fn):
fn_cache = outer_fn._cache
if keys in fn_cache:
fn = fn_cache[keys]
hashkeys = (*keys, fn.__name__)
if hashkeys in fn_cache:
fn = fn_cache[hashkeys]
# update position
del fn_cache[keys]
fn_cache[keys] = fn
del fn_cache[hashkeys]
fn_cache[hashkeys] = fn
else:
fn_cache[keys] = fn
fn_cache[hashkeys] = fn
if len(fn_cache) > max_size:
fn_cache.popitem(last=False)
return fn
Expand Down Expand Up @@ -314,7 +322,7 @@ def fori_collect(
(upper - lower) // thinning if collection_size is None else collection_size
)
assert collection_size >= (upper - lower) // thinning
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
init_val_transformed = transform(init_val)
start_idx = lower + (upper - lower) % thinning
num_chains = progbar_opts.pop("num_chains", 1)
# host_callback does not work yet with multi-GPU platforms
Expand All @@ -326,53 +334,79 @@ def fori_collect(
)
progbar = False

@partial(maybe_jit, donate_argnums=2)
@cached_by(fori_collect, body_fun, transform)
def _body_fn(i, vals):
val, collection, start_idx, thinning = vals
def _body_fn(i, val, collection, start_idx, thinning):
val = body_fun(val)
idx = (i - start_idx) // thinning
collection = cond(
idx >= 0,
collection,
lambda x: x.at[idx].set(ravel_pytree(transform(val))[0]),
collection,
identity,
)

def update_fn(collect_array, new_val):
return cond(
idx >= 0,
collect_array,
lambda x: x.at[idx].set(new_val),
collect_array,
identity,
)

def update_collection(collection, val):
return jax.tree.map(update_fn, collection, transform(val))

collection = update_collection(collection, val)
return val, collection, start_idx, thinning

collection = jnp.zeros(
(collection_size,) + init_val_flat.shape, dtype=init_val_flat.dtype
)
def map_fn(x):
nx = jnp.asarray(x)
return jnp.zeros((collection_size, *nx.shape), dtype=nx.dtype) * nx[None, ...]
fehiepsi marked this conversation as resolved.
Show resolved Hide resolved

collection = jax.tree.map(map_fn, init_val_transformed)

if not progbar:
last_val, collection, _, _ = fori_loop(
0, upper, _body_fn, (init_val, collection, start_idx, thinning)
)

def loop_fn(collection):
return fori_loop(
0,
upper,
lambda i, vals: _body_fn(i, *vals),
(init_val, collection, start_idx, thinning),
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)

elif num_chains > 1:
progress_bar_fori_loop = progress_bar_factory(upper, num_chains)
_body_fn_pbar = progress_bar_fori_loop(_body_fn)
last_val, collection, _, _ = fori_loop(
0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
)
_body_fn_pbar = progress_bar_fori_loop(lambda i, vals: _body_fn(i, *vals))

def loop_fn(collection):
return fori_loop(
0, upper, _body_fn_pbar, (init_val, collection, start_idx, thinning)
)

last_val, collection, _, _ = maybe_jit(loop_fn, donate_argnums=0)(collection)

else:
diagnostics_fn = progbar_opts.pop("diagnostics_fn", None)
progbar_desc = progbar_opts.pop("progbar_desc", lambda x: "")

vals = (init_val, collection, device_put(start_idx), device_put(thinning))

if upper == 0:
# special case, only compiling
jit(_body_fn)(0, vals)
val, collection, start_idx, thinning = vals
_, collection, _, _ = _body_fn(-1, val, collection, start_idx, thinning)
vals = (val, collection, start_idx, thinning)
else:
with tqdm.trange(upper) as t:
for i in t:
vals = jit(_body_fn)(i, vals)
vals = _body_fn(i, *vals)

t.set_description(progbar_desc(i), refresh=False)
if diagnostics_fn:
t.set_postfix_str(diagnostics_fn(vals[0]), refresh=False)

last_val, collection, _, _ = vals

unravel_collection = vmap(unravel_fn)(collection)
return (unravel_collection, last_val) if return_last_val else unravel_collection
return (collection, last_val) if return_last_val else collection


def soft_vmap(fn, xs, batch_ndims=1, chunk_size=None):
Expand Down
Loading