From a933bb3a0b88e47721fee75320be4023ca4932c5 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Thu, 28 Apr 2022 21:32:07 -0500 Subject: [PATCH] Remove unused/deprecated variable attributes added by RandomStream --- aesara/compile/nanguardmode.py | 3 ++- aesara/sandbox/rng_mrg.py | 2 -- aesara/tensor/random/utils.py | 21 ++++++++++----------- tests/tensor/random/test_utils.py | 8 +++----- 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/aesara/compile/nanguardmode.py b/aesara/compile/nanguardmode.py index 06e85a2cfd..3e0cbf719a 100644 --- a/aesara/compile/nanguardmode.py +++ b/aesara/compile/nanguardmode.py @@ -30,12 +30,13 @@ def _is_numeric_value(arr, var): """ from aesara.link.c.type import _cdata_type + from aesara.tensor.random.type import RandomType if isinstance(arr, _cdata_type): return False elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)): return False - elif var and getattr(var.tag, "is_rng", False): + elif var and isinstance(var.type, RandomType): return False elif isinstance(arr, slice): return False diff --git a/aesara/sandbox/rng_mrg.py b/aesara/sandbox/rng_mrg.py index b660a0ee6b..0c4a290971 100644 --- a/aesara/sandbox/rng_mrg.py +++ b/aesara/sandbox/rng_mrg.py @@ -926,8 +926,6 @@ def uniform( size=size, nstreams=orig_nstreams, ) - # Add a reference to distinguish from other shared variables - node_rstate.tag.is_rng = True r = u * (high - low) + low if u.type.broadcastable != r.type.broadcastable: diff --git a/aesara/tensor/random/utils.py b/aesara/tensor/random/utils.py index 0e4aa840b7..0e06741f3e 100644 --- a/aesara/tensor/random/utils.py +++ b/aesara/tensor/random/utils.py @@ -251,20 +251,19 @@ def gen(self, op, *args, **kwargs): # Generate a new random state seed = int(self.gen_seedgen.integers(2**30)) - random_state_variable = shared(self.rng_ctor(seed)) - - # Distinguish it from other shared variables (why?) - random_state_variable.tag.is_rng = True + rng = shared(self.rng_ctor(seed), borrow=True) # Generate the sample - out = op(*args, **kwargs, rng=random_state_variable) - out.rng = random_state_variable + out = op(*args, **kwargs, rng=rng) + + # This is the value that should be used to replace the old state + # (i.e. `rng`) after `out` is sampled/evaluated. + # The updates mechanism in `aesara.function` is supposed to perform + # this replace action. + new_rng = out.owner.outputs[0] - # Update the tracked states - new_r = out.owner.outputs[0] - out.update = (random_state_variable, new_r) - self.state_updates.append(out.update) + self.state_updates.append((rng, new_rng)) - random_state_variable.default_update = new_r + rng.default_update = new_rng return out diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 763d204889..8347ab6659 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -97,7 +97,6 @@ def test_tutorial(self): assert np.all(f() != f()) assert np.all(g() == g()) assert np.all(abs(nearly_zeros()) < 1e-5) - assert isinstance(rv_u.rng.get_value(borrow=True), np.random.Generator) @pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng]) def test_basics(self, rng_ctor): @@ -109,8 +108,7 @@ def test_basics(self, rng_ctor): with pytest.raises(AttributeError): random.blah - # test if standard_normal is available in the namespace, See: GH issue #528 - random.standard_normal + assert hasattr(random, "standard_normal") with pytest.raises(AttributeError): np_random = RandomStream(namespace=np, rng_ctor=rng_ctor) @@ -223,7 +221,7 @@ def test_default_updates(self, rng_ctor): # Explicit updates #2 random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) out_c = random_c.uniform(0, 1, size=(2, 2)) - fn_c = function([], out_c, updates=[out_c.update]) + fn_c = function([], out_c, updates=random_c.state_updates) fn_c_val0 = fn_c() fn_c_val1 = fn_c() assert np.all(fn_c_val0 == fn_a_val0) @@ -241,7 +239,7 @@ def test_default_updates(self, rng_ctor): # No updates for out random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor) out_e = random_e.uniform(0, 1, size=(2, 2)) - fn_e = function([], out_e, no_default_updates=[out_e.rng]) + fn_e = function([], out_e, no_default_updates=[random_e.state_updates[0][0]]) fn_e_val0 = fn_e() fn_e_val1 = fn_e() assert np.all(fn_e_val0 == fn_a_val0)