Skip to content

Commit

Permalink
Allow Truncation of CustomDists
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 25, 2024
1 parent cb84c55 commit 92efe38
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 97 deletions.
201 changes: 120 additions & 81 deletions pymc/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@
import pytensor
import pytensor.tensor as pt

from pytensor import scan
from pytensor import config, graph_replace, scan
from pytensor.graph import Op
from pytensor.graph.basic import Node
from pytensor.raise_op import CheckAndRaise
from pytensor.scan import until
from pytensor.tensor import TensorConstant, TensorVariable
from pytensor.tensor.random.basic import NormalRV
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType

from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import (
CustomSymbolicDistRV,
Distribution,
SymbolicRandomVariable,
_support_point,
Expand All @@ -38,8 +40,9 @@
from pymc.distributions.transforms import _default_transform
from pymc.exceptions import TruncationError
from pymc.logprob.abstract import _logcdf, _logprob
from pymc.logprob.basic import icdf, logcdf
from pymc.logprob.basic import icdf, logcdf, logp
from pymc.math import logdiffexp
from pymc.pytensorf import collect_default_updates
from pymc.util import check_dist_not_registered


Expand All @@ -49,11 +52,17 @@ class TruncatedRV(SymbolicRandomVariable):
that represents a truncated univariate random variable.
"""

default_output = 1
base_rv_op = None
max_n_steps = None

def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
default_output: int = 0
base_rv_op: Op
max_n_steps: int

def __init__(
self,
*args,
base_rv_op: Op,
max_n_steps: int,
**kwargs,
):
self.base_rv_op = base_rv_op
self.max_n_steps = max_n_steps
self._print_name = (
Expand All @@ -63,8 +72,13 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the internal RNG."""
return {node.inputs[-1]: node.outputs[0]}
"""Return the update mapping for the internal RNGs.
TruncatedRVs are created in a way that the rng updats follow the same order as the input RNGs.
"""
rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
return dict(zip(rngs, next_rngs))


@singledispatch
Expand Down Expand Up @@ -141,10 +155,14 @@ class Truncated(Distribution):

@classmethod
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
if not (
isinstance(dist, TensorVariable)
and isinstance(dist.owner.op, (RandomVariable, CustomSymbolicDistRV))
):
if isinstance(dist.owner.op, SymbolicRandomVariable):
raise NotImplementedError(
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
f"You can try wrapping the distribution inside a CustomDist instead."
)
raise ValueError(
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
Expand Down Expand Up @@ -174,46 +192,54 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
if size is None:
size = pt.broadcast_shape(dist, lower, upper)
dist = change_dist_size(dist, new_size=size)
rv_inputs = [
inp
if not isinstance(inp.type, RandomType)
else pytensor.shared(np.random.default_rng())
for inp in dist.owner.inputs
]
graph_inputs = [*rv_inputs, lower, upper]

# Variables with `_` suffix identify dummy inputs for the OpFromGraph
graph_inputs = [*dist.owner.inputs[1:], lower, upper]
graph_inputs_ = [inp.type() for inp in graph_inputs]
graph_inputs_ = [
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
]
*rv_inputs_, lower_, upper_ = graph_inputs_

# We will use a Shared RNG variable because Scan demands it, even though it
# would not be necessary for the OpFromGraph inverse cdf.
rng = pytensor.shared(np.random.default_rng())
rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()

# Try to use inverted cdf sampling
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
try:
# For left truncated discrete RVs, we need to include the whole lower bound.
# This may result in draws below the truncation range, if any uniform == 0
lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
cdf_upper_ = pt.exp(logcdf(rv_, upper_))
# It's okay to reuse the same rng here, because the rng in rv_ will not be
# used by either the logcdf of icdf functions
logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
# We use the first RNG from the base RV, so we don't have to introduce a new one
# This is not problematic because the RNG won't be used in the RV logcdf graph
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
uniform_next_rng_, uniform_ = pt.random.uniform(
cdf_lower_,
cdf_upper_,
rng=rng,
size=rv_inputs_[0],
pt.exp(logcdf_lower_),
pt.exp(logcdf_upper_),
rng=uniform_rng_,
size=rv_.shape,
).owner.outputs
truncated_rv_ = icdf(rv_, uniform_)
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=[*graph_inputs_, rng],
outputs=[uniform_next_rng_, truncated_rv_],
inputs=graph_inputs_,
outputs=[truncated_rv_, uniform_next_rng_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs, rng)
)(*graph_inputs)
except NotImplementedError:
pass

# Fallback to rejection sampling
def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
# truncated_rv = zeros(rv.shape)
# reject_draws = ones(rv.shape, dtype=bool)
# while any(reject_draws):
# truncated_rv[reject_draws] = draw(rv)[reject_draws]
# reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
# Avoid scalar boolean indexing
if truncated_rv.type.ndim == 0:
truncated_rv = new_truncated_rv
Expand All @@ -226,7 +252,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):

