From fa63f9ad7d5daba8214afe623429d8fa01373ee2 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 27 Feb 2024 17:43:39 -0500 Subject: [PATCH 01/42] added zerosumnormal and tests --- numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 110 ++++++++++++++++++++++++++++ test/test_distributions.py | 15 +++- 3 files changed, 126 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 49554097e..d05376573 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -47,6 +47,7 @@ StudentT, Uniform, Weibull, + ZeroSumNormal, ) from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta from numpyro.distributions.directional import ( @@ -196,4 +197,5 @@ "ZeroInflatedDistribution", "ZeroInflatedPoisson", "ZeroInflatedNegativeBinomial2", + "ZeroSumNormal", ] diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 220c23b92..973544ff7 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -25,6 +25,8 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
+from typing import Optional + import numpy as np from jax import lax, vmap @@ -2444,3 +2446,111 @@ def cdf(self, value): def icdf(self, value): return self._ald.icdf(value) + + +class ZeroSumNormal(Distribution): + arg_constraints = {"scale": constraints.positive} + support = constraints.real + reparametrized_params = ["scale"] + pytree_aux_fields = ("n_zerosum_axes","support_shape",) + + def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, validate_args=None): + if not all(tuple(i == 1 for i in jnp.shape( scale ))): + raise ValueError("scale must have length one across the zero-sum axes") + + self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes) + # batch_shape = lax.broadcast_shapes(jnp.shape(scale)) + if jnp.ndim(scale) == 0: + (scale,) = promote_shapes(scale, shape=(1,)) + # temporary append a new axis to scale + scale = scale[..., jnp.newaxis] + cov_placeholder = jnp.eye(len(support_shape)) + scale, cov_placeholder = promote_shapes(scale, cov_placeholder) + batch_shape = lax.broadcast_shapes( + jnp.shape(scale)[:-2], jnp.shape(cov_placeholder)[:-2] + ) + self.scale = scale[..., 0] + super(ZeroSumNormal, self).__init__( + batch_shape=batch_shape, + event_shape=support_shape, + validate_args=validate_args + ) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + zerosum_rv_ = random.normal( + key, shape=sample_shape + self.batch_shape + self.event_shape + ) * self.scale + + if not zerosum_rv_.shape: + return jnp.zeros(zerosum_rv_.shape) + + for axis in range(self.n_zerosum_axes): + zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) + return zerosum_rv_ + + @validate_sample + def log_prob(self, value): + shape = jnp.array(value.shape) + _deg_free_support_shape = shape.at[-self.n_zerosum_axes:].set( shape[-self.n_zerosum_axes:] - 1 ) + _full_size = jnp.prod(shape).astype(float) + _degrees_of_freedom = jnp.prod(_deg_free_support_shape).astype(float) + + if not value.shape or self.batch_shape: + value = 
jnp.expand_dims(value, -1) + + log_pdf = jnp.sum( + -0.5 * jnp.pow(value / self.scale, 2) + - (jnp.log(jnp.sqrt(2.0 * jnp.pi)) + jnp.log(self.scale)) * _degrees_of_freedom / _full_size, + axis=tuple(np.arange(-self.n_zerosum_axes, 0)), + ) + return log_pdf + + # def cdf(self, value): + # scaled = (value - 0) / self.scale + # return ndtr(scaled) + + # def log_cdf(self, value): + # return jax_norm.logcdf(value, loc=0, scale=self.scale) + + # def icdf(self, q): + # return 0 + self.scale * ndtri(q) + + @property + def mean(self): + return jnp.broadcast_to(0, self.batch_shape) + + @property + def variance(self): + theoretical_var = self.scale.astype(float)**2 + for axis in range(1,self.n_zerosum_axes+1): + theoretical_var *= (1 - 1 / self.event_shape[-axis]) + + return theoretical_var + + def check_zerosum_axes(self, n_zerosum_axes: Optional[int]) -> int: + if n_zerosum_axes is None: + n_zerosum_axes = 1 + + is_integer = isinstance(n_zerosum_axes, int) + is_jax_int_array = isinstance(n_zerosum_axes, jnp.ndarray) and jnp.issubdtype(n_zerosum_axes.dtype, jnp.integer) + if not (is_integer or is_jax_int_array): + raise TypeError("n_zerosum_axes has to be an integer") + if not n_zerosum_axes > 0: + raise ValueError("n_zerosum_axes has to be > 0") + return n_zerosum_axes + + @staticmethod + def infer_shapes(scale=1.0, n_zerosum_axes=None, support_shape=(1,)): + '''Numpyro assumes that the event and batch shape can be entirely + determined by the shapes of the distribution inputs. This distribution + doesn't follow those conventions, so the `infer_shapes` method cant be implemented. 
+ ''' + raise NotImplementedError() + + def _validate_sample(self, value): + mask = super(ZeroSumNormal, self)._validate_sample(value) + batch_dim = jnp.ndim(value) - len(self.event_shape) + if batch_dim < jnp.ndim(mask): + mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1) + return mask diff --git a/test/test_distributions.py b/test/test_distributions.py index f80f6bcf6..a862f67d4 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -773,6 +773,10 @@ def get_sp_dist(jax_dist): T(dist.Weibull, 0.2, 1.1), T(dist.Weibull, 2.8, np.array([2.0, 2.0])), T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])), + T(dist.ZeroSumNormal, 1.0, None, (1,)), + T(dist.ZeroSumNormal, 1.0, 1, (1,)), + T(dist.ZeroSumNormal, np.array([2.0]), None, (1,)), + T(dist.ZeroSumNormal, 1.0, 2, (4,5)), T( _GaussianMixture, np.ones(3) / 3.0, @@ -1296,6 +1300,7 @@ def test_jit_log_likelihood(jax_dist, sp_dist, params): "LKJ", "LKJCholesky", "_SparseCAR", + "ZeroSumNormal", ): pytest.xfail(reason="non-jittable params") @@ -1454,6 +1459,9 @@ def test_gof(jax_dist, sp_dist, params): if jax_dist is dist.ProjectedNormal: dim = samples.shape[-1] - 1 + if jax_dist is dist.ZeroSumNormal: + pytest.skip("skip gof test for ZeroSumNormal") + # Test each batch independently. probs = probs.reshape(num_samples, -1) samples = samples.reshape(probs.shape + d.event_shape) @@ -1671,6 +1679,9 @@ def fn(*args): if jax_dist is _SparseCAR and i == 3: # skip taking grad w.r.t. adj_matrix continue + if jax_dist is dist.ZeroSumNormal and i != 0: + # skip taking grad w.r.t. n_zerosum_axes and support_shape + continue if isinstance( params[i], dist.Distribution ): # skip taking grad w.r.t. 
base_dist @@ -1857,7 +1868,7 @@ def get_min_shape(ix, batch_shape): if isinstance(d_jax, dist.Gompertz): pytest.skip("Gompertz distribution does not have `variance` implemented.") if jnp.all(jnp.isfinite(d_jax.variance)): - assert_allclose( + jnp.allclose( jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2 ) @@ -1898,6 +1909,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue + if jax_dist is dist.ZeroSumNormal and dist_args[i] in ("n_zerosum_axes", "support_shape"): + continue if ( jax_dist is dist.SineBivariateVonMises and dist_args[i] == "weighted_correlation" From c28fd0cdadd4d800400c6f3b321a86d3d7ec23bb Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 28 Feb 2024 20:04:33 -0500 Subject: [PATCH 02/42] added edge case handling for support shape --- numpyro/distributions/continuous.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 973544ff7..c9a85398e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -25,8 +25,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
-from typing import Optional - import numpy as np from jax import lax, vmap @@ -2459,9 +2457,9 @@ def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, valida raise ValueError("scale must have length one across the zero-sum axes") self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes) - # batch_shape = lax.broadcast_shapes(jnp.shape(scale)) if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) + # temporary append a new axis to scale scale = scale[..., jnp.newaxis] cov_placeholder = jnp.eye(len(support_shape)) @@ -2472,7 +2470,7 @@ def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, valida self.scale = scale[..., 0] super(ZeroSumNormal, self).__init__( batch_shape=batch_shape, - event_shape=support_shape, + event_shape=self.check_support_shape(support_shape, self.n_zerosum_axes), validate_args=validate_args ) @@ -2528,7 +2526,7 @@ def variance(self): return theoretical_var - def check_zerosum_axes(self, n_zerosum_axes: Optional[int]) -> int: + def check_zerosum_axes(self, n_zerosum_axes): if n_zerosum_axes is None: n_zerosum_axes = 1 @@ -2540,6 +2538,12 @@ def check_zerosum_axes(self, n_zerosum_axes: Optional[int]) -> int: raise ValueError("n_zerosum_axes has to be > 0") return n_zerosum_axes + def check_support_shape(self, support_shape, n_zerosum_axes): + assert n_zerosum_axes <= len(support_shape), "support_shape has to be as long as n_zerosum_axes" + assert all(shape > 0 for shape in support_shape), "support_shape must be a valid shape" + assert len(support_shape) > 0, "support_shape must be a valid shape" + return support_shape + @staticmethod def infer_shapes(scale=1.0, n_zerosum_axes=None, support_shape=(1,)): '''Numpyro assumes that the event and batch shape can be entirely From 93bcf0f6c772c41eb55987e5753c0874ac77f42d Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 28 Feb 2024 20:19:24 -0500 Subject: [PATCH 03/42] removed commented out functions --- numpyro/distributions/continuous.py | 
10 ---------- 1 file changed, 10 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index c9a85398e..7886bb64c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2504,16 +2504,6 @@ def log_prob(self, value): ) return log_pdf - # def cdf(self, value): - # scaled = (value - 0) / self.scale - # return ndtr(scaled) - - # def log_cdf(self, value): - # return jax_norm.logcdf(value, loc=0, scale=self.scale) - - # def icdf(self, q): - # return 0 + self.scale * ndtri(q) - @property def mean(self): return jnp.broadcast_to(0, self.batch_shape) From d9f2b4e50d8cd8b7e9d9c28d718b901f86356835 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 28 Feb 2024 20:25:25 -0500 Subject: [PATCH 04/42] added zerosumnormal to docs --- docs/source/distributions.rst | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 2614e7f7e..c19bd3e6f 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -380,6 +380,13 @@ Weibull :show-inheritance: :member-order: bysource +ZeroSumNormal +^^^^^^^ +.. 
autoclass:: numpyro.distributions.continuous.ZeroSumNormal + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource Discrete Distributions ---------------------- From 0abb60b9d6487b793ae1083b7f7121f5d51b6d6f Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 28 Feb 2024 22:27:24 -0500 Subject: [PATCH 05/42] fixed zerosumnormal support shape default --- numpyro/distributions/continuous.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7886bb64c..51ea5aea9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2457,12 +2457,13 @@ def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, valida raise ValueError("scale must have length one across the zero-sum axes") self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes) + support_shape = self.check_support_shape(support_shape, self.n_zerosum_axes) if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) # temporary append a new axis to scale scale = scale[..., jnp.newaxis] - cov_placeholder = jnp.eye(len(support_shape)) + cov_placeholder = jnp.eye(len(scale)) scale, cov_placeholder = promote_shapes(scale, cov_placeholder) batch_shape = lax.broadcast_shapes( jnp.shape(scale)[:-2], jnp.shape(cov_placeholder)[:-2] @@ -2470,7 +2471,7 @@ def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, valida self.scale = scale[..., 0] super(ZeroSumNormal, self).__init__( batch_shape=batch_shape, - event_shape=self.check_support_shape(support_shape, self.n_zerosum_axes), + event_shape=support_shape, validate_args=validate_args ) @@ -2529,6 +2530,8 @@ def check_zerosum_axes(self, n_zerosum_axes): return n_zerosum_axes def check_support_shape(self, support_shape, n_zerosum_axes): + if support_shape is None: + return () assert n_zerosum_axes <= len(support_shape), "support_shape has to be as long as n_zerosum_axes" 
assert all(shape > 0 for shape in support_shape), "support_shape must be a valid shape" assert len(support_shape) > 0, "support_shape must be a valid shape" From b28f38c0b7b313e2591daf2ed240fb06894bfc49 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 29 Feb 2024 09:08:02 -0500 Subject: [PATCH 06/42] Added v1 of docstrings for zerosumnormal --- numpyro/distributions/continuous.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 51ea5aea9..052341a6f 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2447,12 +2447,25 @@ def icdf(self, value): class ZeroSumNormal(Distribution): + """ + Zero Sum Normal distribution adapted from PyMC as described in [1]. This is a Normal distribution where one or + more axes are constrained to sum to zero (the last axis by default). + + :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is + enforced. + :param int n_zerosum_axes: The number of axes to enforce a zerosum constraint. + :param tuple support_shape: The event shape of the distribution. 
+ + **References** + + [1] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html + """ arg_constraints = {"scale": constraints.positive} support = constraints.real reparametrized_params = ["scale"] pytree_aux_fields = ("n_zerosum_axes","support_shape",) - def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=(1,), *, validate_args=None): + def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=None, *, validate_args=None): if not all(tuple(i == 1 for i in jnp.shape( scale ))): raise ValueError("scale must have length one across the zero-sum axes") From 4e1dd16c305f524381d4d526dbd269fb02d40f4b Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 29 Feb 2024 09:38:34 -0500 Subject: [PATCH 07/42] updated zsn docstring --- numpyro/distributions/continuous.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 052341a6f..f12296668 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2448,7 +2448,7 @@ def icdf(self, value): class ZeroSumNormal(Distribution): """ - Zero Sum Normal distribution adapted from PyMC as described in [1]. This is a Normal distribution where one or + Zero Sum Normal distribution adapted from PyMC [1] as described in [2]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default). :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is @@ -2457,8 +2457,8 @@ class ZeroSumNormal(Distribution): :param tuple support_shape: The event shape of the distribution. 
**References** - - [1] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html + [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 + [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html """ arg_constraints = {"scale": constraints.positive} support = constraints.real From 8cd792ceccb44aab99f43e0a4eb3499e75c34b88 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 29 Feb 2024 10:21:04 -0500 Subject: [PATCH 08/42] improved init shape handling for zerosumnormal --- numpyro/distributions/continuous.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index f12296668..d1398f747 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2474,14 +2474,9 @@ def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=None, *, valida if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) - # temporary append a new axis to scale - scale = scale[..., jnp.newaxis] - cov_placeholder = jnp.eye(len(scale)) - scale, cov_placeholder = promote_shapes(scale, cov_placeholder) - batch_shape = lax.broadcast_shapes( - jnp.shape(scale)[:-2], jnp.shape(cov_placeholder)[:-2] - ) - self.scale = scale[..., 0] + batch_shape = jnp.shape(scale)[:-1] + self.scale = scale + super(ZeroSumNormal, self).__init__( batch_shape=batch_shape, event_shape=support_shape, From dcbdd8545fbd11d710b7741c1a1a25ec8ee3cb60 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 29 Feb 2024 12:33:15 -0500 Subject: [PATCH 09/42] improved docstrings --- numpyro/distributions/continuous.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index d1398f747..7c04aa94a 100644 --- 
a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2447,7 +2447,7 @@ def icdf(self, value): class ZeroSumNormal(Distribution): - """ + r""" Zero Sum Normal distribution adapted from PyMC [1] as described in [2]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default). @@ -2456,6 +2456,13 @@ class ZeroSumNormal(Distribution): :param int n_zerosum_axes: The number of axes to enforce a zerosum constraint. :param tuple support_shape: The event shape of the distribution. + .. math:: + \begin{align*} + ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ + \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ + n = \text{number of zero-sum axes} + \end{align*} + **References** [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html From 13fff40d61728432ee4a742333dddd3bb66c3cd2 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 5 Mar 2024 16:26:21 -0500 Subject: [PATCH 10/42] added ZeroSumTransform --- numpyro/distributions/transforms.py | 53 +++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index b7c7605c4..e56da31d4 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -6,6 +6,7 @@ import weakref import numpy as np +from numpy.core.numeric import normalize_axis_tuple from jax import lax, vmap from jax.flatten_util import ravel_pytree @@ -48,6 +49,7 @@ "StickBreakingTransform", "Transform", "UnpackTransform", + "ZeroSumTransform", ] @@ -1142,6 +1144,57 @@ def __eq__(self, other): return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn +class ZeroSumTransform(ParameterFreeTransform): + """A transform that constrains an array to sum to zero + """ + # domain = 
constraints.real_vector + # codomain = constraints.simplex + + def __call__(self, value, zerosum_axes): + for axis in zerosum_axes: + value = self.extend_axis_rev(value, axis=axis) + return value + + def _inverse(self, value, zerosum_axes): + for axis in zerosum_axes: + value = self.extend_axis(value, axis=axis) + return value + + def extend_axis_rev(self, array, axis): + normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + + n = array.shape[normalized_axis] + last = jnp.take(array, [-1], axis=normalized_axis) + + sum_vals = -last * jnp.sqrt(n) + norm = sum_vals / (jnp.sqrt(n) + n) + slice_before = (slice(None, None),) * normalized_axis + return array[(*slice_before, slice(None, -1))] + norm + + def extend_axis(self, array, axis): + n = (array.shape[axis] + 1) + + sum_vals = array.sum(axis, keepdims=True) + norm = sum_vals / (jnp.sqrt(n) + n) + fill_val = norm - sum_vals / jnp.sqrt(n) + + out = jnp.concatenate([array, fill_val], axis=axis) + return out - norm + + def log_abs_det_jacobian(self, x, y, intermediates=None): + return jnp.array(0.0) + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] + 1,) + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] - 1,) + + def _get_target_shape(shape, forward_shape, inverse_shape): batch_ndims = len(shape) - len(inverse_shape) return shape[:batch_ndims] + forward_shape From 514000c79bbf5c8e32ecc38cbb8c262653bd18a3 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 5 Mar 2024 17:02:32 -0500 Subject: [PATCH 11/42] made n_zerosum_axes an attribute for the zerosumtransform --- numpyro/distributions/transforms.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index e56da31d4..7729ddf7a 100644 --- 
a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1149,16 +1149,18 @@ class ZeroSumTransform(ParameterFreeTransform): """ # domain = constraints.real_vector # codomain = constraints.simplex + def __init__(self, zerosum_axes): + self.zerosum_axes = zerosum_axes - def __call__(self, value, zerosum_axes): - for axis in zerosum_axes: - value = self.extend_axis_rev(value, axis=axis) - return value + def __call__(self, x): + for axis in self.zerosum_axes: + x = self.extend_axis_rev(x, axis=axis) + return x - def _inverse(self, value, zerosum_axes): - for axis in zerosum_axes: - value = self.extend_axis(value, axis=axis) - return value + def _inverse(self, y): + for axis in self.zerosum_axes: + y = self.extend_axis(y, axis=axis) + return y def extend_axis_rev(self, array, axis): normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] @@ -1184,16 +1186,6 @@ def extend_axis(self, array, axis): def log_abs_det_jacobian(self, x, y, intermediates=None): return jnp.array(0.0) - def forward_shape(self, shape): - if len(shape) < 1: - raise ValueError("Too few dimensions on input") - return shape[:-1] + (shape[-1] + 1,) - - def inverse_shape(self, shape): - if len(shape) < 1: - raise ValueError("Too few dimensions on input") - return shape[:-1] + (shape[-1] - 1,) - def _get_target_shape(shape, forward_shape, inverse_shape): batch_ndims = len(shape) - len(inverse_shape) From d6315c36876cef417b3cbe50c74d299c0c4aed5c Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 5 Mar 2024 17:03:22 -0500 Subject: [PATCH 12/42] removed commented out lines --- numpyro/distributions/transforms.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 7729ddf7a..be5895d03 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1147,8 +1147,6 @@ def __eq__(self, other): class ZeroSumTransform(ParameterFreeTransform): """A transform that constrains 
an array to sum to zero """ - # domain = constraints.real_vector - # codomain = constraints.simplex def __init__(self, zerosum_axes): self.zerosum_axes = zerosum_axes From 907cd2ee0de88ae656a14c3e2e4a292e0b03760e Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 7 Mar 2024 18:47:32 -0500 Subject: [PATCH 13/42] added zerosumtransform class --- numpyro/distributions/continuous.py | 5 +++-- numpyro/distributions/transforms.py | 9 +++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7c04aa94a..7aca0cb18 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2448,7 +2448,7 @@ def icdf(self, value): class ZeroSumNormal(Distribution): r""" - Zero Sum Normal distribution adapted from PyMC [1] as described in [2]. This is a Normal distribution where one or + Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default). 
:param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is @@ -2466,9 +2466,10 @@ class ZeroSumNormal(Distribution): **References** [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html + [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ arg_constraints = {"scale": constraints.positive} - support = constraints.real + support = constraints.real # FIXME reparametrized_params = ["scale"] pytree_aux_fields = ("n_zerosum_axes","support_shape",) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index be5895d03..c5a707231 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1145,7 +1145,12 @@ def __eq__(self, other): class ZeroSumTransform(ParameterFreeTransform): - """A transform that constrains an array to sum to zero + """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] + + **References** + [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 + [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html + [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ def __init__(self, zerosum_axes): self.zerosum_axes = zerosum_axes @@ -1164,7 +1169,7 @@ def extend_axis_rev(self, array, axis): normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] n = array.shape[normalized_axis] - last = jnp.take(array, [-1], axis=normalized_axis) + last = jnp.take(array, jnp.array([-1]), axis=normalized_axis) sum_vals = -last * jnp.sqrt(n) norm = sum_vals / (jnp.sqrt(n) + n) From 
fc3f0534c572f3603a60604f39aa162738fb6b41 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Fri, 8 Mar 2024 15:59:08 -0500 Subject: [PATCH 14/42] switched zsn from ParameterFreeTransform to Transform --- numpyro/distributions/transforms.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index c5a707231..1a35f0346 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1144,7 +1144,7 @@ def __eq__(self, other): return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn -class ZeroSumTransform(ParameterFreeTransform): +class ZeroSumTransform(Transform): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] **References** @@ -1189,6 +1189,9 @@ def extend_axis(self, array, axis): def log_abs_det_jacobian(self, x, y, intermediates=None): return jnp.array(0.0) + def tree_flatten(self): + return (self.zerosum_axes,), (("zerosum_axes",), dict()) + def _get_target_shape(shape, forward_shape, inverse_shape): batch_ndims = len(shape) - len(inverse_shape) From 8187421b316c5c828d8cae8b8a8214c4afcfc81c Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Mon, 25 Mar 2024 17:23:23 -0400 Subject: [PATCH 15/42] changed ZeroSumNormal to transformed distribution --- numpyro/distributions/constraints.py | 25 +++++- numpyro/distributions/continuous.py | 128 +++------------------------ numpyro/distributions/transforms.py | 14 ++- 3 files changed, 49 insertions(+), 118 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index d442a3150..e8566b0a5 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -54,6 +54,7 @@ "softplus_lower_cholesky", "softplus_positive", "unit_interval", + "zero_sum", "Constraint", ] @@ -509,7 +510,6 @@ class _UnitInterval(_SingletonConstraint, _Interval): def __init__(self): 
super().__init__(0.0, 1.0) - class _OpenInterval(_Interval): def __call__(self, x): return (x > self.lower_bound) & (x < self.upper_bound) @@ -540,6 +540,28 @@ def feasible_like(self, prototype): ) +class _ZeroSum(Constraint): + def __init__(self, event_dim = 1): + self.event_dim = event_dim + super().__init__() + + def __call__(self, x): + jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + zerosum_true = [] + for dim in range(-self.event_dim, 0): + zerosum_true.append(jnp.allclose(x.sum(-1), 0, rtol=0.05, atol=1e-2)) + return all(zerosum_true) + + def feasible_like(self, prototype): + return jax.numpy.broadcast_to(0, prototype.shape) + + def tree_flatten(self): + return (self.event_dim), ( + ("event_dim"), + dict(), + ) + + class _Multinomial(Constraint): is_discrete = True event_dim = 1 @@ -720,3 +742,4 @@ def feasible_like(self, prototype): sphere = _Sphere() unit_interval = _UnitInterval() open_interval = _OpenInterval +zero_sum = _ZeroSum diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7aca0cb18..cab3ebea9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -58,6 +58,7 @@ ExpTransform, PowerTransform, SigmoidTransform, + ZeroSumTransform, ) from numpyro.distributions.util import ( betainc, @@ -2446,124 +2447,19 @@ def icdf(self, value): return self._ald.icdf(value) -class ZeroSumNormal(Distribution): - r""" - Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or - more axes are constrained to sum to zero (the last axis by default). - - :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is - enforced. - :param int n_zerosum_axes: The number of axes to enforce a zerosum constraint. - :param tuple support_shape: The event shape of the distribution. - - .. 
math:: - \begin{align*} - ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ - \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ - n = \text{number of zero-sum axes} - \end{align*} - - **References** - [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 - [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html - [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ - """ +class ZeroSumNormal(TransformedDistribution): arg_constraints = {"scale": constraints.positive} - support = constraints.real # FIXME reparametrized_params = ["scale"] - pytree_aux_fields = ("n_zerosum_axes","support_shape",) - - def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=None, *, validate_args=None): - if not all(tuple(i == 1 for i in jnp.shape( scale ))): - raise ValueError("scale must have length one across the zero-sum axes") - - self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes) - support_shape = self.check_support_shape(support_shape, self.n_zerosum_axes) - if jnp.ndim(scale) == 0: - (scale,) = promote_shapes(scale, shape=(1,)) - - batch_shape = jnp.shape(scale)[:-1] - self.scale = scale - - super(ZeroSumNormal, self).__init__( - batch_shape=batch_shape, - event_shape=support_shape, - validate_args=validate_args - ) - - def sample(self, key, sample_shape=()): - assert is_prng_key(key) - zerosum_rv_ = random.normal( - key, shape=sample_shape + self.batch_shape + self.event_shape - ) * self.scale - - if not zerosum_rv_.shape: - return jnp.zeros(zerosum_rv_.shape) - - for axis in range(self.n_zerosum_axes): - zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) - return zerosum_rv_ - @validate_sample - def log_prob(self, value): - shape = jnp.array(value.shape) - _deg_free_support_shape = shape.at[-self.n_zerosum_axes:].set( shape[-self.n_zerosum_axes:] - 1 ) - _full_size = 
jnp.prod(shape).astype(float) - _degrees_of_freedom = jnp.prod(_deg_free_support_shape).astype(float) - - if not value.shape or self.batch_shape: - value = jnp.expand_dims(value, -1) - - log_pdf = jnp.sum( - -0.5 * jnp.pow(value / self.scale, 2) - - (jnp.log(jnp.sqrt(2.0 * jnp.pi)) + jnp.log(self.scale)) * _degrees_of_freedom / _full_size, - axis=tuple(np.arange(-self.n_zerosum_axes, 0)), + def __init__(self, scale, event_shape): + event_ndim = len(event_shape) + transformed_shape = tuple(size - 1 for size in event_shape) + zero_sum_axes = tuple(-(i + 1) for i in range(event_ndim)) + super().__init__( + Normal(0, scale).expand(transformed_shape).to_event(event_ndim), + ZeroSumTransform(zero_sum_axes).inv, ) - return log_pdf - @property - def mean(self): - return jnp.broadcast_to(0, self.batch_shape) - - @property - def variance(self): - theoretical_var = self.scale.astype(float)**2 - for axis in range(1,self.n_zerosum_axes+1): - theoretical_var *= (1 - 1 / self.event_shape[-axis]) - - return theoretical_var - - def check_zerosum_axes(self, n_zerosum_axes): - if n_zerosum_axes is None: - n_zerosum_axes = 1 - - is_integer = isinstance(n_zerosum_axes, int) - is_jax_int_array = isinstance(n_zerosum_axes, jnp.ndarray) and jnp.issubdtype(n_zerosum_axes.dtype, jnp.integer) - if not (is_integer or is_jax_int_array): - raise TypeError("n_zerosum_axes has to be an integer") - if not n_zerosum_axes > 0: - raise ValueError("n_zerosum_axes has to be > 0") - return n_zerosum_axes - - def check_support_shape(self, support_shape, n_zerosum_axes): - if support_shape is None: - return () - assert n_zerosum_axes <= len(support_shape), "support_shape has to be as long as n_zerosum_axes" - assert all(shape > 0 for shape in support_shape), "support_shape must be a valid shape" - assert len(support_shape) > 0, "support_shape must be a valid shape" - return support_shape - - @staticmethod - def infer_shapes(scale=1.0, n_zerosum_axes=None, support_shape=(1,)): - '''Numpyro assumes that the 
event and batch shape can be entirely - determined by the shapes of the distribution inputs. This distribution - doesn't follow those conventions, so the `infer_shapes` method cant be implemented. - ''' - raise NotImplementedError() - - def _validate_sample(self, value): - mask = super(ZeroSumNormal, self)._validate_sample(value) - batch_dim = jnp.ndim(value) - len(self.event_shape) - if batch_dim < jnp.ndim(mask): - mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1) - return mask + @constraints.dependent_property(is_discrete=False) + def support(self): + return constraints.zero_sum(len(self.event_shape)) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 1a35f0346..99842b04e 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1148,13 +1148,21 @@ class ZeroSumTransform(Transform): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] **References** - [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 + [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ def __init__(self, zerosum_axes): self.zerosum_axes = zerosum_axes + @property + def domain(self): + return constraints.independent(constraints.real, len(self.zerosum_axes)) + + @property + def codomain(self): + return constraints.zero_sum(len(self.zerosum_axes)) + def __call__(self, x): for axis in self.zerosum_axes: x = self.extend_axis_rev(x, axis=axis) @@ -1396,3 +1404,7 @@ def _transform_to_softplus_lower_cholesky(constraint): @biject_to.register(constraints.simplex) def 
_transform_to_simplex(constraint): return StickBreakingTransform() + +@biject_to.register(constraints.zero_sum) +def _transform_to_zero_sum(constraint): + return ZeroSumTransform(jnp.array([-1])).inv From 005134264c1c38614e05b89dadc9304c7f8b446a Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Mon, 25 Mar 2024 17:46:17 -0400 Subject: [PATCH 16/42] changed input to tuple for _transform_to_zero_sum --- numpyro/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 99842b04e..fe22c57f9 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1407,4 +1407,4 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - return ZeroSumTransform(jnp.array([-1])).inv + return ZeroSumTransform(tuple([-1])).inv From 1820a742e55c37e24ed08b7f99bf679230e42602 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 09:57:11 -0400 Subject: [PATCH 17/42] added forward and inverse shape to transform, fixed zero_sum constraint handling --- numpyro/distributions/constraints.py | 2 +- numpyro/distributions/continuous.py | 2 +- numpyro/distributions/transforms.py | 9 ++++++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index e8566b0a5..0c5742dac 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -549,7 +549,7 @@ def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy zerosum_true = [] for dim in range(-self.event_dim, 0): - zerosum_true.append(jnp.allclose(x.sum(-1), 0, rtol=0.05, atol=1e-2)) + zerosum_true.append(jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2)) return all(zerosum_true) def feasible_like(self, prototype): diff --git a/numpyro/distributions/continuous.py 
b/numpyro/distributions/continuous.py index cab3ebea9..d4aa0c4e1 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2454,7 +2454,7 @@ class ZeroSumNormal(TransformedDistribution): def __init__(self, scale, event_shape): event_ndim = len(event_shape) transformed_shape = tuple(size - 1 for size in event_shape) - zero_sum_axes = tuple(-(i + 1) for i in range(event_ndim)) + zero_sum_axes = tuple(i for i in range(-event_ndim,0)) super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), ZeroSumTransform(zero_sum_axes).inv, diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index fe22c57f9..9ed2cbd58 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1197,6 +1197,12 @@ def extend_axis(self, array, axis): def log_abs_det_jacobian(self, x, y, intermediates=None): return jnp.array(0.0) + def forward_shape(self, shape): + return tuple(s - 1 for s in shape) + + def inverse_shape(self, shape): + return tuple(s + 1 for s in shape) + def tree_flatten(self): return (self.zerosum_axes,), (("zerosum_axes",), dict()) @@ -1407,4 +1413,5 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - return ZeroSumTransform(tuple([-1])).inv + zero_sum_axes = tuple(i for i in range(-constraint.event_dim,0)) + return ZeroSumTransform(zero_sum_axes).inv From ee227bf62b73819932db1467284c82b98ce60f68 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 11:48:02 -0400 Subject: [PATCH 18/42] fixed failing zsn tests --- numpyro/distributions/constraints.py | 52 +++++++----- numpyro/distributions/continuous.py | 20 ++++- numpyro/distributions/transforms.py | 115 +++++++++++++-------------- test/test_distributions.py | 21 +++-- 4 files changed, 120 insertions(+), 88 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 
0c5742dac..d626f7cea 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -510,6 +510,7 @@ class _UnitInterval(_SingletonConstraint, _Interval): def __init__(self): super().__init__(0.0, 1.0) + class _OpenInterval(_Interval): def __call__(self, x): return (x > self.lower_bound) & (x < self.upper_bound) @@ -540,28 +541,6 @@ def feasible_like(self, prototype): ) -class _ZeroSum(Constraint): - def __init__(self, event_dim = 1): - self.event_dim = event_dim - super().__init__() - - def __call__(self, x): - jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - zerosum_true = [] - for dim in range(-self.event_dim, 0): - zerosum_true.append(jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2)) - return all(zerosum_true) - - def feasible_like(self, prototype): - return jax.numpy.broadcast_to(0, prototype.shape) - - def tree_flatten(self): - return (self.event_dim), ( - ("event_dim"), - dict(), - ) - - class _Multinomial(Constraint): is_discrete = True event_dim = 1 @@ -709,6 +688,35 @@ def feasible_like(self, prototype): return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5)) +class _ZeroSum(Constraint): + def __init__(self, event_dim = 1): + self.event_dim = event_dim + super().__init__() + + def __call__(self, x): + jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + zerosum_true = [] + for dim in range(-self.event_dim, 0): + zerosum_true.append(jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2)) + return all(zerosum_true) + + def __eq__(self, other): + return ( + type(self) is type(other) + and self.event_dim == other.event_dim + ) + + def feasible_like(self, prototype): + return jax.numpy.broadcast_to(0, prototype.shape) + + def tree_flatten(self): + return (self.event_dim), ( + ("event_dim"), + dict(), + ) + + + # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 
d4aa0c4e1..ee0965652 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2451,15 +2451,33 @@ class ZeroSumNormal(TransformedDistribution): arg_constraints = {"scale": constraints.positive} reparametrized_params = ["scale"] - def __init__(self, scale, event_shape): + def __init__(self, scale, event_shape, *, validate_args=None): event_ndim = len(event_shape) + if jnp.ndim(scale) == 0: + (scale,) = promote_shapes(scale, shape=(1,)) transformed_shape = tuple(size - 1 for size in event_shape) zero_sum_axes = tuple(i for i in range(-event_ndim,0)) + self.scale = scale super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), ZeroSumTransform(zero_sum_axes).inv, + validate_args=validate_args ) @constraints.dependent_property(is_discrete=False) def support(self): return constraints.zero_sum(len(self.event_shape)) + + @property + def mean(self): + return jnp.broadcast_to(0, self.batch_shape) + + @property + def variance(self): + event_ndim = len(self.event_shape) + zero_sum_axes = tuple(i for i in range(-event_ndim,0)) + theoretical_var = self.scale.astype(float)**2 + for axis in zero_sum_axes: + theoretical_var *= (1 - 1 / self.event_shape[axis]) + + return theoretical_var diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 9ed2cbd58..68ea706cd 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1144,6 +1144,63 @@ def __eq__(self, other): return isinstance(other, UnpackTransform) and self.unpack_fn is other.unpack_fn +def _get_target_shape(shape, forward_shape, inverse_shape): + batch_ndims = len(shape) - len(inverse_shape) + return shape[:batch_ndims] + forward_shape + + +class ReshapeTransform(Transform): + """ + Reshape a sample, leaving batch dimensions unchanged. + + :param forward_shape: Shape to transform the sample to. + :param inverse_shape: Shape of the sample for the inverse transform. 
+ """ + + domain = constraints.real + codomain = constraints.real + + def __init__(self, forward_shape, inverse_shape) -> None: + forward_size = math.prod(forward_shape) + inverse_size = math.prod(inverse_shape) + if forward_size != inverse_size: + raise ValueError( + f"forward shape {forward_shape} (size {forward_size}) and inverse " + f"shape {inverse_shape} (size {inverse_size}) are not compatible" + ) + self._forward_shape = forward_shape + self._inverse_shape = inverse_shape + + def forward_shape(self, shape): + return _get_target_shape(shape, self._forward_shape, self._inverse_shape) + + def inverse_shape(self, shape): + return _get_target_shape(shape, self._inverse_shape, self._forward_shape) + + def __call__(self, x): + return jnp.reshape(x, self.forward_shape(jnp.shape(x))) + + def _inverse(self, y): + return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) + + def log_abs_det_jacobian(self, x, y, intermediates=None): + return 0.0 + + def tree_flatten(self): + aux_data = { + "_forward_shape": self._forward_shape, + "_inverse_shape": self._inverse_shape, + } + return (), ((), aux_data) + + def __eq__(self, other): + return ( + isinstance(other, ReshapeTransform) + and self._forward_shape == other._forward_shape + and self._inverse_shape == other._inverse_shape + ) + + class ZeroSumTransform(Transform): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] @@ -1207,64 +1264,6 @@ def tree_flatten(self): return (self.zerosum_axes,), (("zerosum_axes",), dict()) -def _get_target_shape(shape, forward_shape, inverse_shape): - batch_ndims = len(shape) - len(inverse_shape) - return shape[:batch_ndims] + forward_shape - - -class ReshapeTransform(Transform): - """ - Reshape a sample, leaving batch dimensions unchanged. - - :param forward_shape: Shape to transform the sample to. - :param inverse_shape: Shape of the sample for the inverse transform. 
- """ - - domain = constraints.real - codomain = constraints.real - - def __init__(self, forward_shape, inverse_shape) -> None: - forward_size = math.prod(forward_shape) - inverse_size = math.prod(inverse_shape) - if forward_size != inverse_size: - raise ValueError( - f"forward shape {forward_shape} (size {forward_size}) and inverse " - f"shape {inverse_shape} (size {inverse_size}) are not compatible" - ) - self._forward_shape = forward_shape - self._inverse_shape = inverse_shape - - def forward_shape(self, shape): - return _get_target_shape(shape, self._forward_shape, self._inverse_shape) - - def inverse_shape(self, shape): - return _get_target_shape(shape, self._inverse_shape, self._forward_shape) - - def __call__(self, x): - return jnp.reshape(x, self.forward_shape(jnp.shape(x))) - - def _inverse(self, y): - return jnp.reshape(y, self.inverse_shape(jnp.shape(y))) - - def log_abs_det_jacobian(self, x, y, intermediates=None): - return 0.0 - - def tree_flatten(self): - aux_data = { - "_forward_shape": self._forward_shape, - "_inverse_shape": self._inverse_shape, - } - return (), ((), aux_data) - - def __eq__(self, other): - return ( - isinstance(other, ReshapeTransform) - and self._forward_shape == other._forward_shape - and self._inverse_shape == other._inverse_shape - ) - - - ########################################################## # CONSTRAINT_REGISTRY ########################################################## diff --git a/test/test_distributions.py b/test/test_distributions.py index a862f67d4..cb571a439 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -6,7 +6,6 @@ import inspect from itertools import product import math -import os import numpy as np from numpy.testing import assert_allclose, assert_array_equal @@ -773,10 +772,9 @@ def get_sp_dist(jax_dist): T(dist.Weibull, 0.2, 1.1), T(dist.Weibull, 2.8, np.array([2.0, 2.0])), T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])), - T(dist.ZeroSumNormal, 1.0, None, (1,)), - 
T(dist.ZeroSumNormal, 1.0, 1, (1,)), - T(dist.ZeroSumNormal, np.array([2.0]), None, (1,)), - T(dist.ZeroSumNormal, 1.0, 2, (4,5)), + T(dist.ZeroSumNormal, 1.0, (1,)), + T(dist.ZeroSumNormal, np.array([2.0]), (1,)), + T(dist.ZeroSumNormal, 1.0, (4,5)), T( _GaussianMixture, np.ones(3) / 3.0, @@ -1021,6 +1019,12 @@ def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)): sign = random.bernoulli(key1) bounds = [0, (-1) ** sign * 0.5] return random.uniform(key, size, float, *sorted(bounds)) + elif isinstance(constraint, constraints.zero_sum): + x = random.normal(key, size) + zero_sum_axes = tuple(i for i in range(-constraint.event_dim,0)) + for axis in zero_sum_axes: + x -= x.mean(axis) + return x else: raise NotImplementedError("{} not implemented.".format(constraint)) @@ -1088,6 +1092,9 @@ def gen_values_outside_bounds(constraint, size, key=random.PRNGKey(11)): sign = random.bernoulli(key1) bounds = [(-1) ** sign * 1.1, (-1) ** sign * 2] return random.uniform(key, size, float, *sorted(bounds)) + elif isinstance(constraint, constraints.zero_sum): + x = random.normal(key, size) + return x else: raise NotImplementedError("{} not implemented.".format(constraint)) @@ -1680,7 +1687,7 @@ def fn(*args): # skip taking grad w.r.t. adj_matrix continue if jax_dist is dist.ZeroSumNormal and i != 0: - # skip taking grad w.r.t. n_zerosum_axes and support_shape + # skip taking grad w.r.t. 
event_shape continue if isinstance( params[i], dist.Distribution @@ -1909,7 +1916,7 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): continue if jax_dist is dist.GaussianRandomWalk and dist_args[i] == "num_steps": continue - if jax_dist is dist.ZeroSumNormal and dist_args[i] in ("n_zerosum_axes", "support_shape"): + if jax_dist is dist.ZeroSumNormal and dist_args[i] == "event_shape": continue if ( jax_dist is dist.SineBivariateVonMises From bb4880caa6383e34e7c7b5663ec699a41196fcd4 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 11:57:27 -0400 Subject: [PATCH 19/42] added docstring, removed whitespace, fixed missing import --- numpyro/distributions/constraints.py | 1 - numpyro/distributions/continuous.py | 20 ++++++++++++++++++++ test/test_distributions.py | 1 + 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index d626f7cea..115677fc1 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -716,7 +716,6 @@ def tree_flatten(self): ) - # TODO: Make types consistent # See https://github.com/pytorch/pytorch/issues/50616 diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index ee0965652..badd892d9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2448,6 +2448,26 @@ def icdf(self, value): class ZeroSumNormal(TransformedDistribution): + r""" + Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or + more axes are constrained to sum to zero (the last axis by default). + + :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is + enforced. + :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero. + + .. 
math:: + \begin{align*} + ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ + \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ + n = \text{number of zero-sum axes} + \end{align*} + + **References** + [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 + [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html + [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ + """ arg_constraints = {"scale": constraints.positive} reparametrized_params = ["scale"] diff --git a/test/test_distributions.py b/test/test_distributions.py index cb571a439..3f4b5f778 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -6,6 +6,7 @@ import inspect from itertools import product import math +import os import numpy as np from numpy.testing import assert_allclose, assert_array_equal From 38b8f566518f64148d205fd1325eea951e26bc51 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 12:03:49 -0400 Subject: [PATCH 20/42] fixed allclose to be assert allclose --- test/test_distributions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 3f4b5f778..325f96be1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1876,7 +1876,7 @@ def get_min_shape(ix, batch_shape): if isinstance(d_jax, dist.Gompertz): pytest.skip("Gompertz distribution does not have `variance` implemented.") if jnp.all(jnp.isfinite(d_jax.variance)): - jnp.allclose( + assert jnp.allclose( jnp.std(samples, 0), jnp.sqrt(d_jax.variance), rtol=0.05, atol=1e-2 ) From c8af390b8fe5a3b444b8460d0308226a31da06ea Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 12:49:41 -0400 Subject: [PATCH 21/42] linted and formatted --- numpyro/distributions/constraints.py | 7 ++----- numpyro/distributions/continuous.py | 13 
+++++++------ numpyro/distributions/transforms.py | 6 ++++-- test/test_distributions.py | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 882fbf3a7..34608da64 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -699,7 +699,7 @@ def feasible_like(self, prototype): class _ZeroSum(Constraint): - def __init__(self, event_dim = 1): + def __init__(self, event_dim=1): self.event_dim = event_dim super().__init__() @@ -711,10 +711,7 @@ def __call__(self, x): return all(zerosum_true) def __eq__(self, other): - return ( - type(self) is type(other) - and self.event_dim == other.event_dim - ) + return type(self) is type(other) and self.event_dim == other.event_dim def feasible_like(self, prototype): return jax.numpy.broadcast_to(0, prototype.shape) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index b40d77178..0b5c3b2cf 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2462,20 +2462,21 @@ class ZeroSumNormal(TransformedDistribution): [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ + arg_constraints = {"scale": constraints.positive} reparametrized_params = ["scale"] - def __init__(self, scale, event_shape, *, validate_args=None): + def __init__(self, scale, event_shape, *, validate_args=None): event_ndim = len(event_shape) if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) transformed_shape = tuple(size - 1 for size in event_shape) - zero_sum_axes = tuple(i for i in range(-event_ndim,0)) + zero_sum_axes = tuple(i for i in range(-event_ndim, 0)) self.scale = scale super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), 
ZeroSumTransform(zero_sum_axes).inv, - validate_args=validate_args + validate_args=validate_args, ) @constraints.dependent_property(is_discrete=False) @@ -2489,9 +2490,9 @@ def mean(self): @property def variance(self): event_ndim = len(self.event_shape) - zero_sum_axes = tuple(i for i in range(-event_ndim,0)) - theoretical_var = self.scale.astype(float)**2 + zero_sum_axes = tuple(i for i in range(-event_ndim, 0)) + theoretical_var = self.scale.astype(float) ** 2 for axis in zero_sum_axes: - theoretical_var *= (1 - 1 / self.event_shape[axis]) + theoretical_var *= 1 - 1 / self.event_shape[axis] return theoretical_var diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 00e3ae59c..e201cea0d 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1390,6 +1390,7 @@ class ZeroSumTransform(Transform): [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ + def __init__(self, zerosum_axes): self.zerosum_axes = zerosum_axes @@ -1423,7 +1424,7 @@ def extend_axis_rev(self, array, axis): return array[(*slice_before, slice(None, -1))] + norm def extend_axis(self, array, axis): - n = (array.shape[axis] + 1) + n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) norm = sum_vals / (jnp.sqrt(n) + n) @@ -1596,7 +1597,8 @@ def _transform_to_softplus_lower_cholesky(constraint): def _transform_to_simplex(constraint): return StickBreakingTransform() + @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - zero_sum_axes = tuple(i for i in range(-constraint.event_dim,0)) + zero_sum_axes = tuple(i for i in range(-constraint.event_dim, 0)) return ZeroSumTransform(zero_sum_axes).inv diff --git a/test/test_distributions.py b/test/test_distributions.py index cb276c209..bca28ce15 100644 --- 
a/test/test_distributions.py +++ b/test/test_distributions.py @@ -776,7 +776,7 @@ def get_sp_dist(jax_dist): T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])), T(dist.ZeroSumNormal, 1.0, (1,)), T(dist.ZeroSumNormal, np.array([2.0]), (1,)), - T(dist.ZeroSumNormal, 1.0, (4,5)), + T(dist.ZeroSumNormal, 1.0, (4, 5)), T( _GaussianMixture, np.ones(3) / 3.0, @@ -1023,7 +1023,7 @@ def gen_values_within_bounds(constraint, size, key=random.PRNGKey(11)): return random.uniform(key, size, float, *sorted(bounds)) elif isinstance(constraint, constraints.zero_sum): x = random.normal(key, size) - zero_sum_axes = tuple(i for i in range(-constraint.event_dim,0)) + zero_sum_axes = tuple(i for i in range(-constraint.event_dim, 0)) for axis in zero_sum_axes: x -= x.mean(axis) return x From 3034f4ab85328678c41e9760f440ecd9cd33f83b Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 13:20:13 -0400 Subject: [PATCH 22/42] added sample code to docstring for zsn --- numpyro/distributions/continuous.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 0b5c3b2cf..6d810761b 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2446,6 +2446,18 @@ class ZeroSumNormal(TransformedDistribution): Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default). 
+ Sample code for using ZeroSumNormal in the context of a single axis to zero-constrain:: + + def model(category_ind, y): # X is an indexed categorical variable with 20 categories + N = len(category_ind) + alpha = numpyro.sample("alpha", dist.Normal(0, 2.5)) + beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(20,))) + sigma = numpyro.sample("sigma", dist.Exponential(1)) + with numpyro.plate("observations", N): + mu = alpha + beta[category_ind] + obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) + return obs + :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is enforced. :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero. From ebdd309bfc4d985087291550a9f1df69e13219ae Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 13:21:04 -0400 Subject: [PATCH 23/42] updated docstring --- numpyro/distributions/continuous.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6d810761b..2895bfaf9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2448,7 +2448,7 @@ class ZeroSumNormal(TransformedDistribution): Sample code for using ZeroSumNormal in the context of a single axis to zero-constrain:: - def model(category_ind, y): # X is an indexed categorical variable with 20 categories + def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories N = len(category_ind) alpha = numpyro.sample("alpha", dist.Normal(0, 2.5)) beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(20,))) From 8cb7a5fd0dbcd69062628905056a18f23cfec9a1 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 15:04:30 -0400 Subject: [PATCH 24/42] removed list from ZeroSum constraint call --- numpyro/distributions/constraints.py | 4 ++-- 1 file changed, 2
insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 34608da64..15ea92e84 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -705,9 +705,9 @@ def __init__(self, event_dim=1): def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy - zerosum_true = [] + zerosum_true = True for dim in range(-self.event_dim, 0): - zerosum_true.append(jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2)) + zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2) return all(zerosum_true) def __eq__(self, other): From ae1586fa3d0d488cc47cf94ae1d094bd9aa0d6dd Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 15:05:30 -0400 Subject: [PATCH 25/42] removed unneeded iteration, updated docstring --- numpyro/distributions/continuous.py | 14 +++++++------- numpyro/distributions/transforms.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 2895bfaf9..44b9a1c6c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2456,11 +2456,7 @@ def model(category_ind, y): # category_ind is an indexed categorical variable wi with numpyro.plate("observations", N): mu = alpha + beta[category_ind] obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) - return obs - - :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is - enforced. - :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero. + return obs .. 
math:: \begin{align*} @@ -2469,6 +2465,10 @@ def model(category_ind, y): # category_ind is an indexed categorical variable wi n = \text{number of zero-sum axes} \end{align*} + :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is + enforced. + :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero. + **References** [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html @@ -2483,7 +2483,7 @@ def __init__(self, scale, event_shape, *, validate_args=None): if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) transformed_shape = tuple(size - 1 for size in event_shape) - zero_sum_axes = tuple(i for i in range(-event_ndim, 0)) + zero_sum_axes = tuple(range(-event_ndim, 0)) self.scale = scale super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), @@ -2502,7 +2502,7 @@ def mean(self): @property def variance(self): event_ndim = len(self.event_shape) - zero_sum_axes = tuple(i for i in range(-event_ndim, 0)) + zero_sum_axes = tuple(range(-event_ndim, 0)) theoretical_var = self.scale.astype(float) ** 2 for axis in zero_sum_axes: theoretical_var *= 1 - 1 / self.event_shape[axis] diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index e201cea0d..f51b1af90 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1600,5 +1600,5 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - zero_sum_axes = tuple(i for i in range(-constraint.event_dim, 0)) + zero_sum_axes = tuple(range(-constraint.event_dim, 0)) return ZeroSumTransform(zero_sum_axes).inv From ab582166542cb28a40a007bfd42443cfc49faf3c Mon Sep 17 00:00:00 
2001 From: Kyle Caron Date: Tue, 26 Mar 2024 15:15:34 -0400 Subject: [PATCH 26/42] updated constraint code --- numpyro/distributions/constraints.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 15ea92e84..e09a21fa7 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -707,8 +707,10 @@ def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy zerosum_true = True for dim in range(-self.event_dim, 0): - zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, rtol=0.05, atol=1e-2) - return all(zerosum_true) + zerosum_true = zerosum_true & jnp.allclose( + x.sum(dim), 0, rtol=0.05, atol=1e-2 + ) + return zerosum_true def __eq__(self, other): return type(self) is type(other) and self.event_dim == other.event_dim From ad4e7c2593c8d12e745baf1f26bfe5e6886e4197 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 15:21:15 -0400 Subject: [PATCH 27/42] added ZeroSumTransform to docs --- docs/source/distributions.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 98348b9ec..a9c0ec71b 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -1021,6 +1021,15 @@ StickBreakingTransform :show-inheritance: :member-order: bysource +ZeroSumTransform +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. 
autoclass:: numpyro.distributions.transforms.ZeroSumTransform + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + Flows ----- From 54547f28d7bd158831c1ab9285e8c104b3550dc2 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 17:02:43 -0400 Subject: [PATCH 28/42] fixed transform shapes --- numpyro/distributions/continuous.py | 3 +- numpyro/distributions/transforms.py | 45 ++++++++++++++++++++--------- test/test_transforms.py | 3 ++ 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 44b9a1c6c..d798e77be 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2483,11 +2483,10 @@ def __init__(self, scale, event_shape, *, validate_args=None): if jnp.ndim(scale) == 0: (scale,) = promote_shapes(scale, shape=(1,)) transformed_shape = tuple(size - 1 for size in event_shape) - zero_sum_axes = tuple(range(-event_ndim, 0)) self.scale = scale super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), - ZeroSumTransform(zero_sum_axes).inv, + ZeroSumTransform(event_ndim).inv, validate_args=validate_args, ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index f51b1af90..d7f798ad2 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1391,24 +1391,26 @@ class ZeroSumTransform(Transform): [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ - def __init__(self, zerosum_axes): - self.zerosum_axes = zerosum_axes + def __init__(self, transform_ndims=1): + self.transform_ndims = transform_ndims @property - def domain(self): - return constraints.independent(constraints.real, len(self.zerosum_axes)) + def domain(self) -> constraints.Constraint: + return constraints.independent(constraints.real, self.transform_ndims) @property - def codomain(self): - 
return constraints.zero_sum(len(self.zerosum_axes)) + def codomain(self) -> constraints.Constraint: + return constraints.zero_sum(self.transform_ndims) def __call__(self, x): - for axis in self.zerosum_axes: + zero_sum_axes = tuple(range(-self.transform_ndims, 0)) + for axis in zero_sum_axes: x = self.extend_axis_rev(x, axis=axis) return x def _inverse(self, y): - for axis in self.zerosum_axes: + zero_sum_axes = tuple(range(-self.transform_ndims, 0)) + for axis in zero_sum_axes: y = self.extend_axis(y, axis=axis) return y @@ -1434,16 +1436,32 @@ def extend_axis(self, array, axis): return out - norm def log_abs_det_jacobian(self, x, y, intermediates=None): - return jnp.array(0.0) + shape = jnp.broadcast_shapes( + x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] + ) + return jnp.zeros_like(x, shape=shape) def forward_shape(self, shape): - return tuple(s - 1 for s in shape) + return shape[: -self.transform_ndims] + tuple( + s - 1 for s in shape[-self.transform_ndims :] + ) def inverse_shape(self, shape): - return tuple(s + 1 for s in shape) + return shape[: -self.transform_ndims] + tuple( + s + 1 for s in shape[-self.transform_ndims :] + ) def tree_flatten(self): - return (self.zerosum_axes,), (("zerosum_axes",), dict()) + aux_data = { + "transform_ndims": self.transform_ndims, + } + return (), ((), aux_data) + + def __eq__(self, other): + return ( + isinstance(other, ZeroSumTransform) + and self.transform_ndims == other.transform_ndims + ) ########################################################## @@ -1600,5 +1618,4 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - zero_sum_axes = tuple(range(-constraint.event_dim, 0)) - return ZeroSumTransform(zero_sum_axes).inv + return ZeroSumTransform(constraint.event_dim).inv diff --git a/test/test_transforms.py b/test/test_transforms.py index 1a706bbc6..7771f3fed 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ 
-40,6 +40,7 @@ SoftplusTransform, StickBreakingTransform, UnpackTransform, + ZeroSumTransform, biject_to, ) @@ -134,6 +135,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): "reshape": T( ReshapeTransform, (), {"forward_shape": (3, 4), "inverse_shape": (4, 3)} ), + "zero_sum": T(ZeroSumTransform, (), dict(transform_ndims=1)), } @@ -296,6 +298,7 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims): (SoftplusLowerCholeskyTransform(), (10,)), (SoftplusTransform(), ()), (StickBreakingTransform(), (11,)), + (ZeroSumTransform(1).inv, (5,)), ], ) def test_bijective_transforms(transform, shape): From bdc6480da7b6fb07ab8aa81aa40afb14e453b0c2 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 17:24:42 -0400 Subject: [PATCH 29/42] added doctest example for zsn --- numpyro/distributions/continuous.py | 52 ++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 12 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index d798e77be..763ee876e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2446,18 +2446,6 @@ class ZeroSumNormal(TransformedDistribution): Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or more axes are constrained to sum to zero (the last axis by default). - Sample code for using ZeroSumNormal in the context of a single axis to zero-contrain:: - - def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories - N = len(category_ind) - alpha = numpyro.sample("alpha", dist.Normal(0, 2.5)) - beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(20,))) - sigma = numpyro.sample("sigma", dist.Exponential(1)) - with numpyro.plate("observations", N): - mu = alpha + beta[category_ind] - obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) - return obs - .. 
math:: \begin{align*} ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ @@ -2469,6 +2457,46 @@ def model(category_ind, y): # category_ind is an indexed categorical variable wi enforced. :param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero. + **Example:** + + .. doctest:: + + >>> from numpy.testing import assert_allclose + >>> from jax import random + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer import MCMC, NUTS + + >>> N = 1000 + >>> n_categories = 20 + >>> rng_key = random.PRNGKey(0) + >>> key1, key2, key3 = random.split(rng_key, 3) + >>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,)) + >>> beta = random.normal(key2, shape=(n_categories,)) + >>> beta -= beta.mean(-1) + >>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,)) + + >>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories + ... N = len(category_ind) + ... alpha = numpyro.sample("alpha", dist.Normal(0, 2.5)) + ... beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,))) + ... sigma = numpyro.sample("sigma", dist.Exponential(1)) + ... with numpyro.plate("observations", N): + ... mu = alpha + beta[category_ind] + ... obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y) + ... 
return obs + + >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9) + >>> mcmc = MCMC( + >>> sampler=nuts_kernel, + >>> num_samples=1_000, num_warmup=1_000, num_chains=4 + >>> ) + >>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y) + >>> posterior_samples = mcmc.get_samples() + >>> # Confirm everything along last axis sums to zero + >>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3) + **References** [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html From 0b5070b912d148d60d8a7b60379a457f5b561ed9 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 18:52:23 -0400 Subject: [PATCH 30/42] added constraint test --- numpyro/distributions/constraints.py | 5 +---- test/test_constraints.py | 1 + 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index e09a21fa7..b077b75f2 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -719,10 +719,7 @@ def feasible_like(self, prototype): return jax.numpy.broadcast_to(0, prototype.shape) def tree_flatten(self): - return (self.event_dim), ( - ("event_dim"), - dict(), - ) + return (self.event_dim,), (("event_dim",), dict()) # TODO: Make types consistent diff --git a/test/test_constraints.py b/test/test_constraints.py index 735969fa6..fb34bf6b8 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -62,6 +62,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])): dict(), ), "open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()), + "zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)), } # TODO: BijectorConstraint From b1129bf7d6635386165df62cc39161d5a9e7de1c Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 18:59:39 
-0400 Subject: [PATCH 31/42] added zero_sum constraint to docs --- docs/source/distributions.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index a9c0ec71b..cde5eb7b6 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -827,6 +827,9 @@ unit_interval ^^^^^^^^^^^^^ .. autodata:: numpyro.distributions.constraints.unit_interval +zero_sum +^^^^^^^^^^^^^ +.. autodata:: numpyro.distributions.constraints.zero_sum Transforms ---------- From 5fcaf68eb52455201571439614e60fd6a7719b2f Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 19:11:16 -0400 Subject: [PATCH 32/42] added type hinting to transforms file --- numpyro/distributions/transforms.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index d7f798ad2..7c45ffd29 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1385,13 +1385,15 @@ def __eq__(self, other): class ZeroSumTransform(Transform): """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] + :param transform_ndims: Number of trailing dimensions to transform. 
+ **References** [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266 [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ """ - def __init__(self, transform_ndims=1): + def __init__(self, transform_ndims: int = 1) -> None: self.transform_ndims = transform_ndims @property @@ -1402,19 +1404,19 @@ def domain(self) -> constraints.Constraint: def codomain(self) -> constraints.Constraint: return constraints.zero_sum(self.transform_ndims) - def __call__(self, x): + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: x = self.extend_axis_rev(x, axis=axis) return x - def _inverse(self, y): + def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: y = self.extend_axis(y, axis=axis) return y - def extend_axis_rev(self, array, axis): + def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] n = array.shape[normalized_axis] @@ -1425,7 +1427,7 @@ def extend_axis_rev(self, array, axis): slice_before = (slice(None, None),) * normalized_axis return array[(*slice_before, slice(None, -1))] + norm - def extend_axis(self, array, axis): + def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: n = array.shape[axis] + 1 sum_vals = array.sum(axis, keepdims=True) @@ -1435,18 +1437,20 @@ def extend_axis(self, array, axis): out = jnp.concatenate([array, fill_val], axis=axis) return out - norm - def log_abs_det_jacobian(self, x, y, intermediates=None): + def log_abs_det_jacobian( + self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None + ) -> jnp.ndarray: shape = jnp.broadcast_shapes( x.shape[: 
-self.transform_ndims], y.shape[: -self.transform_ndims] ) return jnp.zeros_like(x, shape=shape) - def forward_shape(self, shape): + def forward_shape(self, shape: tuple) -> tuple: return shape[: -self.transform_ndims] + tuple( s - 1 for s in shape[-self.transform_ndims :] ) - def inverse_shape(self, shape): + def inverse_shape(self, shape: tuple) -> tuple: return shape[: -self.transform_ndims] + tuple( s + 1 for s in shape[-self.transform_ndims :] ) From 619f90bdcad26ad55eec85a6bd013693b852139f Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 20:33:40 -0400 Subject: [PATCH 33/42] fixed docs formatting --- docs/source/distributions.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index cde5eb7b6..06b51c929 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -381,7 +381,7 @@ Weibull :member-order: bysource ZeroSumNormal -^^^^^^^ +^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.ZeroSumNormal :members: :undoc-members: @@ -828,7 +828,7 @@ unit_interval .. autodata:: numpyro.distributions.constraints.unit_interval zero_sum -^^^^^^^^^^^^^ +^^^^^^^^ .. autodata:: numpyro.distributions.constraints.zero_sum Transforms @@ -1025,7 +1025,7 @@ StickBreakingTransform :member-order: bysource ZeroSumTransform -^^^^^^^^^^^^^^^^^^^^^^^^ +^^^^^^^^^^^^^^^^ .. 
autoclass:: numpyro.distributions.transforms.ZeroSumTransform :members: From 2e796777b514541942139db3243e01f2abadc8ff Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 20:47:59 -0400 Subject: [PATCH 34/42] moved skip zsn from test_gof earlier --- test/test_distributions.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index bca28ce15..48ddf1e82 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1455,7 +1455,9 @@ def test_gof(jax_dist, sp_dist, params): d = jax_dist(*params) if d.event_dim > 1: pytest.skip("EulerMaruyama skip test when event shape is non-trivial.") - + if jax_dist is dist.ZeroSumNormal: + pytest.skip("skip gof test for ZeroSumNormal") + num_samples = 10000 if "BetaProportion" in jax_dist.__name__: num_samples = 20000 @@ -1468,9 +1470,6 @@ def test_gof(jax_dist, sp_dist, params): if jax_dist is dist.ProjectedNormal: dim = samples.shape[-1] - 1 - if jax_dist is dist.ZeroSumNormal: - pytest.skip("skip gof test for ZeroSumNormal") - # Test each batch independently. 
probs = probs.reshape(num_samples, -1) samples = samples.reshape(probs.shape + d.event_shape) From da382f58b6ac48d4bf4399429fed093d3ee86037 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Tue, 26 Mar 2024 21:08:24 -0400 Subject: [PATCH 35/42] reversed zerosumtransform --- numpyro/distributions/continuous.py | 2 +- numpyro/distributions/transforms.py | 8 ++++---- test/test_distributions.py | 2 +- test/test_transforms.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 763ee876e..8ede764a4 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2514,7 +2514,7 @@ def __init__(self, scale, event_shape, *, validate_args=None): self.scale = scale super().__init__( Normal(0, scale).expand(transformed_shape).to_event(event_ndim), - ZeroSumTransform(event_ndim).inv, + ZeroSumTransform(event_ndim), validate_args=validate_args, ) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 7c45ffd29..0a9d6ad23 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1407,13 +1407,13 @@ def codomain(self) -> constraints.Constraint: def __call__(self, x: jnp.ndarray) -> jnp.ndarray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: - x = self.extend_axis_rev(x, axis=axis) + x = self.extend_axis(x, axis=axis) return x def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: zero_sum_axes = tuple(range(-self.transform_ndims, 0)) for axis in zero_sum_axes: - y = self.extend_axis(y, axis=axis) + y = self.extend_axis_rev(y, axis=axis) return y def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: @@ -1447,12 +1447,12 @@ def log_abs_det_jacobian( def forward_shape(self, shape: tuple) -> tuple: return shape[: -self.transform_ndims] + tuple( - s - 1 for s in shape[-self.transform_ndims :] + s + 1 for s in shape[-self.transform_ndims :] 
) def inverse_shape(self, shape: tuple) -> tuple: return shape[: -self.transform_ndims] + tuple( - s + 1 for s in shape[-self.transform_ndims :] + s - 1 for s in shape[-self.transform_ndims :] ) def tree_flatten(self): diff --git a/test/test_distributions.py b/test/test_distributions.py index 48ddf1e82..a2ef1bb0d 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1457,7 +1457,7 @@ def test_gof(jax_dist, sp_dist, params): pytest.skip("EulerMaruyama skip test when event shape is non-trivial.") if jax_dist is dist.ZeroSumNormal: pytest.skip("skip gof test for ZeroSumNormal") - + num_samples = 10000 if "BetaProportion" in jax_dist.__name__: num_samples = 20000 diff --git a/test/test_transforms.py b/test/test_transforms.py index 7771f3fed..261818429 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -298,7 +298,7 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims): (SoftplusLowerCholeskyTransform(), (10,)), (SoftplusTransform(), ()), (StickBreakingTransform(), (11,)), - (ZeroSumTransform(1).inv, (5,)), + (ZeroSumTransform(1), (5,)), ], ) def test_bijective_transforms(transform, shape): From 5aa5aeb40d58f9dcfaabf0c42a8c6072e18cd83b Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 27 Mar 2024 08:58:45 -0400 Subject: [PATCH 36/42] broadcasted mean and var of zsn --- numpyro/distributions/continuous.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 8ede764a4..4522d5cc1 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2524,7 +2524,7 @@ def support(self): @property def mean(self): - return jnp.broadcast_to(0, self.batch_shape) + return jnp.broadcast_to(0, self.batch_shape + self.event_shape) @property def variance(self): @@ -2534,4 +2534,4 @@ def variance(self): for axis in zero_sum_axes: theoretical_var *= 1 - 1 / self.event_shape[axis] - return theoretical_var + 
return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape) From f7992d1d72437be8db5874599d4a07e540a615be Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Wed, 27 Mar 2024 20:44:00 -0400 Subject: [PATCH 37/42] added stricter zero_sum constraint tol, improved mean and var functions --- numpyro/distributions/constraints.py | 5 ++--- numpyro/distributions/continuous.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index b077b75f2..3913e683f 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -705,11 +705,10 @@ def __init__(self, event_dim=1): def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy + tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10 zerosum_true = True for dim in range(-self.event_dim, 0): - zerosum_true = zerosum_true & jnp.allclose( - x.sum(dim), 0, rtol=0.05, atol=1e-2 - ) + zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol) return zerosum_true def __eq__(self, other): diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4522d5cc1..80680c9c3 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2524,13 +2524,13 @@ def support(self): @property def mean(self): - return jnp.broadcast_to(0, self.batch_shape + self.event_shape) + return jnp.zeros(self.batch_shape + self.event_shape) @property def variance(self): event_ndim = len(self.event_shape) zero_sum_axes = tuple(range(-event_ndim, 0)) - theoretical_var = self.scale.astype(float) ** 2 + theoretical_var = jnp.square(self.scale) for axis in zero_sum_axes: theoretical_var *= 1 - 1 / self.event_shape[axis] From 1e77815c9cda803c12714a0f154eed61f88fa912 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 28 Mar 2024 14:42:14 -0400 Subject: [PATCH 38/42] fixed _transform_to_zero_sum --- 
numpyro/distributions/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/transforms.py b/numpyro/distributions/transforms.py index 0a9d6ad23..a057d86d2 100644 --- a/numpyro/distributions/transforms.py +++ b/numpyro/distributions/transforms.py @@ -1622,4 +1622,4 @@ def _transform_to_simplex(constraint): @biject_to.register(constraints.zero_sum) def _transform_to_zero_sum(constraint): - return ZeroSumTransform(constraint.event_dim).inv + return ZeroSumTransform(constraint.event_dim) From 98f32f9560f15365e322148b93e3fdaacb62d3dd Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 28 Mar 2024 14:53:50 -0400 Subject: [PATCH 39/42] removed shape promote from zsn, changed broadcast to zeros_like --- numpyro/distributions/constraints.py | 2 +- numpyro/distributions/continuous.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 3913e683f..80eac5547 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -715,7 +715,7 @@ def __eq__(self, other): return type(self) is type(other) and self.event_dim == other.event_dim def feasible_like(self, prototype): - return jax.numpy.broadcast_to(0, prototype.shape) + return jax.numpy.zeros_like(0, prototype.shape) def tree_flatten(self): return (self.event_dim,), (("event_dim",), dict()) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 80680c9c3..5fd00090e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2508,8 +2508,6 @@ class ZeroSumNormal(TransformedDistribution): def __init__(self, scale, event_shape, *, validate_args=None): event_ndim = len(event_shape) - if jnp.ndim(scale) == 0: - (scale,) = promote_shapes(scale, shape=(1,)) transformed_shape = tuple(size - 1 for size in event_shape) self.scale = scale super().__init__( From 
c639e706e1e0c655ce5f2fe19284943baeea4489 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Thu, 28 Mar 2024 15:28:17 -0400 Subject: [PATCH 40/42] chose better zsn test cases --- test/test_distributions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index a2ef1bb0d..6ebb990f2 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -774,8 +774,8 @@ def get_sp_dist(jax_dist): T(dist.Weibull, 0.2, 1.1), T(dist.Weibull, 2.8, np.array([2.0, 2.0])), T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])), - T(dist.ZeroSumNormal, 1.0, (1,)), - T(dist.ZeroSumNormal, np.array([2.0]), (1,)), + T(dist.ZeroSumNormal, 1.0, (5,)), + T(dist.ZeroSumNormal, np.array([2.0]), (5,)), T(dist.ZeroSumNormal, 1.0, (4, 5)), T( _GaussianMixture, From 8a7a9052cd53251e2ed7504bbc777c4c23f5b0e2 Mon Sep 17 00:00:00 2001 From: kylejcaron <44980552+kylejcaron@users.noreply.github.com> Date: Thu, 28 Mar 2024 16:32:54 -0400 Subject: [PATCH 41/42] Update zero_sum constraint feasible_like Co-authored-by: Till Hoffmann --- numpyro/distributions/constraints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 80eac5547..21dac6669 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -715,7 +715,7 @@ def __eq__(self, other): return type(self) is type(other) and self.event_dim == other.event_dim def feasible_like(self, prototype): - return jax.numpy.zeros_like(0, prototype.shape) + return jax.numpy.zeros_like(prototype) def tree_flatten(self): return (self.event_dim,), (("event_dim",), dict()) From d7f05ff49469582cd509382b176b64abbec61965 Mon Sep 17 00:00:00 2001 From: Kyle Caron Date: Fri, 29 Mar 2024 10:58:56 -0400 Subject: [PATCH 42/42] fixed docstring for doctests --- numpyro/distributions/continuous.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff 
--git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 5fd00090e..0e895c827 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2489,9 +2489,9 @@ class ZeroSumNormal(TransformedDistribution): >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9) >>> mcmc = MCMC( - >>> sampler=nuts_kernel, - >>> num_samples=1_000, num_warmup=1_000, num_chains=4 - >>> ) + ... sampler=nuts_kernel, + ... num_samples=1_000, num_warmup=1_000, num_chains=4 + ... ) >>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y) >>> posterior_samples = mcmc.get_samples() >>> # Confirm everything along last axis sums to zero