Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZeroSumNormal distribution #1751

Merged
merged 43 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
fa63f9a
added zerosumnormal and tests
kylejcaron Feb 27, 2024
c28fd0c
added edge case handling for support shape
kylejcaron Feb 29, 2024
93bcf0f
removed commented out functions
kylejcaron Feb 29, 2024
d9f2b4e
added zerosumnormal to docs
kylejcaron Feb 29, 2024
0abb60b
fixed zerosumnormal support shape default
kylejcaron Feb 29, 2024
b28f38c
Added v1 of docstrings for zerosumnormal
kylejcaron Feb 29, 2024
4e1dd16
updated zsn docstring
kylejcaron Feb 29, 2024
8cd792c
improved init shape handling for zerosumnormal
kylejcaron Feb 29, 2024
dcbdd85
improved docstrings
kylejcaron Feb 29, 2024
13fff40
added ZeroSumTransform
kylejcaron Mar 5, 2024
514000c
made n_zerosum_axes an attribute for the zerosumtransform
kylejcaron Mar 5, 2024
d6315c3
removed commented out lines
kylejcaron Mar 5, 2024
907cd2e
added zerosumtransform class
kylejcaron Mar 7, 2024
fc3f053
switched zsn from ParameterFreeTransform to Transform
kylejcaron Mar 8, 2024
8187421
changed ZeroSumNormal to transformed distibutrion
kylejcaron Mar 25, 2024
0051342
changed input to tuple for _transform_to_zero_sum
kylejcaron Mar 25, 2024
1820a74
added forward and inverse shape to transform, fixed zero_sum constrai…
kylejcaron Mar 26, 2024
ee227bf
fixed failing zsn tests
kylejcaron Mar 26, 2024
bb4880c
added docstring, removed whitespace, fixed missing import
kylejcaron Mar 26, 2024
38b8f56
fixed allclose to be assert allclose
kylejcaron Mar 26, 2024
54533ff
Merge branch 'master' into zsn-dist
kylejcaron Mar 26, 2024
c8af390
linted and formatted
kylejcaron Mar 26, 2024
3034f4a
added sample code to docstring for zsn
kylejcaron Mar 26, 2024
ebdd309
updated docstring
kylejcaron Mar 26, 2024
8cb7a5f
removed list from ZeroSum constraint call
kylejcaron Mar 26, 2024
ae1586f
removed unneeded iteration, updated docstring
kylejcaron Mar 26, 2024
ab58216
updated constraint code
kylejcaron Mar 26, 2024
ad4e7c2
added ZeroSumTransform to docs
kylejcaron Mar 26, 2024
54547f2
fixed transform shapes
kylejcaron Mar 26, 2024
bdc6480
added doctest example for zsn
kylejcaron Mar 26, 2024
0b5070b
added constraint test
kylejcaron Mar 26, 2024
b1129bf
added zero_sum constraint to docs
kylejcaron Mar 26, 2024
5fcaf68
added type hinting to transforms file
kylejcaron Mar 26, 2024
619f90b
fixed docs formatting
kylejcaron Mar 27, 2024
2e79677
moved skip zsn from test_gof earlier
kylejcaron Mar 27, 2024
da382f5
reversed zerosumtransform
kylejcaron Mar 27, 2024
5aa5aeb
broadcasted mean and var of zsn
kylejcaron Mar 27, 2024
f7992d1
added stricter zero_sum constraint tol, improved mean and var functions
kylejcaron Mar 28, 2024
1e77815
fixed _transform_to_zero_sum
kylejcaron Mar 28, 2024
98f32f9
removed shape promote from zsn, changed broadcast to zeros_like
kylejcaron Mar 28, 2024
c639e70
chose better zsn test cases
kylejcaron Mar 28, 2024
8a7a905
Update zero_sum constraint feasible_like
kylejcaron Mar 28, 2024
d7f05ff
fixed docstring for doctests
kylejcaron Mar 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ Weibull
:show-inheritance:
:member-order: bysource

ZeroSumNormal
^^^^^^^^^^^^^
.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource

Discrete Distributions
----------------------
Expand Down Expand Up @@ -820,6 +827,9 @@ unit_interval
^^^^^^^^^^^^^
.. autodata:: numpyro.distributions.constraints.unit_interval

