NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Hunk line counts were checked against every
"@@" header and match. Indentation depth (2-space style) is inferred from the
code's nesting, so context lines may differ in leading whitespace from the
target tree; if strict application fails, try `git apply --ignore-whitespace`
and confirm against the repository.

Summary of the change, as shown by the hunks below:
  * is_node() returns False for Variable instances and falls back to
    is_pytree_node() for types not in NODE_TYPES.
  * get_node_impl() now takes an instance only; lookup by type moves to the
    new get_node_impl_for_type().
  * _HashableMapping hashes its sorted items and compares equal only to
    another _HashableMapping.
  * A PytreeType node type is registered, with one-level pytree
    flatten/unflatten helpers, plus two tests.

diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py
index 974dcf7c63..d9ebb9942b 100644
--- a/flax/experimental/nnx/nnx/graph_utils.py
+++ b/flax/experimental/nnx/nnx/graph_utils.py
@@ -105,28 +105,33 @@ def register_node_type(
 
 
 def is_node(x: tp.Any) -> bool:
-  return type(x) in NODE_TYPES
+  if isinstance(x, Variable):
+    return False
+  elif type(x) in NODE_TYPES:
+    return True
+  return is_pytree_node(x)
 
 
 def is_node_type(x: type[tp.Any]) -> bool:
   return x in NODE_TYPES
 
 
-@tp.overload
-def get_node_impl(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
-  ...
+def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
+  if isinstance(x, Variable):
+    raise ValueError(f'Variable is not a node: {x}')
 
+  node_type = type(x)
 
-@tp.overload
-def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
-  ...
+  if node_type not in NODE_TYPES:
+    if is_pytree_node(x):
+      node_type = PytreeType
+    else:
+      raise ValueError(f'Unknown node type: {x}')
+
+  return NODE_TYPES[node_type]
 
 
-def get_node_impl(x: type[Node] | Node) -> NodeImpl[Node, tp.Any, tp.Any]:
-  if not isinstance(x, type):
-    x = type(x)
-  if not is_node_type(x):
-    raise ValueError(f'Unknown node type: {x}')
+def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
   return NODE_TYPES[x]
 
 
@@ -147,12 +152,12 @@ def __len__(self) -> int:
     return len(self._mapping)
 
   def __hash__(self) -> int:
-    return hash(tuple(self._mapping.items()))
+    return hash(tuple(sorted(self._mapping.items())))
 
   def __eq__(self, other: tp.Any) -> bool:
-    if not isinstance(other, tp.Mapping):
-      return False
-    return self._mapping == other
+    return (
+      isinstance(other, _HashableMapping) and self._mapping == other._mapping
+    )
 
   def __repr__(self) -> str:
     return repr(self._mapping)
@@ -253,7 +258,7 @@ def __init__(
     variables: tp.Iterable[tuple[str, VariableDef | int]],
     metadata: tp.Any,
   ):
-    self._type = type
+    self._type: type[Node] = type
     self._index = index
     self._attributes = attributes
     self._subgraphs = _HashableMapping(subgraphs)
@@ -417,7 +422,7 @@ def _graph_flatten(
       static_fields.append((key, value))
 
   graphdef = GraphDef(
-    type=type(node),
+    type=node_impl.type,
     index=index,
     attributes=tuple(key for key, _ in values),
     subgraphs=subgraphs,
@@ -449,7 +454,7 @@ def _graph_unflatten(
     # TODO(cgarciae): why copy here?
     state = state.copy()
 
-  node_impl = get_node_impl(graphdef.type)
+  node_impl = get_node_impl_for_type(graphdef.type)
 
   def _get_children():
     new_state: dict[str, tp.Any] = {}
@@ -874,3 +879,46 @@ def _unflatten_tuple(
   flatten=_flatten_tuple,
   unflatten=_unflatten_tuple,
 )
+
+
+# Pytree
+class PytreeType:
+  pass
+
+
+def is_pytree_node(x: tp.Any) -> bool:
+  return not jax.tree_util.all_leaves([x])
+
+
+def _key_path_to_str(key: tp.Any) -> str:
+  if isinstance(key, jax.tree_util.SequenceKey):
+    return str(key.idx)
+  elif isinstance(
+    key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
+  ):
+    return str(key.key)
+  elif isinstance(key, jax.tree_util.GetAttrKey):
+    return key.name
+  else:
+    return str(key)
+
+
+def _flatten_pytree(pytree: tp.Any):
+  leaves, treedef = jax.tree_util.tree_flatten_with_path(
+    pytree, is_leaf=lambda x: x is not pytree
+  )
+  nodes = tuple((_key_path_to_str(path[0]), value) for path, value in leaves)
+
+  return nodes, treedef
+
+
+def _unflatten_pytree(
+  nodes: tuple[tuple[str, tp.Any], ...], treedef: jax.tree_util.PyTreeDef
+):
+  pytree = treedef.unflatten(value for _, value in nodes)
+  return pytree
+
+
+register_node_type(
+  PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree
+)
diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py
index 8ee03eb0ca..826763d3e6 100644
--- a/flax/experimental/nnx/tests/test_graph_utils.py
+++ b/flax/experimental/nnx/tests/test_graph_utils.py
@@ -16,6 +16,7 @@
 import pytest
 
 from flax.experimental import nnx
+from flax import struct
 
 
 class TestGraphUtils:
@@ -230,3 +231,49 @@ def __init__(self):
     assert m2.a.value == state.a.raw_value
     assert m2.b.value == state.a.raw_value
     assert m2.a is m2.b
+
+  def test_pytree_flatten(self):
+    @struct.dataclass
+    class Tree:
+      a: int
+      b: str = struct.field(pytree_node=False)
+
+    p = Tree(1, 'a')
+
+    leaves, treedef = nnx.graph_utils._flatten_pytree(p)
+    fields = dict(leaves)
+
+    assert 'a' in fields
+    assert 'b' not in fields
+    assert fields['a'] == 1
+
+    p2 = nnx.graph_utils._unflatten_pytree(leaves, treedef)
+
+    assert isinstance(p2, Tree)
+    assert p2.a == 1
+
+  def test_pytree_node(self):
+    @struct.dataclass
+    class Tree:
+      a: nnx.Param[int]
+      b: str = struct.field(pytree_node=False)
+
+    class Foo(nnx.Module):
+      def __init__(self):
+        self.tree = Tree(nnx.Param(1), 'a')
+
+    m = Foo()
+
+    state, static = m.split()
+
+    assert 'tree' in state
+    assert 'a' in state.tree
+    assert static.subgraphs['tree'].type is nnx.graph_utils.PytreeType
+
+    m2 = static.merge(state)
+
+    assert isinstance(m2.tree, Tree)
+    assert m2.tree.a.raw_value == 1
+    assert m2.tree.b == 'a'
+    assert m2.tree.a is not m.tree.a
+    assert m2.tree is not m.tree