Commit

Implement truncated variables
Ricardo committed Apr 11, 2022
1 parent a330723 commit 77343ca
Showing 3 changed files with 433 additions and 3 deletions.
6 changes: 6 additions & 0 deletions aeppl/logprob.py
@@ -45,6 +45,12 @@ def xlogy0(m, x):
    return at.switch(at.eq(x, 0), at.switch(at.eq(m, 0), 0.0, -np.inf), m * at.log(x))


def logdiffexp(a, b):
"""log(exp(a) - exp(b))"""
# TODO: This should be a basic Aesara stabilization
return a + at.log1mexp(b - a)


def logprob(rv_var, *rv_values, **kwargs):
"""Create a graph for the log-probability of a ``RandomVariable``."""
logprob = _logprob(rv_var.owner.op, rv_values, *rv_var.owner.inputs, **kwargs)
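The new `logdiffexp` helper relies on the identity log(exp(a) - exp(b)) = a + log1p(-exp(b - a)), which avoids overflow and catastrophic cancellation when a and b are large or close. A minimal NumPy check of that identity (an illustrative sketch, not taken from the diff itself):

import numpy as np

a, b = np.log(0.975), np.log(0.025)  # e.g. two log-CDF values, with a > b
stable = a + np.log1p(-np.exp(b - a))  # same computation as logdiffexp
naive = np.log(np.exp(a) - np.exp(b))
np.testing.assert_allclose(stable, naive)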
239 changes: 237 additions & 2 deletions aeppl/truncation.py
@@ -1,18 +1,33 @@
import warnings
from functools import singledispatch
from typing import List, Optional

import aesara.tensor as at
import aesara.tensor.random.basic as arb
import numpy as np
from aesara import scan, shared
from aesara.compile.builders import OpFromGraph
from aesara.graph import Op
from aesara.graph.basic import Node
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import local_optimizer
from aesara.raise_op import Assert
from aesara.scalar.basic import Clip
from aesara.scalar.basic import clip as scalar_clip
from aesara.scan import until
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.var import TensorConstant
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.var import TensorConstant, TensorVariable

from aeppl.abstract import MeasurableVariable, assign_custom_measurable_outputs
from aeppl.logprob import CheckParameterValue, _logcdf, _logprob
from aeppl.logprob import (
CheckParameterValue,
_logcdf,
_logprob,
icdf,
logcdf,
logdiffexp,
)
from aeppl.opt import rv_sinking_db


@@ -123,3 +138,223 @@ def censor_logprob(op, values, base_rv, lower_bound, upper_bound, **kwargs):
    )

    return logprob


class TruncatedRV(OpFromGraph):
"""An `Op` constructed from an Aesara graph that represents a truncated univariate RV."""

default_output = 1
base_rv_op = None

def __init__(self, base_rv_op: Op, *args, **kwargs):
self.base_rv_op = base_rv_op
super().__init__(*args, **kwargs)


@singledispatch
def _truncated(op: Op, lower, upper, *params):
"""Return the truncated equivalent of another ``RandomVariable``."""
raise NotImplementedError(
f"{op} does not have an equivalent truncated version implemented"
)


def truncate(
    rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None
):
"""Truncate a univariate RandomVariable between lower and upper.
If lower or upper is ``None``, the variable is not truncated on that side.
Depending on dispatched implementations, this function returns either a specialized
`Op`, or equivalent graph representing the truncation process, via inverse CDF
sampling, or rejection sampling.
The argument `max_n_steps` controls the maximum number of resamples that are
attempted when performing rejection sampling. An Error is raised if convergence is
not reached after that many steps.
TODO: Add Note about updates
"""

    if lower is None and upper is None:
        raise ValueError("lower and upper cannot both be None")

    lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf)
    upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf)

    if not isinstance(rv.owner.op, RandomVariable):
        raise ValueError(f"truncation not implemented for Op {rv.owner.op}")

    if rv.owner.op.ndim_supp > 0:
        raise NotImplementedError(
            "truncation not implemented for multivariate variables"
        )

    if rng is None:
        rng = shared(np.random.RandomState(), borrow=True)

    # Try to use specialized Op
    try:
        truncated_rv = _truncated(rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:])
        truncated_rv.update = (
            truncated_rv.owner.inputs[0],
            truncated_rv.owner.outputs[0],
        )
        return truncated_rv
    except NotImplementedError:
        pass

    # Variables with `_` suffix identify dummy inputs for the OpFromGraph
    graph_inputs = [rng, *rv.owner.inputs[1:], lower, upper]
    graph_inputs_ = [inp.type() for inp in graph_inputs]
    *rv_inputs_, lower_, upper_ = graph_inputs_

    # Try to use inverse CDF sampling
    try:
        rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()
        # For left-truncated discrete RVs, we need to include the lower bound itself.
        # This may result in draws below the truncation range if any uniform draw is exactly 0
        lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_
        cdf_lower_ = at.exp(logcdf(rv_, lower_value))
        cdf_upper_ = at.exp(logcdf(rv_, upper_))
        uniform_ = at.random.uniform(
            cdf_lower_,
            cdf_upper_,
            rng=rv_inputs_[0],
            size=rv_inputs_[1],
        )
        truncated_rv_ = icdf(rv_, uniform_)
        truncated_rv = TruncatedRV(
            base_rv_op=rv.owner.op,
            inputs=graph_inputs_,
            outputs=[uniform_.owner.outputs[0], truncated_rv_],
            inline=True,
        )(*graph_inputs)
        truncated_rv.update = (
            truncated_rv.owner.inputs[0],
            truncated_rv.owner.outputs[0],
        )
        return truncated_rv
    except NotImplementedError:
        pass

    # Fallback to rejection sampling
    # TODO: Handle potential broadcast by lower / upper

    # Scan forces us to use a shared variable for the RNG
    graph_inputs = graph_inputs[1:]
    graph_inputs_ = graph_inputs_[1:]
    *rv_inputs_, lower_, upper_ = (rng, *graph_inputs_)

    rv_ = rv.owner.op.make_node(*rv_inputs_).default_output()

    def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
        # We need to set default_update for scan to generate updates
        next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs
        rng.default_update = next_rng

        truncated_rv = at.set_subtensor(
            truncated_rv[reject_draws],
            new_truncated_rv[reject_draws],
        )
        reject_draws = at.or_((truncated_rv < lower), (truncated_rv > upper))

        return (truncated_rv, reject_draws), until(~at.any(reject_draws))

    (truncated_rv_, reject_draws_), updates = scan(
        loop_fn,
        outputs_info=[
            at.empty_like(rv_),
            at.ones_like(rv_, dtype=bool),
        ],
        non_sequences=[lower_, upper_, *rv_inputs_],
        n_steps=max_n_steps,
        strict=True,
    )

    truncated_rv_ = truncated_rv_[-1]
    convergence_ = ~at.any(reject_draws_[-1])
    truncated_rv_ = Assert(
        f"truncation did not converge within {max_n_steps} steps"
    )(truncated_rv_, convergence_)

    # TODO: Scan does not return updates when a single step is performed, so this
    # will fail with max_n_steps = 1
    truncated_rv = TruncatedRV(
        base_rv_op=rv.owner.op,
        inputs=graph_inputs_,
        outputs=[tuple(updates.values())[0], truncated_rv_],
        inline=True,
    )(*graph_inputs)
    truncated_rv.update = (truncated_rv.owner.inputs[-1], truncated_rv.owner.outputs[0])
    return truncated_rv
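A minimal usage sketch of the `truncate` helper above, assuming this branch of aeppl is installed; the variable names are made up for illustration. `truncate` attaches an (rng, next_rng) pair in the `.update` attribute so the RNG state advances between calls when compiled with Aesara:

import aesara
import aesara.tensor as at

from aeppl.truncation import truncate

x_rv = at.random.normal(0.0, 1.0)
# Draws from x_rv restricted to the interval [-1, 1]
xt_rv = truncate(x_rv, lower=-1.0, upper=1.0)

# Pass the RNG update so consecutive calls return different draws
fn = aesara.function([], xt_rv, updates=[xt_rv.update])
draws = [fn() for _ in range(5)]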


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
    (value,) = values

    # rng shows up as the last input when using rejection sampling
    if op.shared_inputs:
        *rv_inputs, lower_bound, upper_bound, rng = inputs
        rv_inputs = [rng, *rv_inputs]
    else:
        *rv_inputs, lower_bound, upper_bound = inputs

    base_rv_op = op.base_rv_op
    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
    # For discrete left-truncated RVs, the lower bound itself belongs to the truncated
    # support, so its mass must not be subtracted out in the normalization term
    lower_bound_value = (
        lower_bound - 1 if base_rv_op.dtype.startswith("int") else lower_bound
    )
    lower_logcdf = _logcdf(base_rv_op, lower_bound_value, *rv_inputs, **kwargs)
    upper_logcdf = _logcdf(base_rv_op, upper_bound, *rv_inputs, **kwargs)

    if base_rv_op.name:
        logp.name = f"{base_rv_op}_logprob"
        lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
        upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

    is_lower_bounded = not (
        isinstance(lower_bound, TensorConstant)
        and np.all(np.isneginf(lower_bound.value))
    )
    is_upper_bounded = not (
        isinstance(upper_bound, TensorConstant) and np.all(np.isinf(upper_bound.value))
    )

    if is_lower_bounded and is_upper_bounded:
        lognorm = logdiffexp(upper_logcdf, lower_logcdf)
    elif is_lower_bounded:
        lognorm = at.log1mexp(lower_logcdf)
    elif is_upper_bounded:
        lognorm = upper_logcdf
    else:
        lognorm = 0

    logp = logp - lognorm

    if is_lower_bounded:
        logp = at.switch(value < lower_bound, -np.inf, logp)

    if is_upper_bounded:
        logp = at.switch(value <= upper_bound, logp, -np.inf)

    if is_lower_bounded and is_upper_bounded:
        logp = CheckParameterValue("lower_bound <= upper_bound")(
            logp, at.all(at.le(lower_bound, upper_bound))
        )

    return logp
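The normalization above corresponds to dividing the base density by the probability mass inside the truncation bounds, i.e. log p_trunc(x) = log p(x) - log(F(upper) - F(lower)) for lower <= x <= upper. A quick SciPy cross-check of that formula for a truncated standard normal (illustrative only, independent of the aeppl graph machinery):

import numpy as np
from scipy import stats

l, u, x = -1.0, 2.0, 0.5
base = stats.norm(0.0, 1.0)
lognorm = np.log(base.cdf(u) - base.cdf(l))  # log(F(u) - F(l))
logp_trunc = base.logpdf(x) - lognorm  # truncated log-density inside [l, u]
np.testing.assert_allclose(logp_trunc, stats.truncnorm(l, u).logpdf(x))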


@_truncated.register(arb.UniformRV)
def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig):
    return at.random.uniform(
        at.max((lower_orig, lower)),
        at.min((upper_orig, upper)),
        rng=rng,
        size=size,
        dtype=dtype,
    )
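With the dispatch above, truncating a uniform RV needs neither inverse CDF nor rejection sampling: it simply returns a new uniform RV on the intersected interval. A small sketch (the bounds are arbitrary, assuming this commit is applied):

import aesara.tensor as at

from aeppl.truncation import truncate

x_rv = at.random.uniform(0.0, 10.0)
# Dispatches to uniform_truncated and yields a plain uniform RV over [2, 5]
xt_rv = truncate(x_rv, lower=2.0, upper=5.0)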