diff --git a/docs/source/api/distributions/multivariate.rst b/docs/source/api/distributions/multivariate.rst
index 1156f04a255..ac401b9944a 100644
--- a/docs/source/api/distributions/multivariate.rst
+++ b/docs/source/api/distributions/multivariate.rst
@@ -8,6 +8,7 @@ Multivariate
 
    MvNormal
    MvStudentT
+   ZeroSumNormal
    Dirichlet
    Multinomial
    DirichletMultinomial
diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst
index 904ee19ea5c..434e2065c5d 100644
--- a/docs/source/api/distributions/transforms.rst
+++ b/docs/source/api/distributions/transforms.rst
@@ -33,6 +33,7 @@ Specific Transform Classes
    LogExpM1
    Ordered
    SumTo1
+   ZeroSumTransform
 
 Transform Composition Classes
diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index 6887321c786..5ada4d67ce1 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -99,6 +99,7 @@
     StickBreakingWeights,
     Wishart,
     WishartBartlett,
+    ZeroSumNormal,
 )
 from pymc.distributions.simulator import Simulator
 from pymc.distributions.timeseries import (
@@ -116,8 +117,8 @@
     "Uniform",
     "Flat",
     "HalfFlat",
-    "TruncatedNormal",
     "Normal",
+    "TruncatedNormal",
     "Beta",
     "Kumaraswamy",
     "Exponential",
@@ -160,6 +161,7 @@
     "Continuous",
     "Discrete",
     "MvNormal",
+    "ZeroSumNormal",
     "MatrixNormal",
     "KroneckerNormal",
     "MvStudentT",
diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 288822fad81..e92a479d156 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -18,6 +18,7 @@
 import warnings
 
 from functools import reduce
+from typing import Optional
 
 import aesara
 import aesara.tensor as at
@@ -63,15 +64,17 @@
     _change_dist_size,
     broadcast_dist_samples_to,
     change_dist_size,
+    get_support_shape,
     rv_size_is_none,
     to_tuple,
 )
-from pymc.distributions.transforms import Interval, _default_transform
+from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
 from pymc.math import kron_diag, kron_dot
 from pymc.util import check_dist_not_registered
 
 __all__ = [
     "MvNormal",
+    "ZeroSumNormal",
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
@@ -2380,3 +2383,205 @@ def logp(value, alpha, K):
         K > 0,
         msg="alpha > 0, K > 0",
     )
+
+
+class ZeroSumNormalRV(SymbolicRandomVariable):
+    """ZeroSumNormal random variable"""
+
+    _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
+    default_output = 0
+
+
+class ZeroSumNormal(Distribution):
+    r"""
+    ZeroSumNormal distribution, i.e. a Normal distribution where one or
+    several axes are constrained to sum to zero.
+    By default, the last axis is constrained to sum to zero.
+    See the `zerosum_axes` kwarg for more details.
+
+    .. math::
+
+        \begin{align*}
+            ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
+            \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
+            n = \text{length of the zero-sum axis}
+        \end{align*}
+
+    Parameters
+    ----------
+    sigma : tensor_like of float
+        Scale parameter (sigma > 0).
+        It is the standard deviation of the underlying, unconstrained Normal distribution.
+        Defaults to 1 if not specified.
+        For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
+    zerosum_axes: int, defaults to 1
+        Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
+        Defaults to 1, i.e. the rightmost axis.
+    dims: sequence of strings, optional
+        Dimension names of the distribution. Works the same as for other PyMC distributions.
+        Necessary if ``shape`` is not passed.
+    shape: tuple of integers, optional
+        Shape of the distribution. Works the same as for other PyMC distributions.
+        Necessary if ``dims`` or ``observed`` is not passed.
+
+    Warnings
+    --------
+    ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
+    The ability to specify a vector of ``sigma`` may be added in future versions.
+
+    ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
+    just use ``pm.Normal``.
+
+    Examples
+    --------
+    Define a `ZeroSumNormal` variable, with `sigma=1` and
+    `zerosum_axes=1` by default::
+
+        COORDS = {
+            "regions": ["a", "b", "c"],
+            "answers": ["yes", "no", "whatever", "don't understand question"],
+        }
+        with pm.Model(coords=COORDS) as m:
+            # the zero sum axis will be 'answers'
+            v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
+
+        with pm.Model(coords=COORDS) as m:
+            # the zero sum axes will be 'answers' and 'regions'
+            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
+
+        with pm.Model(coords=COORDS) as m:
+            # the zero sum axes will be the last two
+            v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
+    """
+    rv_type = ZeroSumNormalRV
+
+    def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
+        if dims is not None or kwargs.get("observed") is not None:
+            zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
+
+            support_shape = get_support_shape(
+                support_shape=support_shape,
+                shape=None,  # Shape will be checked in `cls.dist`
+                dims=dims,
+                observed=kwargs.get("observed", None),
+                ndim_supp=zerosum_axes,
+            )
+
+        return super().__new__(
+            cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
+        )
+
+    @classmethod
+    def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
+        zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
+
+        sigma = at.as_tensor_variable(floatX(sigma))
+        if sigma.ndim > 0:
+            raise ValueError("sigma has to be a scalar")
+
+        support_shape = get_support_shape(
+            support_shape=support_shape,
+            shape=kwargs.get("shape"),
+            ndim_supp=zerosum_axes,
+        )
+
+        if support_shape is None:
+            if zerosum_axes > 0:
+                raise ValueError("You must specify dims, shape or support_shape parameter")
+            # TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails
+            # else:
+            #     support_shape = () # because it's just a Normal in that case
+
+        support_shape = at.as_tensor_variable(intX(support_shape))
+
+        assert zerosum_axes == at.get_vector_length(
+            support_shape
+        ), "support_shape has to be as long as zerosum_axes"
+
+        return super().dist(
+            [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
+        )
+
+    @classmethod
+    def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
+        if zerosum_axes is None:
+            zerosum_axes = 1
+        if not isinstance(zerosum_axes, int):
+            raise TypeError("zerosum_axes has to be an integer")
+        if not zerosum_axes > 0:
+            raise ValueError("zerosum_axes has to be > 0")
+        return zerosum_axes
+
+    @classmethod
+    def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
+
+        shape = to_tuple(size) + tuple(support_shape)
+        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
+
+        if zerosum_axes > normal_dist.ndim:
+            raise ValueError("Shape of distribution is too small for the number of zerosum axes")
+
+        normal_dist_, sigma_, support_shape_ = (
+            normal_dist.type(),
+            sigma.type(),
+            support_shape.type(),
+        )
+
+        # The zero-sum constraint is enforced by subtracting the mean along the given zerosum_axes
+        zerosum_rv_ = normal_dist_
+        for 
axis in range(zerosum_axes): + zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) + + return ZeroSumNormalRV( + inputs=[normal_dist_, sigma_, support_shape_], + outputs=[zerosum_rv_, support_shape_], + ndim_supp=zerosum_axes, + )(normal_dist, sigma, support_shape) + + +@_change_dist_size.register(ZeroSumNormalRV) +def change_zerosum_size(op, normal_dist, new_size, expand=False): + + normal_dist, sigma, support_shape = normal_dist.owner.inputs + + if expand: + original_shape = tuple(normal_dist.shape) + old_size = original_shape[: len(original_shape) - op.ndim_supp] + new_size = tuple(new_size) + old_size + + return ZeroSumNormal.rv_op( + sigma=sigma, zerosum_axes=op.ndim_supp, support_shape=support_shape, size=new_size + ) + + +@_moment.register(ZeroSumNormalRV) +def zerosumnormal_moment(op, rv, *rv_inputs): + return at.zeros_like(rv) + + +@_default_transform.register(ZeroSumNormalRV) +def zerosum_default_transform(op, rv): + zerosum_axes = tuple(np.arange(-op.ndim_supp, 0)) + return ZeroSumTransform(zerosum_axes) + + +@_logprob.register(ZeroSumNormalRV) +def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): + (value,) = values + shape = value.shape + zerosum_axes = op.ndim_supp + + _deg_free_support_shape = at.inc_subtensor(shape[-zerosum_axes:], -1) + _full_size = at.prod(shape) + _degrees_of_freedom = at.prod(_deg_free_support_shape) + + zerosums = [ + at.all(at.isclose(at.mean(value, axis=-axis - 1), 0, atol=1e-9)) + for axis in range(zerosum_axes) + ] + + out = at.sum( + pm.logp(normal_dist, value) * _degrees_of_freedom / _full_size, + axis=tuple(np.arange(-zerosum_axes, 0)), + ) + + return check_parameters(out, *zerosums, msg="at.mean(value, axis=zerosum_axes) == 0") diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index 59c6666249b..2bcc85a89a9 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -20,7 +20,7 @@ import warnings from functools import singledispatch -from typing import Optional, Sequence, Tuple, Union +from typing import Any, Optional, Sequence, Tuple, Union import numpy as np @@ -28,11 +28,15 @@ from aesara import tensor as at from aesara.graph.basic import Variable from aesara.graph.op import Op, compute_test_value +from aesara.raise_op import Assert from aesara.tensor.random.op import RandomVariable from aesara.tensor.shape import SpecifyShape from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias +from pymc.aesaraf import convert_observed_data +from pymc.model import modelcontext + __all__ = [ "to_tuple", "shapes_broadcasting", @@ -657,3 +661,111 @@ def change_specify_shape_size(op, ss, new_size, expand) -> TensorVariable: # specify_shape has a wrong signature https://github.com/aesara-devs/aesara/issues/1164 return at.specify_shape(new_var, new_shapes) # type: ignore + + +def get_support_shape( + support_shape: Optional[Sequence[Union[int, np.ndarray, TensorVariable]]], + *, + shape: Optional[Shape] = None, + dims: Optional[Dims] = None, + observed: Optional[Any] = None, + support_shape_offset: Sequence[int] = None, + ndim_supp: int = 1, +): + """Extract the support shapes from shape / dims / observed information + + Parameters + ---------- + support_shape: + User-specified support shape for multivariate distribution + shape: + User-specified shape for multivariate distribution + dims: + User-specified dims for multivariate distribution + observed: + User-specified observed data from multivariate distribution + 
support_shape_offset: + Difference between last shape dimensions and the length of + explicit support shapes in multivariate distribution, defaults to 0. + For timeseries, this is shape[-1] = support_shape[-1] + 1 + ndim_supp: + Number of support dimensions of the given multivariate distribution, defaults to 1 + + Returns + ------- + support_shape + Support shape, if specified directly by user, or inferred from the last dimensions of + shape / dims / observed. When two sources of support shape information are provided, + a symbolic Assert is added to ensure they are consistent. + """ + if ndim_supp < 1: + raise NotImplementedError("ndim_supp must be bigger than 0") + if support_shape_offset is None: + support_shape_offset = [0] * ndim_supp + inferred_support_shape = None + + if shape is not None: + shape = to_tuple(shape) + assert isinstance(shape, tuple) + inferred_support_shape = at.stack( + [shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)] + ) + + if inferred_support_shape is None and dims is not None: + dims = convert_dims(dims) + assert isinstance(dims, tuple) + model = modelcontext(None) + inferred_support_shape = at.stack( + [ + model.dim_lengths[dims[i]] - support_shape_offset[i] # type: ignore + for i in np.arange(-ndim_supp, 0) + ] + ) + + if inferred_support_shape is None and observed is not None: + observed = convert_observed_data(observed) + inferred_support_shape = at.stack( + [observed.shape[i] - support_shape_offset[i] for i in np.arange(-ndim_supp, 0)] + ) + + if inferred_support_shape is None: + inferred_support_shape = support_shape + # If there are two sources of information for the support shapes, assert they are consistent: + elif support_shape is not None: + inferred_support_shape = at.stack( + [ + Assert(msg="support_shape does not match respective shape dimension")( + inferred, at.eq(inferred, explicit) + ) + for inferred, explicit in zip(inferred_support_shape, support_shape) + ] + ) + + return inferred_support_shape + + +def get_support_shape_1d( + support_shape: Optional[Union[int, np.ndarray, TensorVariable]], + *, + shape: Optional[Shape] = None, + dims: Optional[Dims] = None, + observed: Optional[Any] = None, + support_shape_offset: int = 0, +): + """Helper function for cases when you just care about one dimension.""" + if support_shape is not None: + support_shape_tuple = (support_shape,) + else: + support_shape_tuple = None + + support_shape_tuple = get_support_shape( + support_shape_tuple, + shape=shape, + dims=dims, + observed=observed, + support_shape_offset=(support_shape_offset,), + ) + if support_shape_tuple is not None: + (support_shape,) = support_shape_tuple + + return support_shape diff --git a/pymc/distributions/timeseries.py b/pymc/distributions/timeseries.py index c99732471b3..376479d5f06 100644 --- a/pymc/distributions/timeseries.py +++ b/pymc/distributions/timeseries.py @@ -13,7 +13,7 @@ # limitations under the License. 
import warnings -from typing import Any, Optional, Union +from typing import Optional import aesara import aesara.tensor as at @@ -21,11 +21,10 @@ from aeppl.logprob import _logprob from aesara.graph.basic import Node, clone_replace -from aesara.raise_op import Assert from aesara.tensor import TensorVariable from aesara.tensor.random.op import RandomVariable -from pymc.aesaraf import constant_fold, convert_observed_data, floatX, intX +from pymc.aesaraf import constant_fold, floatX, intX from pymc.distributions import distribution, multivariate from pymc.distributions.continuous import Flat, Normal, get_tau_sigma from pymc.distributions.distribution import ( @@ -36,15 +35,12 @@ ) from pymc.distributions.logprob import ignore_logprob, logp from pymc.distributions.shape_utils import ( - Dims, - Shape, _change_dist_size, change_dist_size, - convert_dims, + get_support_shape_1d, to_tuple, ) from pymc.exceptions import NotConstantValueError -from pymc.model import modelcontext from pymc.util import check_dist_not_registered __all__ = [ @@ -58,61 +54,6 @@ ] -def get_steps( - steps: Optional[Union[int, np.ndarray, TensorVariable]], - *, - shape: Optional[Shape] = None, - dims: Optional[Dims] = None, - observed: Optional[Any] = None, - step_shape_offset: int = 0, -): - """Extract number of steps from shape / dims / observed information - - Parameters - ---------- - steps: - User specified steps for timeseries distribution - shape: - User specified shape for timeseries distribution - dims: - User specified dims for timeseries distribution - observed: - User specified observed data from timeseries distribution - step_shape_offset: - Difference between last shape dimension and number of steps in timeseries - distribution, defaults to 0 - - Returns - ------- - steps - Steps, if specified directly by user, or inferred from the last dimension of - shape / dims / observed. When two sources of step information are provided, - a symbolic Assert is added to ensure they are consistent. 
- """ - inferred_steps = None - if shape is not None: - shape = to_tuple(shape) - inferred_steps = shape[-1] - step_shape_offset - - if inferred_steps is None and dims is not None: - dims = convert_dims(dims) - model = modelcontext(None) - inferred_steps = model.dim_lengths[dims[-1]] - step_shape_offset - - if inferred_steps is None and observed is not None: - observed = convert_observed_data(observed) - inferred_steps = observed.shape[-1] - step_shape_offset - - if inferred_steps is None: - inferred_steps = steps - # If there are two sources of information for the steps, assert they are consistent - elif steps is not None: - inferred_steps = Assert(msg="Steps do not match last shape dimension")( - inferred_steps, at.eq(inferred_steps, steps) - ) - return inferred_steps - - class RandomWalkRV(SymbolicRandomVariable): """RandomWalk Variable""" @@ -129,21 +70,21 @@ class RandomWalk(Distribution): rv_type = RandomWalkRV def __new__(cls, *args, steps=None, **kwargs): - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=1, + support_shape_offset=1, ) return super().__new__(cls, *args, steps=steps, **kwargs) @classmethod def dist(cls, init_dist, innovation_dist, steps=None, **kwargs) -> at.TensorVariable: - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape"), - step_shape_offset=1, + support_shape_offset=1, ) if steps is None: raise ValueError("Must specify steps or shape parameter") @@ -381,12 +322,12 @@ class AR(Distribution): def __new__(cls, name, rho, *args, steps=None, constant=False, ar_order=None, **kwargs): rhos = at.atleast_1d(at.as_tensor_variable(floatX(rho))) ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=ar_order, + support_shape_offset=ar_order, ) return super().__new__( cls, name, rhos, *args, steps=steps, constant=constant, ar_order=ar_order, **kwargs @@ -417,7 +358,9 @@ def dist( init_dist = kwargs.pop("init") ar_order = cls._get_ar_order(rhos=rhos, constant=constant, ar_order=ar_order) - steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=ar_order) + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=ar_order + ) if steps is None: raise ValueError("Must specify steps or shape parameter") steps = at.as_tensor_variable(intX(steps), ndim=0) @@ -640,18 +583,20 @@ class GARCH11(Distribution): rv_type = GARCH11RV def __new__(cls, *args, steps=None, **kwargs): - steps = get_steps( - steps=steps, + steps = get_support_shape_1d( + support_shape=steps, shape=None, # Shape will be checked in `cls.dist` dims=kwargs.get("dims", None), observed=kwargs.get("observed", None), - step_shape_offset=1, + support_shape_offset=1, ) return super().__new__(cls, *args, steps=steps, **kwargs) @classmethod def dist(cls, omega, alpha_1, beta_1, initial_vol, *, steps=None, **kwargs): - steps = get_steps(steps=steps, shape=kwargs.get("shape", None), step_shape_offset=1) + steps = get_support_shape_1d( + support_shape=steps, shape=kwargs.get("shape", None), support_shape_offset=1 + ) if steps is None: raise ValueError("Must 
specify steps or shape parameter") steps = at.as_tensor_variable(intX(steps), ndim=0) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 2f65902eddc..ee142b46fe1 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -27,6 +27,10 @@ from aesara.graph import Op from aesara.tensor import TensorVariable +# ignore mypy error because it somehow considers that +# "numpy.core.numeric has no attribute normalize_axis_tuple" +from numpy.core.numeric import normalize_axis_tuple # type: ignore + __all__ = [ "RVTransform", "simplex", @@ -39,6 +43,7 @@ "circular", "CholeskyCovPacked", "Chain", + "ZeroSumTransform", ] @@ -266,6 +271,60 @@ def bounds_fn(*rv_inputs): super().__init__(args_fn=bounds_fn) +class ZeroSumTransform(RVTransform): + """ + Constrains any random samples to sum to zero along the user-provided ``zerosum_axes``. + + Parameters + ---------- + zerosum_axes : list of ints + Must be a list of integers (positive or negative). + """ + + name = "zerosum" + + __props__ = ("zerosum_axes",) + + def __init__(self, zerosum_axes): + self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) + + def forward(self, value, *rv_inputs): + for axis in self.zerosum_axes: + value = extend_axis_rev(value, axis=axis) + return value + + def backward(self, value, *rv_inputs): + for axis in self.zerosum_axes: + value = extend_axis(value, axis=axis) + return value + + def log_jac_det(self, value, *rv_inputs): + return at.constant(0.0) + + +def extend_axis(array, axis): + n = array.shape[axis] + 1 + sum_vals = array.sum(axis, keepdims=True) + norm = sum_vals / (np.sqrt(n) + n) + fill_val = norm - sum_vals / np.sqrt(n) + + out = at.concatenate([array, fill_val], axis=axis) + return out - norm + + +def extend_axis_rev(array, axis): + normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] + + n = array.shape[normalized_axis] + last = at.take(array, [-1], axis=normalized_axis) + + sum_vals = -last * np.sqrt(n) + norm = sum_vals / (np.sqrt(n) + n) + slice_before = (slice(None, None),) * normalized_axis + + return array[slice_before + (slice(None, -1),)] + norm + + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ Instantiation of :class:`pymc.distributions.transforms.LogExpM1` diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index 2dde22c915f..263b7205959 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -1008,6 +1008,19 @@ def test_mv_normal_moment(self, mu, cov, size, expected): # MvNormal logp is only implemented for up to 2D variables assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3) + @pytest.mark.parametrize( + "shape, zerosum_axes, expected", + [ + ((2, 5), None, np.zeros((2, 5))), + ((2, 5, 6), 2, np.zeros((2, 5, 6))), + ((2, 5, 6), 3, np.zeros((2, 5, 6))), + ], + ) + def test_zerosum_normal_moment(self, shape, zerosum_axes, expected): + with pm.Model() as model: + pm.ZeroSumNormal("x", shape=shape, zerosum_axes=zerosum_axes) + assert_moment_is_expected(model, expected) + @pytest.mark.parametrize( "mu, size, expected", [ @@ -1368,6 +1381,210 @@ def test_issue_3706(self): assert prior_pred["X"].shape == (1, N, 2) +class TestZeroSumNormal: + coords = { + "regions": ["a", "b", "c"], + "answers": ["yes", "no", "whatever", "don't understand question"], + } + + def assert_zerosum_axes(self, random_samples, axes_to_check, check_zerosum_axes=True): + if check_zerosum_axes: + for ax in axes_to_check: + 
assert np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is a zerosum_axis but is not summing to 0 across all samples." + else: + for ax in axes_to_check: + assert not np.isclose( + random_samples.mean(axis=ax), 0 + ).all(), f"{ax} is not a zerosum_axis, but is nonetheless summing to 0 across all samples." + + @pytest.mark.parametrize( + "dims, zerosum_axes", + [ + (("regions", "answers"), None), + (("regions", "answers"), 1), + (("regions", "answers"), 2), + ], + ) + def test_zsn_dims(self, dims, zerosum_axes): + with pm.Model(coords=self.coords) as m: + v = pm.ZeroSumNormal("v", dims=dims, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + + # to test forward graph + random_samples = pm.draw(v, draws=10) + + assert s.posterior.v.shape == ( + 1, + 10, + len(self.coords["regions"]), + len(self.coords["answers"]), + ) + + ndim_supp = v.owner.op.ndim_supp + zerosum_axes = np.arange(-ndim_supp, 0) + nonzero_axes = np.arange(v.ndim - ndim_supp) + for samples in [ + s.posterior.v, + random_samples, + ]: + self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) + + @pytest.mark.parametrize( + "zerosum_axes", + (None, 1, 2), + ) + def test_zsn_shape(self, zerosum_axes): + shape = (len(self.coords["regions"]), len(self.coords["answers"])) + + with pm.Model(coords=self.coords) as m: + v = pm.ZeroSumNormal("v", shape=shape, zerosum_axes=zerosum_axes) + s = pm.sample(10, chains=1, tune=100) + + # to test forward graph + random_samples = pm.draw(v, draws=10) + + assert s.posterior.v.shape == ( + 1, + 10, + len(self.coords["regions"]), + len(self.coords["answers"]), + ) + + ndim_supp = v.owner.op.ndim_supp + zerosum_axes = np.arange(-ndim_supp, 0) + nonzero_axes = np.arange(v.ndim - ndim_supp) + for samples in [ + s.posterior.v, + random_samples, + ]: + self.assert_zerosum_axes(samples, zerosum_axes) + self.assert_zerosum_axes(samples, nonzero_axes, check_zerosum_axes=False) + + @pytest.mark.parametrize( + "error, match, shape, support_shape, zerosum_axes", + [ + (IndexError, "index out of range", (3, 4, 5), None, 4), + (AssertionError, "does not match", (3, 4), (3,), None), # support_shape should be 4 + ( + AssertionError, + "does not match", + (3, 4), + (3, 4), + None, + ), # doesn't work because zerosum_axes = 1 by default + ], + ) + def test_zsn_fail_axis(self, error, match, shape, support_shape, zerosum_axes): + with pytest.raises(error, match=match): + with pm.Model() as m: + _ = pm.ZeroSumNormal( + "v", shape=shape, support_shape=support_shape, zerosum_axes=zerosum_axes + ) + + @pytest.mark.parametrize( + "shape, support_shape", + [ + (None, (3, 4)), + ((3, 4), (3, 4)), + ], + ) + def test_zsn_support_shape(self, shape, support_shape): + with pm.Model() as m: + v = pm.ZeroSumNormal("v", shape=shape, support_shape=support_shape, zerosum_axes=2) + + random_samples = pm.draw(v, draws=10) + zerosum_axes = np.arange(-2, 0) + self.assert_zerosum_axes(random_samples, zerosum_axes) + + @pytest.mark.parametrize( + "zerosum_axes", + [1, 2], + ) + def test_zsn_change_dist_size(self, zerosum_axes): + base_dist = pm.ZeroSumNormal.dist(shape=(4, 9), zerosum_axes=zerosum_axes) + random_samples = pm.draw(base_dist, draws=100) + + zerosum_axes = np.arange(-zerosum_axes, 0) + self.assert_zerosum_axes(random_samples, zerosum_axes) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=False) + try: + assert new_dist.eval().shape == (5, 3, 9) + except AssertionError: + assert new_dist.eval().shape == (5, 
3, 4, 9) + random_samples = pm.draw(new_dist, draws=100) + self.assert_zerosum_axes(random_samples, zerosum_axes) + + new_dist = change_dist_size(base_dist, new_size=(5, 3), expand=True) + assert new_dist.eval().shape == (5, 3, 4, 9) + random_samples = pm.draw(new_dist, draws=100) + self.assert_zerosum_axes(random_samples, zerosum_axes) + + @pytest.mark.parametrize( + "sigma, n", + [ + (5, 3), + (2, 6), + ], + ) + def test_zsn_variance(self, sigma, n): + + dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=(100_000, n)) + random_samples = pm.draw(dist) + + empirical_var = random_samples.var(axis=0) + theoretical_var = sigma**2 * (n - 1) / n + + np.testing.assert_allclose(empirical_var, theoretical_var, atol=0.4) + + @pytest.mark.parametrize( + "sigma, shape, zerosum_axes, mvn_axes", + [ + (5, 3, None, [-1]), + (2, 6, None, [-1]), + (5, (7, 3), None, [-1]), + (5, (2, 7, 3), 2, [1, 2]), + ], + ) + def test_zsn_logp(self, sigma, shape, zerosum_axes, mvn_axes): + def logp_norm(value, sigma, axes): + """ + Special case of the MvNormal, that's equivalent to the ZSN. + Only to test the ZSN logp + """ + axes = [ax if ax >= 0 else value.ndim + ax for ax in axes] + if len(set(axes)) < len(axes): + raise ValueError("Must specify unique zero sum axes") + other_axes = [ax for ax in range(value.ndim) if ax not in axes] + new_order = other_axes + axes + reshaped_value = np.reshape( + np.transpose(value, new_order), [value.shape[ax] for ax in other_axes] + [-1] + ) + + degrees_of_freedom = np.prod([value.shape[ax] - 1 for ax in axes]) + full_size = np.prod([value.shape[ax] for ax in axes]) + + psdet = (0.5 * np.log(2 * np.pi) + np.log(sigma)) * degrees_of_freedom / full_size + exp = 0.5 * (reshaped_value / sigma) ** 2 + inds = np.ones_like(value, dtype="bool") + for ax in axes: + inds = np.logical_and(inds, np.abs(np.mean(value, axis=ax, keepdims=True)) < 1e-9) + inds = np.reshape( + np.transpose(inds, new_order), [value.shape[ax] for ax in other_axes] + [-1] + )[..., 0] + + return np.where(inds, np.sum(-psdet - exp, axis=-1), -np.inf) + + zsn_dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=shape, zerosum_axes=zerosum_axes) + zsn_logp = pm.logp(zsn_dist, value=np.zeros(shape)).eval() + mvn_logp = logp_norm(value=np.zeros(shape), sigma=sigma, axes=mvn_axes) + + np.testing.assert_allclose(zsn_logp, mvn_logp) + + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, cov, rng): mv_samples = rng.multivariate_normal(np.zeros_like(mu), cov, size=size) diff --git a/pymc/tests/distributions/test_shape_utils.py b/pymc/tests/distributions/test_shape_utils.py index 623a3d306d0..3a9f4bb8dee 100644 --- a/pymc/tests/distributions/test_shape_utils.py +++ b/pymc/tests/distributions/test_shape_utils.py @@ -18,9 +18,10 @@ import numpy as np import pytest -from aesara import Mode from aesara import tensor as at +from aesara.compile.mode import Mode from aesara.graph import Constant, ancestors +from aesara.tensor import TensorVariable from aesara.tensor.random import normal from aesara.tensor.shape import SpecifyShape @@ -36,10 +37,13 @@ convert_shape, convert_size, get_broadcastable_dist_samples, + get_support_shape, + get_support_shape_1d, rv_size_is_none, shapes_broadcasting, to_tuple, ) +from pymc.model import Model test_shapes = [ (tuple(), (1,), (4,), (5, 4)), @@ -599,3 +603,141 @@ def test_change_specify_shape_size_multivariate(): new_x.eval({batch: 5, supp: 3}).shape == (10, 5, 5, 3) with pytest.raises(AssertionError, match=re.escape("expected (None, None, 5, 3)")): 
new_x.eval({batch: 6, supp: 3}).shape == (10, 5, 5, 3) + + +@pytest.mark.parametrize( + "support_shape, shape, support_shape_offset, expected_support_shape, consistent", + [ + (10, None, 0, 10, True), + (10, None, 1, 10, True), + (None, (10,), 0, 10, True), + (None, (10,), 1, 9, True), + (None, (10, 5), 0, 5, True), + (None, None, 0, None, True), + (10, (10,), 0, 10, True), + (10, (11,), 1, 10, True), + (10, (5, 5), 0, 5, False), + (10, (5, 10), 1, 9, False), + ], +) +@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) +def test_get_support_shape_1d( + info_source, support_shape, shape, support_shape_offset, expected_support_shape, consistent +): + if info_source == "shape": + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, shape=shape, support_shape_offset=support_shape_offset + ) + + elif info_source == "dims": + if shape is None: + dims = None + coords = {} + else: + dims = tuple(str(i) for i, _ in enumerate(shape)) + coords = {str(i): range(shape) for i, shape in enumerate(shape)} + with Model(coords=coords): + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, dims=dims, support_shape_offset=support_shape_offset + ) + + elif info_source == "observed": + if shape is None: + observed = None + else: + observed = np.zeros(shape) + inferred_support_shape = get_support_shape_1d( + support_shape=support_shape, + observed=observed, + support_shape_offset=support_shape_offset, + ) + + if not isinstance(inferred_support_shape, TensorVariable): + assert inferred_support_shape == expected_support_shape + else: + if consistent: + assert inferred_support_shape.eval() == expected_support_shape + else: + # check that inferred steps is still correct by ignoring the assert + f = aesara.function( + [], inferred_support_shape, mode=Mode().including("local_remove_all_assert") + ) + assert f() == expected_support_shape + with pytest.raises(AssertionError, match="support_shape does not match"): + inferred_support_shape.eval() + + +@pytest.mark.parametrize( + "support_shape, shape, support_shape_offset, expected_support_shape, ndim_supp, consistent", + [ + ((10, 5), None, (0,), (10, 5), 1, True), + ((10, 5), None, (1, 1), (10, 5), 1, True), + (None, (10, 5), (0,), 5, 1, True), + (None, (10, 5), (1,), 4, 1, True), + (None, (10, 5, 2), (0,), 2, 1, True), + (None, None, None, None, 1, True), + ((10, 5), (10, 5), None, (10, 5), 2, True), + ((10, 5), (11, 10, 5), None, (10, 5), 2, True), + (None, (11, 10, 5), (0, 1, 0), (11, 9, 5), 3, True), + ((10, 5), (10, 5, 5), (0,), (5,), 1, False), + ((10, 5), (10, 5), (1, 1), (9, 4), 2, False), + ], +) +@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) +def test_get_support_shape( + info_source, + support_shape, + shape, + support_shape_offset, + expected_support_shape, + ndim_supp, + consistent, +): + if info_source == "shape": + inferred_support_shape = get_support_shape( + support_shape=support_shape, + shape=shape, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + elif info_source == "dims": + if shape is None: + dims = None + coords = {} + else: + dims = tuple(str(i) for i, _ in enumerate(shape)) + coords = {str(i): range(shape) for i, shape in enumerate(shape)} + with Model(coords=coords): + inferred_support_shape = get_support_shape( + support_shape=support_shape, + dims=dims, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + elif info_source == "observed": + if shape is None: + observed = None + else: + observed = 
np.zeros(shape) + inferred_support_shape = get_support_shape( + support_shape=support_shape, + observed=observed, + support_shape_offset=support_shape_offset, + ndim_supp=ndim_supp, + ) + + if not isinstance(inferred_support_shape, TensorVariable): + assert inferred_support_shape == expected_support_shape + else: + if consistent: + assert (inferred_support_shape.eval() == expected_support_shape).all() + else: + # check that inferred support shape is still correct by ignoring the assert + f = aesara.function( + [], inferred_support_shape, mode=Mode().including("local_remove_all_assert") + ) + assert (f() == expected_support_shape).all() + with pytest.raises(AssertionError, match="support_shape does not match"): + inferred_support_shape.eval() diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 9832515e3b9..f7f6a7227fe 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -19,8 +19,6 @@ import pytest import scipy.stats as st -from aesara.tensor import TensorVariable - import pymc as pm from pymc.aesaraf import floatX @@ -28,14 +26,8 @@ from pymc.distributions.discrete import DiracDelta from pymc.distributions.logprob import logp from pymc.distributions.multivariate import Dirichlet -from pymc.distributions.shape_utils import change_dist_size -from pymc.distributions.timeseries import ( - AR, - GARCH11, - EulerMaruyama, - GaussianRandomWalk, - get_steps, -) +from pymc.distributions.shape_utils import change_dist_size, to_tuple +from pymc.distributions.timeseries import AR, GARCH11, EulerMaruyama, GaussianRandomWalk from pymc.model import Model from pymc.sampling import draw, sample, sample_posterior_predictive from pymc.tests.distributions.util import ( @@ -48,56 +40,6 @@ from pymc.tests.helpers import SeededTest, select_by_precision -@pytest.mark.parametrize( - "steps, shape, step_shape_offset, expected_steps, consistent", - [ - (10, None, 0, 10, True), - (10, None, 1, 10, True), - (None, (10,), 0, 10, True), - (None, (10,), 1, 9, True), - (None, (10, 5), 0, 5, True), - (None, None, 0, None, True), - (10, (10,), 0, 10, True), - (10, (11,), 1, 10, True), - (10, (5, 5), 0, 5, False), - (10, (5, 10), 1, 9, False), - ], -) -@pytest.mark.parametrize("info_source", ("shape", "dims", "observed")) -def test_get_steps(info_source, steps, shape, step_shape_offset, expected_steps, consistent): - if info_source == "shape": - inferred_steps = get_steps(steps=steps, shape=shape, step_shape_offset=step_shape_offset) - - elif info_source == "dims": - if shape is None: - dims = None - coords = {} - else: - dims = tuple(str(i) for i, shape in enumerate(shape)) - coords = {str(i): range(shape) for i, shape in enumerate(shape)} - with Model(coords=coords): - inferred_steps = get_steps(steps=steps, dims=dims, step_shape_offset=step_shape_offset) - - elif info_source == "observed": - if shape is None: - observed = None - else: - observed = np.zeros(shape) - inferred_steps = get_steps( - steps=steps, observed=observed, step_shape_offset=step_shape_offset - ) - - if not isinstance(inferred_steps, TensorVariable): - assert inferred_steps == expected_steps - else: - if consistent: - assert inferred_steps.eval() == expected_steps - else: - assert inferred_steps.owner.inputs[0].eval() == expected_steps - with pytest.raises(AssertionError, match="Steps do not match"): - inferred_steps.eval() - - class TestGaussianRandomWalk: def test_logp(self): def ref_logp(value, mu, sigma): @@ -180,7 +122,9 @@ def 
test_missing_steps(self): GaussianRandomWalk.dist(shape=None, init_dist=Normal.dist(0, 100)) def test_inconsistent_steps_and_shape(self): - with pytest.raises(AssertionError, match="Steps do not match last shape dimension"): + with pytest.raises( + AssertionError, match="support_shape does not match respective shape dimension" + ): x = GaussianRandomWalk.dist(steps=12, shape=45, init_dist=Normal.dist(0, 100)) def test_inferred_steps_from_dims(self):
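The two defining properties of the new distribution can be checked directly once this diff is applied: draws sum to zero along each constrained axis, and the constraint removes one degree of freedom per axis, so the marginal variance is ``sigma**2 * (n - 1) / n`` rather than ``sigma**2`` (this is exactly what ``test_zsn_variance`` asserts above). A minimal sketch::

    import numpy as np
    import pymc as pm

    n, sigma = 6, 2.0
    # 100_000 independent draws of a length-6 zero-sum vector (zerosum_axes=1 by default)
    dist = pm.ZeroSumNormal.dist(sigma=sigma, shape=(100_000, n))
    samples = pm.draw(dist)

    # every draw sums to zero along the last (constrained) axis
    assert np.allclose(samples.sum(axis=-1), 0, atol=1e-8)

    # the marginal variance is shrunk by the zero-sum constraint
    np.testing.assert_allclose(samples.var(axis=0), sigma**2 * (n - 1) / n, atol=0.4)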
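``zerosumnormal_logp`` evaluates the underlying unconstrained Normal logp, rescales it by the degrees of freedom left after the constraint, and guards the zero-sum hyperplane with ``check_parameters`` (off the hyperplane the density is ``-inf``). For a single zero-sum axis this reduces to a simple scaling, which can be sketched against ``scipy.stats``::

    import numpy as np
    import pymc as pm
    import scipy.stats as st

    n, sigma = 4, 2.0
    value = np.array([1.0, -0.5, 0.25, -0.75])  # sums to zero

    zsn = pm.ZeroSumNormal.dist(sigma=sigma, shape=(n,))
    zsn_logp = pm.logp(zsn, value).eval()

    # elementwise Normal logp, rescaled by (n - 1) degrees of freedom out of n
    ref = st.norm(0, sigma).logpdf(value).sum() * (n - 1) / n
    np.testing.assert_allclose(zsn_logp, ref)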
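``ZeroSumTransform`` is what makes the distribution samplable with gradient-based samplers: ``forward`` maps a length-``n`` zero-sum value to ``n - 1`` free coordinates and ``backward`` maps them back, which is why ``log_jac_det`` can be a constant 0 (the map preserves volume). A round-trip sketch, again assuming this diff is applied::

    import numpy as np
    import aesara.tensor as at

    from pymc.distributions.transforms import ZeroSumTransform

    transform = ZeroSumTransform(zerosum_axes=[-1])

    unconstrained = np.array([0.3, -1.2])  # n - 1 = 2 free coordinates
    constrained = transform.backward(at.as_tensor_variable(unconstrained)).eval()
    assert constrained.shape == (3,) and np.isclose(constrained.sum(), 0)

    round_trip = transform.forward(at.as_tensor_variable(constrained)).eval()
    assert np.allclose(round_trip, unconstrained)  # forward inverts backward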
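``get_support_shape`` generalizes the old ``get_steps`` to several trailing support dimensions, and ``get_support_shape_1d`` keeps the one-dimensional behavior the timeseries distributions rely on. A sketch of the inference rules, mirroring the parametrized cases in the new shape-utils tests::

    from pymc.distributions.shape_utils import get_support_shape, get_support_shape_1d

    # an explicit value is returned as-is when it is the only information source
    assert get_support_shape_1d(support_shape=10) == 10

    # inferred from the last shape dimension, minus the offset
    # (timeseries pass offset=1 because the init value occupies one slot)
    assert get_support_shape_1d(support_shape=None, shape=(10,), support_shape_offset=1).eval() == 9

    # with two sources, a symbolic Assert checks that they agree
    assert get_support_shape_1d(support_shape=10, shape=(11,), support_shape_offset=1).eval() == 10

    # the n-dimensional version takes the last `ndim_supp` entries of shape
    support = get_support_shape(support_shape=None, shape=(11, 10, 5), ndim_supp=2)
    assert tuple(support.eval()) == (10, 5)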
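Finally, because ``support_shape`` travels with the ``ZeroSumNormalRV`` op, resizing only touches the batch dimensions, matching ``test_zsn_change_dist_size``::

    import pymc as pm

    from pymc.distributions.shape_utils import change_dist_size

    base = pm.ZeroSumNormal.dist(shape=(4, 9))  # zerosum_axes=1, so support_shape == (9,)

    resized = change_dist_size(base, new_size=(5, 3))
    assert resized.eval().shape == (5, 3, 9)  # batch dims replaced, support dim kept

    expanded = change_dist_size(base, new_size=(5, 3), expand=True)
    assert expanded.eval().shape == (5, 3, 4, 9)  # old batch dims kept after the new ones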