Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ZeroSumNormal distribution #6121

Merged
merged 53 commits into from
Oct 7, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
6260c84
Use None as default value for zerosum_axes
AlexAndorra Sep 12, 2022
af96016
Add tests for ZSN
AlexAndorra Sep 12, 2022
71e5651
Reorder dispatched functions
AlexAndorra Sep 12, 2022
3cadb26
Test pylint
AlexAndorra Sep 12, 2022
a66c586
Ignore type check on normalize_axis_tuple
AlexAndorra Sep 12, 2022
e3be495
Disable mypy on import of normalize_axis_tuple
AlexAndorra Sep 12, 2022
759de36
Remove base class in tests
AlexAndorra Sep 12, 2022
a5a1e45
Use pytest parametrize
AlexAndorra Sep 15, 2022
c9eea6e
Add pm.draw to tests
AlexAndorra Sep 15, 2022
0582d7c
Test moment
AlexAndorra Sep 15, 2022
0bdcdd7
Add change size test
AlexAndorra Sep 15, 2022
854ef4c
Move ZSN to multivariate.py
AlexAndorra Sep 15, 2022
fd3aefa
Move ZSN tests to test_multivariate.py
AlexAndorra Sep 15, 2022
e94e4f1
Add check if zerosum_axes is iterable in dist method
AlexAndorra Sep 17, 2022
dec4a9f
Improve test_zsn_change_dist_size
AlexAndorra Sep 17, 2022
f7a55c5
Improve docstrings
AlexAndorra Sep 18, 2022
da6eaab
Refactor get_steps to work with multivariate support shapes
AlexAndorra Sep 27, 2022
a5ed1f0
Refactor ZSN dist and logp for rightmost zerosum_axes
AlexAndorra Sep 27, 2022
126e76b
Start writing __new__ method
AlexAndorra Sep 28, 2022
3a8d898
Handle single output and fix transform
AlexAndorra Sep 28, 2022
4c52737
Fix indexing of at.stack in get_support_shape
AlexAndorra Sep 28, 2022
7e4ed0a
Fix examples in ZSN docstrings
AlexAndorra Sep 28, 2022
44b5b91
Refactor test_zsn_dims_shape
AlexAndorra Sep 28, 2022
99dbb38
Refactor test_zsn_fail_axis
AlexAndorra Sep 28, 2022
e3dc1d4
Refactor test_zsn_change_dist_size
AlexAndorra Sep 29, 2022
09f0d91
Simplify test_zsn_dims_shape
AlexAndorra Sep 29, 2022
cf5b384
Refactor test_zsn_dims_shape
AlexAndorra Sep 29, 2022
3e86a3e
Fix get_support_shape
AlexAndorra Sep 29, 2022
ce68f02
Test support_shape handling
AlexAndorra Sep 29, 2022
09d849c
Merge branch 'main' into add-zerosumnormal
AlexAndorra Sep 29, 2022
b50909e
Remove TODO list comment
AlexAndorra Sep 29, 2022
c204131
Merge branch 'add-zerosumnormal' of https://github.com/pymc-devs/pymc…
AlexAndorra Sep 29, 2022
7ba1d0f
Add test of ZSN variance
AlexAndorra Sep 29, 2022
5ee950a
Remove unused imports
AlexAndorra Sep 30, 2022
95ffc94
Merge branch 'main' into add-zerosumnormal
AlexAndorra Sep 30, 2022
13a54e6
Replace get_steps by get_support_shape_1d in timeseries.py
AlexAndorra Sep 30, 2022
ca655bc
Split dims and shape test
AlexAndorra Sep 30, 2022
9d419ef
Fix test_get_support_shape_1d
AlexAndorra Sep 30, 2022
85da56c
Add test_get_support_shape
AlexAndorra Sep 30, 2022
f363118
Add ZSN logp test
AlexAndorra Oct 5, 2022
64eca5c
Fix test_inconsistent_steps_and_shape
AlexAndorra Oct 5, 2022
c5e76c9
Integrate review comments
AlexAndorra Oct 5, 2022
08c9df0
Solve freaking pre-commit issues
AlexAndorra Oct 5, 2022
c120f7e
Put assert_zerosum_axes at top of test class
AlexAndorra Oct 5, 2022
ba5f3a1
Improve error message of get_support_shape
AlexAndorra Oct 5, 2022
48dafe9
Nicer format for ZSN logp test
AlexAndorra Oct 5, 2022
6612a24
Increase tolerance for test_zsn_variance
AlexAndorra Oct 6, 2022
6b07a2a
Add ZSN to docs
AlexAndorra Oct 6, 2022
135ed47
Refactor ZSN docs
AlexAndorra Oct 6, 2022
cba0187
Better latex in ZSN docs
AlexAndorra Oct 6, 2022
566f308
Add ZeroSumTransform to docs
AlexAndorra Oct 7, 2022
5954e65
Remove mention of default value in ZS transform docs
AlexAndorra Oct 7, 2022
3e72922
Update pymc/distributions/transforms.py
ricardoV94 Oct 7, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
StickBreakingWeights,
Wishart,
WishartBartlett,
ZeroSumNormal,
)
from pymc.distributions.simulator import Simulator
from pymc.distributions.timeseries import (
Expand All @@ -115,8 +116,8 @@
"Uniform",
"Flat",
"HalfFlat",
"TruncatedNormal",
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
"Normal",
"TruncatedNormal",
"Beta",
"Kumaraswamy",
"Exponential",
Expand Down Expand Up @@ -159,6 +160,7 @@
"Continuous",
"Discrete",
"MvNormal",
"ZeroSumNormal",
"MatrixNormal",
"KroneckerNormal",
"MvStudentT",
Expand Down
171 changes: 170 additions & 1 deletion pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.slinalg import Cholesky, SolveTriangular
from aesara.tensor.type import TensorType
from numpy.core.numeric import normalize_axis_tuple
from scipy import linalg, stats

import pymc as pm
Expand Down Expand Up @@ -63,15 +64,17 @@
_change_dist_size,
broadcast_dist_samples_to,
change_dist_size,
convert_dims,
rv_size_is_none,
to_tuple,
)
from pymc.distributions.transforms import Interval, _default_transform
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
from pymc.math import kron_diag, kron_dot
from pymc.util import check_dist_not_registered

