Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make make_rng default to 'params' #3699

Merged
merged 1 commit into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,11 @@ def has_rng(self, name: str) -> bool:
"""Returns true if a PRNGSequence with name `name` exists."""
return name in self.rngs

def make_rng(self, name: str = 'default') -> PRNGKey:
def make_rng(self, name: str = 'params') -> PRNGKey:
"""Generates A PRNGKey from a PRNGSequence with name `name`."""
if not self.has_rng(name):
if self.has_rng('default'):
name = 'default'
if self.has_rng('params'):
name = 'params'
else:
raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
self._check_valid()
Expand Down
6 changes: 5 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,13 +1883,17 @@ def has_rng(self, name: str) -> bool:
raise ValueError("Can't query for RNGs on unbound modules")
return self.scope.has_rng(name)

def make_rng(self, name: str = 'default') -> PRNGKey:
def make_rng(self, name: str = 'params') -> PRNGKey:
"""Returns a new RNG key from a given RNG sequence for this Module.

The new RNG key is split from the previous one. Thus, every call to
``make_rng`` returns a new RNG key, while still guaranteeing full
reproducibility.

NOTE: if an invalid name is passed (i.e. no RNG key was passed by
the user in ``.init`` or ``.apply`` for this name), then ``name``
will default to ``'params'``.

TODO: Link to Flax RNG design note.

Args:
Expand Down
76 changes: 46 additions & 30 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,28 +829,46 @@ def __call__(self, x, add_dropout=False, add_noise=False):
key0, key1, key2 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key0, (10, 8))

with self.assertRaisesRegex(ValueError, 'First argument passed to an init function should be a ``jax.PRNGKey``'):
with self.assertRaisesRegex(
ValueError,
'First argument passed to an init function should be a ``jax.PRNGKey``',
):
model.init({'params': 'test'}, x)
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
with self.assertRaisesRegex(
errors.InvalidRngError,
'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test',
):
model.init('test', x)
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.init(key1, x, add_dropout=True)
# should not throw an error, since nn.Dropout will get an RNG key from the 'params' stream
model.init(key1, x, add_dropout=True)

v = model.init({'params': key1}, x)
v2 = model.init(key1, x)
jax.tree_map(np.testing.assert_allclose, v, v2)

out = model.apply(v, x, add_noise=True, rngs={'params': key2})
out2 = model.apply(v, x, add_noise=True, rngs=key2)
np.testing.assert_allclose(out, out2)
for add_dropout, add_noise in [[True, False], [False, True], [True, True]]:
out = model.apply(
v,
x,
add_dropout=add_dropout,
add_noise=add_noise,
rngs={'params': key2},
)
out2 = model.apply(
v, x, add_dropout=add_dropout, add_noise=add_noise, rngs=key2
)
np.testing.assert_allclose(out, out2)

with self.assertRaisesRegex(ValueError, 'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``'):
with self.assertRaisesRegex(
ValueError,
'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``',
):
model.apply(v, x, rngs={'params': 'test'})
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
with self.assertRaisesRegex(
errors.InvalidRngError,
'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test',
):
model.apply(v, x, rngs='test')
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.apply(v, x, add_dropout=True, rngs=key2)


def test_module_apply_method(self):
class Foo(nn.Module):
Expand Down Expand Up @@ -1491,6 +1509,7 @@ def test_perturb_setup(self):
class Foo(nn.Module):
def setup(self):
self.a = nn.Dense(10)

def __call__(self, x):
x = self.a(x)
x = self.perturb('before_multiply', x)
Expand Down Expand Up @@ -1534,9 +1553,7 @@ def __call__(self, x):
module.apply({'params': params}, x)

# check errors if perturbations is passed but empty
with self.assertRaisesRegex(
ValueError, 'Perturbation collection'
):
with self.assertRaisesRegex(ValueError, 'Perturbation collection'):
module.apply({'params': params, 'perturbations': {}}, x)

# check no error if perturbations is passed and not empty
Expand Down Expand Up @@ -2619,9 +2636,8 @@ def __call__(self, x, apply_dropout):
model = Model()

# test init equality
default_variables = model.init({'default': key1}, x, apply_dropout=False)
# adding 'default' rng shouldn't change anything
rngs = {'params': key1, 'var_rng': key1, 'noise': key1, 'default': key0}
default_variables = model.init({'params': key1}, x, apply_dropout=False)
rngs = {'params': key1, 'var_rng': key1, 'noise': key1}
explicit_variables = model.init(rngs, x, apply_dropout=False)
self.assertTrue(
jax.tree_util.tree_all(
Expand All @@ -2632,7 +2648,6 @@ def __call__(self, x, apply_dropout):
)

# test init inequality
rngs['default'] = key1 # adding 'default' rng shouldn't change anything
for rng_name in ('params', 'var_rng'):
rngs[rng_name] = key2
explicit_variables = model.init(rngs, x, apply_dropout=False)
Expand All @@ -2649,17 +2664,15 @@ def __call__(self, x, apply_dropout):

# test apply equality
default_out = model.apply(
default_variables, x, apply_dropout=True, rngs={'default': key1}
default_variables, x, apply_dropout=True, rngs={'params': key1}
)
# adding 'default' rng shouldn't change anything
rngs = {'dropout': key1, 'noise': key1, 'default': key0}
rngs = {'dropout': key1, 'noise': key1}
explicit_out = model.apply(
default_variables, x, apply_dropout=True, rngs=rngs
)
np.testing.assert_allclose(default_out, explicit_out)

# test apply inequality
rngs['default'] = key1 # adding 'default' rng shouldn't change anything
for rng_name in ('dropout', 'noise'):
rngs[rng_name] = key2
explicit_out = model.apply(
Expand All @@ -2685,22 +2698,22 @@ def __call__(self, x):

key0, key1 = jax.random.split(jax.random.key(0), 2)
x = jax.random.normal(key0, (10, 4))
default_out = Model().apply({}, x, rngs={'default': key1})
default_out = Model().apply({}, x, rngs={'params': key1})

class SubModel(nn.Module):
@nn.compact
def __call__(self, x):
noise = jax.random.normal(self.make_rng('default'), x.shape)
noise = jax.random.normal(self.make_rng('params'), x.shape)
return x + noise

class Model(nn.Module):
@nn.compact
def __call__(self, x):
x = SubModel()(x)
noise = jax.random.normal(self.make_rng('default'), x.shape)
noise = jax.random.normal(self.make_rng('params'), x.shape)
return x + noise

explicit_out = Model().apply({}, x, rngs={'default': key1})
explicit_out = Model().apply({}, x, rngs={'params': key1})
np.testing.assert_allclose(default_out, explicit_out)

def test_default_rng_error(self):
Expand All @@ -2718,13 +2731,13 @@ def __call__(self, x):
class Model(nn.Module):
@nn.compact
def __call__(self, x):
return x + jax.random.normal(self.make_rng('default'), x.shape)
return x + jax.random.normal(self.make_rng(), x.shape)

model = Model()
with self.assertRaisesRegex(
errors.InvalidRngError, 'None needs PRNG for "default"'
errors.InvalidRngError, 'None needs PRNG for "params"'
):
model.init(jax.random.key(0), jnp.ones((1, 3)))
model.init({'other_rng_stream': jax.random.key(0)}, jnp.ones((1, 3)))

def test_compact_name_scope(self):
class Foo(nn.Module):
Expand Down Expand Up @@ -3106,10 +3119,13 @@ def test_nonstring_keys_in_dict_on_module(self):
class MyEnum(str, enum.Enum):
a = 'a'
b = 'b'

class MyModule(nn.Module):
config: dict[MyEnum, int]

def __call__(self, inputs):
return inputs

module = MyModule(config={MyEnum.a: 1, MyEnum.b: 2})
variables = module.init(jax.random.key(0), jnp.zeros([0]))

Expand Down
Loading