Skip to content

Commit

Permalink
[nnx] add pytree support
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 2, 2024
1 parent f3b57e1 commit 2850c08
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 19 deletions.
86 changes: 67 additions & 19 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,33 @@ def register_node_type(


def is_node(x: tp.Any) -> bool:
return type(x) in NODE_TYPES
if isinstance(x, Variable):
return False
elif type(x) in NODE_TYPES:
return True
return is_pytree_node(x)


def is_node_type(x: type[tp.Any]) -> bool:
return x in NODE_TYPES


@tp.overload
def get_node_impl(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
...
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
if isinstance(x, Variable):
raise ValueError(f'Variable is not a node: {x}')

node_type = type(x)

@tp.overload
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
...
if node_type not in NODE_TYPES:
if is_pytree_node(x):
node_type = PytreeType
else:
raise ValueError(f'Unknown node type: {x}')

return NODE_TYPES[node_type]


def get_node_impl(x: type[Node] | Node) -> NodeImpl[Node, tp.Any, tp.Any]:
if not isinstance(x, type):
x = type(x)
if not is_node_type(x):
raise ValueError(f'Unknown node type: {x}')
def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
return NODE_TYPES[x]


Expand All @@ -147,12 +152,12 @@ def __len__(self) -> int:
return len(self._mapping)

def __hash__(self) -> int:
return hash(tuple(self._mapping.items()))
return hash(tuple(sorted(self._mapping.items())))

def __eq__(self, other: tp.Any) -> bool:
if not isinstance(other, tp.Mapping):
return False
return self._mapping == other
return (
isinstance(other, _HashableMapping) and self._mapping == other._mapping
)

def __repr__(self) -> str:
return repr(self._mapping)
Expand Down Expand Up @@ -253,7 +258,7 @@ def __init__(
variables: tp.Iterable[tuple[str, VariableDef | int]],
metadata: tp.Any,
):
self._type = type
self._type: type[Node] = type
self._index = index
self._attributes = attributes
self._subgraphs = _HashableMapping(subgraphs)
Expand Down Expand Up @@ -417,7 +422,7 @@ def _graph_flatten(
static_fields.append((key, value))

graphdef = GraphDef(
type=type(node),
type=node_impl.type,
index=index,
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
Expand Down Expand Up @@ -449,7 +454,7 @@ def _graph_unflatten(

# TODO(cgarciae): why copy here?
state = state.copy()
node_impl = get_node_impl(graphdef.type)
node_impl = get_node_impl_for_type(graphdef.type)

def _get_children():
new_state: dict[str, tp.Any] = {}
Expand Down Expand Up @@ -874,3 +879,46 @@ def _unflatten_tuple(
flatten=_flatten_tuple,
unflatten=_unflatten_tuple,
)


# Pytree
class PytreeType:
pass


def is_pytree_node(x: tp.Any) -> bool:
return not jax.tree_util.all_leaves([x])


def _key_path_to_str(key: tp.Any) -> str:
if isinstance(key, jax.tree_util.SequenceKey):
return str(key.idx)
elif isinstance(
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
):
return str(key.key)
elif isinstance(key, jax.tree_util.GetAttrKey):
return key.name
else:
return str(key)


def _flatten_pytree(pytree: tp.Any):
leaves, treedef = jax.tree_util.tree_flatten_with_path(
pytree, is_leaf=lambda x: x is not pytree
)
nodes = tuple((_key_path_to_str(path[0]), value) for path, value in leaves)

return nodes, treedef


def _unflatten_pytree(
nodes: tuple[tuple[str, tp.Any], ...], treedef: jax.tree_util.PyTreeDef
):
pytree = treedef.unflatten(value for _, value in nodes)
return pytree


register_node_type(
PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree
)
47 changes: 47 additions & 0 deletions flax/experimental/nnx/tests/test_graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pytest

from flax.experimental import nnx
from flax import struct


class TestGraphUtils:
Expand Down Expand Up @@ -230,3 +231,49 @@ def __init__(self):
assert m2.a.value == state.a.raw_value
assert m2.b.value == state.a.raw_value
assert m2.a is m2.b

def test_pytree_flatten(self):
@struct.dataclass
class Tree:
a: int
b: str = struct.field(pytree_node=False)

p = Tree(1, 'a')

leaves, treedef = nnx.graph_utils._flatten_pytree(p)
fields = dict(leaves)

assert 'a' in fields
assert 'b' not in fields
assert fields['a'] == 1

p2 = nnx.graph_utils._unflatten_pytree(leaves, treedef)

assert isinstance(p2, Tree)
assert p2.a == 1

def test_pytree_node(self):
@struct.dataclass
class Tree:
a: nnx.Param[int]
b: str = struct.field(pytree_node=False)

class Foo(nnx.Module):
def __init__(self):
self.tree = Tree(nnx.Param(1), 'a')

m = Foo()

state, static = m.split()

assert 'tree' in state
assert 'a' in state.tree
assert static.subgraphs['tree'].type is nnx.graph_utils.PytreeType

m2 = static.merge(state)

assert isinstance(m2.tree, Tree)
assert m2.tree.a.raw_value == 1
assert m2.tree.b == 'a'
assert m2.tree.a is not m.tree.a
assert m2.tree is not m.tree

0 comments on commit 2850c08

Please sign in to comment.