diff --git a/flax/core/scope.py b/flax/core/scope.py index ff6c575983..117a63f44c 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -1074,9 +1074,19 @@ def apply( def wrapper( variables: VariableDict, *args, - rngs: Optional[RNGSequences] = None, + rngs: Optional[Union[PRNGKey, RNGSequences]] = None, **kwargs, ) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: + if rngs is not None: + if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): + raise ValueError( + 'The ``rngs`` argument passed to an apply function should be a ' + '``jax.PRNGKey`` or a dictionary mapping strings to ' + '``jax.PRNGKey``.' + ) + if not isinstance(rngs, (dict, FrozenDict)): + rngs = {'params': rngs} + # Try to detect if user accidentally passed {'params': {'params': ...}. if ( 'params' in variables @@ -1118,10 +1128,10 @@ def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): raise ValueError( 'First argument passed to an init function should be a ' - '`jax.PRNGKey` or a dictionary mapping strings to ' - '`jax.PRNGKey`.' + '``jax.PRNGKey`` or a dictionary mapping strings to ' + '``jax.PRNGKey``.' ) - if not isinstance(rngs, dict): + if not isinstance(rngs, (dict, FrozenDict)): rngs = {'params': rngs} init_flags = {**(flags if flags is not None else {}), 'initializing': True} return apply(fn, mutable=mutable, flags=init_flags)( @@ -1217,7 +1227,7 @@ def _is_valid_rng(rng: Array): return True -def _is_valid_rngs(rngs: RNGSequences): +def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]): if not isinstance(rngs, (FrozenDict, dict)): return False for key, val in rngs.items(): diff --git a/flax/linen/module.py b/flax/linen/module.py index 0462d4d9b4..f0e035a1ce 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2041,7 +2041,7 @@ def apply( self, variables: VariableDict, *args, - rngs: Optional[RNGSequences] = None, + rngs: Optional[Union[PRNGKey, RNGSequences]] = None, method: Union[Callable[..., Any], str, None] = None, mutable: CollectionFilter = False, capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, @@ -2115,6 +2115,14 @@ def apply( """ Module._module_checks(self) + if rngs is not None and not isinstance(rngs, dict): + if not core.scope._is_valid_rng(rngs): + raise errors.InvalidRngError( + 'RNGs should be of shape (2,) or PRNGKey in Module ' + f'{self.__class__.__name__}, but rngs are: {rngs}' + ) + rngs = {'params': rngs} + if isinstance(method, str): attribute_name = method method = getattr(self, attribute_name) diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index a381209415..f6dffd049f 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -815,6 +815,43 @@ def __call__(self, x): trace = mlp.apply(variables, x) self.assertEqual(trace, expected_trace) + def test_default_params_rng_equivalence(self): + class Model(nn.Module): + @nn.compact + def __call__(self, x, add_dropout=False, add_noise=False): + x = nn.Dense(16)(x) + x = nn.Dropout(0.5)(x, deterministic=not add_dropout) + if add_noise: + x += jax.random.normal(self.make_rng('params')) + return x + + model = Model() + key0, key1, key2 = jax.random.split(jax.random.key(0), 3) + x = jax.random.normal(key0, (10, 8)) + + with self.assertRaisesRegex(ValueError, 'First argument passed to an init function should be a ``jax.PRNGKey``'): + model.init({'params': 'test'}, x) + with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'): + model.init('test', x) + with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'): + model.init(key1, x, add_dropout=True) + + v = model.init({'params': key1}, x) + v2 = model.init(key1, x) + jax.tree_map(np.testing.assert_allclose, v, v2) + + out = model.apply(v, x, add_noise=True, rngs={'params': key2}) + out2 = model.apply(v, x, add_noise=True, rngs=key2) + np.testing.assert_allclose(out, out2) + + with self.assertRaisesRegex(ValueError, 'The ``rngs`` argument passed to an apply function should be a ``jax.PRNGKey``'): + model.apply(v, x, rngs={'params': 'test'}) + with self.assertRaisesRegex(errors.InvalidRngError, 'RNGs should be of shape \\(2,\\) or PRNGKey in Module Model, but rngs are: test'): + model.apply(v, x, rngs='test') + with self.assertRaisesRegex(errors.InvalidRngError, 'Dropout_0 needs PRNG for "dropout"'): + model.apply(v, x, add_dropout=True, rngs=key2) + + def test_module_apply_method(self): class Foo(nn.Module): not_callable: int = 1