Skip to content

Commit

Permalink
Merge pull request #4099 from google:nnx-reseed
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654933687
  • Loading branch information
Flax Authors committed Jul 22, 2024
2 parents d8bc194 + 3f8c62f commit a7bdadb
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/api_reference/flax.nnx/rnglib.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ rnglib
:members: __init__
.. autoclass:: RngStream
:members:
.. autofunction:: reseed
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .nnx.rnglib import RngCount as RngCount
from .nnx.rnglib import ForkStates as ForkStates
from .nnx.rnglib import fork as fork
from .nnx.rnglib import reseed as reseed
from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
from .nnx.spmd import get_partition_spec as get_partition_spec
from .nnx.spmd import get_named_sharding as get_named_sharding
Expand Down
53 changes: 52 additions & 1 deletion flax/nnx/nnx/rnglib.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,55 @@ def backup_keys(node: tp.Any, /):

def restore_keys(backups: list[tuple[RngStream, jax.Array]], /):
for stream, key in backups:
stream.key.value = key
stream.key.value = key


def reseed(node, /, **stream_keys: RngValue):
"""Update the keys of the specified RNG streams with new keys.
Args:
node: the node to reseed the RNG streams in.
**stream_keys: a mapping of stream names to new keys. The keys can be
either integers or jax arrays. If an integer is passed in, then the
key will be generated using ``jax.random.key``.
Raises:
ValueError: if an existing stream key is not a scalar.
Example::
>>> from flax import nnx
>>> import jax.numpy as jnp
...
>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear = nnx.Linear(2, 3, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, rngs=rngs)
... def __call__(self, x):
... return self.dropout(self.linear(x))
...
>>> model = Model(nnx.Rngs(params=0, dropout=42))
>>> x = jnp.ones((1, 2))
...
>>> y1 = model(x)
...
>>> # reset the ``dropout`` stream key to 42
>>> nnx.reseed(model, dropout=42)
>>> y2 = model(x)
...
>>> jnp.allclose(y1, y2)
Array(True, dtype=bool)
"""
for _, stream in graph.iter_graph(node):
if isinstance(stream, RngStream):
if stream.key.tag in stream_keys:
if stream.key.shape != ():
raise ValueError(
f'Cannot reseed stream {stream.key.tag!r} with a non-scalar key, '
f' found key with shape {stream.key.shape}.'
)
key = stream_keys[stream.key.tag]
if isinstance(key, int):
key = jax.random.key(key)
stream.key.value = key
stream.count.value = jnp.array(0, dtype=jnp.uint32)
24 changes: 23 additions & 1 deletion flax/nnx/tests/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@

import jax
import jax.numpy as jnp
import numpy as np
import pytest
from absl.testing import absltest

from flax import nnx


class TestRngs:
class TestRngs(absltest.TestCase):
def test_call(self):
rngs = nnx.Rngs(0)
key = rngs()
Expand Down Expand Up @@ -235,3 +237,23 @@ def test_state_fork_multidimensional_split_mixed(self):
assert broadcast_keys.dropout.key.value.shape == ()
assert split_counts.params.count.value == 0
assert broadcast_counts.dropout.count.value == 0

def test_reseed(self):
class Model(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(2, 3, rngs=rngs)
self.dropout = nnx.Dropout(0.5, rngs=rngs)

def __call__(self, x):
return self.dropout(self.linear(x))

model = Model(nnx.Rngs(params=0, dropout=42))
x = jnp.ones((1, 2))

y1 = model(x)

# reset the ``dropout`` stream key to 42
nnx.reseed(model, dropout=42)
y2 = model(x)

np.testing.assert_allclose(y1, y2)

0 comments on commit a7bdadb

Please sign in to comment.