Skip to content

Commit

Permalink
Reverting unsafe_pytree #4030
Browse files Browse the repository at this point in the history
Exploring better alternatives that will not clash with the assumption that graph nodes are not pytrees present in many functions.

PiperOrigin-RevId: 651347331
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Jul 11, 2024
1 parent 5c97143 commit 6eb272b
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 174 deletions.
171 changes: 66 additions & 105 deletions docs/nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

53 changes: 1 addition & 52 deletions docs/nnx/nnx_basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ that have allowed Linen to scale effectively to large codebases.
```{code-cell} ipython3
:tags: [skip-execution]
# ! pip install -U flax penzai
! pip install -U flax penzai
```

```{code-cell} ipython3
Expand Down Expand Up @@ -378,54 +378,3 @@ model = nnx.merge(graphdef, params, counts)
# update with multiple States
nnx.update(model, params, counts)
```

## Using Modules as Pytrees

Even though `nnx.split` and `nnx.merge` can be used to interact with any JAX
API, they are not always the most convenient way to do so as they introduce
some syntactic overhead. `Module`s and other `Object`-derived types can be
registered as PyTrees via the `unsafe_pytree` class argument for convenience.
This allows you to pass Modules directly to JAX functions without having to
split them first.

```{code-cell} ipython3
class Block(nnx.Module, unsafe_pytree=True): # <== 👀 unsafe_pytree
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.linear = Linear(din, dout, rngs=rngs)
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
def __call__(self, x: jax.Array):
return nnx.gelu(self.dropout(self.linear(x)))
model = Block(3, 5, rngs=nnx.Rngs(0))
@jax.jit # regular jax.jit!
def forward(model: Block, x: jax.Array):
y = model(x)
return y, model # manually propagate state updates
y, model = forward(model, jnp.ones((1, 3)))
```

**WARNING**: The reason the features is called `unsafe` is because NNX's
reference semantics are broken by JAX's referential transparency, this
is specially problematic when there is shared state between NNX graph nodes
as reference identity is lost. Use `unsafe_pytree` only when there's only
a single top-level object or when top-level object have no shared state
between them.

```{code-cell} ipython3
class Foo(nnx.Module, unsafe_pytree=True):
def __init__(self, shared):
self.shared = shared
shared = nnx.Linear(3, 5, rngs=nnx.Rngs(0))
ma, mb = Foo(shared), Foo(shared)
print(f'Before: {ma.shared is mb.shared = }')
# flatten + unflatten
ma, mb = jax.tree.map(lambda x: x, (ma, mb))
print(f'After: {ma.shared is mb.shared = }')
```
4 changes: 2 additions & 2 deletions flax/nnx/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,8 @@ def is_initializing(self) -> bool:

return self._object__state._initializing

def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
super().__init_subclass__(unsafe_pytree=unsafe_pytree)
def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__(experimental_pytree=experimental_pytree)

cls = dataclasses.dataclass(repr=False)(cls)

Expand Down
12 changes: 2 additions & 10 deletions flax/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,18 +392,10 @@ def eval(self, **attributes):
raise_if_not_found=False,
)

def __init_subclass__(cls, unsafe_pytree: bool = False) -> None:
"""
Args:
unsafe_pytree: If True, the Module subclass will be
registered as a pytree node with JAX. This breaks reference
semantics and should be used with caution, however it can be
useful to use Modules with vanillay JAX transformations. See
`Using Modules as PyTrees <https://flax.readthedocs.io/en/latest/nnx/nnx_basics.html#using-modules-as-pytrees>`__.
"""
def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()

if unsafe_pytree:
if experimental_pytree:
jtu.register_pytree_with_keys(
cls,
partial(_module_flatten, with_keys=True),
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/tests/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ class SimpleModule(nnx.Module):
pass


class SimplePyTreeModule(nnx.Module, unsafe_pytree=True):
class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
pass


Expand Down
4 changes: 2 additions & 2 deletions flax/nnx/tests/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):

class TestModulePytree:
def test_tree_map(self):
class Foo(nnx.Module, unsafe_pytree=True):
class Foo(nnx.Module, experimental_pytree=True):
def __init__(self):
self.node = nnx.Param(1)
self.graphdef = 1
Expand All @@ -490,7 +490,7 @@ def __init__(self):
assert m.graphdef == 1

def test_static(self):
class C(nnx.Module, unsafe_pytree=True):
class C(nnx.Module, experimental_pytree=True):
def __init__(self, x):
self.x = x

Expand Down
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,6 @@ filterwarnings = [
"ignore:.*Deprecated call to.*pkg_resources.declare_namespace.*:DeprecationWarning",
# jax.xla_computation is deprecated but TF still uses it.
"ignore:.*jax.xla_computation is deprecated.*:DeprecationWarning",
# FutureWarning: The key path API is deprecated and will be removed in a future version
"ignore:.*The key path API is deprecated and will be removed in a future version.*:FutureWarning",
]

[tool.coverage.report]
Expand Down

0 comments on commit 6eb272b

Please sign in to comment.