From 7639c7f122104f54814ad96533fcbf731ac5008f Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 10 Nov 2021 12:10:53 +0100 Subject: [PATCH] Add Censored distributions --- pymc/distributions/__init__.py | 2 + pymc/distributions/censored.py | 68 +++++++++ pymc/distributions/distribution.py | 228 ++++++++++++++++++++++++++++- pymc/tests/test_distributions.py | 41 ++++++ 4 files changed, 338 insertions(+), 1 deletion(-) create mode 100644 pymc/distributions/censored.py diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py index bb91ea23c2a..ce0fe041d0f 100644 --- a/pymc/distributions/__init__.py +++ b/pymc/distributions/__init__.py @@ -21,6 +21,7 @@ ) from pymc.distributions.bound import Bound +from pymc.distributions.censored import Censored from pymc.distributions.continuous import ( AsymmetricLaplace, Beta, @@ -194,4 +195,5 @@ "logp_transform", "logcdf", "logpt_sum", + "Censored", ] diff --git a/pymc/distributions/censored.py b/pymc/distributions/censored.py new file mode 100644 index 00000000000..43577a4c497 --- /dev/null +++ b/pymc/distributions/censored.py @@ -0,0 +1,68 @@ +# Copyright 2020 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from aesara.tensor import clip + +from pymc.distributions.distribution import DerivedDistribution + + +class Censored(DerivedDistribution): + @classmethod + def dist(cls, distribution, lower, upper, **kwargs): + # TODO: Assert distribution is a RandomVariable + if distribution.owner.op.ndim_supp > 0: + raise NotImplemented( + "Censoring of multivariate distributions has not been implemented yet" + ) + return super().dist([distribution, lower, upper], **kwargs) + + @classmethod + def rv_op(cls, dist, lower=None, upper=None, size=None, rngs=None): + if lower is None: + lower = -np.inf + if upper is None: + upper = np.inf + + rv_out = clip(dist, lower, upper) + if size is not None: + rv_out = cls.change_size(rv_out, size) + if rngs is not None: + rv_out = cls.change_rngs(rv_out, rngs) + return rv_out + + @classmethod + def ndim_supp(cls, *dist_params): + return 0 + + @classmethod + def change_size(cls, rv, new_size): + dist, lower, upper = rv.owner.inputs + dist_node = dist.owner + rng, old_size, dtype, *dist_params = dist_node.inputs + new_dist = dist_node.op.make_node(rng, new_size, dtype, *dist_params).default_output() + return cls.rv_op(new_dist, lower, upper) + + @classmethod + def change_rngs(cls, rv, new_rngs): + (new_rng,) = new_rngs + dist, lower, upper = rv.owner.inputs + dist_node = dist.owner + olg_rng, size, dtype, *dist_params = dist_node.inputs + new_dist = dist_node.op.make_node(new_rng, size, dtype, *dist_params).default_output() + return cls.rv_op(new_dist, lower, upper) + + @classmethod + def graph_rvs(cls, dist, *bounds): + return (dist,) diff --git a/pymc/distributions/distribution.py b/pymc/distributions/distribution.py index 37b5483297b..1463358fe91 100644 --- a/pymc/distributions/distribution.py +++ b/pymc/distributions/distribution.py @@ -19,7 +19,7 @@ from abc import ABCMeta from functools import singledispatch -from typing import Callable, Optional, Sequence +from typing import Callable, Iterable, Optional, Sequence import aesara @@ -353,6 +353,232 @@ def dist( return rv_out +class DerivedDistribution: + def __new__( + cls, + name: str, + *args, + rngs: Optional[Iterable] = None, + dims: Optional[Dims] = None, + initval=None, + observed=None, + total_size=None, + transform=UNSET, + **kwargs, + ) -> TensorVariable: + """Adds a TensorVariable corresponding to a PyMC derived distribution to the current model. + + Note that all remaining kwargs must be compatible with ``.dist()`` + + Parameters + ---------- + cls : type + A PyMC distribution. + name : str + Name for the new model variable. + rngs : optional + Random number generator to use with the RandomVariable. + dims : tuple, optional + A tuple of dimension names known to the model. + initval : optional + Numeric or symbolic untransformed initial value of matching shape, + or one of the following initial value strategies: "moment", "prior". + Depending on the sampler's settings, a random jitter may be added to numeric, symbolic + or moment-based initial values in the transformed space. + observed : optional + Observed data to be passed when registering the random variable in the model. + See ``Model.register_rv``. + total_size : float, optional + See ``Model.register_rv``. + transform : optional + See ``Model.register_rv``. + **kwargs + Keyword arguments that will be forwarded to ``.dist()``. + Most prominently: ``shape`` and ``size`` + + Returns + ------- + var : TensorVariable + The created variable, registered in the Model. + """ + + try: + from pymc.model import Model + + model = Model.get_context() + except TypeError: + raise TypeError( + "No model on context stack, which is needed to " + "instantiate distributions. Add variable inside " + "a 'with model:' block, or use the '.dist' syntax " + "for a standalone distribution." + ) + + if "testval" in kwargs: + initval = kwargs.pop("testval") + warnings.warn( + "The `testval` argument is deprecated; use `initval`.", + FutureWarning, + stacklevel=2, + ) + + if not isinstance(name, string_types): + raise TypeError(f"Name needs to be a string but got: {name}") + + if dims is not None and "shape" in kwargs: + raise ValueError( + f"Passing both `dims` ({dims}) and `shape` ({kwargs['shape']}) is not supported!" + ) + if dims is not None and "size" in kwargs: + raise ValueError( + f"Passing both `dims` ({dims}) and `size` ({kwargs['size']}) is not supported!" + ) + dims = convert_dims(dims) + + if rngs is None: + rngs = [model.next_rng() for _ in cls.graph_rvs(args)] + + # Create the RV without dims information, because that's not something tracked at the Aesara level. + # If necessary we'll later replicate to a different size implied by already known dims. + rv_out = cls.dist(*args, rngs=rngs, **kwargs) + ndim_actual = rv_out.ndim + resize_shape = None + + # # `dims` are only available with this API, because `.dist()` can be used + # # without a modelcontext and dims are not tracked at the Aesara level. + if dims is not None: + ndim_resize, resize_shape, dims = resize_from_dims(dims, ndim_actual, model) + elif observed is not None: + ndim_resize, resize_shape, observed = resize_from_observed(observed, ndim_actual) + + if resize_shape: + # A batch size was specified through `dims`, or implied by `observed`. + rv_out = cls.change_size( + rv=rv_out, + new_size=resize_shape, + ) + + rv_out = model.register_rv( + rv_out, + name, + observed, + total_size, + dims=dims, + transform=transform, + initval=initval, + ) + + # TODO: Refactor this + # add in pretty-printing support + rv_out.str_repr = lambda *args, **kwargs: name + rv_out._repr_latex_ = f"\\text{name}" + # rv_out.str_repr = types.MethodType(str_for_dist, rv_out) + # rv_out._repr_latex_ = types.MethodType( + # functools.partial(str_for_dist, formatting="latex"), rv_out + # ) + + return rv_out + + @classmethod + def dist( + cls, + dist_params, + *, + shape: Optional[Shape] = None, + size: Optional[Size] = None, + **kwargs, + ) -> TensorVariable: + """Creates a TensorVariable corresponding to the `cls` derived distribution. + + Parameters + ---------- + dist_params : array-like + The inputs to the `RandomVariable` `Op`. + shape : int, tuple, Variable, optional + A tuple of sizes for each dimension of the new RV. + + An Ellipsis (...) may be inserted in the last position to short-hand refer to + all the dimensions that the RV would get if no shape/size/dims were passed at all. + size : int, tuple, Variable, optional + For creating the RV like in Aesara/NumPy. + + Returns + ------- + var : TensorVariable + """ + + if "testval" in kwargs: + kwargs.pop("testval") + warnings.warn( + "The `.dist(testval=...)` argument is deprecated and has no effect. " + "Initial values for sampling/optimization can be specified with `initval` in a modelcontext. " + "For using Aesara's test value features, you must assign the `.tag.test_value` yourself.", + FutureWarning, + stacklevel=2, + ) + if "initval" in kwargs: + raise TypeError( + "Unexpected keyword argument `initval`. " + "This argument is not available for the `.dist()` API." + ) + + if "dims" in kwargs: + raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.") + if shape is not None and size is not None: + raise ValueError( + f"Passing both `shape` ({shape}) and `size` ({size}) is not supported!" + ) + + shape = convert_shape(shape) + size = convert_size(size) + + create_size, ndim_expected, ndim_batch, ndim_supp = find_size( + shape=shape, size=size, ndim_supp=cls.ndim_supp(*dist_params) + ) + # Create the RV with a `size` right away. + # This is not necessarily the final result. + graph = cls.rv_op(*dist_params, size=create_size, **kwargs) + # TODO: Refactor this branch + # graph = maybe_resize( + # graph, + # cls.rv_op, + # dist_params, + # ndim_expected, + # ndim_batch, + # ndim_supp, + # shape, + # size, + # **kwargs, + # ) + + rngs = kwargs.pop("rngs", None) + if rngs is not None: + graph_rvs = cls.graph_rvs(*graph.owner.inputs) + assert len(rngs) == len(graph_rvs) + for rng, rv_out in zip(rngs, graph_rvs): + if ( + rv_out.owner + and isinstance(rv_out.owner.op, RandomVariable) + and isinstance(rng, RandomStateSharedVariable) + and not getattr(rng, "default_update", None) + ): + # This tells `aesara.function` that the shared RNG variable + # is mutable, which--in turn--tells the `FunctionGraph` + # `Supervisor` feature to allow in-place updates on the variable. + # Without it, the `RandomVariable`s could not be optimized to allow + # in-place RNG updates, forcing all sample results from compiled + # functions to be the same on repeated evaluations. + new_rng = rv_out.owner.outputs[0] + rv_out.update = (rng, new_rng) + rng.default_update = new_rng + + # TODO: Create new attr error stating that these are not available for DerivedDistribution + # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)") + # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)") + # rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()") + return graph + + @singledispatch def _get_moment(op, rv, size, *rv_inputs) -> TensorVariable: raise NotImplementedError( diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 68822b264f8..74d63ac791b 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -3349,3 +3349,44 @@ def logp(value, mu): ).shape == to_tuple(size) ) + + +class TestCensored: + @pytest.mark.parametrize("censored", (False, True)) + def test_censored_workflow(self, censored): + # Based on pymc-examples/censored_data + rng = np.random.default_rng(1234) + size = 500 + true_mu = 13.0 + true_sigma = 5.0 + + # Set censoring limits + low = 3.0 + high = 16.0 + + # Draw censored samples + data = rng.normal(true_mu, true_sigma, size) + data[data <= low] = low + data[data >= high] = high + + with pm.Model(rng_seeder=17092021) as m: + mu = pm.Normal("mu", mu=((high - low) / 2) + low, sigma=(high - low) / 2.0) + sigma = pm.HalfNormal("sigma", sigma=(high - low) / 2.0) + observed = pm.Censored( + "observed", + pm.Normal.dist(mu=mu, sigma=sigma), + lower=low if censored else None, + upper=high if censored else None, + observed=data, + ) + + prior_pred = pm.sample_prior_predictive() + # TODO: Log-likelihood in bakend/arviz is failing + posterior = pm.sample(idata_kwargs=dict(log_likelihood=False), tune=500, draws=500) + posterior_pred = pm.sample_posterior_predictive(posterior) + + expected = True if censored else False + assert (9 < prior_pred.prior_predictive.mean() < 10) == expected + assert (13 < posterior.posterior["mu"].mean() < 14) == expected + assert (4.5 < posterior.posterior["sigma"].mean() < 5.5) == expected + assert (12 < posterior_pred.posterior_predictive.mean() < 13) == expected