diff --git a/aeppl/truncation.py b/aeppl/truncation.py index d4415462..830bf490 100644 --- a/aeppl/truncation.py +++ b/aeppl/truncation.py @@ -1,14 +1,15 @@ from functools import singledispatch -from typing import Tuple +from typing import Optional, Tuple import aesara.tensor as at import aesara.tensor.random.basic as arb import numpy as np -from aesara import scan, shared +from aesara import scan from aesara.compile.builders import OpFromGraph from aesara.graph.op import Op from aesara.raise_op import CheckAndRaise from aesara.scan import until +from aesara.tensor.random import RandomStream from aesara.tensor.random.op import RandomVariable from aesara.tensor.var import TensorConstant, TensorVariable @@ -68,7 +69,11 @@ def __str__(self): def truncate( - rv: TensorVariable, lower=None, upper=None, max_n_steps: int = 10_000, rng=None + rv: TensorVariable, + lower=None, + upper=None, + max_n_steps: int = 10_000, + srng: Optional[RandomStream] = None, ) -> Tuple[TensorVariable, Tuple[TensorVariable, TensorVariable]]: """Truncate a univariate `RandomVariable` between `lower` and `upper`. 
@@ -99,13 +104,13 @@ def truncate( lower = at.as_tensor_variable(lower) if lower is not None else at.constant(-np.inf) upper = at.as_tensor_variable(upper) if upper is not None else at.constant(np.inf) - if rng is None: - rng = shared(np.random.RandomState(), borrow=True) + if srng is None: + srng = RandomStream() # Try to use specialized Op try: truncated_rv, updates = _truncated( - rv.owner.op, lower, upper, rng, *rv.owner.inputs[1:] + rv.owner.op, lower, upper, srng, *rv.owner.inputs[1:] ) return truncated_rv, updates except NotImplementedError: @@ -116,8 +121,8 @@ def truncate( # though it would not be necessary for the icdf OpFromGraph graph_inputs = [*rv.owner.inputs[1:], lower, upper] graph_inputs_ = [inp.type() for inp in graph_inputs] - *rv_inputs_, lower_, upper_ = graph_inputs_ - rv_ = rv.owner.op.make_node(rng, *rv_inputs_).default_output() + size_, dtype_, *rv_inputs_, lower_, upper_ = graph_inputs_ + rv_ = srng.gen(rv.owner.op, *rv_inputs_, size=size_, dtype=dtype_) # Try to use inverted cdf sampling try: @@ -126,11 +131,10 @@ def truncate( lower_value = lower_ - 1 if rv.owner.op.dtype.startswith("int") else lower_ cdf_lower_ = at.exp(logcdf(rv_, lower_value)) cdf_upper_ = at.exp(logcdf(rv_, upper_)) - uniform_ = at.random.uniform( + uniform_ = srng.uniform( cdf_lower_, cdf_upper_, - rng=rng, - size=rv_inputs_[0], + size=size_, ) truncated_rv_ = icdf(rv_, uniform_) truncated_rv = TruncatedRV( @@ -146,19 +150,15 @@ def truncate( # Fallback to rejection sampling # TODO: Handle potential broadcast by lower / upper - def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): - next_rng, new_truncated_rv = rv.owner.op.make_node(rng, *rv_inputs).outputs + def loop_fn(truncated_rv, reject_draws, lower, upper, size, dtype, *rv_inputs): + new_truncated_rv = srng.gen(rv.owner.op, *rv_inputs, size=size, dtype=dtype) # type: ignore truncated_rv = at.set_subtensor( truncated_rv[reject_draws], new_truncated_rv[reject_draws], ) reject_draws = 
at.or_((truncated_rv < lower), (truncated_rv > upper)) - return ( - (truncated_rv, reject_draws), - [(rng, next_rng)], - until(~at.any(reject_draws)), - ) + return (truncated_rv, reject_draws), until(~at.any(reject_draws)) (truncated_rv_, reject_draws_), updates = scan( loop_fn, @@ -166,7 +166,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): at.zeros_like(rv_), at.ones_like(rv_, dtype=bool), ], - non_sequences=[lower_, upper_, rng, *rv_inputs_], + non_sequences=[lower_, upper_, size_, dtype_, *rv_inputs_], n_steps=max_n_steps, strict=True, ) @@ -180,10 +180,16 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): truncated_rv = TruncatedRV( base_rv_op=rv.owner.op, inputs=graph_inputs_, - outputs=[truncated_rv_, tuple(updates.values())[0]], + # This will fail with `n_steps==1`, because in that case `Scan` won't return any updates + outputs=[truncated_rv_, rv_.owner.outputs[0], tuple(updates.values())[0]], inline=True, )(*graph_inputs) - updates = {truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1]} + # TODO: Is the order of multiple shared variables deterministic? 
+ assert truncated_rv.owner.inputs[-2] is rv_.owner.inputs[0] + updates = { + truncated_rv.owner.inputs[-2]: truncated_rv.owner.outputs[-2], + truncated_rv.owner.inputs[-1]: truncated_rv.owner.outputs[-1], + } return truncated_rv, updates @@ -191,7 +197,11 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs): def truncated_logprob(op, values, *inputs, **kwargs): (value,) = values - *rv_inputs, lower_bound, upper_bound, rng = inputs + # Rejection sample graph has two rngs + if len(op.shared_inputs) == 2: + *rv_inputs, lower_bound, upper_bound, _, rng = inputs + else: + *rv_inputs, lower_bound, upper_bound, rng = inputs rv_inputs = [rng, *rv_inputs] base_rv_op = op.base_rv_op @@ -242,11 +252,11 @@ def truncated_logprob(op, values, *inputs, **kwargs): @_truncated.register(arb.UniformRV) -def uniform_truncated(op, lower, upper, rng, size, dtype, lower_orig, upper_orig): - truncated_uniform = at.random.uniform( +def uniform_truncated(op, lower, upper, srng, size, dtype, lower_orig, upper_orig): + truncated_uniform = srng.gen( + op, at.max((lower_orig, lower)), at.min((upper_orig, upper)), - rng=rng, size=size, dtype=dtype, ) diff --git a/tests/test_truncation.py b/tests/test_truncation.py index cf08928a..7c64b2bf 100644 --- a/tests/test_truncation.py +++ b/tests/test_truncation.py @@ -50,10 +50,10 @@ def _icdf_not_implemented(*args, **kwargs): def test_truncation_specialized_op(): x = at.random.uniform(0, 10, name="x", size=100) - rng = aesara.shared(np.random.RandomState()) - xt, _ = truncate(x, lower=5, upper=15, rng=rng) + srng = at.random.RandomStream() + xt, _ = truncate(x, lower=5, upper=15, srng=srng) assert isinstance(xt.owner.op, UniformRV) - assert xt.owner.inputs[0] is rng + assert xt.owner.inputs[0] is srng.updates()[0][0] lower_upper = at.stack(xt.owner.inputs[3:]) assert np.all(lower_upper.eval() == [5, 10]) @@ -68,10 +68,10 @@ def test_truncation_continuous_random(op_type, lower, upper): normal_op = icdf_normal if op_type == "icdf" 
else rejection_normal x = normal_op(loc, scale, name="x", size=100) - rng = aesara.shared(np.random.RandomState()) - xt, xt_update = truncate(x, lower=lower, upper=upper, rng=rng) + srng = at.random.RandomStream() + xt, xt_update = truncate(x, lower=lower, upper=upper, srng=srng) assert isinstance(xt.owner.op, TruncatedRV) - assert xt.owner.inputs[-1] is rng + assert xt.owner.inputs[-1] is srng.updates()[1 if op_type == "icdf" else 2][0] assert xt.type.dtype == x.type.dtype assert xt.type.ndim == x.type.ndim @@ -94,7 +94,7 @@ def test_truncation_continuous_random(op_type, lower, upper): assert scipy.stats.cramervonmises(xt_draws.ravel(), ref_xt.cdf).pvalue > 0.001 # Test max_n_steps - xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=1) + xt, xt_update = truncate(x, lower=lower, upper=upper, max_n_steps=2) xt_fn = aesara.function([], xt, updates=xt_update) if op_type == "icdf": xt_draws = xt_fn()