__all__ = [
"MvNormal",
"ZeroSumNormal",
"MvStudentT",
"Dirichlet",
"Multinomial",
Expand Down Expand Up @@ -2380,3 +2383,169 @@ def logp(value, alpha, K):
K > 0,
msg="alpha > 0, K > 0",
)


class ZeroSumNormalRV(SymbolicRandomVariable):
"""ZeroSumNormal random variable"""

_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
zerosum_axes = None

def __init__(self, *args, zerosum_axes, **kwargs):
self.zerosum_axes = zerosum_axes
super().__init__(*args, **kwargs)


class ZeroSumNormal(Distribution):
r"""
ZeroSumNormal distribution, i.e Normal distribution where one or
several axes are constrained to sum to zero.
By default, the last axis is constrained to sum to zero.
See `zerosum_axes` kwarg for more details.
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
sigma : tensor_like of float
Standard deviation (sigma > 0).
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
Defaults to 1 if not specified.
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
zerosum_axes: list or tuple of strings or integers
Axis (or axes) along which the zero-sum constraint is enforced.
Defaults to [-1], i.e the last axis.
If strings are passed, then ``dims`` is needed.
Otherwise, ``shape`` and ``size`` work as they do for other PyMC distributions.
dims: list or tuple of strings, optional
The dimension names of the axes.
Necessary when ``zerosum_axes`` is specified with strings.

Warnings
--------
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
The ability to specifiy a vector of ``sigma`` may be added in future versions.

Examples
--------
.. code-block:: python
COORDS = {
"regions": ["a", "b", "c"],
"answers": ["yes", "no", "whatever", "don't understand question"],
}
with pm.Model(coords=COORDS) as m:
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes="answers")

with pm.Model(coords=COORDS) as m:
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=("regions", "answers"))

with pm.Model(coords=COORDS) as m:
...: v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=1)
"""
rv_type = ZeroSumNormalRV

