Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add submodule iterator #3581

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
from flax.linen.pooling import pool as pool

from .nnx import compatibility as compatibility
from .nnx import graph_utils
from .nnx.dataclasses import dataclass as dataclass
from .nnx.dataclasses import field as field
from .nnx import graph_utils as graph_utils
from .nnx.dataclasses import param_field as param_field
from .nnx.dataclasses import treenode_field as treenode_field
from .nnx.dataclasses import variable_field as variable_field
Expand Down
22 changes: 22 additions & 0 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def _graph_unflatten(
if graphdef.index in index_to_node:
raise RuntimeError(f'GraphDef index {graphdef.index} already used.')

state = state.copy()
node_impl = get_node_impl(graphdef.type)

def _get_children():
Expand Down Expand Up @@ -721,6 +722,27 @@ def clone(node: Node) -> Node:
return static.merge(state)


def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[Path, tp.Any]]:
  """Yield ``(path, node)`` pairs for ``node`` and the nodes nested in it.

  This is the public entry point of the traversal: it seeds the recursive
  helper ``_iter_nodes`` with a fresh, empty set of visited object ids and
  an empty tuple of path parts, then forwards everything the helper yields.

  Args:
    node: The object at which the traversal starts.

  Yields:
    The ``(path, node)`` pairs produced by ``_iter_nodes``.
  """
  # A new visited-set is created per call so that separate traversals never
  # share de-duplication state.
  yield from _iter_nodes(node, set(), ())


def _iter_nodes(
node: tp.Any, visited: set[int], path_parts: PathParts
Comment on lines +725 to +732
Copy link
Collaborator

@chiamp chiamp Dec 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the `node` arguments have a more specific type annotation than `tp.Any`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that a node could be of any registered type.

) -> tp.Iterator[tuple[Path, tp.Any]]:
if not is_node(node):
return
if id(node) in visited:
return
visited.add(id(node))
path = '/'.join(path_parts)
yield path, node
node_impl = get_node_impl(node)
for key, value in node_impl.items(node):
yield from _iter_nodes(value, visited, (*path_parts, key))


# -----------------------------
# register node types
# -----------------------------
Expand Down
27 changes: 4 additions & 23 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,29 +482,10 @@ def sow(
reduced_value = reduce_fn(init_fn(), value)
setattr(self, name, variable_type(reduced_value))

def for_each(
Copy link
Member

@superbobry superbobry Dec 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be called `submodules` and not `modules`? Or, better yet, `iter_submodules`?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I called it `modules()` because it would be familiar to PyTorch users. I wonder whether we should try to follow their conventions when possible.

self, module_type: tp.Type[M], fn: tp.Callable[[M], None]
) -> None:
visited: tp.Set[ids.UUID] = set()
self._on_all(module_type, fn, visited)

def _on_all(
self,
module_type: tp.Type[M],
fn: tp.Callable[[M], None],
visited: tp.Set[ids.UUID],
) -> None:
if self._module__state.id in visited:
return

visited.add(self._module__state.id)

if isinstance(self, module_type):
fn(self)

for value in vars(self).values():
def modules(self) -> tp.Iterator[tuple[Path, Module]]:
for path, value in graph_utils.iter_nodes(self):
if isinstance(value, Module):
value._on_all(module_type, fn, visited)
yield path, value

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()
Expand Down Expand Up @@ -603,7 +584,7 @@ def first_from(arg_name: str, *args: tp.Optional[A]) -> A:


def merge(
state_and_def: tuple[tpe.Unpack[tuple[State, ...]], GraphDef[M]]
state_and_def: tuple[tpe.Unpack[tuple[State, ...]], GraphDef[M]],
) -> M:
*states, graphdef = state_and_def
return graphdef.merge(*states)
20 changes: 20 additions & 0 deletions flax/experimental/nnx/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,3 +615,23 @@ def __call__(self, x, *, rngs: nnx.Rngs):
y, (state, graphdef) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1))

assert isinstance(y, jax.Array)

def test_modules_iterator(self):
  """Check that ``Module.modules()`` yields the root and nested submodules."""

  class Foo(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
      # The submodules sit inside a list of dicts, so the expected paths
      # below pass through both container levels.
      self.submodules = [
        {'a': nnx.Linear(1, 1, rngs=rngs)},
        {'b': nnx.Conv(1, 1, 1, rngs=rngs)},
      ]

  module = Foo(rngs=nnx.Rngs(0))

  found = list(module.modules())

  # Exactly three entries are expected: the root, the Linear and the Conv.
  assert len(found) == 3
  (root_path, root), (linear_path, linear), (conv_path, conv) = found

  assert root_path == ''
  assert isinstance(root, Foo)
  assert linear_path == 'submodules/0/a'
  assert isinstance(linear, nnx.Linear)
  assert conv_path == 'submodules/1/b'
  assert isinstance(conv, nnx.Conv)
Loading