From 91402020579f4fb5ee5b92e3ee8c33d3d49dcd93 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 19 Feb 2023 18:29:50 +0000
Subject: [PATCH 1/5] subsample scale

---
 numpyro/infer/elbo.py          | 57 ++++++++++++++++++++++++----------
 test/contrib/test_enum_elbo.py | 42 ++++++++++++-------------
 2 files changed, 61 insertions(+), 38 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index b324f67b6..f76eca1b4 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import OrderedDict, defaultdict
-from functools import partial
+from functools import partial, reduce
 from operator import itemgetter
 import warnings
 
@@ -992,6 +992,19 @@ def single_particle_elbo(rng_key):
                     *(frozenset(f.inputs) & group_plates for f in group_factors)
                 )
                 elim_plates = group_plates - outermost_plates
+                plate_to_scale = {}
+                for name in group_names:
+                    for plate, value in (
+                        model_trace[name].get("plate_to_scale", {}).items()
+                    ):
+                        if plate in plate_to_scale:
+                            if value != plate_to_scale[plate]:
+                                raise ValueError(
+                                    "Expected all enumerated sample sites to share a common scale factor, "
+                                    f"but found different scales at plate('{plate}')."
+                                )
+                        else:
+                            plate_to_scale[plate] = value
                 with funsor.interpretations.normalize:
                     cost = funsor.sum_product.sum_product(
                         funsor.ops.logaddexp,
@@ -999,26 +1012,36 @@ def single_particle_elbo(rng_key):
                         group_factors,
                         plates=group_plates,
                         eliminate=group_sum_vars | elim_plates,
+                        plate_to_scale=plate_to_scale,
                     )
                 # TODO: add memoization
                 cost = funsor.optimizer.apply_optimizer(cost)
                 # incorporate the effects of subsampling and handlers.scale through a common scale factor
-                scales_set = set()
-                for name in group_names | group_sum_vars:
-                    site_scale = model_trace[name]["scale"]
-                    if site_scale is None:
-                        site_scale = 1.0
-                    if isinstance(site_scale, jnp.ndarray):
-                        raise ValueError(
-                            "Enumeration only supports scalar handlers.scale"
-                        )
-                    scales_set.add(float(site_scale))
-                if len(scales_set) != 1:
-                    raise ValueError(
-                        "Expected all enumerated sample sites to share a common scale, "
-                        f"but found {len(scales_set)} different scales."
-                    )
-                scale = next(iter(scales_set))
+                scale = reduce(
+                    funsor.ops.mul,
+                    [
+                        value
+                        for plate, value in plate_to_scale.items()
+                        if plate not in elim_plates
+                    ],
+                    1.0,
+                )
+                # scales_set = set()
+                # for name in group_names | group_sum_vars:
+                #     site_scale = model_trace[name]["scale"]
+                #     if site_scale is None:
+                #         site_scale = 1.0
+                #     if isinstance(site_scale, jnp.ndarray):
+                #         raise ValueError(
+                #             "Enumeration only supports scalar handlers.scale"
+                #         )
+                #     scales_set.add(float(site_scale))
+                # if len(scales_set) != 1:
+                #     raise ValueError(
+                #         "Expected all enumerated sample sites to share a common scale, "
+                #         f"but found {len(scales_set)} different scales."
+                #     )
+                # scale = next(iter(scales_set))
                 # combine deps
                 deps = frozenset().union(
                     *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index d464fb33a..d3a49db50 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2314,14 +2314,14 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    with pytest.raises(
-        ValueError, match="Expected all enumerated sample sites to share a common scale"
-    ):
-        # This never gets run because we don't support this yet.
-        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    # with pytest.raises(
+    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
+    # ):
+    # This never gets run because we don't support this yet.
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-        assert_equal(actual_loss, expected_loss, prec=1e-5)
-        assert_equal(actual_grads, expected_grads, prec=1e-5)
+    assert_equal(actual_loss, expected_loss, prec=1e-5)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
 @pytest.mark.parametrize("scale", [1, 10])
@@ -2389,14 +2389,14 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    with pytest.raises(
-        ValueError, match="Expected all enumerated sample sites to share a common scale"
-    ):
-        # This never gets run because we don't support this yet.
-        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    # with pytest.raises(
+    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
+    # ):
+    # This never gets run because we don't support this yet.
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-        assert_equal(actual_loss, expected_loss, prec=1e-5)
-        assert_equal(actual_grads, expected_grads, prec=1e-5)
+    assert_equal(actual_loss, expected_loss, prec=1e-5)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
 @pytest.mark.parametrize("scale", [1, 10])
@@ -2464,14 +2464,14 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    with pytest.raises(
-        ValueError, match="Expected all enumerated sample sites to share a common scale"
-    ):
-        # This never gets run because we don't support this yet.
-        actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+    # with pytest.raises(
+    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
+    # ):
+    # This never gets run because we don't support this yet.
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
-        assert_equal(actual_loss, expected_loss, prec=1e-3)
-        assert_equal(actual_grads, expected_grads, prec=1e-5)
+    assert_equal(actual_loss, expected_loss, prec=1e-3)
+    assert_equal(actual_grads, expected_grads, prec=1e-5)
 
 
 def test_guide_plate_contraction():

From 0c612d95367232c0b809c34388b0cb2e4a9be7e4 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Sun, 9 Apr 2023 01:50:30 +0000
Subject: [PATCH 2/5] remove comments

---
 numpyro/infer/elbo.py          | 16 ----------------
 test/contrib/test_enum_elbo.py | 12 ------------
 2 files changed, 28 deletions(-)

diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index e03c43d30..b29cbd45e 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -993,22 +993,6 @@ def single_particle_elbo(rng_key):
                     ],
                     1.0,
                 )
-                # scales_set = set()
-                # for name in group_names | group_sum_vars:
-                #     site_scale = model_trace[name]["scale"]
-                #     if site_scale is None:
-                #         site_scale = 1.0
-                #     if isinstance(site_scale, jnp.ndarray):
-                #         raise ValueError(
-                #             "Enumeration only supports scalar handlers.scale"
-                #         )
-                #     scales_set.add(float(site_scale))
-                # if len(scales_set) != 1:
-                #     raise ValueError(
-                #         "Expected all enumerated sample sites to share a common scale, "
-                #         f"but found {len(scales_set)} different scales."
-                #     )
-                # scale = next(iter(scales_set))
                 # combine deps
                 deps = frozenset().union(
                     *[model_deps[name] for name in group_names]
diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index d3a49db50..1a6b0ef66 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2314,10 +2314,6 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    # with pytest.raises(
-    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
-    # ):
-    # This never gets run because we don't support this yet.
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
     assert_equal(actual_loss, expected_loss, prec=1e-5)
@@ -2385,10 +2381,6 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    # with pytest.raises(
-    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
-    # ):
-    # This never gets run because we don't support this yet.
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
     assert_equal(actual_loss, expected_loss, prec=1e-5)
@@ -2456,10 +2448,6 @@ def actual_loss_fn(params_raw):
         }
         return elbo.loss(random.PRNGKey(0), {}, model_subsample, guide, params)
 
-    # with pytest.raises(
-    #     ValueError, match="Expected all enumerated sample sites to share a common scale"
-    # ):
-    # This never gets run because we don't support this yet.
     actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
 
     assert_equal(actual_loss, expected_loss, prec=1e-3)

From feecd2f03923d31e0c39f83d60e9bf3337fc66ce Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 31 Aug 2023 17:39:34 +0000
Subject: [PATCH 3/5] fix comment

---
 test/contrib/test_enum_elbo.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py
index 1a6b0ef66..2d166cce7 100644
--- a/test/contrib/test_enum_elbo.py
+++ b/test/contrib/test_enum_elbo.py
@@ -2394,7 +2394,7 @@ def actual_loss_fn(params_raw):
 @pytest.mark.parametrize("scale", [1, 10])
 def test_model_enum_subsample_3(scale):
     # Enumerate: a
-    # Subsample: a, b, c
+    # Subsample: b, c
     # [ a - [----> b ]
     # [ \   [        ]
     # [ - [- [-> c ] ]

From 2bb86f763600503a79b9c12eed935b5763fd8252 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 31 Aug 2023 17:49:28 +0000
Subject: [PATCH 4/5] pin to funsor master branch

---
 setup.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index bba87d335..1202e6589 100644
--- a/setup.py
+++ b/setup.py
@@ -61,7 +61,8 @@
     "dev": [
         "dm-haiku",
         "flax",
-        "funsor>=0.4.1",
+        "funsor @ git+https://github.com/pyro-ppl/funsor.git",
+        # "funsor>=0.4.1",
         "graphviz",
         "jaxns>=2.0.1",
         "matplotlib",

From 1d9813e9fd95491d8baa7ce5c394097fd7d04646 Mon Sep 17 00:00:00 2001
From: Yerdos Ordabayev
Date: Thu, 31 Aug 2023 18:46:05 +0000
Subject: [PATCH 5/5] pin funsor to release version

---
 setup.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/setup.py b/setup.py
index 1202e6589..334905d17 100644
--- a/setup.py
+++ b/setup.py
@@ -61,8 +61,7 @@
     "dev": [
         "dm-haiku",
         "flax",
-        "funsor @ git+https://github.com/pyro-ppl/funsor.git",
-        # "funsor>=0.4.1",
+        "funsor>=0.4.6",
        "graphviz",
         "jaxns>=2.0.1",
         "matplotlib",
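
Usage note (reviewer sketch, not part of the patch series): the plate_to_scale bookkeeping above targets models where an enumerated discrete site sits outside a subsampled plate while downstream sites sit inside it, so the sites no longer share a single scalar handlers.scale; the per-plate scales are now passed to funsor.sum_product.sum_product and folded back into the ELBO by the reduce over plate_to_scale. The funsor pin in patches 4-5 suggests the plate_to_scale argument of sum_product first ships in funsor 0.4.6. The example below is a minimal sketch of that situation under those assumptions; the model, guide, and parameter names are illustrative and are not taken from test_enum_elbo.py.

# Minimal sketch: an enumerated global site "a" outside a subsampled plate,
# with an observed site "b" inside it.  This is the shape of model that the
# old tests wrapped in pytest.raises("Expected all enumerated sample sites
# to share a common scale"); with plate_to_scale the subsample scale is
# tracked per plate instead.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, TraceEnum_ELBO


def model(data, subsample_size):
    probs = numpyro.param(
        "probs", jnp.array([0.6, 0.4]), constraint=dist.constraints.simplex
    )
    locs = numpyro.param("locs", jnp.array([-1.0, 1.0]))
    # Global discrete site, marginalized by enumeration; carries no subsample scale.
    a = numpyro.sample("a", dist.Categorical(probs), infer={"enumerate": "parallel"})
    # Sites inside this plate are rescaled by len(data) / subsample_size.
    with numpyro.plate("data", len(data), subsample_size=subsample_size) as idx:
        numpyro.sample("b", dist.Normal(locs[a], 1.0), obs=data[idx])


def guide(data, subsample_size):
    # "a" is enumerated out in the model, so no latent sites remain to sample here.
    pass


data = random.normal(random.PRNGKey(1), (100,))
svi = SVI(model, guide, numpyro.optim.Adam(step_size=0.01), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(0), 100, data, subsample_size=10)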