return (
(truncated_rv, reject_draws),
[(rng, next_rng)],
collect_default_updates(new_truncated_rv),
until(~pt.any(reject_draws)),
)

Expand All @@ -236,7 +262,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
pt.zeros_like(rv_),
pt.ones_like(rv_, dtype=bool),
],
non_sequences=[lower_, upper_, rng, *rv_inputs_],
non_sequences=[lower_, upper_, *rv_inputs_],
n_steps=max_n_steps,
strict=True,
)
Expand All @@ -246,24 +272,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
truncated_rv_, convergence_
)
# Sort updates of each RNG so that they show in the same order as the input RNGs

def sort_updates(update):
rng, next_rng = update
return graph_inputs.index(rng)

next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]

[next_rng] = updates.values()
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=[*graph_inputs_, rng],
outputs=[next_rng, truncated_rv_],
inputs=graph_inputs_,
outputs=[truncated_rv_, *next_rngs],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs, rng)
)(*graph_inputs)

@staticmethod
def _create_logcdf_exprs(
base_rv: TensorVariable,
value: TensorVariable,
lower: TensorVariable,
upper: TensorVariable,
) -> tuple[TensorVariable, TensorVariable]:
"""Create lower and upper logcdf expressions for base_rv.
Uses `value` as a template for broadcasting.
"""
# For left truncated discrete RVs, we need to include the whole lower bound.
lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
upper_value = pt.full_like(value, upper, dtype=config.floatX)
lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
return lower_logcdf, upper_logcdf


@_change_dist_size.register(TruncatedRV)
def change_truncated_size(op, dist, new_size, expand):
*rv_inputs, lower, upper, rng = dist.owner.inputs
# Recreate the original untruncated RV
untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
*rv_inputs, lower, upper = truncated_rv.owner.inputs
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()

if expand:
new_size = to_tuple(new_size) + tuple(dist.shape)
new_size = to_tuple(new_size) + tuple(truncated_rv.shape)

return Truncated.rv_op(
untruncated_rv,
Expand All @@ -275,11 +326,11 @@ def change_truncated_size(op, dist, new_size, expand):


@_support_point.register(TruncatedRV)
def truncated_support_point(op, rv, *inputs):
*rv_inputs, lower, upper, rng = inputs
def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
*rv_inputs, lower, upper = inputs

# recreate untruncated rv and respective support_point
untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
untruncated_support_point = support_point(untruncated_rv)

fallback_support_point = pt.switch(
Expand All @@ -300,31 +351,25 @@ def truncated_support_point(op, rv, *inputs):


@_default_transform.register(TruncatedRV)
def truncated_default_transform(op, rv):
def truncated_default_transform(op, truncated_rv):
# Don't transform discrete truncated distributions
if op.base_rv_op.dtype.startswith("int"):
if truncated_rv.type.dtype.startswith("int"):
return None
# Lower and Upper are the arguments -3 and -2
return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
# Lower and Upper are the arguments -2 and -1
return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))


@_logprob.register(TruncatedRV)
def truncated_logprob(op, values, *inputs, **kwargs):
(value,) = values

