diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 0156fd9222..cb1c54dd99 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -68,16 +68,6 @@ def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M: if isinstance(module, _HasSetup): module.setup() - assert isinstance(module, Module) - - for field in dataclasses.fields(module): - if not field.init: - continue - value = vars(module)[field.name] - # set Rngs instances to None - if isinstance(value, Rngs): - vars(module)[field.name] = None - return module diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index 3bad4e0aec..fa2d7ce79d 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations -from typing import Optional, Sequence +import dataclasses +from typing import Sequence import jax import jax.numpy as jnp @@ -20,7 +22,6 @@ from flax.experimental.nnx.nnx import rnglib from flax.experimental.nnx.nnx.module import Module, first_from -import dataclasses @dataclasses.dataclass @@ -38,15 +39,16 @@ class Dropout(Module): rate: float broadcast_dims: Sequence[int] = () - deterministic: Optional[bool] = None + deterministic: bool | None = None rng_collection: str = 'dropout' + rngs: rnglib.Rngs | None = None def __call__( self, inputs, *, - deterministic: Optional[bool] = None, - rngs: Optional[rnglib.Rngs] = None, + deterministic: bool | None = None, + rngs: rnglib.Rngs | None = None, ) -> jax.Array: """Applies a random dropout mask to the input. @@ -59,6 +61,7 @@ def __call__( Returns: The masked inputs reweighted to preserve mean. """ + rngs = rngs or self.rngs deterministic = first_from( deterministic, self.deterministic, diff --git a/flax/experimental/nnx/tests/nn/test_stochastic.py b/flax/experimental/nnx/tests/nn/test_stochastic.py new file mode 100644 index 0000000000..e9b8e1dc14 --- /dev/null +++ b/flax/experimental/nnx/tests/nn/test_stochastic.py @@ -0,0 +1,27 @@ + +import jax.numpy as jnp + +from flax.experimental import nnx + + +class TestStochastic: + def test_dropout_internal_rngs(self): + n = 0 + m = nnx.Dropout(rate=0.5, deterministic=False, rngs=nnx.Rngs(dropout=0)) + + @nnx.jit + def f(m, x): + nonlocal n + n += 1 + return m(x) + + x = jnp.ones((1, 10)) + assert m.rngs is not None and m.rngs.dropout.count.value == 0 + + y = f(m, x) + assert n == 1 + assert m.rngs.dropout.count.value == 1 + + y = f(m, x) + assert n == 1 + assert m.rngs.dropout.count.value == 2 diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index d4d3d4ac23..bb348d13e9 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -550,7 +550,7 @@ class Foo(nnx.Module): assert state.d == nnx.Variable(4) assert state.e == nnx.BatchStat(5) - def test_context_none_after_init(self): + def test_post_init(self): @dataclasses.dataclass class DFoo(nnx.Module): din: int @@ -566,7 +566,6 @@ def __call__(self, x): m = DFoo(1, 1, rngs=nnx.Rngs(0)) assert hasattr(m, 'bar') - assert m.rngs is None def test_setup_is_called(self): @dataclasses.dataclass @@ -584,7 +583,6 @@ def __call__(self, x): m = DFoo(1, 1, rngs=nnx.Rngs(0)) assert hasattr(m, 'bar') - assert m.rngs is None class TestModuleDef: