diff --git a/aeppl/truncation.py b/aeppl/truncation.py
new file mode 100644
index 00000000..830bf490
--- /dev/null
+++ b/aeppl/truncation.py
@@ -0,0 +1,265 @@
+from functools import singledispatch
+from typing import Optional, Tuple
+
+import aesara.tensor as at
+import aesara.tensor.random.basic as arb
+import numpy as np
+from aesara import scan
+from aesara.compile.builders import OpFromGraph
+from aesara.graph.op import Op
+from aesara.raise_op import CheckAndRaise
+from aesara.scan import until
+from aesara.tensor.random import RandomStream
+from aesara.tensor.random.op import RandomVariable
+from aesara.tensor.var import TensorConstant, TensorVariable
+
+from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
+from aeppl.logprob import (
+    CheckParameterValue,
+    _logcdf,
+    _logprob,
+    icdf,
+    logcdf,
+    logdiffexp,
+)
+
+
+class TruncatedRV(OpFromGraph):
+    """An `Op` constructed from an Aesara graph that represents a truncated univariate random variable."""
+
+    default_output = 0
+    base_rv_op = None
+
+    def __init__(self, base_rv_op: Op, *args, **kwargs):
+        self.base_rv_op = base_rv_op
+        super().__init__(*args, **kwargs)
+
+
+MeasurableVariable.register(TruncatedRV)
+
+
+@_get_measurable_outputs.register(TruncatedRV)
+def _get_measurable_outputs_TruncatedRV(op, node):
+    return [node.outputs[0]]
+
+
+@singledispatch
+def _truncated(op: Op, lower, upper, *params):
+    """Return the truncated equivalent of another `RandomVariable`."""
+    raise NotImplementedError(
+        f"{op} does not have an equivalent truncated version implemented"
+    )
+
+
+class TruncationError(Exception):
+    """Exception for errors generated from truncated graphs."""
+
+
+class TruncationCheck(CheckAndRaise):
+    """Implements a check in truncated graphs.
+
+    Raises `TruncationError` if the check is not True.
+    """
+
+    def __init__(self, msg=""):
+        super().__init__(TruncationError, msg)
+
+    def __str__(self):
+        return f"TruncationCheck{{{self.msg}}}"
+
+
+def truncate(
+    rv: TensorVariable,
+    lower=None,
+    upper=None,
+    max_n_steps: int = 10_000,
+    srng: Optional[RandomStream] = None,
+) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]:
+    """Truncate a univariate `RandomVariable` between `lower` and `upper`.
+
+    If `lower` or `upper` is ``None``, the variable is not truncated on that side.
+
+    Depending on whether a dispatch implementation is available, this
+    function returns either a specialized `Op`, or an equivalent graph
+    representing the truncation process via inverse CDF or rejection
+    sampling.
+
+    The argument `max_n_steps` controls the maximum number of resamples that are
+    attempted when performing rejection sampling. A `TruncationError` is raised if
+    convergence is not reached after that many steps.
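+
+    A minimal usage sketch (the bounds and size here are illustrative)::
+
+        import aesara
+        import aesara.tensor as at
+
+        x = at.random.normal(0, 1, name="x", size=100)
+        xt, updates = truncate(x, lower=-1, upper=1)
+        # `updates` must be passed to `aesara.function` so that every call
+        # produces fresh draws
+        fn = aesara.function([], xt, updates=updates)
+        draws = fn()  # all draws lie within [-1, 1]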
+
+    Returns
+    =======
+    `TensorVariable` graph representing the truncated `RandomVariable` and respective updates
+    """
+
+    if lower is None and upper is None:
+        raise ValueError("lower and upper cannot both be None")
+
+    if not (isinstance(rv.owner.op, RandomVariable) and rv.owner.op.ndim_supp == 0):
+        raise NotImplementedError(
+            f"Truncation is only implemented for univariate random variables, got {rv.owner.op}"
+        )
+
+    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
+    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)
+
+    if srng is None:
+        srng = RandomStream()
+
+    # Try to use specialized Op
+    try:
+        truncated_rv, updates = _truncated(
+            rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:]
+        )
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Variables with `_` suffix identify dummy inputs for the OpFromGraph
+    # We will use the shared RNG variable directly because Scan demands it, even
+    # though it would not be necessary for the icdf OpFromGraph
+    graph_inputs = [*rv.owner.inputs[1:], lower, upper]
+    graph_inputs_ = [inp.type() for inp in graph_inputs]
+    size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_
+    rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_)
+
+    # Try to use inverted cdf sampling
+    try:
+        # For left-truncated discrete RVs the lower bound itself must remain in
+        # the support, so the CDF is evaluated at `lower - 1`. This may result
+        # in draws below the truncation range, if any uniform == 0
+        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
+        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
+        cdf_upper_ = at.exp(logcdf(rv_, upper_))
+        uniform_ = srng.uniform(
+            cdf_lower_,
+            cdf_upper_,
+            size=size_,
+        )
+        truncated_rv_ = icdf(rv_, uniform_)
+        truncated_rv = TruncatedRV(
+            base_rv_op=rv.owner.op,
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, uniform_.owner.outputs[0]],
+            inline=True,
+        )(*graph_inputs)
+        updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]}
+        return truncated_rv, updates
+    except NotImplementedError:
+        pass
+
+    # Fallback to rejection sampling
+    # TODO: Handle potential broadcast by lower / upper
+    def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs):
+        # Resample the whole batch, but only replace the rejected entries
+        new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype)  # type: ignore
+        truncated_rv = at.set_subtensor(
+            truncated_rv[reject_draws],
+            new_truncated_rv[reject_draws],
+        )
+        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))
+
+        # Stop the Scan as soon as every draw lies within the bounds
+        return (truncated_rv, reject_draws), until(~at.any(reject_draws))
+
+    (truncated_rv_, reject_draws_), updates = scan(
+        loop_fn,
+        outputs_info=[
+            at.zeros_like(rv_),
+            at.ones_like(rv_, dtype=bool),
+        ],
+        non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_],
+        n_steps=max_n_steps,
+        strict=True,
+    )
+
+    truncated_rv_ = truncated_rv_[-1]
+    convergence_ = ~at.any(reject_draws_[-1])
+    truncated_rv_ = TruncationCheck(
+        f"Truncation did not converge in {max_n_steps} steps"
+    )(truncated_rv_, convergence_)
+
+    truncated_rv = TruncatedRV(
+        base_rv_op=rv.owner.op,
+        inputs=graph_inputs_,
+        # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates
+        outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]],
+        inline=True,
+    )(*graph_inputs)
+    # TODO: Is the order of multiple shared variables deterministic?
+    assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0]
+    updates = {
+        truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2],
+        truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1],
+    }
+    return truncated_rv, updates
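+
+
+# Note: the log-density implemented by `truncated_logprob` below is the base
+# density renormalized to the truncation interval; a worked sketch of the
+# formula it computes:
+#
+#   log p_trunc(x) = log p(x) - log(CDF(upper) - CDF(lower))
+#                  = log p(x) - logdiffexp(upper_logcdf, lower_logcdf)
+#
+# for lower <= x <= upper, and -inf otherwise. For discrete RVs the lower CDF
+# is evaluated at `lower - 1` so that the lower bound itself stays in the
+# support.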
+@_logprob.register(TruncatedRV)
+def truncated_logprob(op, values, *inputs, **kwargs):
+    (value,) = values
+
+    # Rejection sample graph has two RNGs
+    if len(op.shared_inputs) == 2:
+        *rv_inputs, lower_bound, upper_bound, _, rng = inputs
+    else:
+        *rv_inputs, lower_bound, upper_bound, rng = inputs
+    rv_inputs = [rng, *rv_inputs]
+
+    base_rv_op = op.base_rv_op
+    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
+    # For left-truncated discrete RVs the lower bound belongs to the support,
+    # so the normalization term subtracts the CDF evaluated at `lower - 1`
+    lower_bound_value = (
+        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
+    )
+    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
+    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)
+
+    if base_rv_op.name:
+        logp.name = f"{base_rv_op}_logprob"
+        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
+        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
+
+    is_lower_bounded = not (
+        isinstance(lower_bound, TensorConstant)
+        and np.all(np.isneginf(lower_bound.value))
+    )
+    is_upper_bounded = not (
+        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
+    )
+
+    lognorm = 0
+    if is_lower_bounded and is_upper_bounded:
+        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
+    elif is_lower_bounded:
+        lognorm = at.log1mexp(lower_logcdf)
+    elif is_upper_bounded:
+        lognorm = upper_logcdf
+
+    logp = logp - lognorm
+
+    if is_lower_bounded:
+        logp = at.switch(value < lower_bound, -np.inf, logp)
+
+    if is_upper_bounded:
+        logp = at.switch(value <= upper_bound, logp, -np.inf)
+
+    if is_lower_bounded and is_upper_bounded:
+        logp = CheckParameterValue("lower_bound <= upper_bound")(
+            logp, at.all(at.le(lower_bound, upper_bound))
+        )
+
+    return logp
+
+
+@_truncated.register(arb.UniformRV)
+def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig):
+    truncated_uniform = srng.gen(
+        op,
+        at.max((lower_orig, lower)),
+        at.min((upper_orig, upper)),
+        size=size,
+        dtype=dtype,
+    )
+    return truncated_uniform, {
+        truncated_uniform.owner.inputs[0]: truncated_uniform.owner.outputs[0]
+    }
diff --git a/tests/test_truncation.py b/tests/test_truncation.py
new file mode 100644
index 00000000..7c64b2bf
--- /dev/null
+++ b/tests/test_truncation.py
@@ -0,0 +1,221 @@
+import aesara
+import aesara.tensor as at
+import numpy as np
+import pytest
+import scipy.special
+import scipy.stats
+import scipy.stats as st
+from aesara.tensor.random.basic import GeometricRV, NormalRV, UniformRV
+
+from aeppl import joint_logprob, logprob
+from aeppl.logprob import ParameterValueError, _icdf
+from aeppl.truncation import TruncatedRV, TruncationError, _truncated, truncate
+
+
+class IcdfNormalRV(NormalRV):
+    """Normal RV that has icdf but not truncated dispatching."""
+
+
+class RejectionNormalRV(NormalRV):
+    """Normal RV that has neither icdf nor truncated dispatching."""
+
+
+class IcdfGeometricRV(GeometricRV):
+    """Geometric RV that has icdf but not truncated dispatching."""
+
+
+class RejectionGeometricRV(GeometricRV):
+    """Geometric RV that has neither icdf nor truncated dispatching."""
+
+
+icdf_normal = IcdfNormalRV()
+rejection_normal = RejectionNormalRV()
+icdf_geometric = IcdfGeometricRV()
+rejection_geometric = RejectionGeometricRV()
+
+
+@_truncated.register(IcdfNormalRV)
+@_truncated.register(RejectionNormalRV)
+@_truncated.register(IcdfGeometricRV)
+@_truncated.register(RejectionGeometricRV)
+def _truncated_not_implemented(*args, **kwargs):
+    raise NotImplementedError()
+
+
+@_icdf.register(RejectionNormalRV)
+@_icdf.register(RejectionGeometricRV)
+def _icdf_not_implemented(*args, **kwargs):
+    raise NotImplementedError()
+
+
+def test_truncation_specialized_op():
+    x = at.random.uniform(0, 10, name="x", size=100)
+
+    srng = at.random.RandomStream()
+    xt, _ = truncate(x, lower=5, upper=15, srng=srng)
+    assert isinstance(xt.owner.op, UniformRV)
+    assert xt.owner.inputs[0] is srng.updates()[0][0]
+
+    # The specialized uniform dispatch clips the bounds to the original
+    # support: max(0, 5) == 5 and min(10, 15) == 10
+    lower_upper = at.stack(xt.owner.inputs[3:])
+    assert np.all(lower_upper.eval() == [5, 10])
+
+
+@pytest.mark.filterwarnings("ignore:Rewrite warning")
+@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_continuous_random(op_type, lower, upper):
+    loc = 0.15
+    scale = 10
+    normal_op = icdf_normal if op_type == "icdf" else rejection_normal
+    x = normal_op(loc, scale, name="x", size=100)
+
+    srng = at.random.RandomStream()
+    xt, xt_update = truncate(x, lower=lower, upper=upper, srng=srng)
+    assert isinstance(xt.owner.op, TruncatedRV)
+    assert xt.owner.inputs[-1] is srng.updates()[1 if op_type == "icdf" else 2][0]
+    assert xt.type.dtype == x.type.dtype
+    assert xt.type.ndim == x.type.ndim
+
+    # Check that original op can be used on its own
+    assert x.eval().shape == (100,)
+
+    xt_fn = aesara.function([], xt, updates=xt_update)
+    xt_draws = np.array([xt_fn() for _ in range(5)])
+    assert np.all(xt_draws >= lower)
+    assert np.all(xt_draws <= upper)
+    assert np.unique(xt_draws).size == xt_draws.size
+
+    # Compare with reference (scipy's truncnorm takes standardized bounds)
+    ref_xt = scipy.stats.truncnorm(
+        (lower - loc) / scale,
+        (upper - loc) / scale,
+        loc,
+        scale,
+    )
+    assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001
+
+    # Test max_n_steps
+    xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=2)
+    xt_fn = aesara.function([], xt, updates=xt_update)
+    if op_type == "icdf":
+        xt_draws = xt_fn()
+        assert np.all(xt_draws >= lower)
+        assert np.all(xt_draws <= upper)
+        assert np.unique(xt_draws).size == xt_draws.size
+    else:
+        with pytest.raises(TruncationError, match="^Truncation did not converge"):
+            xt_fn()
+
+
+@pytest.mark.parametrize("lower, upper", [(-1, np.inf), (-1, 1.5), (-np.inf, 1.5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_continuous_logp(op_type, lower, upper):
+    loc = 0.15
+    scale = 10
+    op = icdf_normal if op_type == "icdf" else rejection_normal
+
+    x = op(loc, scale, name="x")
+    xt, _ = truncate(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    logp, (xt_vv,) = joint_logprob(xt)
+    xt_logp_fn = aesara.function([xt_vv], logp)
+
+    ref_xt = scipy.stats.truncnorm(
+        (lower - loc) / scale,
+        (upper - loc) / scale,
+        loc,
+        scale,
+    )
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logp_fn(test_xt_v), ref_xt.logpdf(test_xt_v))
+
+
+@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
+def test_truncation_discrete_random(op_type, lower, upper):
+    p = 0.2
+    geometric_op = icdf_geometric if op_type == "icdf" else rejection_geometric
+
+    x = geometric_op(p, name="x", size=500)
+    xt, xt_update = truncate(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+    assert xt.type.is_super(x.type)
+
+    xt_draws = aesara.function([], xt, updates=xt_update)()
+    assert np.all(xt_draws >= lower)
+    assert np.all(xt_draws <= upper)
+    assert np.any(xt_draws == (max(1, lower)))
+    if upper != np.inf:
+        assert np.any(xt_draws == upper)
+
+    # Test max_n_steps
+    xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=3)
+    xt_fn = aesara.function([], xt, updates=xt_update)
+    if op_type == "icdf":
+        xt_draws = xt_fn()
+        assert np.all(xt_draws >= lower)
+        assert np.all(xt_draws <= upper)
+        assert np.any(xt_draws == (max(1, lower)))
+        if upper != np.inf:
+            assert np.any(xt_draws == upper)
+    else:
+        with pytest.raises(TruncationError, match="^Truncation did not converge"):
+            xt_fn()
+
+
+@pytest.mark.parametrize("lower, upper", [(2, np.inf), (2, 5), (-np.inf, 5)])
+@pytest.mark.parametrize("op_type", ["icdf", "rejection"])
def test_truncation_discrete_logp(op_type, lower, upper):
+    p = 0.7
+    op = icdf_geometric if op_type == "icdf" else rejection_geometric
+
+    x = op(p, name="x")
+    xt, _ = truncate(x, lower=lower, upper=upper)
+    assert isinstance(xt.owner.op, TruncatedRV)
+
+    xt_vv = xt.clone()
+    xt_logp_fn = aesara.function([xt_vv], logprob(xt, xt_vv))
+
+    ref_xt = st.geom(p)
+    log_norm = np.log(ref_xt.cdf(upper) - ref_xt.cdf(lower - 1))
+
+    def ref_xt_logpmf(value):
+        if value < lower or value > upper:
+            return -np.inf
+        return ref_xt.logpmf(value) - log_norm
+
+    for bound in (lower, upper):
+        if np.isinf(bound):
+            continue
+        for offset in (-1, 0, 1):
+            test_xt_v = bound + offset
+            assert np.isclose(xt_logp_fn(test_xt_v), ref_xt_logpmf(test_xt_v))
+
+    # Check that it integrates to 1
+    log_integral = scipy.special.logsumexp(
+        [xt_logp_fn(v) for v in range(min(upper + 1, 20))]
+    )
+    assert np.isclose(log_integral, 0.0, atol=1e-5)
+
+
+def test_truncation_exceptions():
+    with pytest.raises(ValueError, match="lower and upper cannot both be None"):
+        truncate(at.random.normal())
+
+    with pytest.raises(NotImplementedError, match="Truncation is only implemented for"):
+        truncate(at.clip(at.random.normal(), -1, 1), -1, 1)
+
+    with pytest.raises(NotImplementedError, match="Truncation is only implemented for"):
+        truncate(at.random.dirichlet([1, 1, 1]), -1, 1)
+
+
+def test_truncation_bound_check():
+    x = at.random.normal(name="x")
+    xt, _ = truncate(x, lower=5, upper=-5)
+    xt_vv = xt.clone()
+    with pytest.raises(ParameterValueError):
+        logprob(xt, xt_vv).eval({xt_vv: 0})
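+
+
+def test_truncate_docstring_example():
+    # A minimal sketch mirroring the usage example in the `truncate`
+    # docstring; it assumes only the public `truncate` API from this diff.
+    # Draws must respect the bounds when the returned updates are passed
+    # to `aesara.function`.
+    x = at.random.normal(0, 1, name="x", size=50)
+    xt, updates = truncate(x, lower=-1, upper=1)
+    draws = aesara.function([], xt, updates=updates)()
+    assert np.all(draws >= -1)
+    assert np.all(draws <= 1)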