zero_sum
^^^^^^^^
.. autodata:: numpyro.distributions.constraints.zero_sum

Transforms
----------
Expand Down Expand Up @@ -1014,6 +1024,15 @@ StickBreakingTransform
:show-inheritance:
:member-order: bysource

ZeroSumTransform
^^^^^^^^^^^^^^^^

.. autoclass:: numpyro.distributions.transforms.ZeroSumTransform
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource


Flows
-----
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
StudentT,
Uniform,
Weibull,
ZeroSumNormal,
)
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
from numpyro.distributions.directional import (
Expand Down Expand Up @@ -196,4 +197,5 @@
"ZeroInflatedDistribution",
"ZeroInflatedPoisson",
"ZeroInflatedNegativeBinomial2",
"ZeroSumNormal",
]
25 changes: 25 additions & 0 deletions numpyro/distributions/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"softplus_lower_cholesky",
"softplus_positive",
"unit_interval",
"zero_sum",
"Constraint",
]

Expand Down Expand Up @@ -697,6 +698,29 @@ def feasible_like(self, prototype):
return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5))


class _ZeroSum(Constraint):
def __init__(self, event_dim=1):
self.event_dim = event_dim
super().__init__()

def __call__(self, x):
jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
tol = jnp.finfo(x.dtype).eps * x.shape[-1] * 10
zerosum_true = True
for dim in range(-self.event_dim, 0):
zerosum_true = zerosum_true & jnp.allclose(x.sum(dim), 0, atol=tol)
return zerosum_true

def __eq__(self, other):
return type(self) is type(other) and self.event_dim == other.event_dim

def feasible_like(self, prototype):
return jax.numpy.broadcast_to(0, prototype.shape)
kylejcaron marked this conversation as resolved.
Show resolved Hide resolved

def tree_flatten(self):
return (self.event_dim,), (("event_dim",), dict())


# TODO: Make types consistent
# See https://github.com/pytorch/pytorch/issues/50616

Expand Down Expand Up @@ -731,3 +755,4 @@ def feasible_like(self, prototype):
sphere = _Sphere()
unit_interval = _UnitInterval()
open_interval = _OpenInterval
zero_sum = _ZeroSum
97 changes: 97 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
ExpTransform,
PowerTransform,
SigmoidTransform,
ZeroSumTransform,
)
from numpyro.distributions.util import (
add_diag,
Expand Down Expand Up @@ -2438,3 +2439,99 @@ def cdf(self, value):

def icdf(self, value):
return self._ald.icdf(value)


class ZeroSumNormal(TransformedDistribution):
r"""
Zero Sum Normal distribution adapted from PyMC [1] as described in [2,3]. This is a Normal distribution where one or
kylejcaron marked this conversation as resolved.
Show resolved Hide resolved
more axes are constrained to sum to zero (the last axis by default).

.. math::
\begin{align*}
ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
n = \text{number of zero-sum axes}
\end{align*}

:param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is
enforced.
:param tuple event_shape: The event shape of the distribution, the axes of which get constrained to sum to zero.

**Example:**

.. doctest::

>>> from numpy.testing import assert_allclose
>>> from jax import random
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro.distributions as dist
>>> from numpyro.infer import MCMC, NUTS

>>> N = 1000
>>> n_categories = 20
>>> rng_key = random.PRNGKey(0)
>>> key1, key2, key3 = random.split(rng_key, 3)
>>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
>>> beta = random.normal(key2, shape=(n_categories,))
>>> beta -= beta.mean(-1)
>>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))

>>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
... N = len(category_ind)
... alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
... beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
... sigma = numpyro.sample("sigma", dist.Exponential(1))
... with numpyro.plate("observations", N):
... mu = alpha + beta[category_ind]
... obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
... return obs

>>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
>>> mcmc = MCMC(
>>> sampler=nuts_kernel,
>>> num_samples=1_000, num_warmup=1_000, num_chains=4
>>> )
>>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y)
>>> posterior_samples = mcmc.get_samples()
>>> # Confirm everything along last axis sums to zero
>>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)

**References**
[1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637
[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
[3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
"""

arg_constraints = {"scale": constraints.positive}
reparametrized_params = ["scale"]

def __init__(self, scale, event_shape, *, validate_args=None):
event_ndim = len(event_shape)
if jnp.ndim(scale) == 0:
kylejcaron marked this conversation as resolved.
Show resolved Hide resolved
(scale,) = promote_shapes(scale, shape=(1,))
transformed_shape = tuple(size - 1 for size in event_shape)
self.scale = scale
kylejcaron marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(
Normal(0, scale).expand(transformed_shape).to_event(event_ndim),
ZeroSumTransform(event_ndim),
validate_args=validate_args,
)

@constraints.dependent_property(is_discrete=False)
def support(self):
return constraints.zero_sum(len(self.event_shape))

@property
def mean(self):
return jnp.zeros(self.batch_shape + self.event_shape)

@property
def variance(self):
event_ndim = len(self.event_shape)
zero_sum_axes = tuple(range(-event_ndim, 0))
theoretical_var = jnp.square(self.scale)
for axis in zero_sum_axes:
theoretical_var *= 1 - 1 / self.event_shape[axis]

return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape)
93 changes: 93 additions & 0 deletions numpyro/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import weakref

