diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index 974dcf7c63..619ad0b03a 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -43,90 +43,90 @@ @dataclasses.dataclass(frozen=True) -class NodeImpl(tp.Generic[Node, Leaf, AuxData]): +class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): type: type flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]] - set_key: tp.Callable[[Node, str, Leaf], None] | None - pop_key: tp.Callable[[Node, str], Leaf] | None - unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] | None - create_empty: tp.Callable[[AuxData], Node] | None - init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None def node_dict(self, node: Node) -> dict[str, Leaf]: nodes, _ = self.flatten(node) return dict(nodes) -@tp.overload -def register_node_type( +@dataclasses.dataclass(frozen=True) +class MutableNodeImpl(NodeImplBase[Node, Leaf, AuxData]): + set_key: tp.Callable[[Node, str, Leaf], None] + pop_key: tp.Callable[[Node, str], Leaf] + create_empty: tp.Callable[[AuxData], Node] + + def init(self, node: Node, items: tuple[tuple[str, Leaf], ...]): + for key, value in items: + self.set_key(node, key, value) + + +@dataclasses.dataclass(frozen=True) +class ImmutableNodeImpl(NodeImplBase[Node, Leaf, AuxData]): + unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] + + +NodeImpl = tp.Union[ + MutableNodeImpl[Node, Leaf, AuxData], ImmutableNodeImpl[Node, Leaf, AuxData] +] + + +def register_immutable_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], - *, unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node], ): - ... + NODE_TYPES[type] = ImmutableNodeImpl( + type=type, flatten=flatten, unflatten=unflatten + ) -@tp.overload -def register_node_type( +def register_mutable_node_type( type: type, flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], - *, set_key: tp.Callable[[Node, str, Leaf], None], pop_key: tp.Callable[[Node, str], Leaf], create_empty: tp.Callable[[AuxData], Node], - init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None], -): - ... - - -def register_node_type( - type: type, - flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]], - *, - set_key: tp.Callable[[Node, str, Leaf], None] | None = None, - pop_key: tp.Callable[[Node, str], Leaf] | None = None, - unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] - | None = None, - create_empty: tp.Callable[[AuxData], Node] | None = None, - init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None = None, ): - if type in NODE_TYPES: - raise ValueError(f"Node type '{type}' already registered.") - NODE_TYPES[type] = NodeImpl( + NODE_TYPES[type] = MutableNodeImpl( type=type, flatten=flatten, set_key=set_key, pop_key=pop_key, - unflatten=unflatten, create_empty=create_empty, - init=init, ) def is_node(x: tp.Any) -> bool: - return type(x) in NODE_TYPES + if isinstance(x, Variable): + return False + elif type(x) in NODE_TYPES: + return True + return is_pytree_node(x) def is_node_type(x: type[tp.Any]) -> bool: return x in NODE_TYPES -@tp.overload -def get_node_impl(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: - ... +def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: + if isinstance(x, Variable): + raise ValueError(f'Variable is not a node: {x}') + + node_type = type(x) + if node_type not in NODE_TYPES: + if is_pytree_node(x): + node_type = PytreeType + else: + raise ValueError(f'Unknown node type: {x}') -@tp.overload -def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: - ... + return NODE_TYPES[node_type] -def get_node_impl(x: type[Node] | Node) -> NodeImpl[Node, tp.Any, tp.Any]: - if not isinstance(x, type): - x = type(x) - if not is_node_type(x): - raise ValueError(f'Unknown node type: {x}') +def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: return NODE_TYPES[x] @@ -147,12 +147,12 @@ def __len__(self) -> int: return len(self._mapping) def __hash__(self) -> int: - return hash(tuple(self._mapping.items())) + return hash(tuple(sorted(self._mapping.items()))) def __eq__(self, other: tp.Any) -> bool: - if not isinstance(other, tp.Mapping): - return False - return self._mapping == other + return ( + isinstance(other, _HashableMapping) and self._mapping == other._mapping + ) def __repr__(self) -> str: return repr(self._mapping) @@ -253,7 +253,7 @@ def __init__( variables: tp.Iterable[tuple[str, VariableDef | int]], metadata: tp.Any, ): - self._type = type + self._type: type[Node] = type self._index = index self._attributes = attributes self._subgraphs = _HashableMapping(subgraphs) @@ -417,7 +417,7 @@ def _graph_flatten( static_fields.append((key, value)) graphdef = GraphDef( - type=type(node), + type=node_impl.type, index=index, attributes=tuple(key for key, _ in values), subgraphs=subgraphs, @@ -449,7 +449,7 @@ def _graph_unflatten( # TODO(cgarciae): why copy here? state = state.copy() - node_impl = get_node_impl(graphdef.type) + node_impl = get_node_impl_for_type(graphdef.type) def _get_children(): new_state: dict[str, tp.Any] = {} @@ -520,8 +520,7 @@ def _get_children(): return new_state - if node_impl.create_empty: - assert node_impl.init is not None + if isinstance(node_impl, MutableNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle node = node_impl.create_empty(graphdef.metadata) @@ -531,7 +530,6 @@ def _get_children(): else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first - assert node_impl.unflatten is not None children = _get_children() node = node_impl.unflatten(tuple(children.items()), graphdef.metadata) index_to_node[graphdef.index] = node @@ -581,7 +579,7 @@ def _graph_pop( node_impl = get_node_impl(node) for state, predicate in zip(states, predicates): if predicate(path, value): - if node_impl.pop_key is None: + if isinstance(node_impl, ImmutableNodeImpl): raise ValueError( f'Cannot pop key {name!r} from node of type {type(node).__name__}' ) @@ -621,7 +619,7 @@ def _graph_update_dynamic( for key, value in state.items(): # case 1: new state is being added if key not in node_dict: - if node_impl.set_key is None: + if isinstance(node_impl, ImmutableNodeImpl): raise ValueError( f'Cannot set key {key!r} on immutable node of ' f'type {type(node).__name__}' @@ -717,7 +715,7 @@ def _graph_update_static( ) else: # case 3: adding a new subgraph - if node_impl.set_key is None: + if isinstance(node_impl, ImmutableNodeImpl): raise ValueError( f'Cannot set key {name!r} on immutable node of ' f'type {type(node).__name__}' @@ -736,7 +734,7 @@ def _graph_update_static( node_impl.set_key(node, name, value_updates) else: # static field - if node_impl.set_key is None: + if isinstance(node_impl, ImmutableNodeImpl): if name in node_dict and node_dict[name] == value_updates: # if the value is the same, skip continue @@ -799,17 +797,12 @@ def _create_empty_dict(metadata: None) -> dict[str, tp.Any]: return {} -def _init_dict(node: dict[str, tp.Any], items: tuple[tuple[str, tp.Any], ...]): - node.update(items) - - -register_node_type( +register_mutable_node_type( dict, flatten=_flatten_dict, set_key=_set_key_dict, pop_key=_pop_key_dict, create_empty=_create_empty_dict, - init=_init_dict, ) @@ -838,18 +831,12 @@ def _create_empty_list(length: int) -> list[tp.Any]: return [EMPTY] * length -def _init_list(node: list[tp.Any], items: tuple[tuple[str, tp.Any], ...]): - for key, value in items: - _set_key_list(node, key, value) - - -register_node_type( +register_mutable_node_type( type=list, flatten=_flatten_list, set_key=_set_key_list, pop_key=_pop_key_list, create_empty=_create_empty_list, - init=_init_list, ) @@ -869,8 +856,51 @@ def _unflatten_tuple( return tuple(node) -register_node_type( +register_immutable_node_type( type=tuple, flatten=_flatten_tuple, unflatten=_unflatten_tuple, ) + + +# Pytree +class PytreeType: + pass + + +def is_pytree_node(x: tp.Any) -> bool: + return not jax.tree_util.all_leaves([x]) + + +def _key_path_to_str(key: tp.Any) -> str: + if isinstance(key, jax.tree_util.SequenceKey): + return str(key.idx) + elif isinstance( + key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) + ): + return str(key.key) + elif isinstance(key, jax.tree_util.GetAttrKey): + return key.name + else: + return str(key) + + +def _flatten_pytree(pytree: tp.Any): + leaves, treedef = jax.tree_util.tree_flatten_with_path( + pytree, is_leaf=lambda x: x is not pytree + ) + nodes = tuple((_key_path_to_str(path[0]), value) for path, value in leaves) + + return nodes, treedef + + +def _unflatten_pytree( + nodes: tuple[tuple[str, tp.Any], ...], treedef: jax.tree_util.PyTreeDef +): + pytree = treedef.unflatten(value for _, value in nodes) + return pytree + + +register_immutable_node_type( + PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree +) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index a4d69375f1..431185a54e 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -455,13 +455,12 @@ def modules(self) -> tp.Iterator[tuple[Path, Module]]: def __init_subclass__(cls, experimental_pytree: bool = False) -> None: super().__init_subclass__() - graph_utils.register_node_type( + graph_utils.register_mutable_node_type( type=cls, flatten=_module_graph_flatten, set_key=_module_graph_set_key, pop_key=_module_graph_pop_key, create_empty=_module_graph_create_empty, - init=_module_graph_init, ) if experimental_pytree: @@ -532,10 +531,6 @@ def _module_graph_create_empty(cls: tp.Type[M]) -> M: return module -def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]): - vars(node).update(items) - - def first_from(*args: tp.Optional[A], error_msg: str) -> A: """Return the first non-None argument. diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index 8ee03eb0ca..826763d3e6 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -16,6 +16,7 @@ import pytest from flax.experimental import nnx +from flax import struct class TestGraphUtils: @@ -230,3 +231,49 @@ def __init__(self): assert m2.a.value == state.a.raw_value assert m2.b.value == state.a.raw_value assert m2.a is m2.b + + def test_pytree_flatten(self): + @struct.dataclass + class Tree: + a: int + b: str = struct.field(pytree_node=False) + + p = Tree(1, 'a') + + leaves, treedef = nnx.graph_utils._flatten_pytree(p) + fields = dict(leaves) + + assert 'a' in fields + assert 'b' not in fields + assert fields['a'] == 1 + + p2 = nnx.graph_utils._unflatten_pytree(leaves, treedef) + + assert isinstance(p2, Tree) + assert p2.a == 1 + + def test_pytree_node(self): + @struct.dataclass + class Tree: + a: nnx.Param[int] + b: str = struct.field(pytree_node=False) + + class Foo(nnx.Module): + def __init__(self): + self.tree = Tree(nnx.Param(1), 'a') + + m = Foo() + + state, static = m.split() + + assert 'tree' in state + assert 'a' in state.tree + assert static.subgraphs['tree'].type is nnx.graph_utils.PytreeType + + m2 = static.merge(state) + + assert isinstance(m2.tree, Tree) + assert m2.tree.a.raw_value == 1 + assert m2.tree.b == 'a' + assert m2.tree.a is not m.tree.a + assert m2.tree is not m.tree