*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]
*rv_inputs, lower, upper = inputs

base_rv_op = op.base_rv_op
logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
# For left truncated RVs, we don't want to include the lower bound in the
# normalization term
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)

base_rv = base_rv_op.make_node(*rv_inputs).default_output()
base_logp = logp(base_rv, value)
lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
if base_rv_op.name:
logp.name = f"{base_rv_op}_logprob"
base_logp.name = f"{base_rv_op}_logprob"
lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
upper_logcdf.name = f"{base_rv_op}_upper_logcdf"

Expand All @@ -339,37 +384,31 @@ def truncated_logprob(op, values, *inputs, **kwargs):
elif is_upper_bounded:
lognorm = upper_logcdf

logp = logp - lognorm
truncated_logp = base_logp - lognorm

if is_lower_bounded:
logp = pt.switch(value < lower, -np.inf, logp)
truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)

if is_upper_bounded:
logp = pt.switch(value <= upper, logp, -np.inf)
truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)

if is_lower_bounded and is_upper_bounded:
logp = check_parameters(
logp,
truncated_logp = check_parameters(
truncated_logp,
pt.le(lower, upper),
msg="lower_bound <= upper_bound",
)

return logp
return truncated_logp


@_logcdf.register(TruncatedRV)
def truncated_logcdf(op, value, *inputs, **kwargs):
*rv_inputs, lower, upper, rng = inputs
rv_inputs = [rng, *rv_inputs]

base_rv_op = op.base_rv_op
logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
*rv_inputs, lower, upper = inputs

# For left truncated discrete RVs, we don't want to include the lower bound in the
# normalization term
lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
base_logcdf = logcdf(base_rv, value)
lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)

is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
Expand All @@ -382,7 +421,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs):
elif is_upper_bounded:
lognorm = upper_logcdf

logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
logcdf_trunc = logcdf_numerator - lognorm

if is_lower_bounded:
Expand Down
12 changes: 6 additions & 6 deletions tests/distributions/test_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,8 +1588,8 @@ def test_hurdle_negativebinomial_graph(self):
_, nonzero_dist = self.check_hurdle_mixture_graph(dist)

assert isinstance(nonzero_dist.owner.op.base_rv_op, NegativeBinomial)
assert nonzero_dist.owner.inputs[2].data == n
assert nonzero_dist.owner.inputs[3].data == p
assert nonzero_dist.owner.inputs[-4].data == n
assert nonzero_dist.owner.inputs[-3].data == p

def test_hurdle_gamma_graph(self):
psi, alpha, beta = 0.25, 3, 4
Expand All @@ -1599,17 +1599,17 @@ def test_hurdle_gamma_graph(self):
# Under the hood it uses the shape-scale parametrization of the Gamma distribution.
# So the second value is the reciprocal of the rate (i.e. 1 / beta)
assert isinstance(nonzero_dist.owner.op.base_rv_op, Gamma)
assert nonzero_dist.owner.inputs[2].data == alpha
assert nonzero_dist.owner.inputs[3].eval() == 1 / beta
assert nonzero_dist.owner.inputs[-4].data == alpha
assert nonzero_dist.owner.inputs[-3].eval() == 1 / beta

def test_hurdle_lognormal_graph(self):
psi, mu, sigma = 0.1, 2, 2.5
dist = HurdleLogNormal.dist(psi=psi, mu=mu, sigma=sigma)
_, nonzero_dist = self.check_hurdle_mixture_graph(dist)

assert isinstance(nonzero_dist.owner.op.base_rv_op, LogNormal)
assert nonzero_dist.owner.inputs[2].data == mu
assert nonzero_dist.owner.inputs[3].data == sigma
assert nonzero_dist.owner.inputs[-4].data == mu
assert nonzero_dist.owner.inputs[-3].data == sigma

@pytest.mark.parametrize(
"dist, psi, non_psi_args",
Expand Down
Loading

0 comments on commit 92efe38

Please sign in to comment.