Skip to content

Commit

Permalink
Remove unused/deprecated variable attributes added by RandomStream
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Apr 29, 2022
1 parent b6b7aef commit a933bb3
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 19 deletions.
3 changes: 2 additions & 1 deletion aesara/compile/nanguardmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def _is_numeric_value(arr, var):
"""
from aesara.link.c.type import _cdata_type
from aesara.tensor.random.type import RandomType

if isinstance(arr, _cdata_type):
return False
elif isinstance(arr, (np.random.mtrand.RandomState, np.random.Generator)):
return False
elif var and getattr(var.tag, "is_rng", False):
elif var and isinstance(var.type, RandomType):
return False
elif isinstance(arr, slice):
return False
Expand Down
2 changes: 0 additions & 2 deletions aesara/sandbox/rng_mrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,8 +926,6 @@ def uniform(
size=size,
nstreams=orig_nstreams,
)
# Add a reference to distinguish from other shared variables
node_rstate.tag.is_rng = True
r = u * (high - low) + low

if u.type.broadcastable != r.type.broadcastable:
Expand Down
21 changes: 10 additions & 11 deletions aesara/tensor/random/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,19 @@ def gen(self, op, *args, **kwargs):

# Generate a new random state
seed = int(self.gen_seedgen.integers(2**30))
random_state_variable = shared(self.rng_ctor(seed))

# Distinguish it from other shared variables (why?)
random_state_variable.tag.is_rng = True
rng = shared(self.rng_ctor(seed), borrow=True)

# Generate the sample
out = op(*args, **kwargs, rng=random_state_variable)
out.rng = random_state_variable
out = op(*args, **kwargs, rng=rng)

# This is the value that should be used to replace the old state
# (i.e. `rng`) after `out` is sampled/evaluated.
# The updates mechanism in `aesara.function` is supposed to perform
# this replace action.
new_rng = out.owner.outputs[0]

# Update the tracked states
new_r = out.owner.outputs[0]
out.update = (random_state_variable, new_r)
self.state_updates.append(out.update)
self.state_updates.append((rng, new_rng))

random_state_variable.default_update = new_r
rng.default_update = new_rng

return out
8 changes: 3 additions & 5 deletions tests/tensor/random/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def test_tutorial(self):
assert np.all(f() != f())
assert np.all(g() == g())
assert np.all(abs(nearly_zeros()) < 1e-5)
assert isinstance(rv_u.rng.get_value(borrow=True), np.random.Generator)

@pytest.mark.parametrize("rng_ctor", [np.random.RandomState, np.random.default_rng])
def test_basics(self, rng_ctor):
Expand All @@ -109,8 +108,7 @@ def test_basics(self, rng_ctor):
with pytest.raises(AttributeError):
random.blah

# test if standard_normal is available in the namespace, See: GH issue #528
random.standard_normal
assert hasattr(random, "standard_normal")

with pytest.raises(AttributeError):
np_random = RandomStream(namespace=np, rng_ctor=rng_ctor)
Expand Down Expand Up @@ -223,7 +221,7 @@ def test_default_updates(self, rng_ctor):
# Explicit updates #2
random_c = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_c = random_c.uniform(0, 1, size=(2, 2))
fn_c = function([], out_c, updates=[out_c.update])
fn_c = function([], out_c, updates=random_c.state_updates)
fn_c_val0 = fn_c()
fn_c_val1 = fn_c()
assert np.all(fn_c_val0 == fn_a_val0)
Expand All @@ -241,7 +239,7 @@ def test_default_updates(self, rng_ctor):
# No updates for out
random_e = RandomStream(utt.fetch_seed(), rng_ctor=rng_ctor)
out_e = random_e.uniform(0, 1, size=(2, 2))
fn_e = function([], out_e, no_default_updates=[out_e.rng])
fn_e = function([], out_e, no_default_updates=[random_e.state_updates[0][0]])
fn_e_val0 = fn_e()
fn_e_val1 = fn_e()
assert np.all(fn_e_val0 == fn_a_val0)
Expand Down

0 comments on commit a933bb3

Please sign in to comment.