Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scan markov with history > 1 #848

Merged
merged 11 commits into from
Dec 23, 2020
21 changes: 15 additions & 6 deletions docs/source/svi.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,28 @@ ELBO
:show-inheritance:
:member-order: bysource

RenyiELBO
---------
Trace_ELBO
----------

.. autoclass:: numpyro.infer.elbo.RenyiELBO
.. autoclass:: numpyro.infer.elbo.Trace_ELBO
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Trace_ELBO
----------
TraceMeanField_ELBO
-------------------

.. autoclass:: numpyro.infer.elbo.Trace_ELBO
.. autoclass:: numpyro.infer.elbo.TraceMeanField_ELBO
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

RenyiELBO
---------

.. autoclass:: numpyro.infer.elbo.RenyiELBO
:members:
:undoc-members:
:show-inheritance:
Expand Down
68 changes: 64 additions & 4 deletions examples/hmm_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
3. *Modeling Temporal Dependencies in High-Dimensional Sequences:
Application to Polyphonic Music Generation and Transcription*,
Boulanger-Lewandowski, N., Bengio, Y. and Vincent, P.
4. *Tensor Variable Elimination for Plated Factor Graphs*,
Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
Neeraj Pradhan, Alexander Rush, Noah Goodman (https://arxiv.org/abs/1902.03210)
"""

import argparse
Expand All @@ -65,8 +68,9 @@
logger.setLevel(logging.INFO)


# %%
# Let's start with a simple Hidden Markov Model.
#

# x[t-1] --> x[t] --> x[t+1]
# | | |
# V V V
Expand Down Expand Up @@ -100,8 +104,9 @@ def transition_fn(carry, y):
scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))


# %%
# Next let's add a dependency of y[t] on y[t-1].
#

# x[t-1] --> x[t] --> x[t+1]
# | | |
# V V V
Expand Down Expand Up @@ -137,8 +142,9 @@ def transition_fn(carry, y):
scan(transition_fn, (x_init, y_init, 0), jnp.swapaxes(sequences, 0, 1))


# %%
# Next consider a Factorial HMM with two hidden states.
#

# w[t-1] ----> w[t] ---> w[t+1]
# \ x[t-1] --\-> x[t] --\-> x[t+1]
# \ / \ / \ /
Expand Down Expand Up @@ -185,9 +191,10 @@ def transition_fn(carry, y):
scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))


# %%
# By adding a dependency of x on w, we generalize to a
# Dynamic Bayesian Network.
#

# w[t-1] ----> w[t] ---> w[t+1]
# | \ | \ | \
# | x[t-1] ----> x[t] ----> x[t+1]
Expand Down Expand Up @@ -230,6 +237,59 @@ def transition_fn(carry, y):
scan(transition_fn, (w_init, x_init, 0), jnp.swapaxes(sequences, 0, 1))


# %%
# Next let's consider a second-order HMM model
# in which x[t+1] depends on both x[t] and x[t-1].

# _______>______
# _____>_____/______ \
# / / \ \
# x[t-1] --> x[t] --> x[t+1] --> x[t+2]
# | | | |
# V V V V
# y[t-1] y[t] y[t+1] y[t+2]
#
# Note that in this model (in contrast to the previous model) we treat
# the transition and emission probabilities as parameters (so they have no prior).
#
# Note that this is the "2HMM" model in reference [4].
def model_6(sequences, lengths, args, include_prior=False):
    """Second-order HMM ("2HMM"): x[t+1] depends on both x[t] and x[t-1].

    :param sequences: observed batch of shape (num_sequences, max_length, data_dim)
        — binary (tone on/off) emissions.  # presumably polyphonic-music data as in
        # the other models of this file — confirm against the caller.
    :param lengths: per-sequence valid lengths, used to mask out padding steps.
    :param args: namespace providing ``args.hidden_dim`` (number of hidden states).
    :param include_prior: when False (the default), the prior terms on
        ``probs_x``/``probs_y`` are masked out, so they behave as learned
        parameters rather than latent variables with a prior.
    """
    num_sequences, max_length, data_dim = sequences.shape

    # Mask the prior factors entirely when include_prior=False.
    with mask(mask=include_prior):
        # Explicitly parameterize the full tensor of transition probabilities, which
        # has hidden_dim cubed entries: probs_x[i, j] is the categorical distribution
        # over x[t+1] given x[t-1]=i, x[t]=j.
        probs_x = numpyro.sample("probs_x",
                                 dist.Dirichlet(0.9 * jnp.eye(args.hidden_dim) + 0.1)
                                 .expand([args.hidden_dim, args.hidden_dim])
                                 .to_event(2))

        # Per-state, per-tone emission probabilities for the Bernoulli observations.
        probs_y = numpyro.sample("probs_y",
                                 dist.Beta(0.1, 0.9)
                                 .expand([args.hidden_dim, data_dim])
                                 .to_event(2))

    def transition_fn(carry, y):
        # carry holds the two most recent hidden states (second-order dependence)
        # plus the current time index t.
        x_prev, x_curr, t = carry
        with numpyro.plate("sequences", num_sequences, dim=-2):
            # Mask out time steps beyond each sequence's true length.
            with mask(mask=(t < lengths)[..., None]):
                # Vindex performs advanced indexing that stays correct under
                # enumeration broadcasting: select row (x_prev, x_curr).
                probs_x_t = Vindex(probs_x)[x_prev, x_curr]
                # Shift the history window: old x_curr becomes x_prev.
                x_prev, x_curr = x_curr, numpyro.sample("x", dist.Categorical(probs_x_t))
                with numpyro.plate("tones", data_dim, dim=-1):
                    probs_y_t = probs_y[x_curr.squeeze(-1)]
                    numpyro.sample("y",
                                   dist.Bernoulli(probs_y_t),
                                   obs=y)
        return (x_prev, x_curr, t + 1), None

    # Initialize both history slots to state 0; shape (num_sequences, 1) so the
    # trailing dim broadcasts against the "tones" plate.
    x_prev = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    x_curr = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
    # history=2 tells scan's enumeration machinery to track a 2-step Markov
    # dependency; time is moved to the leading axis for scanning.
    scan(transition_fn, (x_prev, x_curr, 0), jnp.swapaxes(sequences, 0, 1), history=2)


# %%
# Do inference

# Collect every model defined above into a registry keyed by its numeric
# suffix (e.g. "model_6" -> key "6"), so the CLI can select one by name.
models = {name[len('model_'):]: model
          for name, model in globals().items()
          if name.startswith('model_')}
Expand Down
97 changes: 58 additions & 39 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections import OrderedDict
from functools import partial

from jax import lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten
from jax import device_put, lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

Expand Down Expand Up @@ -105,27 +105,23 @@ def process_message(self, msg):
msg["fn"] = tree_map(lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn)


def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None):
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None, history=1):
from numpyro.contrib.funsor import config_enumerate, enum, markov
from numpyro.contrib.funsor import trace as packed_trace

# XXX: This implementation only works for history size=1 but can be
# extended to history size > 1 by running `f` `history_size` times
# for initialization. However, `sequential_sum_product` does not
# support history size > 1, so we skip supporting it here.
# Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1.
history = min(history, length)
if reverse:
x0 = tree_map(lambda x: x[-1], xs)
xs_ = tree_map(lambda x: x[:-1], xs)
x0 = tree_map(lambda x: x[-history:][::-1], xs)
xs_ = tree_map(lambda x: x[:-history], xs)
else:
x0 = tree_map(lambda x: x[0], xs)
xs_ = tree_map(lambda x: x[1:], xs)
x0 = tree_map(lambda x: x[:history], xs)
xs_ = tree_map(lambda x: x[history:], xs)

carry_shape_at_t1 = None
carry_shapes = []

def body_fn(wrapped_carry, x, prefix=None):
i, rng_key, carry = wrapped_carry
init = True if (not_jax_tracer(i) and i == 0) else False
init = True if (not_jax_tracer(i) and i in range(history)) else False
rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None)

# we need to tell unconstrained messenger in potential energy computation
Expand All @@ -141,7 +137,8 @@ def body_fn(wrapped_carry, x, prefix=None):
seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)

if init:
with handlers.scope(prefix="_init"):
# prefix the site name so it matches the naming pattern expected by
# funsor's `sarkka_bilmes_product`
with handlers.scope(prefix='P' * (history - i), divider='_'):
new_carry, y = seeded_fn(carry, x)
trace = {}
else:
Expand All @@ -152,24 +149,33 @@ def body_fn(wrapped_carry, x, prefix=None):
# at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`,
# and value's batch_shape is (3,), then we promote shape of
# value so that its batch shape is (1, 3)).
new_carry, y = config_enumerate(seeded_fn)(carry, x)
with handlers.scope(divider='_'):
new_carry, y = config_enumerate(seeded_fn)(carry, x)

# store shape of new_carry at a global variable
nonlocal carry_shape_at_t1
carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]]
if len(carry_shapes) < (history + 1):
carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]])
# make new_carry have the same shape as carry
# FIXME: is this rigorous?
new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)),
new_carry, carry)
return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y)
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

with markov():
with markov(history=history):
wrapped_carry = (0, rng_key, init)
wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0)
if length == 1:
ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0)
return wrapped_carry, (PytreeTrace({}), ys)
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse)
y0s = []
for i in range(history):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For clarification, after changing this to for i in range(2 * history) what other changes would be necessary to get the correct final carry_shape?

wrapped_carry, (_, y0) = body_fn(wrapped_carry, tree_map(lambda z: z[i], x0))
if i > 0:
# reshape y1, y2,... to have the same shape as y0
y0 = tree_multimap(lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0)
y0s.append(y0)
carry_shapes.append([jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]])
y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s)
if length == history:
return wrapped_carry, (PytreeTrace({}), y0s)
wrapped_carry = device_put(wrapped_carry)
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - history, reverse)

first_var = None
for name, site in pytree_trace.trace.items():
Expand All @@ -187,24 +193,34 @@ def body_fn(wrapped_carry, x, prefix=None):
site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var)

# similar to carry, we need to reshape due to shape alternating in markov
ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys)
ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys)
# then join with y0s
ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys)
# we also need to reshape `carry` to match sequential behavior
if length % 2 == 0:
i = (length - 1) % (history + 1)
# XXX: unless we unroll more steps, we only know the correct shapes starting from
# the `history`-th iteration; that means we only know carry shapes when
# i == history - 1 or i == history, which corresponds to input carry or output carry
# at the `history`-th iteration.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fehiepsi I think I understand this issue, thanks for explaining. How much compilation overhead does unrolling 2 * history steps add in the HMM examples? Unrolling history steps in exchange for incorrect output carry shapes is a very subtle optimization that only produces constant-time savings in program length, and may cause maintenance difficulties down the line, so it's important to be sure that it's worth it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it again, I guess it won't take much compiling time because history is usually small (1 or 2). Let me address the issue in this PR.


# NB: no need to reshape if i == history - 1
if i == history:
t, rng_key, carry = wrapped_carry
carry_shape = carry_shapes[i]
flatten_carry, treedef = tree_flatten(carry)
flatten_carry = [jnp.reshape(x, t1_shape)
for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)]
for x, t1_shape in zip(flatten_carry, carry_shape)]
carry = tree_unflatten(treedef, flatten_carry)
wrapped_carry = (t, rng_key, carry)
return wrapped_carry, (pytree_trace, ys)


def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False):
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False, history=1):
if length is None:
length = tree_flatten(xs)[0][0].shape[0]

if enum:
return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack)
if enum and history > 0:
return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack, history)

def body_fn(wrapped_carry, x):
i, rng_key, carry = wrapped_carry
Expand All @@ -229,10 +245,11 @@ def body_fn(wrapped_carry, x):

return (i + 1, rng_key, carry), (PytreeTrace(trace), y)

return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse)
wrapped_carry = device_put((0, rng_key, init))
return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse)


def scan(f, init, xs, length=None, reverse=False):
def scan(f, init, xs, length=None, reverse=False, history=1):
"""
This primitive scans a function over the leading array axes of
`xs` while carrying along state. See :func:`jax.lax.scan` for more
Expand Down Expand Up @@ -297,13 +314,11 @@ def g(*args, **kwargs):
evaluated using parallel-scan (reference [1]) over time dimension, which
reduces parallel complexity to `O(log(length))`.

Currently, only the equivalence to
:class:`~numpyro.contrib.funsor.enum_messenger.markov(history_size=1)`
is supported. A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
A :class:`~numpyro.handlers.trace` of `scan` with discrete latent
variables will contain the following sites:

+ init sites: those sites belong to the first trace of `f`. Each of
them will have name prefixed with `_init/`.
+ init sites: those sites belong to the first `history` traces of `f`.
Sites at the `i`-th trace will have name prefixed with `P` * (history - i).
+ scanned sites: those sites collect the values of the remaining scan
loop over `f`. An addition time dimension `_time_foo` will be
added to those sites, where `foo` is the name of the first site
Expand All @@ -312,7 +327,8 @@ def g(*args, **kwargs):
Not all transition functions `f` are supported. All of the restrictions from
Pyro's enumeration tutorial [2] still apply here. In addition, there should
not have any site outside of `scan` depend on the first output of `scan`
(the last carry value).
(the last carry value). In addition, when `history > 1`, under enumeration,
the last carry value might not have correct shapes.

** References **

Expand All @@ -331,6 +347,8 @@ def g(*args, **kwargs):
but can be used when `xs` is an empty pytree (e.g. None)
:param bool reverse: optional boolean specifying whether to run the scan iteration
forward (the default) or in reverse
:param int history: The number of previous contexts visible from the current context.
Defaults to 1. If zero, this is similar to :class:`numpyro.plate`.
:return: output of scan, quoted from :func:`jax.lax.scan` docs:
"pair of type (c, [b]) where the first element represents the final loop
carry value and the second element represents the stacked outputs of the
Expand All @@ -347,7 +365,8 @@ def g(*args, **kwargs):
'fn': scan_wrapper,
'args': (f, init, xs, length, reverse),
'kwargs': {'rng_key': None,
'substitute_stack': []},
'substitute_stack': [],
'history': history},
'value': None,
}

Expand Down
2 changes: 1 addition & 1 deletion numpyro/contrib/funsor/enum_messenger.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def postprocess_message(self, msg):
if msg["type"] == "sample":
total_batch_shape = lax.broadcast_shapes(
tuple(msg["fn"].batch_shape),
msg["value"].shape[:len(msg["value"].shape) - msg["fn"].event_dim]
jnp.shape(msg["value"])[:jnp.ndim(msg["value"]) - msg["fn"].event_dim]
)
msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(total_batch_shape)
if msg["type"] in ("sample", "param"):
Expand Down
Loading