From 92a58668b10bccf9ccafd5b7c3997660f0f483b2 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 13 Jan 2022 13:54:34 +0100 Subject: [PATCH 1/3] Add docstrings to transforms.simplex --- pymc/distributions/transforms.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index ff287635e18..d02fa210feb 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -165,6 +165,9 @@ def log_jac_det(self, value, *inputs): simplex = Simplex() +simplex.__doc__ = """ +Instantiation of :class:`aeppl.transforms.Simplex` +for use in the ``transform`` argument of a random variable.""" logodds = LogOddsTransform() logodds.__doc__ = """ From 9e4fcdb17b7bd200ba80d1f23dd738a8d5f0478c Mon Sep 17 00:00:00 2001 From: Ricardo Date: Thu, 13 Jan 2022 13:55:02 +0100 Subject: [PATCH 2/3] Create helper wrapper around Aeppl IntervalTransform --- docs/source/api/distributions/transforms.rst | 20 +++-- pymc/distributions/continuous.py | 4 +- pymc/distributions/multivariate.py | 4 +- pymc/distributions/transforms.py | 88 ++++++++++++++++++-- pymc/tests/test_distributions.py | 2 +- pymc/tests/test_sampling.py | 6 +- pymc/tests/test_transforms.py | 31 +++---- 7 files changed, 119 insertions(+), 36 deletions(-) diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst index 8a08baca4fd..ffccd979889 100644 --- a/docs/source/api/distributions/transforms.rst +++ b/docs/source/api/distributions/transforms.rst @@ -15,21 +15,12 @@ Transform instances are the entities that should be used in the simplex logodds - interval log_exp_m1 ordered log sum_to_1 circular -Transform Composition Classes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. autosummary:: - :toctree: generated - - Chain - CholeskyCovPacked Specific Transform Classes ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -37,6 +28,17 @@ Specific Transform Classes .. autosummary:: :toctree: generated + CholeskyCovPacked + Interval LogExpM1 Ordered SumTo1 + + +Transform Composition Classes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: generated + + Chain diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 03c4a5a2e36..567db3ef5e0 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -195,7 +195,7 @@ def transform_params(*args): return lower, upper - return transforms.interval(transform_params) + return transforms.Interval(bounds_fn=transform_params) def assert_negative_support(var, label, distname, value=-1e-6): @@ -3796,7 +3796,7 @@ def transform_params(*params): _, _, _, x_points, _, _ = params return floatX(x_points[0]), floatX(x_points[-1]) - kwargs["transform"] = transforms.interval(transform_params) + kwargs["transform"] = transforms.Interval(bounds_fn=transform_params) return super().__new__(cls, *args, **kwargs) @classmethod diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index d8d56453378..9ee2b083b52 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -62,7 +62,7 @@ rv_size_is_none, to_tuple, ) -from pymc.distributions.transforms import interval +from pymc.distributions.transforms import Interval from pymc.math import kron_diag, kron_dot from pymc.util import UNSET, check_dist_not_registered @@ -1554,7 +1554,7 @@ class LKJCorr(BoundedContinuous): def __new__(cls, *args, **kwargs): transform = kwargs.get("transform", UNSET) if transform is UNSET: - kwargs["transform"] = interval(lambda *args: (floatX(-1.0), floatX(1.0))) + kwargs["transform"] = Interval(floatX(-1.0), floatX(1.0)) return super().__new__(cls, *args, **kwargs) @classmethod diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index d02fa210feb..48f943fdffb 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -13,6 +13,7 @@ # limitations under the License. import aesara.tensor as at +import numpy as np from aeppl.transforms import ( CircularTransform, @@ -27,7 +28,7 @@ "RVTransform", "simplex", "logodds", - "interval", + "Interval", "log_exp_m1", "ordered", "log", @@ -174,10 +175,87 @@ def log_jac_det(self, value, *inputs): Instantiation of :class:`aeppl.transforms.LogOddsTransform` for use in the ``transform`` argument of a random variable.""" -interval = IntervalTransform -interval.__doc__ = """ -Instantiation of :class:`aeppl.transforms.IntervalTransform` -for use in the ``transform`` argument of a random variable.""" + +class Interval(IntervalTransform): + """Wrapper around :class:`aeppl.transforms.IntervalTransform` for use in the + ``transform`` argument of a random variable. + + Parameters + ---------- + lower : int, float, or None + Lower bound of the interval transform. Must be a constant value. If ``None``, the + interval is not bounded below. + upper : int, float or None + Upper bound of the interval transfrom. Must be a finite value. If ``None``, the + interval is not bounded above. + bounds_fn : callable + Alternative to lower and upper. Must return a tuple of lower and upper bounds + as a symbolic function of the respective distribution inputs. If lower or + upper is ``None``, the interval is unbounded on that edge. + + .. warning:: Expressions returned by `bounds_fn` should depend only on the + distribution inputs or other constants. Expressions that depend on other + symbolic variables, including nonlocal variables defined in the model + context will likely break sampling. + + + Examples + -------- + .. code-block:: python + + # Create an interval transform between -1 and +1 + with pm.Model(): + interval = pm.distributions.transforms.Interval(lower=-1, upper=1) + x = pm.Normal("x", transform=interval) + + .. code-block:: python + + # Create an interval transform between -1 and +1 using a callable + def get_bounds(rng, size, dtype, loc, scale): + return 0, None + + with pm.Model(): + interval = pm.distributions.transforms.Interval(bouns_fn=get_bounds) + x = pm.Normal("x", transform=interval) + + .. code-block:: python + + # Create a lower bounded interval transform based on a distribution parameter + def get_bounds(rng, size, dtype, loc, scale): + return loc, None + + interval = pm.distributions.transforms.Interval(bounds_fn=get_bounds) + + with pm.Model(): + loc = pm.Normal("loc") + x = pm.Normal("x", mu=loc, sigma=2, transform=interval) + """ + + def __init__(self, lower=None, upper=None, *, bounds_fn=None): + if bounds_fn is None: + try: + bounds = tuple( + None if bound is None else at.constant(bound, ndim=0).data + for bound in (lower, upper) + ) + except (ValueError, TypeError): + raise ValueError( + "Interval bounds must be constant values. If you need expressions that " + "depend on symbolic variables use `args_fn`" + ) + + lower, upper = ( + None if (bound is None or np.isinf(bound)) else bound for bound in bounds + ) + + if lower is None and upper is None: + raise ValueError("Lower and upper interval bounds cannot both be None") + + def bounds_fn(*rv_inputs): + return lower, upper + + super().__init__(args_fn=bounds_fn) + log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ diff --git a/pymc/tests/test_distributions.py b/pymc/tests/test_distributions.py index 0e4e65baf40..94ab4e7fb7e 100644 --- a/pymc/tests/test_distributions.py +++ b/pymc/tests/test_distributions.py @@ -2710,7 +2710,7 @@ def test_arguments_checks(self): with pm.Model() as m: x = pm.Poisson.dist(0.5) with pytest.raises(ValueError, match=msg): - pm.Bound("bound", x, transform=pm.transforms.interval) + pm.Bound("bound", x, transform=pm.distributions.transforms.log) msg = "Given dims do not exist in model coordinates." with pm.Model() as m: diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index 9090189a7b7..78ee3ef5337 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -327,11 +327,13 @@ def test_deterministic_of_unobserved(self): np.testing.assert_allclose(idata.posterior["y"], idata.posterior["x"] + 100) - def test_transform_with_rv_depenency(self): + def test_transform_with_rv_dependency(self): # Test that untransformed variables that depend on upstream variables are properly handled with pm.Model() as m: x = pm.HalfNormal("x", observed=1) - transform = pm.transforms.IntervalTransform(lambda *inputs: (inputs[-2], inputs[-1])) + transform = pm.distributions.transforms.Interval( + bounds_fn=lambda *inputs: (inputs[-2], inputs[-1]) + ) y = pm.Uniform("y", lower=0, upper=x, transform=transform) trace = pm.sample(tune=10, draws=50, return_inferencedata=False, random_seed=336) diff --git a/pymc/tests/test_transforms.py b/pymc/tests/test_transforms.py index 6931b3d668c..04e1a588b24 100644 --- a/pymc/tests/test_transforms.py +++ b/pymc/tests/test_transforms.py @@ -177,10 +177,7 @@ def test_logodds(): def test_lowerbound(): - def transform_params(*inputs): - return 0.0, None - - trans = tr.interval(transform_params) + trans = tr.Interval(0.0, None) check_transform(trans, Rplusbig) check_jacobian_det(trans, Rplusbig, elemwise=True) @@ -191,10 +188,7 @@ def transform_params(*inputs): def test_upperbound(): - def transform_params(*inputs): - return None, 0.0 - - trans = tr.interval(transform_params) + trans = tr.Interval(None, 0.0) check_transform(trans, Rminusbig) check_jacobian_det(trans, Rminusbig, elemwise=True) @@ -208,10 +202,7 @@ def test_interval(): for a, b in [(-4, 5.5), (0.1, 0.7), (-10, 4.3)]: domain = Unit * np.float64(b - a) + np.float64(a) - def transform_params(z=a, y=b): - return z, y - - trans = tr.interval(transform_params) + trans = tr.Interval(a, b) check_transform(trans, domain) check_jacobian_det(trans, domain, elemwise=True) @@ -375,7 +366,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) model = self.build_model( pm.Uniform, {"lower": lower, "upper": upper}, size=size, transform=interval ) @@ -396,7 +387,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) model = self.build_model( pm.Triangular, {"lower": lower, "c": c, "upper": upper}, size=size, transform=interval ) @@ -491,7 +482,7 @@ def transform_params(*inputs): upper = at.as_tensor_variable(upper) if upper is not None else None return lower, upper - interval = tr.interval(transform_params) + interval = tr.Interval(bounds_fn=transform_params) initval = np.sort(np.abs(np.random.rand(*size))) model = self.build_model( @@ -556,3 +547,13 @@ def test_triangular_transform(): transform = x.tag.value_var.tag.transform assert np.isclose(transform.backward(-np.inf, *x.owner.inputs).eval(), 0) assert np.isclose(transform.backward(np.inf, *x.owner.inputs).eval(), 2) + + +def test_interval_transform_raises(): + with pytest.raises(ValueError, match="Lower and upper interval bounds cannot both be None"): + tr.Interval(None, None) + + with pytest.raises(ValueError, match="Interval bounds must be constant values"): + tr.Interval(at.constant(5) + 1, None) + + assert tr.Interval(at.constant(5), None) From d0d592963cafa63552fd5c99e0e487d245b43317 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 18 Mar 2022 13:40:09 +0100 Subject: [PATCH 3/3] Do not make transforms module accessible at root level --- RELEASE-NOTES.md | 1 + docs/source/api/distributions/transforms.rst | 2 +- pymc/__init__.py | 1 - pymc/distributions/transforms.py | 6 +++--- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 3548518785c..09326a4fe47 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -38,6 +38,7 @@ All of the above apply to: - `pm.sample_prior_predictive`, `pm.sample_posterior_predictive` and `pm.sample_posterior_predictive_w` now return an `InferenceData` object by default, instead of a dictionary (see [#5073](https://github.com/pymc-devs/pymc/pull/5073)). - `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc/pull/4769)). - `pm.sample(trace=...)` no longer accepts `MultiTrace` or `len(.) > 0` traces ([see 5019#](https://github.com/pymc-devs/pymc/pull/5019)). +- `transforms` module is no longer accessible ta the root level. It is accessible at `pymc.distributions.transforms` (see[#5347](https://github.com/pymc-devs/pymc/pull/5347)). - `logpt`, `logpt_sum`, `logp_elemwiset` and `nojac` variations were removed. Use `Model.logpt(jacobian=True/False, sum=True/False)` instead. - `dlogp_nojact` and `d2logp_nojact` were removed. Use `Model.dlogpt` and `d2logpt` with `jacobian=False` instead. - `logp`, `dlogp`, and `d2logp` and `nojac` variations were removed. Use `Model.compile_logp`, `compile_dlgop` and `compile_d2logp` with `jacobian` keyword instead. diff --git a/docs/source/api/distributions/transforms.rst b/docs/source/api/distributions/transforms.rst index ffccd979889..904ee19ea5c 100644 --- a/docs/source/api/distributions/transforms.rst +++ b/docs/source/api/distributions/transforms.rst @@ -2,7 +2,7 @@ Transformations *************** -.. currentmodule:: pymc.transforms +.. currentmodule:: pymc.distributions.transforms Transform Instances ~~~~~~~~~~~~~~~~~~~ diff --git a/pymc/__init__.py b/pymc/__init__.py index 9c4419c623c..0982b3edc6e 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -55,7 +55,6 @@ def __set_compiler_flags(): from pymc.blocking import * from pymc.data import * from pymc.distributions import * -from pymc.distributions import transforms from pymc.exceptions import * from pymc.func_utils import find_constrained_prior from pymc.math import ( diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 48f943fdffb..7a5d9750a8b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -259,12 +259,12 @@ def bounds_fn(*rv_inputs): log_exp_m1 = LogExpM1() log_exp_m1.__doc__ = """ -Instantiation of :class:`pymc.transforms.LogExpM1` +Instantiation of :class:`pymc.distributions.transforms.LogExpM1` for use in the ``transform`` argument of a random variable.""" ordered = Ordered() ordered.__doc__ = """ -Instantiation of :class:`pymc.transforms.Ordered` +Instantiation of :class:`pymc.distributions.transforms.Ordered` for use in the ``transform`` argument of a random variable.""" log = LogTransform() @@ -274,7 +274,7 @@ def bounds_fn(*rv_inputs): sum_to_1 = SumTo1() sum_to_1.__doc__ = """ -Instantiation of :class:`pymc.transforms.SumTo1` +Instantiation of :class:`pymc.distributions.transforms.SumTo1` for use in the ``transform`` argument of a random variable.""" circular = CircularTransform()