def __new__(cls, *args, zerosum_axes=None, dims=None, **kwargs):
dims = convert_dims(dims)
if zerosum_axes is None:
zerosum_axes = [-1]
if not isinstance(zerosum_axes, (list, tuple)):
zerosum_axes = [zerosum_axes]

if isinstance(zerosum_axes[0], str):
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might want to handle the case where zerosum_axes=[]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that just a Normal distribution?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is, but if you write more general code that somehow produces the zerosum_axes automatically, that could easily happen I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh so you mean erroring out in that case, gotcha

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, erroring out would be ok, but can't we just handle that case correctly? I can't think of anything that should go wrong in this case. So why not just

if len(zerosum_axes) > 0 and isinstinace(zerosum_axes[0]), str):

And maybe add a test that checks if pm.ZeroSumNormal("y", zerosum_axes=[], shape=(3,)) or so works and gives us a normal dist.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh ok, I understand what you mean now. I'm curious what @ricardoV94 thinks, but I would prefer not to do that: if people wanna use a Normal they should just use pm.Normal, otherwise I find it makes code and models more confusing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I agree with you now @aseyboldt (mainly because that behavior would be consistent with other PyMC distributions' behavior).
However, this edge-case needs a bit more work right now, because at.stack in get_support_shape fails with an empty tuple.
As it's not a core feature of ZSN, I'm marking this as TODO. The code is actually already there, so if someone needs this in the future, there are already some breadcrumbs to get started.

if not dims:
raise ValueError("You need to specify dims if zerosum_axes are strings.")
else:
zerosum_axes_ = []
for axis in zerosum_axes:
zerosum_axes_.append(dims.index(axis))
zerosum_axes = zerosum_axes_
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved

return super().__new__(cls, *args, zerosum_axes=zerosum_axes, dims=dims, **kwargs)

@classmethod
def dist(cls, sigma=1, zerosum_axes=None, **kwargs):
if zerosum_axes is None:
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
zerosum_axes = [-1]

sigma = at.as_tensor_variable(floatX(sigma))
if sigma.ndim > 0:
raise ValueError("sigma has to be a scalar")

return super().dist([sigma], zerosum_axes=zerosum_axes, **kwargs)

# TODO: This is if we want ZeroSum constraint on other dists than Normal
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
# def dist(cls, dist, lower, upper, **kwargs):
# if not isinstance(dist, TensorVariable) or not isinstance(
# dist.owner.op, (RandomVariable, SymbolicRandomVariable)
# ):
# raise ValueError(
# f"Censoring dist must be a distribution created via the `.dist()` API, got {type(dist)}"
# )
# if dist.owner.op.ndim_supp > 0:
# raise NotImplementedError(
# "Censoring of multivariate distributions has not been implemented yet"
# )
# check_dist_not_registered(dist)
# return super().dist([dist, lower, upper], **kwargs)

@classmethod
def rv_op(cls, sigma, zerosum_axes, size=None):
if size is None:
zerosum_axes_ = np.asarray(zerosum_axes)
# just a placeholder size to infer minimum shape
size = np.ones(
max((max(np.abs(zerosum_axes_) - 1), max(zerosum_axes_))) + 1, dtype=int
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
).tolist()

# check if zerosum_axes is valid
normalize_axis_tuple(zerosum_axes, len(size))

normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, size=size))
normal_dist_, sigma_ = normal_dist.type(), sigma.type()

# Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes
zerosum_rv_ = normal_dist_
for axis in zerosum_axes:
zerosum_rv_ -= zerosum_rv_.mean(axis=axis, keepdims=True)

return ZeroSumNormalRV(
inputs=[normal_dist_, sigma_],
outputs=[zerosum_rv_],
zerosum_axes=zerosum_axes,
ndim_supp=0,
)(normal_dist, sigma)


