ZeroSum bijector and ZeroSumNormal distribution #1980

Open
jeffpollock9 opened this issue Dec 12, 2024 · 0 comments
numpyro and pymc have a zero-sum normal distribution based on a zero-sum bijector (see e.g. the numpyro zero sum normal and numpyro zero sum transform docs).

I was wondering if there is any appetite for adding this to TFP? I already have a simple port working (it needs some changes, in particular perhaps allowing a variable number of axes to be constrained to sum to zero; there is a rough sketch of that after the numpyro comparison below):

"""ZeroSum bijector."""

import tensorflow as tf
from tensorflow_probability.python.bijectors import bijector
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import tensorshape_util


class ZeroSum(bijector.AutoCompositeTensorBijector):
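    """Bijector constraining the last axis to sum to zero.

    Forward maps vectors in R^(n-1) to zero-sum vectors in R^n, following
    the numpyro/pymc zero-sum transform.
    """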

    def __init__(self, validate_args=False, name="zero_sum"):
        parameters = dict(locals())
        super(ZeroSum, self).__init__(
            is_constant_jacobian=True,
            forward_min_event_ndims=1,
            validate_args=validate_args,
            parameters=parameters,
            name=name,
        )

    @classmethod
    def _parameter_properties(cls, dtype):
        return dict()

    def _forward(self, x):
        # Append one element and recenter so the result sums to zero along
        # the last axis (same arithmetic as the numpyro/pymc transform).
        n = ps.cast(ps.shape(x)[-1], x.dtype) + 1
        sum_vals = tf.reduce_sum(x, axis=-1, keepdims=True)
        norm = sum_vals / (ps.sqrt(n) + n)
        fill_val = norm - sum_vals / ps.sqrt(n)
        out = tf.concat([x, fill_val], axis=-1)
        return out - norm

    def _inverse(self, y):
        n = ps.cast(ps.shape(y)[-1], y.dtype)
        # Keep the sliced axis (size 1) so `norm` broadcasts correctly
        # against any leading batch dimensions of `y`.
        last = y[..., -1:]
        sum_vals = -last * ps.sqrt(n)
        norm = sum_vals / (ps.sqrt(n) + n)
        return y[..., :-1] + norm

    def _inverse_log_det_jacobian(self, y):
        # The forward map is an isometric embedding (its Jacobian has
        # orthonormal columns), so the log-det-Jacobian is identically zero.
        return tf.zeros([], dtype=y.dtype)

    def _forward_log_det_jacobian(self, x):
        return tf.zeros([], dtype=x.dtype)

    def _forward_event_shape(self, input_shape):
        return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] + 1)

    def _forward_event_shape_tensor(self, input_shape):
        n = ps.shape(input_shape)[-1]
        return ps.tensor_scatter_nd_add(input_shape, [[n - 1]], [1])

    def _inverse_event_shape(self, input_shape):
        return tensorshape_util.concatenate(input_shape[:-1], input_shape[-1] - 1)

    def _inverse_event_shape_tensor(self, input_shape):
        n = ps.shape(input_shape)[-1]
        return ps.tensor_scatter_nd_sub(input_shape, [[n - 1]], [1])
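
A quick sanity check of the port (a minimal sketch, assuming the ZeroSum class above is in scope and eager execution): the forward pass should land on the zero-sum hyperplane, and the inverse should undo it, including for batched inputs:

import numpy as np
import tensorflow as tf

bij = ZeroSum()
x = tf.constant([[0.5, -1.2], [2.0, 0.3]])  # batch of two points in R^2

y = bij.forward(x)  # each output in R^3 should sum to zero
print(np.sum(y.numpy(), axis=-1))  # ~[0., 0.] up to float32 rounding

x_back = bij.inverse(y)  # round trip should recover the input
print(np.max(np.abs(x_back.numpy() - x.numpy())))  # ~0.0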

usage:

import numpy as np
import tensorflow_probability as tfp

tfd = tfp.distributions

zero_sum_normal = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(loc=0.0, scale_diag=[1.0, 1.0]),
    bijector=ZeroSum(),
)
zero_sum_normal
# <tfp.distributions.TransformedDistribution 'zero_sumMultivariateNormalDiag' batch_shape=[] event_shape=[3] dtype=float32>

samples = zero_sum_normal.sample(int(1e7))

np.max(np.abs(np.sum(samples, axis=-1)))
# 4.7683716e-07

np.mean(samples, axis=0)
# array([ 4.8274879e-04, -5.6865485e-04,  8.5900021e-05], dtype=float32)

np.std(samples, axis=0)
# array([0.8100124, 0.8101918, 0.8100392], dtype=float32)
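
For reference, my understanding from the numpyro/pymc docs is that a zero-sum normal with unit scale has per-component marginal standard deviation sqrt((n - 1) / n), i.e. about 0.8165 for n = 3, which is the ballpark of the stds above and in the numpyro run below:

import numpy as np

n = 3
print(np.sqrt((n - 1) / n))  # 0.8164965809277259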

compare to numpyro:

import jax.numpy as jnp
import jax.random as jr
import numpyro.distributions as dist

zero_sum_normal = dist.ZeroSumNormal(scale=jnp.array(1.0), event_shape=[3])

rng = jr.key(123)

samples = zero_sum_normal.sample(rng, sample_shape=(int(1e7),))

jnp.max(jnp.abs(jnp.sum(samples, axis=1)))
# Array(5.9604645e-07, dtype=float32)

jnp.mean(samples, axis=0)
# Array([-1.7739683e-04, -6.7088688e-05,  2.4448565e-04], dtype=float32)

jnp.std(samples, axis=0)
# Array([0.8164292 , 0.81622946, 0.8164195 ], dtype=float32)
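
On the variable-number-of-axes point mentioned at the top, here is a rough, untested sketch of how the forward pass could loop over trailing axes, mirroring numpyro's transform_ndims idea (the names _extend_axis and forward_multi are hypothetical):

import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps


def _extend_axis(x, axis):
    # Same arithmetic as ZeroSum._forward above, but along an arbitrary axis.
    n = ps.cast(ps.shape(x)[axis], x.dtype) + 1
    sum_vals = tf.reduce_sum(x, axis=axis, keepdims=True)
    norm = sum_vals / (ps.sqrt(n) + n)
    fill_val = norm - sum_vals / ps.sqrt(n)
    return tf.concat([x, fill_val], axis=axis) - norm


def forward_multi(x, zero_sum_ndims=1):
    # Constrain each of the trailing `zero_sum_ndims` axes to sum to zero,
    # one axis at a time.
    for axis in range(-zero_sum_ndims, 0):
        x = _extend_axis(x, axis)
    return x

The inverse would peel the axes off in the same order, and the event-shape methods would add/subtract one along each constrained axis.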

If this sounds useful, I can work on bits of it over the next couple of weeks; or if someone else wants to take it over, that's great too.

Thanks.
