Add ZeroSumNormal distribution (#6121)
Also:
* Refactor get_steps to work with multivariate support shapes
* Replace get_steps by get_support_shape_1d in timeseries.py

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
AlexAndorra and ricardoV94 authored Oct 7, 2022
1 parent faebc60 commit 9aeb6b5
Showing 10 changed files with 769 additions and 141 deletions.
1 change: 1 addition & 0 deletions docs/source/api/distributions/multivariate.rst
@@ -8,6 +8,7 @@ Multivariate

MvNormal
MvStudentT
ZeroSumNormal
Dirichlet
Multinomial
DirichletMultinomial
1 change: 1 addition & 0 deletions docs/source/api/distributions/transforms.rst
@@ -33,6 +33,7 @@ Specific Transform Classes
LogExpM1
Ordered
SumTo1
ZeroSumTransform


Transform Composition Classes
4 changes: 3 additions & 1 deletion pymc/distributions/__init__.py
@@ -99,6 +99,7 @@
    StickBreakingWeights,
    Wishart,
    WishartBartlett,
    ZeroSumNormal,
)
from pymc.distributions.simulator import Simulator
from pymc.distributions.timeseries import (
@@ -116,8 +117,8 @@
"Uniform",
"Flat",
"HalfFlat",
"TruncatedNormal",
"Normal",
"TruncatedNormal",
"Beta",
"Kumaraswamy",
"Exponential",
@@ -160,6 +161,7 @@
"Continuous",
"Discrete",
"MvNormal",
"ZeroSumNormal",
"MatrixNormal",
"KroneckerNormal",
"MvStudentT",
207 changes: 206 additions & 1 deletion pymc/distributions/multivariate.py
@@ -18,6 +18,7 @@
import warnings

from functools import reduce
from typing import Optional

import aesara
import aesara.tensor as at
@@ -63,15 +64,17 @@
    _change_dist_size,
    broadcast_dist_samples_to,
    change_dist_size,
    get_support_shape,
    rv_size_is_none,
    to_tuple,
)
from pymc.distributions.transforms import Interval, _default_transform
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
from pymc.math import kron_diag, kron_dot
from pymc.util import check_dist_not_registered

__all__ = [
"MvNormal",
"ZeroSumNormal",
"MvStudentT",
"Dirichlet",
"Multinomial",
@@ -2380,3 +2383,205 @@ def logp(value, alpha, K):
            K > 0,
            msg="alpha > 0, K > 0",
        )


class ZeroSumNormalRV(SymbolicRandomVariable):
"""ZeroSumNormal random variable"""

_print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
default_output = 0


class ZeroSumNormal(Distribution):
r"""
ZeroSumNormal distribution, i.e Normal distribution where one or
several axes are constrained to sum to zero.
By default, the last axis is constrained to sum to zero.
See `zerosum_axes` kwarg for more details.

    .. math::

        \begin{align*}
            ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
            \text{where} \ J_{ij} = 1 \ \text{and} \\
            n = \text{number of elements along the zero-sum axis}
        \end{align*}

    Parameters
    ----------
    sigma : tensor_like of float
        Scale parameter (sigma > 0).
        This is the standard deviation of the underlying, unconstrained Normal distribution.
        Defaults to 1 if not specified.
        For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
    zerosum_axes : int, defaults to 1
        Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
        Defaults to 1, i.e. the rightmost axis.
    dims : sequence of strings, optional
        Dimension names of the distribution. Works the same as for other PyMC distributions.
        Necessary if ``shape`` is not passed.
    shape : tuple of integers, optional
        Shape of the distribution. Works the same as for other PyMC distributions.
        Necessary if ``dims`` or ``observed`` is not passed.

    Warnings
    --------
    ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
    The ability to specify a vector of ``sigma`` may be added in future versions.

    ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
    just use ``pm.Normal``.

    Examples
    --------
    Define a `ZeroSumNormal` variable, with `sigma=1` and
    `zerosum_axes=1` by default::

        COORDS = {
            "regions": ["a", "b", "c"],
            "answers": ["yes", "no", "whatever", "don't understand question"],
        }

        with pm.Model(coords=COORDS) as m:
            # the zero sum axis will be 'answers'
            v = pm.ZeroSumNormal("v", dims=("regions", "answers"))

        with pm.Model(coords=COORDS) as m:
            # the zero sum axes will be 'answers' and 'regions'
            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)

        with pm.Model(coords=COORDS) as m:
            # the zero sum axes will be the last two
            v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
    """
    rv_type = ZeroSumNormalRV

    def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
        if dims is not None or kwargs.get("observed") is not None:
            zerosum_axes = cls.check_zerosum_axes(zerosum_axes)

            support_shape = get_support_shape(
                support_shape=support_shape,
                shape=None,  # Shape will be checked in `cls.dist`
                dims=dims,
                observed=kwargs.get("observed", None),
                ndim_supp=zerosum_axes,
            )

        return super().__new__(
            cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
        )

    @classmethod
    def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
        zerosum_axes = cls.check_zerosum_axes(zerosum_axes)

        sigma = at.as_tensor_variable(floatX(sigma))
        if sigma.ndim > 0:
            raise ValueError("sigma has to be a scalar")

        support_shape = get_support_shape(
            support_shape=support_shape,
            shape=kwargs.get("shape"),
            ndim_supp=zerosum_axes,
        )

        if support_shape is None:
            if zerosum_axes > 0:
                raise ValueError("You must specify dims, shape or support_shape parameter")
            # TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails
            # else:
            #     support_shape = ()  # because it's just a Normal in that case
        support_shape = at.as_tensor_variable(intX(support_shape))

        assert zerosum_axes == at.get_vector_length(
            support_shape
        ), "support_shape has to be as long as zerosum_axes"

        return super().dist(
            [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
        )

    @classmethod
    def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
        if zerosum_axes is None:
            zerosum_axes = 1
        if not isinstance(zerosum_axes, int):
            raise TypeError("zerosum_axes has to be an integer")
        if not zerosum_axes > 0:
            raise ValueError("zerosum_axes has to be > 0")
        return zerosum_axes

    @classmethod
    def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):

        shape = to_tuple(size) + tuple(support_shape)
        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))

        if zerosum_axes > normal_dist.ndim:
            raise ValueError("Shape of distribution is too small for the number of zerosum axes")

        normal_dist_, sigma_, support_shape_ = (
            normal_dist.type(),
            sigma.type(),
            support_shape.type(),
        )

        # Zero-summing is achieved by subtracting the mean along the given zerosum_axes
        zerosum_rv_ = normal_dist_
        for axis in range(zerosum_axes):
            zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

        return ZeroSumNormalRV(
            inputs=[normal_dist_, sigma_, support_shape_],
            outputs=[zerosum_rv_, support_shape_],
            ndim_supp=zerosum_axes,
        )(normal_dist, sigma, support_shape)
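As an aside, the mean-subtraction loop in rv_op above is what produces the covariance quoted in the class docstring: for a single constrained axis of length n, centering i.i.d. Normal draws yields N(0, sigma^2 (I - J/n)). A minimal NumPy sketch (illustrative only, not part of this commit; all names below are hypothetical) that checks this numerically:

import numpy as np

rng = np.random.default_rng(42)
sigma, n, ndraws = 2.0, 4, 200_000

# Unconstrained draws, then centering along the last axis (the rv_op step)
x = rng.normal(scale=sigma, size=(ndraws, n))
z = x - x.mean(axis=-1, keepdims=True)

# The empirical covariance should match sigma**2 * (I - J/n),
# and every centered row should sum to (numerically) zero
expected = sigma**2 * (np.eye(n) - np.ones((n, n)) / n)
assert np.allclose(np.cov(z, rowvar=False), expected, atol=0.05)
assert np.allclose(z.sum(axis=-1), 0.0)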


@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):

    normal_dist, sigma, support_shape = normal_dist.owner.inputs

    if expand:
        original_shape = tuple(normal_dist.shape)
        old_size = original_shape[: len(original_shape) - op.ndim_supp]
        new_size = tuple(new_size) + old_size

    return ZeroSumNormal.rv_op(
        sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size
    )


@_moment.register(ZeroSumNormalRV)
def zerosumnormal_moment(op, rv, *rv_inputs):
    return at.zeros_like(rv)


@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
    zerosum_axes = tuple(np.arange(-op.ndim_supp, 0))
    return ZeroSumTransform(zerosum_axes)


@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    (value,) = values
    shape = value.shape
    zerosum_axes = op.ndim_supp

    _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1)
    _full_size = at.prod(shape)
    _degrees_of_freedom = at.prod(_deg_free_support_shape)

    zerosums = [
        at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9))
        for axis in range(zerosum_axes)
    ]

    out = at.sum(
        pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size,
        axis=tuple(np.arange(-zerosum_axes, 0)),
    )

    return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0")
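
To round off the new multivariate.py code, here is a short usage sketch (not part of the diff) showing that forward draws from the new distribution respect the constraint that rv_op builds in and zerosumnormal_logp checks. It assumes this PyMC version's pm.draw helper for prior sampling; the coords, model and variable names are illustrative only:

import numpy as np
import pymc as pm

coords = {"answers": ["yes", "no", "maybe"]}
with pm.Model(coords=coords) as model:
    # One zero-sum axis (the default), over the 'answers' dimension
    v = pm.ZeroSumNormal("v", sigma=1.0, dims="answers")

# Prior draws should sum to (numerically) zero along the constrained axis
draws = pm.draw(v, draws=1000, random_seed=42)
assert draws.shape == (1000, 3)
assert np.allclose(draws.sum(axis=-1), 0.0, atol=1e-9)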