From 427f25264e8276c58de4457b0cb964d7a23591d5 Mon Sep 17 00:00:00 2001 From: frans Date: Thu, 23 Nov 2023 15:01:47 +0100 Subject: [PATCH 01/10] scan replay handling --- numpyro/contrib/control_flow/scan.py | 25 +++++++++++++++++++++++-- numpyro/handlers.py | 4 ++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index b04c51862..845f55645 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -15,6 +15,22 @@ from numpyro.util import not_jax_tracer +def _replay_wrapper(replay_trace, seeded_fn_trace, i, length): + def get_ith_value(site): + if site["name"] not in seeded_fn_trace.keys(): + return site + shape = jnp.shape(site["value"]) + if shape[0] == length: + site["value"] = site["value"][i] + elif shape[0] != length: + raise RuntimeError( + f"Replay value for site {site['name']} " + "requires length equal to scan length." + f" Expected length == {length}, but got {shape[0]}." + ) + return site + return {k: get_ith_value(v.copy()) for k, v in replay_trace.items()} + def _subs_wrapper(subs_map, i, length, site): if site["type"] != "sample": return @@ -39,7 +55,7 @@ def _subs_wrapper(subs_map, i, length, site): # where we apply init_strategy to each element in the scanned series return value elif value_ndim == fn_ndim + 1: - # this branch happens when we substitute a series of values + # this branch happens when xwe substitute a series of values shape = jnp.shape(value) if shape[0] == length: return value[i] @@ -259,6 +275,7 @@ def scan_wrapper( reverse, rng_key=None, substitute_stack=[], + replay_trace=None, enum=False, history=1, first_available_dim=None, @@ -289,7 +306,6 @@ def body_fn(wrapped_carry, x): fn = handlers.infer_config( f, config_fn=lambda msg: {"_scan_current_index": i} ) - seeded_fn = handlers.seed(fn, subkey) if subkey is not None else fn for subs_type, subs_map in substitute_stack: subs_fn = partial(_subs_wrapper, subs_map, i, length) @@ -298,6 +314,11 @@ def body_fn(wrapped_carry, x): elif subs_type == "substitute": seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) + if replay_trace is not None: + seeded_fn_trace = handlers.trace(seeded_fn).get_trace(carry, x) + replay_trace_i = _replay_wrapper(replay_trace, seeded_fn_trace, i, length) + seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) + with handlers.trace() as trace: carry, y = seeded_fn(carry, x) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 5745cea32..ad3a030bd 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -223,6 +223,10 @@ def process_message(self, msg): msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"].copy() + if msg["type"] == "control_flow": + print('''msg.get("_control_flow_done", False)''', msg.get("_control_flow_done", False)) + msg["kwargs"]["replay_trace"] = self.trace + class block(Messenger): """ From 9df26281e2afabc130a3118f6c4918bd1b9062d4 Mon Sep 17 00:00:00 2001 From: frans Date: Thu, 23 Nov 2023 15:12:08 +0100 Subject: [PATCH 02/10] fix off by one error --- numpyro/contrib/control_flow/scan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 845f55645..77b29e4ab 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -55,7 +55,7 @@ def _subs_wrapper(subs_map, i, length, site): # where we apply init_strategy to each element in the scanned series return value elif value_ndim == fn_ndim + 1: - # this branch happens when xwe substitute a series of values + # this branch happens when we substitute a series of values shape = jnp.shape(value) if shape[0] == length: return value[i] @@ -316,7 +316,7 @@ def body_fn(wrapped_carry, x): if replay_trace is not None: seeded_fn_trace = handlers.trace(seeded_fn).get_trace(carry, x) - replay_trace_i = _replay_wrapper(replay_trace, seeded_fn_trace, i, length) + replay_trace_i = _replay_wrapper(replay_trace, seeded_fn_trace, i-1, length) seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: From 0d0f2485cd201c5520b503d17621845368bf7a64 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Fri, 24 Nov 2023 09:00:45 +0100 Subject: [PATCH 03/10] cleaned a bit --- numpyro/contrib/control_flow/scan.py | 29 +++++++++++++++++----------- numpyro/handlers.py | 5 ++++- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 77b29e4ab..81db5e7fc 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -15,22 +15,27 @@ from numpyro.util import not_jax_tracer -def _replay_wrapper(replay_trace, seeded_fn_trace, i, length): +def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): - if site["name"] not in seeded_fn_trace.keys(): + if site["name"] not in trace: return site - shape = jnp.shape(site["value"]) - if shape[0] == length: - site["value"] = site["value"][i] - elif shape[0] != length: + + site_len = jnp.shape(site["value"])[0] + + if site_len != length: raise RuntimeError( f"Replay value for site {site['name']} " "requires length equal to scan length." - f" Expected length == {length}, but got {shape[0]}." + f" Expected length {length}, but got {site_len}." ) + + + site["value"] = site["value"][i] return site + return {k: get_ith_value(v.copy()) for k, v in replay_trace.items()} + def _subs_wrapper(subs_map, i, length, site): if site["type"] != "sample": return @@ -281,10 +286,10 @@ def scan_wrapper( first_available_dim=None, ): if length is None: - length = tree_flatten(xs)[0][0].shape[0] + length = jnp.shape(tree_flatten(xs)[0][0])[0] if enum and history > 0: - return scan_enum( + return scan_enum( # TODO: replay for enum f, init, xs, @@ -315,8 +320,10 @@ def body_fn(wrapped_carry, x): seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) if replay_trace is not None: - seeded_fn_trace = handlers.trace(seeded_fn).get_trace(carry, x) - replay_trace_i = _replay_wrapper(replay_trace, seeded_fn_trace, i-1, length) + trace = handlers.trace(seeded_fn).get_trace(carry, x) + replay_trace_i = _replay_wrapper( + replay_trace, trace, i - 1, length + ) seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: diff --git a/numpyro/handlers.py b/numpyro/handlers.py index ad3a030bd..4040e5716 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -224,7 +224,10 @@ def process_message(self, msg): msg["infer"] = guide_msg["infer"].copy() if msg["type"] == "control_flow": - print('''msg.get("_control_flow_done", False)''', msg.get("_control_flow_done", False)) + print( + """msg.get("_control_flow_done", False)""", + msg.get("_control_flow_done", False), + ) msg["kwargs"]["replay_trace"] = self.trace From d3548cf5ad54c785b9baeb79c9ff2f0a20b672e1 Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Fri, 24 Nov 2023 09:07:26 +0100 Subject: [PATCH 04/10] lint --- numpyro/contrib/control_flow/scan.py | 7 ++----- numpyro/handlers.py | 7 +------ 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 81db5e7fc..a511c75ab 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -19,7 +19,7 @@ def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): if site["name"] not in trace: return site - + site_len = jnp.shape(site["value"])[0] if site_len != length: @@ -28,7 +28,6 @@ def get_ith_value(site): "requires length equal to scan length." f" Expected length {length}, but got {site_len}." ) - site["value"] = site["value"][i] return site @@ -321,9 +320,7 @@ def body_fn(wrapped_carry, x): if replay_trace is not None: trace = handlers.trace(seeded_fn).get_trace(carry, x) - replay_trace_i = _replay_wrapper( - replay_trace, trace, i - 1, length - ) + replay_trace_i = _replay_wrapper(replay_trace, trace, i - 1, length) seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 4040e5716..830ebf7e1 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -222,13 +222,8 @@ def process_message(self, msg): raise RuntimeError(f"Site {name} must be sampled in trace.") msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"].copy() - if msg["type"] == "control_flow": - print( - """msg.get("_control_flow_done", False)""", - msg.get("_control_flow_done", False), - ) - msg["kwargs"]["replay_trace"] = self.trace + msg["kwargs"]["substitute_stack"].append({"replay": self.trace}) class block(Messenger): From ba2682cf531a87348ca47e3656f3745623b664ce Mon Sep 17 00:00:00 2001 From: OlaRonning Date: Fri, 24 Nov 2023 09:35:50 +0100 Subject: [PATCH 05/10] use substack --- numpyro/contrib/control_flow/scan.py | 9 ++++----- numpyro/handlers.py | 2 +- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index a511c75ab..84d904366 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -279,7 +279,6 @@ def scan_wrapper( reverse, rng_key=None, substitute_stack=[], - replay_trace=None, enum=False, history=1, first_available_dim=None, @@ -318,10 +317,10 @@ def body_fn(wrapped_carry, x): elif subs_type == "substitute": seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) - if replay_trace is not None: - trace = handlers.trace(seeded_fn).get_trace(carry, x) - replay_trace_i = _replay_wrapper(replay_trace, trace, i - 1, length) - seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) + elif subs_type == "replay": + trace = handlers.trace(seeded_fn).get_trace(carry, x) + replay_trace_i = _replay_wrapper(subs_map, trace, i - 1, length) + seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: carry, y = seeded_fn(carry, x) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 830ebf7e1..3e38e204c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -223,7 +223,7 @@ def process_message(self, msg): msg["value"] = guide_msg["value"] msg["infer"] = guide_msg["infer"].copy() if msg["type"] == "control_flow": - msg["kwargs"]["substitute_stack"].append({"replay": self.trace}) + msg["kwargs"]["substitute_stack"].append(("replay", self.trace)) class block(Messenger): From ce3be2d1ff674adf9596838b2f86e1d057c731ca Mon Sep 17 00:00:00 2001 From: = Date: Fri, 24 Nov 2023 13:26:42 +0100 Subject: [PATCH 06/10] handle custom guides --- numpyro/contrib/control_flow/scan.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 84d904366..0739dc1d1 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -17,17 +17,10 @@ def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): - if site["name"] not in trace: - return site - site_len = jnp.shape(site["value"])[0] - - if site_len != length: - raise RuntimeError( - f"Replay value for site {site['name']} " - "requires length equal to scan length." - f" Expected length {length}, but got {site_len}." - ) + + if site["name"] not in trace or site_len != length or site["type"] not in ("sample", "deterministic"): + return site site["value"] = site["value"][i] return site @@ -319,7 +312,7 @@ def body_fn(wrapped_carry, x): elif subs_type == "replay": trace = handlers.trace(seeded_fn).get_trace(carry, x) - replay_trace_i = _replay_wrapper(subs_map, trace, i - 1, length) + replay_trace_i = _replay_wrapper(subs_map, trace, i, length) seeded_fn = handlers.replay(seeded_fn, trace=replay_trace_i) with handlers.trace() as trace: From 42bb6f210071d7516d454b1700324510aedccc22 Mon Sep 17 00:00:00 2001 From: frans Date: Wed, 6 Dec 2023 13:06:39 +0100 Subject: [PATCH 07/10] handle empty value shape --- numpyro/contrib/control_flow/scan.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 0739dc1d1..9d6f4da1a 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -17,7 +17,8 @@ def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): - site_len = jnp.shape(site["value"])[0] + value_shape = jnp.shape(site["value"]) + site_len = value_shape[0] if value_shape else 0 if site["name"] not in trace or site_len != length or site["type"] not in ("sample", "deterministic"): return site @@ -309,7 +310,6 @@ def body_fn(wrapped_carry, x): seeded_fn = handlers.condition(seeded_fn, condition_fn=subs_fn) elif subs_type == "substitute": seeded_fn = handlers.substitute(seeded_fn, substitute_fn=subs_fn) - elif subs_type == "replay": trace = handlers.trace(seeded_fn).get_trace(carry, x) replay_trace_i = _replay_wrapper(subs_map, trace, i, length) From 70bc8f3aa7939feebb36aad11f765f9600541468 Mon Sep 17 00:00:00 2001 From: frans Date: Wed, 6 Dec 2023 13:12:29 +0100 Subject: [PATCH 08/10] a --- numpyro/contrib/control_flow/scan.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index 9d6f4da1a..cc3a251f8 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -19,8 +19,11 @@ def _replay_wrapper(replay_trace, trace, i, length): def get_ith_value(site): value_shape = jnp.shape(site["value"]) site_len = value_shape[0] if value_shape else 0 - - if site["name"] not in trace or site_len != length or site["type"] not in ("sample", "deterministic"): + if ( + site["name"] not in trace + or site_len != length + or site["type"] not in ("sample", "deterministic") + ): return site site["value"] = site["value"][i] From a975ab3803f6dfaf9be71590636eb54bdd75b0ab Mon Sep 17 00:00:00 2001 From: frans Date: Wed, 6 Dec 2023 14:41:09 +0100 Subject: [PATCH 09/10] test --- test/contrib/test_control_flow.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/contrib/test_control_flow.py b/test/contrib/test_control_flow.py index 67273d02c..f75686daf 100644 --- a/test/contrib/test_control_flow.py +++ b/test/contrib/test_control_flow.py @@ -12,7 +12,9 @@ import numpyro.distributions as dist from numpyro.handlers import mask, seed, substitute, trace from numpyro.infer import MCMC, NUTS, SVI, Predictive, Trace_ELBO +from numpyro.infer.autoguide import AutoNormal from numpyro.infer.util import log_density, potential_energy +from numpyro.optim import Adam def test_scan(): @@ -241,3 +243,32 @@ def transition(carry, y_curr): assert model_density assert model_trace["x"]["fn"].batch_shape == (12, 10) assert model_trace["x"]["fn"].event_shape == (3,) + + +def test_scan_svi(): + T = 3 + N = 5 + + def gaussian_hmm(y=None, T=T, N=N): + def transition(x_prev, y_curr): + with numpyro.plate("data", N): + x_curr = numpyro.sample("x", dist.Normal(x_prev, 1.5)) + y_curr = numpyro.sample("y", dist.Normal(x_curr, 0.1), obs=y_curr) + return x_curr, (x_curr, y_curr) + + with numpyro.plate("data", N): + x0 = numpyro.sample("x_0", dist.Normal(jnp.zeros(N), 5.0)) + _, (x, y) = scan(transition, x0, y, length=T) + return (x, y) + + with numpyro.handlers.seed(rng_seed=0): + x, y = gaussian_hmm() + with numpyro.handlers.seed(rng_seed=0): + tr = numpyro.handlers.trace(gaussian_hmm).get_trace(y=y, T=T, N=N) + + guide = AutoNormal(gaussian_hmm) + svi = SVI(gaussian_hmm, guide, Adam(0.1), Trace_ELBO(), y=y, T=T, N=N) + results = svi.run(random.PRNGKey(0), 10**3) + + xhat = results.params["x_auto_loc"] + assert_allclose(xhat, tr["x"]["value"], rtol=0.1) From 9c18d50dd0a4fbe4539fca6dc5dd863d11573249 Mon Sep 17 00:00:00 2001 From: Frans Zdyb Date: Mon, 18 Dec 2023 13:14:28 +0100 Subject: [PATCH 10/10] move copy inside get_ith_value --- numpyro/contrib/control_flow/scan.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index cc3a251f8..4bd2143a0 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -26,10 +26,11 @@ def get_ith_value(site): ): return site + site = site.copy() site["value"] = site["value"][i] return site - return {k: get_ith_value(v.copy()) for k, v in replay_trace.items()} + return {k: get_ith_value(v) for k, v in replay_trace.items()} def _subs_wrapper(subs_map, i, length, site):