import numpy as np
from numpy.core.numeric import normalize_axis_tuple

from jax import lax, vmap
from jax.flatten_util import ravel_pytree
Expand Down Expand Up @@ -50,6 +51,7 @@
"StickBreakingTransform",
"Transform",
"UnpackTransform",
"ZeroSumTransform",
]


Expand Down Expand Up @@ -1380,6 +1382,92 @@ def __eq__(self, other):
return jnp.array_equal(self.transition_matrix, other.transition_matrix)


class ZeroSumTransform(Transform):
"""A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]
Copy link
Contributor Author

@kylejcaron kylejcaron Mar 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AlexAndorra @aseyboldt @ricardoV94 same as I said above, this PR is nearing ready to go - let me know if there's more I can add to properly credit all of you and pymc


:param transform_ndims: Number of trailing dimensions to transform.

**References**
[1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
[2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
[3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
"""

def __init__(self, transform_ndims: int = 1) -> None:
self.transform_ndims = transform_ndims

@property
def domain(self) -> constraints.Constraint:
return constraints.independent(constraints.real, self.transform_ndims)

@property
def codomain(self) -> constraints.Constraint:
return constraints.zero_sum(self.transform_ndims)

def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
zero_sum_axes = tuple(range(-self.transform_ndims, 0))
for axis in zero_sum_axes:
x = self.extend_axis(x, axis=axis)
return x

def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
zero_sum_axes = tuple(range(-self.transform_ndims, 0))
for axis in zero_sum_axes:
y = self.extend_axis_rev(y, axis=axis)
return y

def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

n = array.shape[normalized_axis]
last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)

sum_vals = -last * jnp.sqrt(n)
norm = sum_vals / (jnp.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis
return array[(*slice_before, slice(None, -1))] + norm

def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
n = array.shape[axis] + 1

sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (jnp.sqrt(n) + n)
fill_val = norm - sum_vals / jnp.sqrt(n)

out = jnp.concatenate([array, fill_val], axis=axis)
return out - norm

def log_abs_det_jacobian(
self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
) -> jnp.ndarray:
shape = jnp.broadcast_shapes(
x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
)
return jnp.zeros_like(x, shape=shape)

def forward_shape(self, shape: tuple) -> tuple:
return shape[: -self.transform_ndims] + tuple(
s + 1 for s in shape[-self.transform_ndims :]
)

def inverse_shape(self, shape: tuple) -> tuple:
return shape[: -self.transform_ndims] + tuple(
s - 1 for s in shape[-self.transform_ndims :]
)

def tree_flatten(self):
aux_data = {
"transform_ndims": self.transform_ndims,
}
return (), ((), aux_data)

def __eq__(self, other):
return (
isinstance(other, ZeroSumTransform)
and self.transform_ndims == other.transform_ndims
)


##########################################################
# CONSTRAINT_REGISTRY
##########################################################
Expand Down Expand Up @@ -1530,3 +1618,8 @@ def _transform_to_softplus_lower_cholesky(constraint):
@biject_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
return StickBreakingTransform()


@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
return ZeroSumTransform(constraint.event_dim).inv
kylejcaron marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions test/test_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
dict(),
),
"open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()),
"zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)),
}

# TODO: BijectorConstraint
Expand Down
Loading