diff --git a/docs/source/svi.rst b/docs/source/svi.rst index 6f176d2c7..7281ef62e 100644 --- a/docs/source/svi.rst +++ b/docs/source/svi.rst @@ -16,19 +16,28 @@ ELBO :show-inheritance: :member-order: bysource -RenyiELBO ---------- +Trace_ELBO +---------- -.. autoclass:: numpyro.infer.elbo.RenyiELBO +.. autoclass:: numpyro.infer.elbo.Trace_ELBO :members: :undoc-members: :show-inheritance: :member-order: bysource -Trace_ELBO ----------- +TraceMeanField_ELBO +------------------- -.. autoclass:: numpyro.infer.elbo.Trace_ELBO +.. autoclass:: numpyro.infer.elbo.TraceMeanField_ELBO + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +RenyiELBO +--------- + +.. autoclass:: numpyro.infer.elbo.RenyiELBO :members: :undoc-members: :show-inheritance: diff --git a/examples/hmm_enum.py b/examples/hmm_enum.py index bd6a114d7..5a52f192a 100644 --- a/examples/hmm_enum.py +++ b/examples/hmm_enum.py @@ -43,6 +43,9 @@ 3. *Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription*, Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P. + 4. *Tensor Variable Elimination for Plated Factor Graphs*, + Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu, + Neeraj Pradhan, Alexander Rush, Noah Goodman (https://arxiv.org/abs/1902.03210) """ import argparse @@ -65,8 +68,9 @@ logger.setLevel(logging.INFO) +# %% # Let's start with a simple Hidden Markov Model. -# + # x[t-1] --> x[t] --> x[t+1] # | | | # V V V @@ -100,8 +104,9 @@ def transition_fn(carry, y): scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1)) +# %% # Next let's add a dependency of y[t] on y[t-1]. -# + # x[t-1] --> x[t] --> x[t+1] # | | | # V V V @@ -137,8 +142,9 @@ def transition_fn(carry, y): scan(transition_fn, (x_init, y_init, 0), jnp.swapaxes(sequences, 0, 1)) +# %% # Next consider a Factorial HMM with two hidden states. -# + # w[t-1] ----> w[t] ---> w[t+1] # \ x[t-1] --\-> x[t] --\-> x[t+1] # \ / \ / \ / @@ -185,9 +191,10 @@ def transition_fn(carry, y): scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1)) +# %% # By adding a dependency of x on w, we generalize to a # Dynamic Bayesian Network. -# + # w[t-1] ----> w[t] ---> w[t+1] # | \ | \ | \ # | x[t-1] ----> x[t] ----> x[t+1] @@ -230,6 +237,59 @@ def transition_fn(carry, y): scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1)) +# %% +# Next let's consider a second-order HMM model +# in which x[t+1] depends on both x[t] and x[t-1]. + +# _______>______ +# _____>_____/______ \ +# / / \ \ +# x[t-1] --> x[t] --> x[t+1] --> x[t+2] +# | | | | +# V V V V +# y[t-1] y[t] y[t+1] y[t+2] +# +# Note that in this model (in contrast to the previous model) we treat +# the transition and emission probabilities as parameters (so they have no prior). +# +# Note that this is the "2HMM" model in reference [4]. +def model_6(sequences, lengths, args, include_prior=False): + num_sequences, max_length, data_dim = sequences.shape + + with mask(mask=include_prior): + # Explicitly parameterize the full tensor of transition probabilities, which + # has hidden_dim cubed entries. 
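+        # (e.g. args.hidden_dim = 16 gives a 16 * 16 * 16 = 4096-entry table)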
+        probs_x = numpyro.sample("probs_x",
+                                 dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1)
+                                 .expand([args.hidden_dim, args.hidden_dim])
+                                 .to_event(2))
+
+        probs_y = numpyro.sample("probs_y",
+                                 dist.Beta(0.1, 0.9)
+                                 .expand([args.hidden_dim, data_dim])
+                                 .to_event(2))
+
+    def transition_fn(carry, y):
+        x_prev, x_curr, t = carry
+        with numpyro.plate("sequences", num_sequences, dim=-2):
+            with mask(mask=(t < lengths)[..., None]):
+                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
+                x_prev, x_curr = x_curr, numpyro.sample("x", dist.Categorical(probs_x_t))
+                with numpyro.plate("tones", data_dim, dim=-1):
+                    probs_y_t = probs_y[x_curr.squeeze(-1)]
+                    numpyro.sample("y",
+                                   dist.Bernoulli(probs_y_t),
+                                   obs=y)
+        return (x_prev, x_curr, t + 1), None
+
+    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
+    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
+    scan(transition_fn, (x_prev, x_curr, 0), jnp.swapaxes(sequences, 0, 1), history=2)
+
+
+# %%
+# Do inference
+
 models = {name[len('model_'):]: model
           for name, model in globals().items()
           if name.startswith('model_')}
diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py
index 7bf4b6b9b..a81083421 100644
--- a/numpyro/contrib/control_flow/scan.py
+++ b/numpyro/contrib/control_flow/scan.py
@@ -4,7 +4,7 @@
 from collections import OrderedDict
 from functools import partial
 
-from jax import lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten
+from jax import device_put, lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten
 import jax.numpy as jnp
 from jax.tree_util import register_pytree_node_class
 
@@ -87,45 +87,53 @@ def _subs_wrapper(subs_map, i, length, site):
                              " which is currently not supported. Please report the issue to us!")
 
 
-class promote_shapes(Messenger):
-    # a helper messenger to promote shapes of `fn` and `value`
-    #   + msg: fn.batch_shape = (2, 3), value.shape = (3,) + fn.event_shape
-    # process_message(msg): promote value so that value.shape = (1, 3) + fn.event_shape
+class _promote_fn_shapes(Messenger):
+    # a helper messenger to promote shapes of `fn`
+    #   + msg: fn.batch_shape = (3,), value.shape = (2, 3) + fn.event_shape
+    # postprocess_message(msg): promote fn so that fn.batch_shape = (1, 3).
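+    # NOTE: the matching promotion of `value` happens after the scan has run; see
+    # `_promote_scanned_value_shapes` below.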
-    def process_message(self, msg):
+    def postprocess_message(self, msg):
         if msg["type"] == "sample" and msg["value"] is not None:
             fn, value = msg["fn"], msg["value"]
             value_batch_ndims = jnp.ndim(value) - fn.event_dim
             fn_batch_ndim = len(fn.batch_shape)
-            prepend_shapes = (1,) * abs(fn_batch_ndim - value_batch_ndims)
-            if fn_batch_ndim > value_batch_ndims:
-                msg["value"] = jnp.reshape(value, prepend_shapes + jnp.shape(value))
-            elif fn_batch_ndim < value_batch_ndims:
+            if fn_batch_ndim < value_batch_ndims:
+                prepend_shapes = (1,) * (value_batch_ndims - fn_batch_ndim)
                 msg["fn"] = tree_map(lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn)
 
 
-def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
+def _promote_scanned_value_shapes(value, fn):
+    # a helper function to promote shapes of `value`
+    #   + given: fn.batch_shape = (T, 2, 3), value.shape = (T, 3) + fn.event_shape
+    #   + we promote value so that value.shape = (T, 1, 3) + fn.event_shape
+    value_batch_ndims = jnp.ndim(value) - fn.event_dim
+    fn_batch_ndim = len(fn.batch_shape)
+    if fn_batch_ndim > value_batch_ndims:
+        prepend_shapes = (1,) * (fn_batch_ndim - value_batch_ndims)
+        return jnp.reshape(value, jnp.shape(value)[:1] + prepend_shapes + jnp.shape(value)[1:])
+    else:
+        return value
+
+
+def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None, history=1,
+              first_available_dim=None):
     from numpyro.contrib.funsor import config_enumerate, enum, markov
     from numpyro.contrib.funsor import trace as packed_trace
 
-    # XXX: This implementation only works for history size=1 but can be
-    # extended to history size > 1 by running `f` `history_size` times
-    # for initialization. However, `sequential_sum_product` does not
-    # support history size > 1, so we skip supporting it here.
-    # Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
+    # number of steps to unroll
+    history = min(history, length)
+    unroll_steps = min(2 * history - 1, length)
     if reverse:
-        x0 = tree_map(lambda x: x[-1], xs)
-        xs_ = tree_map(lambda x: x[:-1], xs)
+        x0 = tree_map(lambda x: x[-unroll_steps:][::-1], xs)
+        xs_ = tree_map(lambda x: x[:-unroll_steps], xs)
     else:
-        x0 = tree_map(lambda x: x[0], xs)
-        xs_ = tree_map(lambda x: x[1:], xs)
+        x0 = tree_map(lambda x: x[:unroll_steps], xs)
+        xs_ = tree_map(lambda x: x[unroll_steps:], xs)
 
-    carry_shape_at_t1 = None
+    carry_shapes = []
 
     def body_fn(wrapped_carry, x, prefix=None):
         i, rng_key, carry = wrapped_carry
-        init = True if (not_jax_tracer(i) and i == 0) else False
+        init = True if (not_jax_tracer(i) and i in range(unroll_steps)) else False
 
         rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)
 
         # we need to tell unconstrained messenger in potential energy computation
@@ -141,35 +149,57 @@ def body_fn(wrapped_carry, x, prefix=None):
             seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)
 
         if init:
-            with handlers.scope(prefix="_init"):
-                new_carry, y = seeded_fn(carry, x)
+            # prefix the site names to match the naming pattern of sarkka_bilmes_product
+            with handlers.scope(prefix='P' * (unroll_steps - i), divider='_'):
+                new_carry, y = config_enumerate(seeded_fn)(carry, x)
             trace = {}
         else:
-            with handlers.block(), packed_trace() as trace, promote_shapes(), enum(), markov():
-                # Like scan_wrapper, we collect the trace of scan's transition function
-                # `seeded_fn` here. 
To put time dimension to the correct position, we need to
-                # promote shapes to make `fn` and `value`
-                # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
-                # and value's batch_shape is (3,), then we promote shape of
-                # value so that its batch shape is (1, 3)).
+            # Like scan_wrapper, we collect the trace of scan's transition function
+            # `seeded_fn` here. To put the time dimension in the correct position, we need to
+            # promote shapes to make `fn` and `value`
+            # at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
+            # and value's batch_shape is (3,), then we promote the shape of
+            # value so that its batch shape is (1, 3)).
+            # Here we promote the `fn` shape first; the `value` shape will be promoted
+            # after the scan. We don't promote the `value` shape here because we need to
+            # store the carry shape at this step; if we reshaped the `value` here, the
+            # output carry might get the wrong shape.
+            with _promote_fn_shapes(), packed_trace() as trace, handlers.scope(divider='_'):
                 new_carry, y = config_enumerate(seeded_fn)(carry, x)
 
             # store shape of new_carry at a global variable
-            nonlocal carry_shape_at_t1
-            carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
+            if len(carry_shapes) < (history + 1):
+                carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
             # make new_carry have the same shape as carry
             # FIXME: is this rigorous?
             new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
                                       new_carry, carry)
-        return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
+        return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)
 
-    with markov():
+    with handlers.block(hide_fn=lambda site: site["name"].startswith("_")), \
+            enum(first_available_dim=first_available_dim):
         wrapped_carry = (0, rng_key, init)
-        wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
-        if length == 1:
-            ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
-            return wrapped_carry, (PytreeTrace({}), ys)
-        wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)
+        y0s = []
+        # We run unroll_steps + 1 steps, where the last step is used for rolling with `lax.scan`
+        for i in markov(range(unroll_steps + 1), history=history):
+            if i < unroll_steps:
+                wrapped_carry, (_, y0) = body_fn(wrapped_carry, tree_map(lambda z: z[i], x0))
+                if i > 0:
+                    # reshape y1, y2,... 
to have the same shape as y0
+                    y0 = tree_multimap(lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0)
+                y0s.append(y0)
+                # shapes of the first `history - 1` steps are not needed to interpret the last
+                # carry shape, so we don't record them here
+                if (i >= history - 1) and (len(carry_shapes) < history + 1):
+                    carry_shapes.append([jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
+            else:
+                # this is the last rolling step
+                y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
+                # return early if length == unroll_steps
+                if length == unroll_steps:
+                    return wrapped_carry, (PytreeTrace({}), y0s)
+                wrapped_carry = device_put(wrapped_carry)
+                wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_,
+                                                             length - unroll_steps, reverse)
 
     first_var = None
     for name, site in pytree_trace.trace.items():
@@ -181,30 +211,38 @@ def body_fn(wrapped_carry, x, prefix=None):
         if first_var is None:
             first_var = name
 
+        # we haven't promoted the shapes of values during `lax.scan`, so we do it here
+        site["value"] = _promote_scanned_value_shapes(site["value"], site["fn"])
+
         # XXX: site['infer']['dim_to_name'] is not enough to determine leftmost dimension because
         # we don't record 1-size dimensions in this field
         time_dim = -min(len(site['fn'].batch_shape), jnp.ndim(site['value']) - site['fn'].event_dim)
         site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)
 
     # similar to carry, we need to reshape due to shape alternating in markov
-    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
+    ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
+    # then join with y0s
+    ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
     # we also need to reshape `carry` to match sequential behavior
-    if length % 2 == 0:
-        t, rng_key, carry = wrapped_carry
-        flatten_carry, treedef = tree_flatten(carry)
-        flatten_carry = [jnp.reshape(x, t1_shape)
-                         for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)]
-        carry = tree_unflatten(treedef, flatten_carry)
-        wrapped_carry = (t, rng_key, carry)
+    i = (length + 1) % (history + 1)
+    t, rng_key, carry = wrapped_carry
+    carry_shape = carry_shapes[i]
+    flatten_carry, treedef = tree_flatten(carry)
+    flatten_carry = [jnp.reshape(x, t1_shape)
+                     for x, t1_shape in zip(flatten_carry, carry_shape)]
+    carry = tree_unflatten(treedef, flatten_carry)
+    wrapped_carry = (t, rng_key, carry)
    return wrapped_carry, (pytree_trace, ys)
 
 
-def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False):
+def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False,
+                 history=1, first_available_dim=None):
     if length is None:
         length = tree_flatten(xs)[0][0].shape[0]
-    if enum:
-        return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack)
+    if enum and history > 0:
+        return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack, history,
+                         first_available_dim)
 
     def body_fn(wrapped_carry, x):
         i, rng_key, carry = wrapped_carry
@@ -229,10 +267,11 @@ def body_fn(wrapped_carry, x):
 
         return (i + 1, rng_key, carry), (PytreeTrace(trace), y)
 
-    return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
+    wrapped_carry = device_put((0, rng_key, init))
+    return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse)
 
 
-def scan(f, init, xs, length=None, reverse=False):
+def scan(f, init, xs, length=None, reverse=False, history=1):
     """
     This primitive scans a function 
over the leading array axes of `xs` while carrying along state. See :func:`jax.lax.scan` for more
@@ -297,13 +336,11 @@ def g(*args, **kwargs):
     evaluated using parallel-scan (reference [1]) over time dimension, which
     reduces parallel complexity to `O(log(length))`.
 
-    Currently, only the equivalence to
-    :class:`~numpyro.contrib.funsor.enum_messenger.markov(history_size=1)`
-    is supported. A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
+    A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
     variables will contain the following sites:
 
-      + init sites: those sites belong to the first trace of `f`. Each of
-        them will have name prefixed with `_init/`.
+      + init sites: those sites belong to the first `2 * history - 1` traces of `f`.
+        Sites at the `i`-th trace will have names prefixed with
+        `'P' * (2 * history - 1 - i)` followed by a `_` divider.
       + scanned sites: those sites collect the values of the remaining scan
         loop over `f`. An addition time dimension `_time_foo` will be
         added to those sites, where `foo` is the name of the first site
@@ -331,6 +368,8 @@ def g(*args, **kwargs):
     but can be used when `xs` is an empty pytree (e.g. None)
     :param bool reverse: optional boolean specifying whether to run the scan iteration
         forward (the default) or in reverse
+    :param int history: The number of previous contexts visible from the current context.
+        Defaults to 1. If zero, this is similar to :class:`numpyro.plate`.
     :return: output of scan, quoted from :func:`jax.lax.scan` docs:
         "pair of type (c, [b]) where the first element represents the final loop
         carry value and the second element represents the stacked outputs of the
@@ -347,7 +386,8 @@ def g(*args, **kwargs):
         'fn': scan_wrapper,
         'args': (f, init, xs, length, reverse),
         'kwargs': {'rng_key': None,
-                   'substitute_stack': []},
+                   'substitute_stack': [],
+                   'history': history},
         'value': None,
     }
diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py
index 6f63948b3..dd241c897 100644
--- a/numpyro/contrib/funsor/enum_messenger.py
+++ b/numpyro/contrib/funsor/enum_messenger.py
@@ -494,6 +494,7 @@ def process_message(self, msg):
             msg["infer"].get("enumerate") != "parallel" or (not msg["fn"].has_enumerate_support):
             if msg["type"] == "control_flow":
                 msg["kwargs"]["enum"] = True
+                msg["kwargs"]["first_available_dim"] = self.first_available_dim
             return super().process_message(msg)
 
         if msg["infer"].get("num_samples", None) is not None:
@@ -526,7 +527,7 @@ def postprocess_message(self, msg):
         if msg["type"] == "sample":
             total_batch_shape = lax.broadcast_shapes(
                 tuple(msg["fn"].batch_shape),
-                msg["value"].shape[:len(msg["value"].shape) - msg["fn"].event_dim]
+                jnp.shape(msg["value"])[:jnp.ndim(msg["value"]) - msg["fn"].event_dim]
             )
             msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(total_batch_shape)
         if msg["type"] in ("sample", "param"):
diff --git a/numpyro/contrib/funsor/infer_util.py b/numpyro/contrib/funsor/infer_util.py
index 7ba7dfe56..433a2d2b0 100644
--- a/numpyro/contrib/funsor/infer_util.py
+++ b/numpyro/contrib/funsor/infer_util.py
@@ -71,7 +71,8 @@ def config_fn(site):
     return infer_config(fn, config_fn)
 
 
-def compute_markov_factors(time_to_factors, time_to_init_vars, time_to_markov_dims, sum_vars, prod_vars):
+def compute_markov_factors(time_to_factors, time_to_init_vars, time_to_markov_dims,
+                           sum_vars, prod_vars, history):
     """
     :param dict time_to_factors: a map from time variable to the log prob factors.
     :param dict time_to_init_vars: a map from time variable to init discrete sites. 
@@ -79,13 +80,13 @@ def compute_markov_factors(time_to_factors, time_to_init_vars, time_to_markov_di (discrete sites that depend on previous steps). :param frozenset sum_vars: all plate and enum dimensions in the trace. :param frozenset prod_vars: all plate dimensions in the trace. + :param int history: The number of previous contexts visible from the current context. :returns: a list of factors after eliminate time dimensions """ markov_factors = [] for time_var, log_factors in time_to_factors.items(): prev_vars = time_to_init_vars[time_var] - # remove `_init/` prefix to convert prev to curr - prev_to_curr = {k: "/".join(k.split("/")[1:]) for k in prev_vars} + # we eliminate all plate and enum dimensions not available at markov sites. eliminate_vars = (sum_vars | prod_vars) - time_to_markov_dims[time_var] with funsor.interpreter.interpretation(funsor.terms.lazy): @@ -93,8 +94,19 @@ def compute_markov_factors(time_to_factors, time_to_init_vars, time_to_markov_di funsor.ops.logaddexp, funsor.ops.add, log_factors, eliminate=eliminate_vars, plates=prod_vars) trans = funsor.optimizer.apply_optimizer(lazy_result) - markov_factors.append(funsor.sum_product.sequential_sum_product( - funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr)) + + if history > 1: + global_vars = frozenset(set(trans.inputs) - {time_var.name} - prev_vars + - {k.lstrip("P") for k in prev_vars}) + markov_factors.append(funsor.sum_product.sarkka_bilmes_product( + funsor.ops.logaddexp, funsor.ops.add, trans, time_var, global_vars + )) + else: + # remove `P` prefix to convert prev to curr + prev_to_curr = {k: k.lstrip("P") for k in prev_vars} + markov_factors.append(funsor.sum_product.sequential_sum_product( + funsor.ops.logaddexp, funsor.ops.add, trans, time_var, prev_to_curr + )) return markov_factors @@ -124,9 +136,10 @@ def model(*args, **kwargs): model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs) log_factors = [] time_to_factors = defaultdict(list) # log prob factors - time_to_init_vars = defaultdict(frozenset) # _init/... variables + time_to_init_vars = defaultdict(frozenset) # PP... variables time_to_markov_dims = defaultdict(frozenset) # dimensions at markov sites sum_vars, prod_vars = frozenset(), frozenset() + history = 1 for site in model_trace.values(): if site['type'] == 'sample': value = site['value'] @@ -148,8 +161,9 @@ def model(*args, **kwargs): if name.startswith("_time"): time_dim = funsor.Variable(name, funsor.domains.bint(log_prob.shape[dim])) time_to_factors[time_dim].append(log_prob_factor) + history = max(history, max((len(s) - len(s.lstrip("P"))) for s in dim_to_name.values())) time_to_init_vars[time_dim] |= frozenset( - s for s in dim_to_name.values() if s.startswith("_init")) + s for s in dim_to_name.values() if s.startswith("P")) break if time_dim is None: log_factors.append(log_prob_factor) @@ -160,14 +174,14 @@ def model(*args, **kwargs): for time_dim, init_vars in time_to_init_vars.items(): for var in init_vars: - curr_var = "/".join(var.split("/")[1:]) + curr_var = var.lstrip("P") dim_to_name = model_trace[curr_var]["infer"]["dim_to_name"] - if var in dim_to_name.values(): # i.e. _init (i.e. prev) in dim_to_name + if var in dim_to_name.values(): # i.e. P* (i.e. 
prev) in dim_to_name
                 time_to_markov_dims[time_dim] |= frozenset(name for name in dim_to_name.values())
 
     if len(time_to_factors) > 0:
         markov_factors = compute_markov_factors(time_to_factors, time_to_init_vars,
-                                                time_to_markov_dims, sum_vars, prod_vars)
+                                                time_to_markov_dims, sum_vars, prod_vars, history)
         log_factors = log_factors + markov_factors
 
     with funsor.interpreter.interpretation(funsor.terms.lazy):
diff --git a/numpyro/handlers.py b/numpyro/handlers.py
index 0fa21767c..b03b6c4f5 100644
--- a/numpyro/handlers.py
+++ b/numpyro/handlers.py
@@ -391,7 +391,7 @@ def __init__(self, fn=None, config_fn=None):
         self.config_fn = config_fn
 
     def process_message(self, msg):
-        if msg["type"] in ("sample", "param"):
+        if msg["type"] in ("sample",):
             msg["infer"].update(self.config_fn(msg))
 
 
@@ -565,7 +565,7 @@ def process_message(self, msg):
 
 class scope(Messenger):
     """
-    This handler prepend a prefix followed by a ``/`` to the name of sample sites.
+    This handler prepends a prefix followed by a divider to the name of sample sites.
 
     Example::
 
@@ -577,21 +577,23 @@ class scope(Messenger):
     >>>
     >>> def model():
     ...     with scope(prefix="a"):
-    ...         with scope(prefix="b"):
+    ...         with scope(prefix="b", divider="."):
     ...             return numpyro.sample("x", dist.Bernoulli(0.5))
     ...
-    >>> assert "a/b/x" in trace(seed(model, 0)).get_trace()
+    >>> assert "a/b.x" in trace(seed(model, 0)).get_trace()
 
     :param fn: Python callable with NumPyro primitives.
     :param str prefix: a string to prepend to sample names
+    :param str divider: a string to join the prefix and sample name; defaults to `'/'`
     """
-    def __init__(self, fn=None, prefix=''):
+    def __init__(self, fn=None, prefix='', divider='/'):
         self.prefix = prefix
+        self.divider = divider
         super().__init__(fn)
 
     def process_message(self, msg):
         if msg.get('name'):
-            msg['name'] = f"{self.prefix}/{msg['name']}"
+            msg['name'] = f"{self.prefix}{self.divider}{msg['name']}"
 
 
 class seed(Messenger):
diff --git a/test/contrib/test_funsor.py b/test/contrib/test_funsor.py
index 4ae1a133d..ea7d54a2a 100644
--- a/test/contrib/test_funsor.py
+++ b/test/contrib/test_funsor.py
@@ -226,14 +226,14 @@ def transition_fn(x, y):
             probs = init_probs if x is None else transition_probs[x]
             x = numpyro.sample("x", dist.Categorical(probs))
             numpyro.sample("y", dist.Normal(locs[x], 1), obs=y)
-            return x, None
+            return x, 1
 
         x, collections = scan(transition_fn, None, data)
-        assert collections is None
+        assert collections.shape == data.shape[:1]
         return x
 
-    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
     expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
+    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[0]
     assert_allclose(actual_log_joint, expected_log_joint)
 
     actual_last_x = enum(config_enumerate(fun_model))(data)
@@ -267,8 +267,8 @@ def transition_fn(x, y):
 
         scan(transition_fn, None, data)
 
-    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
     expected_log_joint = log_density(enum(config_enumerate(model), -2), (data,), {}, {})[0]
+    actual_log_joint = log_density(enum(config_enumerate(fun_model), -2), (data,), {}, {})[0]
     assert_allclose(actual_log_joint, expected_log_joint)
 
 
@@ -433,6 +433,46 @@ def transition_fn(name, probs, locs, x, y):
     assert_allclose(actual_log_joint, expected_log_joint)
 
 
+@pytest.mark.parametrize('history', [2, 3])
+@pytest.mark.parametrize('T', [1, 2, 3, 4, 10, 11, 12, 13])
+def test_scan_history(history, T):
+    def model():
+        p = numpyro.param("p", 0.25 * 
jnp.ones((2, 2, 2))) + q = numpyro.param("q", 0.25 * jnp.ones(2)) + z = numpyro.sample("z", dist.Bernoulli(0.5)) + x_prev = 0 + x_curr = 0 + for t in markov(range(T), history=history): + probs = p[x_prev, x_curr, z] + x_prev, x_curr = x_curr, numpyro.sample("x_{}".format(t), dist.Bernoulli(probs)) + numpyro.sample("y_{}".format(t), dist.Bernoulli(q[x_curr]), obs=0) + return x_prev, x_curr + + def fun_model(): + p = numpyro.param("p", 0.25 * jnp.ones((2, 2, 2))) + q = numpyro.param("q", 0.25 * jnp.ones(2)) + z = numpyro.sample("z", dist.Bernoulli(0.5)) + + def transition_fn(carry, y): + x_prev, x_curr = carry + probs = p[x_prev, x_curr, z] + x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(probs)) + numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y) + return (x_prev, x_curr), None + + (x_prev, x_curr), _ = scan(transition_fn, (0, 0), jnp.zeros(T), history=history) + return x_prev, x_curr + + expected_log_joint = log_density(enum(config_enumerate(model)), (), {}, {})[0] + actual_log_joint = log_density(enum(config_enumerate(fun_model)), (), {}, {})[0] + assert_allclose(actual_log_joint, expected_log_joint) + + expected_x_prev, expected_x_curr = enum(config_enumerate(model))() + actual_x_prev, actual_x_curr = enum(config_enumerate(fun_model))() + assert_allclose(actual_x_prev, expected_x_prev) + assert_allclose(actual_x_curr, expected_x_curr) + + def test_missing_plate(monkeypatch): K, N = 3, 1000 diff --git a/test/test_examples.py b/test/test_examples.py index de34c86ae..3d0029df5 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -30,6 +30,7 @@ 'hmm_enum.py -m 3 -t 3 -d 3 --num-warmup 1 -n 4', 'hmm_enum.py -m 3 -t 3 -d 4 --num-warmup 1 -n 4', 'hmm_enum.py -m 4 -t 3 -d 4 --num-warmup 1 -n 4', + 'hmm_enum.py -m 6 -t 4 -d 3 --num-warmup 1 -n 4', 'minipyro.py', 'neutra.py --num-samples 100 --num-warmup 100', 'ode.py --num-samples 100 --num-warmup 100 --num-chains 1',
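
Reviewer note: below is a minimal end-to-end sketch of how the new `history` argument is meant to be used, adapted from `test_scan_history` above. The model and the names `second_order_model` / `transition_fn` are illustrative only, not part of the patch; it assumes numpyro with the funsor backend installed and `log_density` exported from `numpyro.contrib.funsor` as used in `test_funsor.py`.

```python
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import config_enumerate, enum, log_density


def second_order_model(T):
    # transition table indexed by the two most recent states (x[t-2], x[t-1])
    p = numpyro.param("p", 0.25 * jnp.ones((2, 2)))
    # emission table indexed by the current state
    q = numpyro.param("q", 0.25 * jnp.ones(2))

    def transition_fn(carry, y):
        x_prev, x_curr = carry
        # x[t] depends on both x[t-2] and x[t-1], so history=1 would be incorrect here
        x_prev, x_curr = x_curr, numpyro.sample("x", dist.Bernoulli(p[x_prev, x_curr]))
        numpyro.sample("y", dist.Bernoulli(q[x_curr]), obs=y)
        return (x_prev, x_curr), None

    scan(transition_fn, (0, 0), jnp.zeros(T), history=2)


# enumerate the Bernoulli states and compute the exact log joint for T = 5
log_joint, _ = log_density(enum(config_enumerate(second_order_model)), (5,), {}, {})
```

With `history=2`, `scan` unrolls `2 * history - 1 = 3` initial steps (the `P`-prefixed init sites) and eliminates the time dimension with `funsor.sum_product.sarkka_bilmes_product` instead of `sequential_sum_product`.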