Manage SharedVariables explicitly in SymbolicRandomVariable
ricardoV94 committed Apr 8, 2024
1 parent f0390d4 commit 1195261
Showing 5 changed files with 74 additions and 41 deletions.
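
As a minimal sketch of the new contract (mirroring the test_recreate_with_different_rng_inputs test added at the bottom of this commit): the RNG shared variable is passed to the SymbolicRandomVariable as an explicit input, and its next state is returned as an explicit output, so update() can pair the two.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.distributions.distribution import SymbolicRandomVariable

rng = pytensor.shared(np.random.default_rng())

# Dummy variables define the inner graph of the OpFromGraph
dummy_rng = rng.type()
dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs

# The RNG is an explicit input; its update is an explicit output
op = SymbolicRandomVariable([dummy_rng], [dummy_next_rng, dummy_x], ndim_supp=0)
next_rng, x = op(rng)
assert op.update(x.owner) == {rng: next_rng}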
35 changes: 23 additions & 12 deletions pymc/distributions/distribution.py
@@ -299,6 +299,7 @@ def __init__(
raise ValueError("ndim_supp or gufunc_signature must be provided")

kwargs.setdefault("inline", True)
kwargs.setdefault("strict", True)
super().__init__(*args, **kwargs)

def update(self, node: Node) -> dict[Variable, Variable]:
@@ -702,7 +703,7 @@ class CustomSymbolicDistRV(SymbolicRandomVariable):
symbolic random methods.
"""

default_output = -1
default_output = 0

_print_name = ("CustomSymbolicDist", "\\operatorname{CustomSymbolicDist}")

@@ -805,14 +806,17 @@ def rv_op(
if logp is not None:

@_logprob.register(rv_type)
def custom_dist_logp(op, values, size, *params, **kwargs):
return logp(values[0], *params[: len(dist_params)])
def custom_dist_logp(op, values, size, *inputs, **kwargs):
[value] = values
rv_params = inputs[: len(dist_params)]
return logp(value, *rv_params)

if logcdf is not None:

@_logcdf.register(rv_type)
def custom_dist_logcdf(op, value, size, *params, **kwargs):
return logcdf(value, *params[: len(dist_params)])
def custom_dist_logcdf(op, value, size, *inputs, **kwargs):
rv_params = inputs[: len(dist_params)]
return logcdf(value, *rv_params)

if support_point is not None:

@@ -845,22 +849,29 @@ def change_custom_symbolic_dist_size(op, rv, new_size, expand):
dummy_dist_params = [dist_param.type() for dist_param in old_dist_params]
dummy_rv = dist(*dummy_dist_params, dummy_size_param)
dummy_params = [dummy_size_param, *dummy_dist_params]
dummy_updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
rngs = updates_dict.keys()
rngs_updates = updates_dict.values()
new_rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
)
new_rv = new_rv_op(new_size, *dist_params)
new_rv = new_rv_op(new_size, *dist_params, *rngs)

return new_rv

# RNGs are not passed as explicit inputs (because we usually don't know how many are needed)
# We retrieve them here
updates_dict = collect_default_updates(inputs=dummy_params, outputs=(dummy_rv,))
rngs = updates_dict.keys()
rngs_updates = updates_dict.values()
rv_op = rv_type(
inputs=dummy_params,
outputs=[*dummy_updates_dict.values(), dummy_rv],
inputs=[*dummy_params, *rngs],
outputs=[dummy_rv, *rngs_updates],
signature=signature,
)
return rv_op(size, *dist_params)
return rv_op(size, *dist_params, *rngs)

@staticmethod
def _infer_final_signature(signature: str, n_inputs, n_updates) -> str:
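
For context, a hedged usage sketch of the CustomDist path exercised by rv_op above (the parameters and the exp transform are arbitrary, in the style of the CustomDist docstring): the user's dist function draws with an implicit RNG, collect_default_updates discovers that RNG, and the resulting CustomSymbolicDistRV receives it as an explicit input.

import pymc as pm
import pytensor.tensor as pt

def dist(mu, size):
    # Implicit RNG inside the user graph; rv_op collects it and passes it explicitly
    return pt.exp(pm.Normal.dist(mu, 1.0, size=size))

x = pm.CustomDist.dist(0.0, dist=dist)
print(pm.draw(x, draws=2))     # two different draws; the RNG update is explicit
print(pm.logp(x, 1.0).eval())  # logp is derived automatically from the dist graph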
20 changes: 8 additions & 12 deletions pymc/distributions/timeseries.py
@@ -436,7 +436,6 @@ def __init__(self, *args, ar_order, constant_term, **kwargs):

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


@@ -658,13 +657,13 @@ def step(*args):
ar_ = pt.concatenate([init_, innov_.T], axis=-1)

ar_op = AutoRegressiveRV(
inputs=[rhos_, sigma_, init_, steps_],
inputs=[rhos_, sigma_, init_, steps_, noise_rng],
outputs=[noise_next_rng, ar_],
ar_order=ar_order,
constant_term=constant_term,
)

ar = ar_op(rhos, sigma, init_dist, steps)
ar = ar_op(rhos, sigma, init_dist, steps, noise_rng)
return ar
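
A hedged end-to-end check of the AR change (the coefficients, sigma, and steps below are arbitrary): the user-facing API is unchanged, the noise RNG is simply threaded through the AutoRegressiveRV as its last explicit input.

import numpy as np
import pymc as pm

ar = pm.AR.dist([0.9, -0.2], sigma=0.1, init_dist=pm.Normal.dist(size=2), steps=50)
draws = pm.draw(ar, draws=2, random_seed=1)
print(draws.shape)  # expected (2, 52): ar_order + steps points per draw
assert not np.allclose(draws[0], draws[1])  # the explicit RNG update advances state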


@@ -731,7 +730,6 @@ class GARCH11RV(SymbolicRandomVariable):

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


@@ -797,7 +795,6 @@ def rv_op(cls, omega, alpha_1, beta_1, initial_vol, init_dist, steps, size=None)
# In this case the size of the init_dist depends on the parameters shape
batch_size = pt.broadcast_shape(omega, alpha_1, beta_1, initial_vol)
init_dist = change_dist_size(init_dist, batch_size)
# initial_vol = initial_vol * pt.ones(batch_size)

# Create OpFromGraph representing random draws from GARCH11 process
# Variables with underscore suffix are dummy inputs into the OpFromGraph
@@ -819,7 +816,7 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):

(y_t, _), innov_updates_ = pytensor.scan(
fn=step,
outputs_info=[init_, initial_vol_ * pt.ones(batch_size)],
outputs_info=[init_, pt.broadcast_to(initial_vol_.astype("floatX"), init_.shape)],
non_sequences=[omega_, alpha_1_, beta_1_, noise_rng],
n_steps=steps_,
strict=True,
@@ -831,11 +828,11 @@ def step(prev_y, prev_sigma, omega, alpha_1, beta_1, rng):
)

garch11_op = GARCH11RV(
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_],
inputs=[omega_, alpha_1_, beta_1_, initial_vol_, init_, steps_, noise_rng],
outputs=[noise_next_rng, garch11_],
)

garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps)
garch11 = garch11_op(omega, alpha_1, beta_1, initial_vol, init_dist, steps, noise_rng)
return garch11
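
The scan pattern shared by these timeseries RVs, reduced to a standalone hedged sketch (a plain random walk instead of GARCH): the RNG is a non_sequence, draws happen inside step, and the inner update {rng: next_rng} is reported back to scan; the surrounding OpFromGraph then exposes that next state as its first output.

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng())

def step(prev_y, rng):
    # Draw noise from the RNG passed as a non_sequence and report its update
    next_rng, noise = pt.random.normal(rng=rng).owner.outputs
    return prev_y + noise, {rng: next_rng}

ys, updates = pytensor.scan(
    fn=step,
    outputs_info=[pt.zeros(())],
    non_sequences=[rng],
    n_steps=5,
    strict=True,
)
fn = pytensor.function([], ys, updates=updates)
print(fn())  # a fresh 5-step walk on every call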


@@ -891,14 +888,13 @@ class EulerMaruyamaRV(SymbolicRandomVariable):
ndim_supp = 1
_print_name = ("EulerMaruyama", "\\operatorname{EulerMaruyama}")

def __init__(self, *args, dt, sde_fn, **kwargs):
def __init__(self, *args, dt: float, sde_fn: Callable, **kwargs):
self.dt = dt
self.sde_fn = sde_fn
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since noise is a shared variable it shows up as the last node input
return {node.inputs[-1]: node.outputs[0]}


@@ -1010,14 +1006,14 @@ def step(*prev_args):
)

eulermaruyama_op = EulerMaruyamaRV(
inputs=[init_, steps_, *sde_pars_],
inputs=[init_, steps_, *sde_pars_, noise_rng],
outputs=[noise_next_rng, sde_out_],
dt=dt,
sde_fn=sde_fn,
signature=f"(),(s),{','.join('()' for _ in sde_pars_)}->(),(t)",
)

eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars)
eulermaruyama = eulermaruyama_op(init_dist, steps, *sde_pars, noise_rng)
return eulermaruyama


20 changes: 10 additions & 10 deletions pymc/distributions/truncated.py
@@ -63,8 +63,7 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
super().__init__(*args, **kwargs)

def update(self, node: Node):
"""Return the update mapping for the noise RV."""
# Since RNG is a shared variable it shows up as the last node input
"""Return the update mapping for the internal RNG."""
return {node.inputs[-1]: node.outputs[0]}


@@ -195,20 +194,20 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
cdf_upper_ = pt.exp(logcdf(rv_, upper_))
# It's okay to reuse the same rng here, because the rng in rv_ will not be
# used by either the logcdf or icdf functions
uniform_ = pt.random.uniform(
uniform_next_rng_, uniform_ = pt.random.uniform(
cdf_lower_,
cdf_upper_,
rng=rng,
size=rv_inputs_[0],
)
).owner.outputs
truncated_rv_ = icdf(rv_, uniform_)
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs_,
outputs=[uniform_.owner.outputs[0], truncated_rv_],
inputs=[*graph_inputs_, rng],
outputs=[uniform_next_rng_, truncated_rv_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
)(*graph_inputs, rng)
except NotImplementedError:
pass
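
A hedged micro-example of the pattern used in the icdf branch above: a RandomVariable node has the outputs (next_rng, draw), so indexing .owner.outputs exposes the next RNG state that TruncatedRV now returns explicitly.

import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng())
next_rng, draw = pt.random.uniform(0, 1, rng=rng).owner.outputs

fn = pytensor.function([], draw, updates={rng: next_rng})
print(fn(), fn())  # different values: the update advances the shared RNG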

@@ -248,13 +247,14 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
truncated_rv_, convergence_
)

[next_rng] = updates.values()
return TruncatedRV(
base_rv_op=dist.owner.op,
inputs=graph_inputs_,
outputs=[next(iter(updates.values())), truncated_rv_],
inputs=[*graph_inputs_, rng],
outputs=[next_rng, truncated_rv_],
ndim_supp=0,
max_n_steps=max_n_steps,
)(*graph_inputs)
)(*graph_inputs, rng)


@_change_dist_size.register(TruncatedRV)
2 changes: 1 addition & 1 deletion pymc/pytensorf.py
@@ -801,7 +801,7 @@ def collect_default_updates_inner_fgraph(node: Node) -> dict[Variable, Variable]


def collect_default_updates(
outputs: Sequence[Variable],
outputs: Variable | Sequence[Variable],
*,
inputs: Sequence[Variable] | None = None,
must_be_shared: bool = True,
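
The widened annotation suggests a single output variable can be passed directly instead of a sequence; a hedged sketch of that call style:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

rng = pytensor.shared(np.random.default_rng())
x = pt.random.normal(rng=rng)

updates = collect_default_updates(x)  # expected: {rng: next_rng}, no list needed
fn = pytensor.function([], x, updates=updates)
print(fn(), fn())                     # the RNG state advances between calls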
38 changes: 32 additions & 6 deletions tests/distributions/test_distribution.py
@@ -41,13 +41,14 @@
CustomDist,
CustomDistRV,
CustomSymbolicDistRV,
DiracDelta,
PartialObservedRV,
SymbolicRandomVariable,
_support_point,
create_partial_observed_rv,
support_point,
)
from pymc.distributions.shape_utils import change_dist_size, rv_size_is_none, to_tuple
from pymc.distributions.shape_utils import change_dist_size, to_tuple
from pymc.distributions.transforms import log
from pymc.exceptions import BlockModelAccessError
from pymc.logprob.basic import conditional_logp, logcdf, logp
@@ -584,9 +585,7 @@ def custom_dist(p, sigma, size):

def test_custom_methods(self):
def custom_dist(mu, size):
if rv_size_is_none(size):
return mu
return pt.full(size, mu)
return DiracDelta.dist(mu, size=size)

def custom_support_point(rv, size, mu):
return pt.full_like(rv, mu + 1)
@@ -778,7 +777,8 @@ def test_inline(self):
class TestSymbolicRV(SymbolicRandomVariable):
pass

x = TestSymbolicRV([], [Flat.dist()], ndim_supp=0)()
rng = pytensor.shared(np.random.default_rng())
x = TestSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)

# By default, the SymbolicRandomVariable will not be inlined. Because we did not
# dispatch a custom logprob function it will raise next
@@ -788,7 +788,7 @@ class TestSymbolicRV(SymbolicRandomVariable):
class TestInlinedSymbolicRV(SymbolicRandomVariable):
inline_logprob = True

x_inline = TestInlinedSymbolicRV([], [Flat.dist()], ndim_supp=0)()
x_inline = TestInlinedSymbolicRV([rng], [Flat.dist(rng=rng)], ndim_supp=0)(rng)
assert np.isclose(logp(x_inline, 0).eval(), 0)

def test_default_update(self):
@@ -826,6 +826,32 @@ def update(self, node):
):
compile_pymc(inputs=[], outputs=x, random_seed=431)

def test_recreate_with_different_rng_inputs(self):
"""Test that we can recreate a SymbolicRandomVariable with new RNG inputs.
Related to https://github.com/pymc-devs/pytensor/issues/473
"""
rng = pytensor.shared(np.random.default_rng())

dummy_rng = rng.type()
dummy_next_rng, dummy_x = pt.random.normal(rng=dummy_rng).owner.outputs

op = SymbolicRandomVariable(
[dummy_rng],
[dummy_next_rng, dummy_x],
ndim_supp=0,
)

next_rng, x = op(rng)
assert op.update(x.owner) == {rng: next_rng}

new_rng = pytensor.shared(np.random.default_rng())
inputs = x.owner.inputs.copy()
inputs[0] = new_rng
# This would fail with the default OpFromGraph.__call__()
new_next_rng, new_x = x.owner.op(*inputs)
assert op.update(new_x.owner) == {new_rng: new_next_rng}


def test_tag_future_warning_dist():
# Test no unexpected warnings
Expand Down
