Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scan for Trace_ELBO #1693

Merged
merged 10 commits into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,24 @@
from numpyro.util import not_jax_tracer


def _replay_wrapper(replay_trace, trace, i, length):
def get_ith_value(site):
value_shape = jnp.shape(site["value"])
site_len = value_shape[0] if value_shape else 0
if (
site["name"] not in trace
or site_len != length
or site["type"] not in ("sample", "deterministic")
):
return site

site = site.copy()
site["value"] = site["value"][i]
return site

return {k: get_ith_value(v) for k, v in replay_trace.items()}


def _subs_wrapper(subs_map, i, length, site):
if site["type"] != "sample":
return
Expand Down Expand Up @@ -264,10 +282,10 @@ def scan_wrapper(
first_available_dim=None,
):
if length is None:
length = tree_flatten(xs)[0][0].shape[0]
length = jnp.shape(tree_flatten(xs)[0][0])[0]

if enum and history > 0:
return scan_enum(
return scan_enum( # TODO: replay for enum
f,
init,
xs,
Expand All @@ -289,14 +307,17 @@ def body_fn(wrapped_carry, x):
fn = handlers.infer_config(
f, config_fn=lambda msg: {"_scan_current_index": i}
)

seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn
for subs_type, subs_map in substitute_stack:
subs_fn = partial(_subs_wrapper, subs_map, i, length)
if subs_type == "condition":
seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn)
elif subs_type == "substitute":
seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn)
elif subs_type == "replay":
trace = handlers.trace(seeded_fn).get_trace(carry, x)
replay_trace_i = _replay_wrapper(subs_map, trace, i, length)
seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i)

with handlers.trace() as trace:
carry, y = seeded_fn(carry, x)
Expand Down
2 changes: 2 additions & 0 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ def process_message(self, msg):
raise RuntimeError(f"Site {name} must be sampled in trace.")
msg["value"] = guide_msg["value"]
msg["infer"] = guide_msg["infer"].copy()
if msg["type"] == "control_flow":
msg["kwargs"]["substitute_stack"].append(("replay", self.trace))


class block(Messenger):
Expand Down
31 changes: 31 additions & 0 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import numpyro.distributions as dist
from numpyro.handlers import mask, seed, substitute, trace
from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import log_density, potential_energy
from numpyro.optim import Adam


def test_scan():
Expand Down Expand Up @@ -241,3 +243,32 @@ def transition(carry, y_curr):
assert model_density
assert model_trace["x"]["fn"].batch_shape == (12, 10)
assert model_trace["x"]["fn"].event_shape == (3,)


def test_scan_svi():
T = 3
N = 5

def gaussian_hmm(y=None, T=T, N=N):
def transition(x_prev, y_curr):
with numpyro.plate("data", N):
x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5))
y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr)
return x_curr, (x_curr, y_curr)

with numpyro.plate("data", N):
x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0))
_, (x, y) = scan(transition, x0, y, length=T)
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
x, y = gaussian_hmm()
with numpyro.handlers.seed(rng_seed=0):
tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N)

guide = AutoNormal(gaussian_hmm)
svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N)
results = svi.run(random.PRNGKey(0), 10**3)

xhat = results.params["x_auto_loc"]
assert_allclose(xhat, tr["x"]["value"], rtol=0.1)
Loading