Skip to content

Commit

Permalink
Merge pull request #3889 from google:nnx-fix-iter-nodes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 632184511
  • Loading branch information
Flax Authors committed May 9, 2024
2 parents b000c75 + 90c2687 commit a680767
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
1 change: 1 addition & 0 deletions flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from .nnx.graph import pop as pop
from .nnx.graph import state as state
from .nnx.graph import graphdef as graphdef
from .nnx.graph import iter_nodes as iter_nodes
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down
2 changes: 1 addition & 1 deletion flax/experimental/nnx/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,11 +1072,11 @@ def _iter_nodes(
if id(node) in visited:
return
visited.add(id(node))
yield path_parts, node
node_impl = get_node_impl(node)
node_dict = node_impl.node_dict(node)
for key, value in node_dict.items():
yield from _iter_nodes(value, visited, (*path_parts, key))
yield path_parts, node


def compose_mapping(
Expand Down
2 changes: 1 addition & 1 deletion flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,10 @@ def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
>>> for path, module in model.iter_modules():
... print(path, type(module).__name__)
...
() Block
('batch_norm',) BatchNorm
('dropout',) Dropout
('linear',) Linear
() Block
"""
for path, value in graph.iter_nodes(self):
if isinstance(value, Module):
Expand Down
12 changes: 6 additions & 6 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,12 +639,12 @@ def __init__(self, *, rngs: nnx.Rngs):
modules = list(module.iter_modules())

assert len(modules) == 3
assert modules[0][0] == ()
assert isinstance(modules[0][1], Foo)
assert modules[1][0] == ('submodules', 0, 'a')
assert isinstance(modules[1][1], nnx.Linear)
assert modules[2][0] == ('submodules', 1, 'b')
assert isinstance(modules[2][1], nnx.Conv)
assert modules[0][0] == ('submodules', 0, 'a')
assert isinstance(modules[0][1], nnx.Linear)
assert modules[1][0] == ('submodules', 1, 'b')
assert isinstance(modules[1][1], nnx.Conv)
assert modules[2][0] == ()
assert isinstance(modules[2][1], Foo)

def test_array_in_module(self):
class Foo(nnx.Module):
Expand Down

0 comments on commit a680767

Please sign in to comment.