Skip to content

Commit

Permalink
added default params rng to .apply
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Feb 15, 2024
1 parent efbe705 commit 925fe9c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 6 deletions.
20 changes: 15 additions & 5 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,9 +1074,19 @@ def apply(
def wrapper(
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
**kwargs,
) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]:
if rngs is not None:
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
raise ValueError(
'The ``rngs`` argument passed to an apply function should be a '
'``jax.PRNGKey`` or a dictionary mapping strings to '
'``jax.PRNGKey``.'
)
if not isinstance(rngs, (dict, FrozenDict)):
rngs = {'params': rngs}

# Try to detect if user accidentally passed {'params': {'params': ...}.
if (
'params' in variables
Expand Down Expand Up @@ -1118,10 +1128,10 @@ def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
raise ValueError(
'First argument passed to an init function should be a '
'`jax.PRNGKey` or a dictionary mapping strings to '
'`jax.PRNGKey`.'
'``jax.PRNGKey`` or a dictionary mapping strings to '
'``jax.PRNGKey``.'
)
if not isinstance(rngs, dict):
if not isinstance(rngs, (dict, FrozenDict)):
rngs = {'params': rngs}
init_flags = {**(flags if flags is not None else {}), 'initializing': True}
return apply(fn, mutable=mutable, flags=init_flags)(
Expand Down Expand Up @@ -1217,7 +1227,7 @@ def _is_valid_rng(rng: Array):
return True


def _is_valid_rngs(rngs: RNGSequences):
def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]):
if not isinstance(rngs, (FrozenDict, dict)):
return False
for key, val in rngs.items():
Expand Down
10 changes: 9 additions & 1 deletion flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,7 +2041,7 @@ def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
rngs: Optional[Union[PRNGKey, RNGSequences]] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False,
Expand Down Expand Up @@ -2115,6 +2115,14 @@ def apply(
"""
Module._module_checks(self)

if rngs is not None and not isinstance(rngs, dict):
if not core.scope._is_valid_rng(rngs):
raise errors.InvalidRngError(
'RNGs should be of shape (2,) or PRNGKey in Module '
f'{self.__class__.__name__}, but rngs are: {rngs}'
)
rngs = {'params': rngs}

if isinstance(method, str):
attribute_name = method
method = getattr(self, attribute_name)
Expand Down
37 changes: 37 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,43 @@ def __call__(self, x):
trace = mlp.apply(variables, x)
self.assertEqual(trace, expected_trace)

def test_default_params_rng_equivalence(self):
class Model(nn.Module):
@nn.compact
def __call__(self, x, add_dropout=False, add_noise=False):
x = nn.Dense(16)(x)
x = nn.Dropout(0.5)(x, deterministic=not add_dropout)
if add_noise:
x += jax.random.normal(self.make_rng('params'))
return x

model = Model()
key0, key1, key2 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key0, (10, 8))

with self.assertRaisesRegex(ValueError, 'First argument passed to an init function should be a ``jax.PRNGKey``'):
model.init({'params': 'test'}, x)
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
model.init('test', x)
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.init(key1, x, add_dropout=True)

v = model.init({'params': key1}, x)
v2 = model.init(key1, x)
jax.tree_map(np.testing.assert_allclose, v, v2)

out = model.apply(v, x, add_noise=True, rngs={'params': key2})
out2 = model.apply(v, x, add_noise=True, rngs=key2)
np.testing.assert_allclose(out, out2)

with self.assertRaisesRegex(ValueError, 'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``'):
model.apply(v, x, rngs={'params': 'test'})
with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'):
model.apply(v, x, rngs='test')
with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'):
model.apply(v, x, add_dropout=True, rngs=key2)


def test_module_apply_method(self):
class Foo(nn.Module):
not_callable: int = 1
Expand Down

0 comments on commit 925fe9c

Please sign in to comment.