Skip to content

Commit

Permalink
Use SeedSequence to seed RNG states in RandomStream
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Apr 29, 2022
1 parent a933bb3 commit b6dc523
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 44 deletions.
51 changes: 35 additions & 16 deletions aesara/tensor/random/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from typing import Optional, Union
from types import ModuleType
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
from typing_extensions import Literal

from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Constant, Variable
Expand All @@ -13,6 +15,11 @@
from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes
from aesara.tensor.var import TensorVariable


if TYPE_CHECKING:
from aesara.tensor.random.op import RandomVariable


def params_broadcast_shapes(param_shapes, ndims_params, use_aesara=True):
Expand Down Expand Up @@ -161,7 +168,14 @@ class RandomStream:
"""

def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng):
def __init__(
self,
seed: Optional[int] = None,
namespace: Optional[ModuleType] = None,
rng_ctor: Literal[
np.random.RandomState, np.random.Generator
] = np.random.default_rng,
):
if namespace is None:
from aesara.tensor.random import basic # pylint: disable=import-self

Expand All @@ -171,7 +185,14 @@ def __init__(self, seed=None, namespace=None, rng_ctor=np.random.default_rng):

self.default_instance_seed = seed
self.state_updates = []
self.gen_seedgen = np.random.default_rng(seed)
self.gen_seedgen = np.random.SeedSequence(seed)

if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):

# The legacy state does not accept `SeedSequence`s directly
def rng_ctor(seed):
return np.random.RandomState(np.random.MT19937(seed))

self.rng_ctor = rng_ctor

def __getattr__(self, obj):
Expand Down Expand Up @@ -206,7 +227,7 @@ def seed(self, seed=None):
Parameters
----------
seed : None or integer in range 0 to 2**30
seed : None or integer
Each random stream will be assigned a unique state that depends
deterministically on this value.
Expand All @@ -218,18 +239,18 @@ def seed(self, seed=None):
if seed is None:
seed = self.default_instance_seed

self.gen_seedgen = np.random.default_rng(seed)
self.gen_seedgen = np.random.SeedSequence(seed)
old_r_seeds = self.gen_seedgen.spawn(len(self.state_updates))

for old_r, new_r in self.state_updates:
old_r_seed = self.gen_seedgen.integers(2**30)
old_r.set_value(self.rng_ctor(int(old_r_seed)), borrow=True)
for (old_r, new_r), old_r_seed in zip(self.state_updates, old_r_seeds):
old_r.set_value(self.rng_ctor(old_r_seed), borrow=True)

def gen(self, op, *args, **kwargs):
"""Create a new random stream in this container.
def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
r"""Generate a draw from `op` seeded from this `RandomStream`.
Parameters
----------
op : RandomVariable
op
A `RandomVariable` instance
args
Positional arguments passed to `op`.
Expand All @@ -238,10 +259,8 @@ def gen(self, op, *args, **kwargs):
Returns
-------
TensorVariable
The symbolic random draw part of op()'s return value.
This function stores the updated `RandomGeneratorType` variable
for use at `build` time.
The symbolic random draw performed by `op`. This function stores
the updated `RandomType`\s for use at compile time.
"""
if "rng" in kwargs:
Expand All @@ -250,7 +269,7 @@ def gen(self, op, *args, **kwargs):
)

# Generate a new random state
seed = int(self.gen_seedgen.integers(2**30))
(seed,) = self.gen_seedgen.spawn(1)
rng = shared(self.rng_ctor(seed), borrow=True)

# Generate the sample
Expand Down
2 changes: 1 addition & 1 deletion tests/sandbox/test_rng_mrg.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ def test_normal0():

sys.stdout.flush()

RR = RandomStream(234)
RR = RandomStream(235)

nn = RR.normal(avg, std, size=size)
ff = function(var_input, nn)
Expand Down
15 changes: 7 additions & 8 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,8 +888,9 @@ def test_simple_shared_random(self):
)
my_f = function([], values, updates=updates, allow_input_downcast=True)

rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = np.random.default_rng(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)
rng = aesara_rng.rng_ctor(rng_seed)

numpy_v = np.zeros((10, 2))
for i in range(10):
Expand Down Expand Up @@ -2698,12 +2699,10 @@ def f(vsample_tm1):
[vsample], aesara_vsamples[-1], updates=updates, allow_input_downcast=True
)

_rng = np.random.default_rng(utt.fetch_seed())
rng_seed = _rng.integers(2**30)
nrng1 = np.random.default_rng(int(rng_seed)) # int() is for 32bit

rng_seed = _rng.integers(2**30)
nrng2 = np.random.default_rng(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed_1, rng_seed_2) = rng_seed.spawn(2)
nrng1 = trng.rng_ctor(rng_seed_1)
nrng2 = trng.rng_ctor(rng_seed_2)

def numpy_implementation(vsample):
for idx in range(10):
Expand Down
34 changes: 15 additions & 19 deletions tests/tensor/random/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def test_basics(self, rng_ctor):
fn_val0 = fn()
fn_val1 = fn()

rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)
rng = random.rng_ctor(rng_seed)

numpy_val0 = rng.uniform(0, 1, size=(2, 2))
numpy_val1 = rng.uniform(0, 1, size=(2, 2))
Expand All @@ -133,26 +134,18 @@ def test_seed(self, rng_ctor):
init_seed = 234
random = RandomStream(init_seed, rng_ctor=rng_ctor)

ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]

new_seed = 43298
random.seed(new_seed)

ref_state = np.random.default_rng(new_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]
rng_seed = np.random.SeedSequence(new_seed)
assert random.gen_seedgen.entropy == rng_seed.entropy

random.seed()
ref_state = np.random.default_rng(init_seed).__getstate__()
random_state = random.gen_seedgen.__getstate__()
assert random.default_instance_seed == init_seed
assert random_state["bit_generator"] == ref_state["bit_generator"]
assert random_state["state"] == ref_state["state"]

rng_seed = np.random.SeedSequence(init_seed)
assert random.gen_seedgen.entropy == rng_seed.entropy

# Reset the seed
random.seed(new_seed)
Expand All @@ -163,8 +156,9 @@ def test_seed(self, rng_ctor):
# Now, change the seed when there are state updates
random.seed(new_seed)

update_seed = np.random.default_rng(new_seed).integers(2**30)
ref_rng = rng_ctor(update_seed)
update_seed = np.random.SeedSequence(new_seed)
(update_seed,) = update_seed.spawn(1)
ref_rng = random.rng_ctor(update_seed)
state_rng = random.state_updates[0][0].get_value(borrow=True)

if hasattr(state_rng, "get_state"):
Expand All @@ -188,8 +182,10 @@ def test_uniform(self, rng_ctor):
fn_val0 = fn()
fn_val1 = fn()

rng_seed = np.random.default_rng(utt.fetch_seed()).integers(2**30)
rng = rng_ctor(int(rng_seed)) # int() is for 32bit
rng_seed = np.random.SeedSequence(utt.fetch_seed())
(rng_seed,) = rng_seed.spawn(1)

rng = random.rng_ctor(rng_seed)
numpy_val0 = rng.uniform(-1, 1, size=(2, 2))
numpy_val1 = rng.uniform(-1, 1, size=(2, 2))

Expand Down

0 comments on commit b6dc523

Please sign in to comment.