-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support scan markov with history > 1 #848
Changes from 6 commits
bfb1bf6
51aaf62
2474ac8
9a44196
ab6a591
60c9312
5becd4f
a1e55b7
a17571a
629335d
74fe244
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
from collections import OrderedDict | ||
from functools import partial | ||
|
||
from jax import lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten | ||
from jax import device_put, lax, random, tree_flatten, tree_map, tree_multimap, tree_unflatten | ||
import jax.numpy as jnp | ||
from jax.tree_util import register_pytree_node_class | ||
|
||
|
@@ -105,27 +105,23 @@ def process_message(self, msg): | |
msg["fn"] = tree_map(lambda x: jnp.reshape(x, prepend_shapes + jnp.shape(x)), fn) | ||
|
||
|
||
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None): | ||
def scan_enum(f, init, xs, length, reverse, rng_key=None, substitute_stack=None, history=1): | ||
from numpyro.contrib.funsor import config_enumerate, enum, markov | ||
from numpyro.contrib.funsor import trace as packed_trace | ||
|
||
# XXX: This implementation only works for history size=1 but can be | ||
# extended to history size > 1 by running `f` `history_size` times | ||
# for initialization. However, `sequential_sum_product` does not | ||
# support history size > 1, so we skip supporting it here. | ||
# Note that `funsor.sum_product.sarkka_bilmes_product` does support history > 1. | ||
history = min(history, length) | ||
if reverse: | ||
x0 = tree_map(lambda x: x[-1], xs) | ||
xs_ = tree_map(lambda x: x[:-1], xs) | ||
x0 = tree_map(lambda x: x[-history:][::-1], xs) | ||
xs_ = tree_map(lambda x: x[:-history], xs) | ||
else: | ||
x0 = tree_map(lambda x: x[0], xs) | ||
xs_ = tree_map(lambda x: x[1:], xs) | ||
x0 = tree_map(lambda x: x[:history], xs) | ||
xs_ = tree_map(lambda x: x[history:], xs) | ||
|
||
carry_shape_at_t1 = None | ||
carry_shapes = [] | ||
|
||
def body_fn(wrapped_carry, x, prefix=None): | ||
i, rng_key, carry = wrapped_carry | ||
init = True if (not_jax_tracer(i) and i == 0) else False | ||
init = True if (not_jax_tracer(i) and i in range(history)) else False | ||
rng_key, subkey = random.split(rng_key) if rng_key is not None else (None, None) | ||
|
||
# we need to tell unconstrained messenger in potential energy computation | ||
|
@@ -141,7 +137,8 @@ def body_fn(wrapped_carry, x, prefix=None): | |
seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) | ||
|
||
if init: | ||
with handlers.scope(prefix="_init"): | ||
# handle the name to match the pattern of sarkka_bilmes_product | ||
with handlers.scope(prefix='P' * (history - i), divider='_'): | ||
new_carry, y = seeded_fn(carry, x) | ||
trace = {} | ||
else: | ||
|
@@ -152,24 +149,33 @@ def body_fn(wrapped_carry, x, prefix=None): | |
# at each site have the same batch dims (e.g. if `fn.batch_shape = (2, 3)`, | ||
# and value's batch_shape is (3,), then we promote shape of | ||
# value so that its batch shape is (1, 3)). | ||
new_carry, y = config_enumerate(seeded_fn)(carry, x) | ||
with handlers.scope(divider='_'): | ||
new_carry, y = config_enumerate(seeded_fn)(carry, x) | ||
|
||
# store shape of new_carry at a global variable | ||
nonlocal carry_shape_at_t1 | ||
carry_shape_at_t1 = [jnp.shape(x) for x in tree_flatten(new_carry)[0]] | ||
if len(carry_shapes) < (history + 1): | ||
carry_shapes.append([jnp.shape(x) for x in tree_flatten(new_carry)[0]]) | ||
# make new_carry have the same shape as carry | ||
# FIXME: is this rigorous? | ||
new_carry = tree_multimap(lambda a, b: jnp.reshape(a, jnp.shape(b)), | ||
new_carry, carry) | ||
return (i + jnp.array(1), rng_key, new_carry), (PytreeTrace(trace), y) | ||
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) | ||
|
||
with markov(): | ||
with markov(history=history): | ||
wrapped_carry = (0, rng_key, init) | ||
wrapped_carry, (_, y0) = body_fn(wrapped_carry, x0) | ||
if length == 1: | ||
ys = tree_map(lambda x: jnp.expand_dims(x, 0), y0) | ||
return wrapped_carry, (PytreeTrace({}), ys) | ||
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - 1, reverse) | ||
y0s = [] | ||
for i in range(history): | ||
wrapped_carry, (_, y0) = body_fn(wrapped_carry, tree_map(lambda z: z[i], x0)) | ||
if i > 0: | ||
# reshape y1, y2,... to have the same shape as y0 | ||
y0 = tree_multimap(lambda z0, z: jnp.reshape(z, jnp.shape(z0)), y0s[0], y0) | ||
y0s.append(y0) | ||
carry_shapes.append([jnp.shape(x) for x in tree_flatten(wrapped_carry[-1])[0]]) | ||
y0s = tree_multimap(lambda *z: jnp.stack(z, axis=0), *y0s) | ||
if length == history: | ||
return wrapped_carry, (PytreeTrace({}), y0s) | ||
wrapped_carry = device_put(wrapped_carry) | ||
wrapped_carry, (pytree_trace, ys) = lax.scan(body_fn, wrapped_carry, xs_, length - history, reverse) | ||
|
||
first_var = None | ||
for name, site in pytree_trace.trace.items(): | ||
|
@@ -187,24 +193,34 @@ def body_fn(wrapped_carry, x, prefix=None): | |
site['infer']['dim_to_name'][time_dim] = '_time_{}'.format(first_var) | ||
|
||
# similar to carry, we need to reshape due to shape alternating in markov | ||
ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)), y0, ys) | ||
ys = tree_multimap(lambda z0, z: jnp.reshape(z, z.shape[:1] + jnp.shape(z0)[1:]), y0s, ys) | ||
# then join with y0s | ||
ys = tree_multimap(lambda z0, z: jnp.concatenate([z0, z], axis=0), y0s, ys) | ||
# we also need to reshape `carry` to match sequential behavior | ||
if length % 2 == 0: | ||
i = (length - 1) % (history + 1) | ||
# XXX: unless we unroll more steps, we only know the correct shapes starting from | ||
# the `history`-th iteration; that means we only know carry shapes when | ||
# i == history - 1 or i == history, which corresponds to input carry or output carry | ||
# at the `history`-th iteration. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @fehiepsi I think I understand this issue, thanks for explaining. How much compilation overhead does unrolling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thinking about it again, I guess it won't take much compiling time because history is usually small (1 or 2). Let me address the issue in this PR. |
||
|
||
# NB: no need to reshape if i == history - 1 | ||
if i == history: | ||
t, rng_key, carry = wrapped_carry | ||
carry_shape = carry_shapes[i] | ||
flatten_carry, treedef = tree_flatten(carry) | ||
flatten_carry = [jnp.reshape(x, t1_shape) | ||
for x, t1_shape in zip(flatten_carry, carry_shape_at_t1)] | ||
for x, t1_shape in zip(flatten_carry, carry_shape)] | ||
carry = tree_unflatten(treedef, flatten_carry) | ||
wrapped_carry = (t, rng_key, carry) | ||
return wrapped_carry, (pytree_trace, ys) | ||
|
||
|
||
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False): | ||
def scan_wrapper(f, init, xs, length, reverse, rng_key=None, substitute_stack=[], enum=False, history=1): | ||
if length is None: | ||
length = tree_flatten(xs)[0][0].shape[0] | ||
|
||
if enum: | ||
return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack) | ||
if enum and history > 0: | ||
return scan_enum(f, init, xs, length, reverse, rng_key, substitute_stack, history) | ||
|
||
def body_fn(wrapped_carry, x): | ||
i, rng_key, carry = wrapped_carry | ||
|
@@ -229,10 +245,11 @@ def body_fn(wrapped_carry, x): | |
|
||
return (i + 1, rng_key, carry), (PytreeTrace(trace), y) | ||
|
||
return lax.scan(body_fn, (jnp.array(0), rng_key, init), xs, length=length, reverse=reverse) | ||
wrapped_carry = device_put((0, rng_key, init)) | ||
return lax.scan(body_fn, wrapped_carry, xs, length=length, reverse=reverse) | ||
|
||
|
||
def scan(f, init, xs, length=None, reverse=False): | ||
def scan(f, init, xs, length=None, reverse=False, history=1): | ||
""" | ||
This primitive scans a function over the leading array axes of | ||
`xs` while carrying along state. See :func:`jax.lax.scan` for more | ||
|
@@ -297,13 +314,11 @@ def g(*args, **kwargs): | |
evaluated using parallel-scan (reference [1]) over time dimension, which | ||
reduces parallel complexity to `O(log(length))`. | ||
|
||
Currently, only the equivalence to | ||
:class:`~numpyro.contrib.funsor.enum_messenger.markov(history_size=1)` | ||
is supported. A :class:`~numpyro.handlers.trace` of `scan` with discrete latent | ||
A :class:`~numpyro.handlers.trace` of `scan` with discrete latent | ||
variables will contain the following sites: | ||
|
||
+ init sites: those sites belong to the first trace of `f`. Each of | ||
them will have name prefixed with `_init/`. | ||
+ init sites: those sites belong to the first `history` traces of `f`. | ||
Sites at the `i`-th trace will have name prefixed with `P` * (history - i). | ||
+ scanned sites: those sites collect the values of the remaining scan | ||
loop over `f`. An additional time dimension `_time_foo` will be | ||
added to those sites, where `foo` is the name of the first site | ||
|
@@ -312,7 +327,8 @@ def g(*args, **kwargs): | |
Not all transition functions `f` are supported. All of the restrictions from | ||
Pyro's enumeration tutorial [2] still apply here. In addition, there should | ||
not be any site outside of `scan` that depends on the first output of `scan` | ||
(the last carry value). | ||
(the last carry value). In addition, when `history > 1`, under enumeration, | ||
the last carry value might not have correct shapes. | ||
|
||
** References ** | ||
|
||
|
@@ -331,6 +347,8 @@ def g(*args, **kwargs): | |
but can be used when `xs` is an empty pytree (e.g. None) | ||
:param bool reverse: optional boolean specifying whether to run the scan iteration | ||
forward (the default) or in reverse | ||
:param int history: The number of previous contexts visible from the current context. | ||
Defaults to 1. If zero, this is similar to :class:`numpyro.plate`. | ||
:return: output of scan, quoted from :func:`jax.lax.scan` docs: | ||
"pair of type (c, [b]) where the first element represents the final loop | ||
carry value and the second element represents the stacked outputs of the | ||
|
@@ -347,7 +365,8 @@ def g(*args, **kwargs): | |
'fn': scan_wrapper, | ||
'args': (f, init, xs, length, reverse), | ||
'kwargs': {'rng_key': None, | ||
'substitute_stack': []}, | ||
'substitute_stack': [], | ||
'history': history}, | ||
'value': None, | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For clarification, after changing this to
for i in range(2 * history)
what other changes would be necessary to get the correct final carry_shape?