@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):
normal_dist, sigma = normal_dist.owner.inputs
if expand:
new_size = tuple(new_size) + tuple(normal_dist.shape)
return ZeroSumNormal.rv_op(sigma=sigma, zerosum_axes=op.zerosum_axes, size=new_size)
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved


@_moment.register(ZeroSumNormalRV)
def zerosumnormal_moment(op, rv, *rv_inputs):
return at.zeros_like(rv)


@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
return ZeroSumTransform(op.zerosum_axes)


@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, **kwargs):
(value,) = values
shape = value.shape
_deg_free_shape = at.inc_subtensor(shape[at.as_tensor_variable(op.zerosum_axes)], -1)
_full_size = at.prod(shape)
_degrees_of_freedom = at.prod(_deg_free_shape)
zerosums = [
at.all(at.isclose(at.mean(value, axis=axis), 0, atol=1e-9)) for axis in op.zerosum_axes
]
# out = at.sum(
# pm.logp(dist, value) * _degrees_of_freedom / _full_size,
# axis=op.zerosum_axes,
# )
# figure out how dimensionality should be handled for logp
# for now, we assume ZSN is a scalar distribut, which is not correct
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
out = pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size
return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
65 changes: 65 additions & 0 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
from aesara.graph import Op
from aesara.tensor import TensorVariable

# ignore mypy error because it somehow considers that
# "numpy.core.numeric has no attribute normalize_axis_tuple"
from numpy.core.numeric import normalize_axis_tuple # type: ignore

__all__ = [
"RVTransform",
"simplex",
Expand All @@ -39,6 +43,7 @@
"circular",
"CholeskyCovPacked",
"Chain",
"ZeroSumTransform",
AlexAndorra marked this conversation as resolved.
Show resolved Hide resolved
]


Expand Down Expand Up @@ -266,6 +271,66 @@ def bounds_fn(*rv_inputs):
super().__init__(args_fn=bounds_fn)


class ZeroSumTransform(RVTransform):
"""
Constrains the samples of a Normal distribution to sum to zero
twiecki marked this conversation as resolved.
Show resolved Hide resolved
along the user-provided ``zerosum_axes``.
By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed
on the last axis.
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
"""

name = "zerosum"

__props__ = ("zerosum_axes",)

def __init__(self, zerosum_axes):
"""
Parameters
----------
zerosum_axes : list of ints
Must be a list of integers (positive or negative).
By default (``zerosum_axes=[-1]``), the sum-to-zero constraint is imposed
on the last axis.
"""
self.zerosum_axes = zerosum_axes

def forward(self, value, *rv_inputs):
for axis in self.zerosum_axes:
value = extend_axis_rev(value, axis=axis)
return value

def backward(self, value, *rv_inputs):
for axis in self.zerosum_axes:
value = extend_axis(value, axis=axis)
return value

def log_jac_det(self, value, *rv_inputs):
return at.constant(0.0)


def extend_axis(array, axis):
n = array.shape[axis] + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could maybe add a comment here saying that this is using a householder reflection plus a projection operator to move forward from the constrained space onto the zero sum manifold. I’ll look up our notes and write something here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you find your notes @lucianopaz ?

sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (np.sqrt(n) + n)
fill_val = norm - sum_vals / np.sqrt(n)

out = at.concatenate([array, fill_val], axis=axis)
return out - norm


def extend_axis_rev(array, axis):
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

n = array.shape[normalized_axis]
last = at.take(array, [-1], axis=normalized_axis)

sum_vals = -last * np.sqrt(n)
norm = sum_vals / (np.sqrt(n) + n)
slice_before = (slice(None, None),) * normalized_axis

return array[slice_before + (slice(None, -1),)] + norm


log_exp_m1 = LogExpM1()
log_exp_m1.__doc__ = """
Instantiation of :class:`pymc.distributions.transforms.LogExpM1`
Expand Down
Loading