From 7aa520cf747abbac37af9dfeefa31ca1ffafad69 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 23 Mar 2022 12:59:31 +0100 Subject: [PATCH] Add test for interactions between missing, default and explicit updates in `compile_pymc` --- pymc/tests/test_aesaraf.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 74622a16a3..c5b54b0eaf 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -604,3 +604,26 @@ def test_compile_pymc_with_updates(): f = compile_pymc([], x, updates={x: x + 1}) assert f() == 0 assert f() == 1 + + +def test_compile_pymc_missing_default_explicit_updates(): + rng = aesara.shared(np.random.default_rng(0)) + x = pm.Normal.dist(rng=rng) + + # By default, compile_pymc should update the rng of x + f = compile_pymc([], x) + assert f() != f() + + # An explicit update should override the default_update, like aesara.function does + # For testing purposes, we use an update that leaves the rng unchanged + f = compile_pymc([], x, updates={rng: rng}) + assert f() == f() + + # If we specify a custom default_update directly it should use that instead. + rng.default_update = rng + f = compile_pymc([], x) + assert f() == f() + + # And again, it should be overridden by an explicit update + f = compile_pymc([], x, updates={rng: x.owner.outputs[0]}) + assert f() != f()