From e7793af17df1a96f6b1bf3d0b0efb82b17450f4a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 11 Oct 2023 11:07:59 +0100 Subject: [PATCH] Allow Truncation of CustomDists --- pymc/distributions/truncated.py | 188 +++++++++++++++----------- tests/distributions/test_mixture.py | 12 +- tests/distributions/test_truncated.py | 117 ++++++++++++++-- 3 files changed, 224 insertions(+), 93 deletions(-) diff --git a/pymc/distributions/truncated.py b/pymc/distributions/truncated.py index 2665656f2b8..646246cdb1f 100644 --- a/pymc/distributions/truncated.py +++ b/pymc/distributions/truncated.py @@ -17,7 +17,7 @@ import pytensor import pytensor.tensor as pt -from pytensor import scan +from pytensor import config, graph_replace, scan from pytensor.graph import Op from pytensor.graph.basic import Node from pytensor.raise_op import CheckAndRaise @@ -25,10 +25,12 @@ from pytensor.tensor import TensorConstant, TensorVariable from pytensor.tensor.random.basic import NormalRV from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.random.type import RandomType from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform from pymc.distributions.dist_math import check_parameters from pymc.distributions.distribution import ( + CustomSymbolicDistRV, Distribution, SymbolicRandomVariable, _moment, @@ -38,8 +40,9 @@ from pymc.distributions.transforms import _default_transform from pymc.exceptions import TruncationError from pymc.logprob.abstract import _logcdf, _logprob -from pymc.logprob.basic import icdf, logcdf +from pymc.logprob.basic import icdf, logcdf, logp from pymc.math import logdiffexp +from pymc.pytensorf import collect_default_updates from pymc.util import check_dist_not_registered @@ -49,7 +52,7 @@ class TruncatedRV(SymbolicRandomVariable): that represents a truncated univariate random variable. """ - default_output = 1 + default_output = 0 def __init__( self, @@ -63,8 +66,13 @@ def __init__( super().__init__(*args, **kwargs) def update(self, node: Node): - """Return the update mapping for the internal RNG.""" - return {node.inputs[-1]: node.outputs[0]} + """Return the update mapping for the internal RNGs. + + TruncatedRVs are created in a way that the rng updats follow the same order as the input RNGs. + """ + rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)] + next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)] + return dict(zip(rngs, next_rngs)) @singledispatch @@ -141,10 +149,14 @@ class Truncated(Distribution): @classmethod def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs): - if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)): + if not ( + isinstance(dist, TensorVariable) + and isinstance(dist.owner.op, (RandomVariable, CustomSymbolicDistRV)) + ): if isinstance(dist.owner.op, SymbolicRandomVariable): raise NotImplementedError( - f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}" + f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n" + f"You can try wrapping the distribution inside a CustomDist instead." ) raise ValueError( f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}" @@ -174,46 +186,54 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None): if size is None: size = pt.broadcast_shape(dist, lower, upper) dist = change_dist_size(dist, new_size=size) + rv_inputs = [ + inp + if not isinstance(inp.type, RandomType) + else pytensor.shared(np.random.default_rng()) + for inp in dist.owner.inputs + ] + graph_inputs = [*rv_inputs, lower, upper] # Variables with `_` suffix identify dummy inputs for the OpFromGraph - graph_inputs = [*dist.owner.inputs[1:], lower, upper] - graph_inputs_ = [inp.type() for inp in graph_inputs] + graph_inputs_ = [ + inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs + ] *rv_inputs_, lower_, upper_ = graph_inputs_ - # We will use a Shared RNG variable because Scan demands it, even though it - # would not be necessary for the OpFromGraph inverse cdf. - rng = pytensor.shared(np.random.default_rng()) - rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output() + rv_ = dist.owner.op.make_node(*rv_inputs_).default_output() # Try to use inverted cdf sampling + # truncated_rv = icdf(rv, draw(uniform(lower, upper))) try: - # For left truncated discrete RVs, we need to include the whole lower bound. - # This may result in draws below the truncation range, if any uniform == 0 - lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_ - cdf_lower_ = pt.exp(logcdf(rv_, lower_value)) - cdf_upper_ = pt.exp(logcdf(rv_, upper_)) - # It's okay to reuse the same rng here, because the rng in rv_ will not be - # used by either the logcdf of icdf functions + logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_) + # We use the first RNG from the base RV, so we don't have to introduce a new one + # This is not problematic because the RNG won't be used in the RV logcdf graph + uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType)) uniform_next_rng_, uniform_ = pt.random.uniform( - cdf_lower_, - cdf_upper_, - rng=rng, - size=rv_inputs_[0], + pt.exp(logcdf_lower_), + pt.exp(logcdf_upper_), + rng=uniform_rng_, + size=rv_.shape, ).owner.outputs - truncated_rv_ = icdf(rv_, uniform_) + truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False) return TruncatedRV( base_rv_op=dist.owner.op, - inputs=[*graph_inputs_, rng], - outputs=[uniform_next_rng_, truncated_rv_], + inputs=graph_inputs_, + outputs=[truncated_rv_, uniform_next_rng_], ndim_supp=0, max_n_steps=max_n_steps, - )(*graph_inputs, rng) + )(*graph_inputs) except NotImplementedError: pass # Fallback to rejection sampling - def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): - next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs + # truncated_rv = zeros(rv.shape) + # reject_draws = ones(rv.shape, dtype=bool) + # while any(reject_draws): + # truncated_rv[reject_draws] = draw(rv)[reject_draws] + # reject_draws = (truncated_rv < lower) | (truncated_rv > upper) + def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs): + new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output() # Avoid scalar boolean indexing if truncated_rv.type.ndim == 0: truncated_rv = new_truncated_rv @@ -226,7 +246,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): return ( (truncated_rv, reject_draws), - [(rng, next_rng)], + collect_default_updates(new_truncated_rv), until(~pt.any(reject_draws)), ) @@ -236,7 +256,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): pt.zeros_like(rv_), pt.ones_like(rv_, dtype=bool), ], - non_sequences=[lower_, upper_, rng, *rv_inputs_], + non_sequences=[lower_, upper_, *rv_inputs_], n_steps=max_n_steps, strict=True, ) @@ -246,23 +266,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")( truncated_rv_, convergence_ ) + # Sort updates of each RNG so that they show in the same order as the input RNGs + + def sort_updates(update): + rng, next_rng = update + return graph_inputs.index(rng) + + next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)] return TruncatedRV( base_rv_op=dist.owner.op, - inputs=[*graph_inputs_, rng], - outputs=[tuple(updates.values())[0], truncated_rv_], + inputs=graph_inputs_, + outputs=[truncated_rv_, *next_rngs], ndim_supp=0, max_n_steps=max_n_steps, - )(*graph_inputs, rng) + )(*graph_inputs) + + @staticmethod + def _create_logcdf_exprs( + base_rv: TensorVariable, + value: TensorVariable, + lower: TensorVariable, + upper: TensorVariable, + ) -> tuple[TensorVariable, TensorVariable]: + """Create lower and upper logcdf expressions for base_rv. + + Uses `value` as a template for broadcasting. + """ + # For left truncated discrete RVs, we need to include the whole lower bound. + lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower + lower_value = pt.full_like(value, lower_value, dtype=config.floatX) + upper_value = pt.full_like(value, upper, dtype=config.floatX) + lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False) + upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value}) + return lower_logcdf, upper_logcdf @_change_dist_size.register(TruncatedRV) -def change_truncated_size(op, dist, new_size, expand): - *rv_inputs, lower, upper, rng = dist.owner.inputs - # Recreate the original untruncated RV - untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output() +def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand): + *rv_inputs, lower, upper = truncated_rv.owner.inputs + untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output() + if expand: - new_size = to_tuple(new_size) + tuple(dist.shape) + new_size = to_tuple(new_size) + tuple(truncated_rv.shape) return Truncated.rv_op( untruncated_rv, @@ -274,11 +320,9 @@ def change_truncated_size(op, dist, new_size, expand): @_moment.register(TruncatedRV) -def truncated_moment(op, rv, *inputs): - *rv_inputs, lower, upper, rng = inputs - - # recreate untruncated rv and respective moment - untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output() +def truncated_moment(op: TruncatedRV, truncated_rv, *inputs): + *rv_inputs, lower, upper = inputs + untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output() untruncated_moment = moment(untruncated_rv) fallback_moment = pt.switch( @@ -299,31 +343,25 @@ def truncated_moment(op, rv, *inputs): @_default_transform.register(TruncatedRV) -def truncated_default_transform(op, rv): +def truncated_default_transform(op, truncated_rv): # Don't transform discrete truncated distributions - if op.base_rv_op.dtype.startswith("int"): + if truncated_rv.type.dtype.startswith("int"): return None - # Lower and Upper are the arguments -3 and -2 - return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2)) + # Lower and Upper are the arguments -2 and -1 + return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1)) @_logprob.register(TruncatedRV) def truncated_logprob(op, values, *inputs, **kwargs): (value,) = values - - *rv_inputs, lower, upper, rng = inputs - rv_inputs = [rng, *rv_inputs] + *rv_inputs, lower, upper = inputs base_rv_op = op.base_rv_op - logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs) - # For left truncated RVs, we don't want to include the lower bound in the - # normalization term - lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower - lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs) - upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs) - + base_rv = base_rv_op.make_node(*rv_inputs).default_output() + base_logp = logp(base_rv, value) + lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper) if base_rv_op.name: - logp.name = f"{base_rv_op}_logprob" + base_logp.name = f"{base_rv_op}_logprob" lower_logcdf.name = f"{base_rv_op}_lower_logcdf" upper_logcdf.name = f"{base_rv_op}_upper_logcdf" @@ -338,37 +376,31 @@ def truncated_logprob(op, values, *inputs, **kwargs): elif is_upper_bounded: lognorm = upper_logcdf - logp = logp - lognorm + truncated_logp = base_logp - lognorm if is_lower_bounded: - logp = pt.switch(value < lower, -np.inf, logp) + truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp) if is_upper_bounded: - logp = pt.switch(value <= upper, logp, -np.inf) + truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf) if is_lower_bounded and is_upper_bounded: - logp = check_parameters( - logp, + truncated_logp = check_parameters( + truncated_logp, pt.le(lower, upper), msg="lower_bound <= upper_bound", ) - return logp + return truncated_logp @_logcdf.register(TruncatedRV) -def truncated_logcdf(op, value, *inputs, **kwargs): - *rv_inputs, lower, upper, rng = inputs - rv_inputs = [rng, *rv_inputs] - - base_rv_op = op.base_rv_op - logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs) +def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs): + *rv_inputs, lower, upper = inputs - # For left truncated discrete RVs, we don't want to include the lower bound in the - # normalization term - lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower - lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs) - upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs) + base_rv = op.base_rv_op.make_node(*rv_inputs).default_output() + base_logcdf = logcdf(base_rv, value) + lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper) is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value))) is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value))) @@ -381,7 +413,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs): elif is_upper_bounded: lognorm = upper_logcdf - logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf + logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf logcdf_trunc = logcdf_numerator - lognorm if is_lower_bounded: diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py index 9632efd859f..2288a001ed0 100644 --- a/tests/distributions/test_mixture.py +++ b/tests/distributions/test_mixture.py @@ -1602,8 +1602,8 @@ def test_hurdle_negativebinomial_graph(self): _, nonzero_dist = self.check_hurdle_mixture_graph(dist) assert isinstance(nonzero_dist.owner.op.base_rv_op, NegativeBinomial) - assert nonzero_dist.owner.inputs[2].data == n - assert nonzero_dist.owner.inputs[3].data == p + assert nonzero_dist.owner.inputs[-4].data == n + assert nonzero_dist.owner.inputs[-3].data == p def test_hurdle_gamma_graph(self): psi, alpha, beta = 0.25, 3, 4 @@ -1613,8 +1613,8 @@ def test_hurdle_gamma_graph(self): # Under the hood it uses the shape-scale parametrization of the Gamma distribution. # So the second value is the reciprocal of the rate (i.e. 1 / beta) assert isinstance(nonzero_dist.owner.op.base_rv_op, Gamma) - assert nonzero_dist.owner.inputs[2].data == alpha - assert nonzero_dist.owner.inputs[3].eval() == 1 / beta + assert nonzero_dist.owner.inputs[-4].data == alpha + assert nonzero_dist.owner.inputs[-3].eval() == 1 / beta def test_hurdle_lognormal_graph(self): psi, mu, sigma = 0.1, 2, 2.5 @@ -1622,8 +1622,8 @@ def test_hurdle_lognormal_graph(self): _, nonzero_dist = self.check_hurdle_mixture_graph(dist) assert isinstance(nonzero_dist.owner.op.base_rv_op, LogNormal) - assert nonzero_dist.owner.inputs[2].data == mu - assert nonzero_dist.owner.inputs[3].data == sigma + assert nonzero_dist.owner.inputs[-4].data == mu + assert nonzero_dist.owner.inputs[-3].data == sigma @pytest.mark.parametrize( "dist, psi, non_psi_args", diff --git a/tests/distributions/test_truncated.py b/tests/distributions/test_truncated.py index d9d007c51f1..ac0811eb0f9 100644 --- a/tests/distributions/test_truncated.py +++ b/tests/distributions/test_truncated.py @@ -18,9 +18,17 @@ import scipy from pytensor.tensor.random.basic import GeometricRV, NormalRV - -from pymc import Censored, Model, draw, find_MAP -from pymc.distributions.continuous import Exponential, Gamma, TruncatedNormalRV +from pytensor.tensor.random.type import RandomType + +from pymc import Censored, CustomDist, Mixture, Model, draw, find_MAP +from pymc.distributions.continuous import ( + ChiSquared, + Exponential, + Gamma, + HalfNormal, + LogNormal, + TruncatedNormalRV, +) from pymc.distributions.shape_utils import change_dist_size from pymc.distributions.transforms import _default_transform from pymc.distributions.truncated import Truncated, TruncatedRV, _truncated @@ -29,6 +37,11 @@ from pymc.logprob.basic import logcdf, logp from pymc.logprob.transforms import IntervalTransform from pymc.logprob.utils import ParameterValueError +from pymc.pytensorf import ( + collect_default_updates, + collect_default_updates_inner_fgraph, + constant_fold, +) from pymc.testing import assert_moment_is_expected @@ -54,6 +67,24 @@ class RejectionGeometricRV(GeometricRV): rejection_geometric = RejectionGeometricRV() +def icdf_normal_customdist(loc, scale, name=None, size=None): + def dist(loc, scale, size): + return loc + icdf_normal(size=size) * scale + + x = CustomDist.dist(loc, scale, dist=dist, size=size) + x.name = name + return x + + +def rejection_normal_customdist(loc, scale, name=None, size=None): + def dist(loc, scale, size): + return loc + rejection_normal(size=size) * scale + + x = CustomDist.dist(loc, scale, dist=dist, size=size) + x.name = name + return x + + @_truncated.register(IcdfNormalRV) @_truncated.register(RejectionNormalRV) @_truncated.register(IcdfGeometricRV) @@ -102,10 +133,14 @@ def test_truncation_specialized_op(shape_info): @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) @pytest.mark.parametrize("scalar", [True, False]) -def test_truncation_continuous_random(op_type, lower, upper, scalar): +@pytest.mark.parametrize("custom_dist", [False, True]) +def test_truncation_continuous_random(op_type, lower, upper, scalar, custom_dist): loc = 0.15 scale = 10 - normal_op = icdf_normal if op_type == "icdf" else rejection_normal + if custom_dist: + normal_op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist + else: + normal_op = icdf_normal if op_type == "icdf" else rejection_normal x = normal_op(loc, scale, name="x", size=() if scalar else (100,)) xt = Truncated.dist(x, lower=lower, upper=upper) @@ -140,10 +175,14 @@ def test_truncation_continuous_random(op_type, lower, upper, scalar): @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) -def test_truncation_continuous_logp(op_type, lower, upper): +@pytest.mark.parametrize("custom_dist", [False, True]) +def test_truncation_continuous_logp(op_type, lower, upper, custom_dist): loc = 0.15 scale = 10 - op = icdf_normal if op_type == "icdf" else rejection_normal + if custom_dist: + op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist + else: + op = icdf_normal if op_type == "icdf" else rejection_normal x = op(loc, scale, name="x") xt = Truncated.dist(x, lower=lower, upper=upper) @@ -168,10 +207,14 @@ def test_truncation_continuous_logp(op_type, lower, upper): @pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)]) @pytest.mark.parametrize("op_type", ["icdf", "rejection"]) -def test_truncation_continuous_logcdf(op_type, lower, upper): +@pytest.mark.parametrize("custom_dist", [False, True]) +def test_truncation_continuous_logcdf(op_type, lower, upper, custom_dist): loc = 0.15 scale = 10 - op = icdf_normal if op_type == "icdf" else rejection_normal + if custom_dist: + op = icdf_normal_customdist if op_type == "icdf" else rejection_normal_customdist + else: + op = icdf_normal if op_type == "icdf" else rejection_normal x = op(loc, scale, name="x") xt = Truncated.dist(x, lower=lower, upper=upper) @@ -423,3 +466,59 @@ def test_truncated_gamma(): logp_resized_pymc, logp_scipy, ) + + +def test_truncated_multiple_rngs(): + def mix_dist_fn(size): + return Mixture.dist( + w=[0.3, 0.7], comp_dists=[HalfNormal.dist(), LogNormal.dist()], shape=size + ) + + upper = 0.1 + x = CustomDist.dist(dist=mix_dist_fn) + x_trunc = Truncated.dist(x, lower=0, upper=upper, shape=(5,)) + + # Mixture doesn't have an icdf method, so TruncatedRV uses a RejectionSampling representation + # Check that RNGs updates are correct + # TODO: Find out way of testing updates were not mixed + rngs = [inp for inp in x_trunc.owner.inputs if isinstance(inp.type, RandomType)] + next_rngs = [out for out in x_trunc.owner.outputs if isinstance(out.type, RandomType)] + assert len(set(rngs)) == len(set(next_rngs)) == 3 + + draws1 = draw(x_trunc, random_seed=1) + draws2 = draw(x_trunc, random_seed=1) + draws3 = draw(x_trunc, random_seed=2) + assert np.unique(draws1).size == 5 + assert np.unique(draws3).size == 5 + assert np.all(draws1 == draws2) + assert np.all(draws1 != draws3) + + test_x = np.array([-1, 0, 1, 2, 3]) + mix_rv = mix_dist_fn((5,)) + expected_logp = logp(mix_rv, test_x) - logcdf(mix_rv, upper) + expected_logp = pt.where(test_x <= upper, expected_logp, -np.inf) + np.testing.assert_allclose( + logp(x_trunc, test_x).eval(), + expected_logp.eval(), + ) + + +def test_truncated_maxwell_dist(): + def maxwell_dist(scale, size): + return pt.sqrt(ChiSquared.dist(nu=3, size=size)) * scale + + scale = 5.0 + upper = 2.0 + x = CustomDist.dist(scale, dist=maxwell_dist) + trunc_x = Truncated.dist(x, lower=None, upper=upper, size=(5,)) + assert np.all(draw(trunc_x, draws=20) < 2) + + test_value = np.array([-0.5, 0.0, 0.5, 1.5, 2.5]) + expected_logp = scipy.stats.maxwell.logpdf( + test_value, scale=scale + ) - scipy.stats.maxwell.logcdf(upper, scale=scale) + expected_logp[(test_value <= 0) | (test_value > upper)] = -np.inf + np.testing.assert_allclose( + logp(trunc_x, test_value).eval(), + expected_logp, + )