Skip to content

Commit

Permalink
Do not set RNG updates inplace in compile_pymc
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 18, 2022
1 parent 80f8195 commit 8b063f9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
16 changes: 12 additions & 4 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Apply,
Constant,
Variable,
ancestors,
clone_get_equiv,
graph_inputs,
walk,
Expand Down Expand Up @@ -963,18 +964,19 @@ def compile_pymc(
this function is called within a model context and the model `check_bounds` flag
is set to False.
"""
# Set the default update of RandomVariable's RNG so that it is automatically
# Create an update mapping of RandomVariable's RNG so that it is automatically
# updated after every function call
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
rng_updates = {}
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
for rv in (
node
for node in walk_model(output_to_list, walk_past_rvs=True)
for node in ancestors(output_to_list)
if node.owner and isinstance(node.owner.op, RandomVariable)
):
rng = rv.owner.inputs[0]
if not hasattr(rng, "default_update"):
rng.default_update = rv.owner.outputs[0]
rng_updates[rng] = rv.owner.outputs[0]

# If called inside a model context, see if check_bounds flag is set to False
try:
Expand All @@ -991,5 +993,11 @@ def compile_pymc(
mode = get_mode(mode)
opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
mode = Mode(linker=mode.linker, optimizer=opt_qry)
aesara_function = aesara.function(inputs, outputs, mode=mode, **kwargs)
aesara_function = aesara.function(
inputs,
outputs,
updates={**rng_updates, **kwargs.pop("updates", {})},
mode=mode,
**kwargs,
)
return aesara_function
14 changes: 13 additions & 1 deletion pymc/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,9 +576,21 @@ def test_check_bounds_flag():
assert np.all(compile_pymc([], bound)() == -np.inf)


def test_compile_pymc_sets_default_updates():
def test_compile_pymc_sets_rng_updates():
rng = aesara.shared(np.random.default_rng(0))
x = pm.Normal.dist(rng=rng)
assert x.owner.inputs[0] is rng
f = compile_pymc([], x)
assert not np.isclose(f(), f())

# Check that update was not done inplace
assert not hasattr(rng, "default_update")
f = aesara.function([], x)
assert f() == f()


def test_compile_pymc_with_updates():
x = aesara.shared(0)
f = compile_pymc([], x, updates={x: x + 1})
assert f() == 0
assert f() == 1

0 comments on commit 8b063f9

Please sign in to comment.