Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add pytree support #3732

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 103 additions & 73 deletions flax/experimental/nnx/nnx/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,90 +43,90 @@


@dataclasses.dataclass(frozen=True)
class NodeImpl(tp.Generic[Node, Leaf, AuxData]):
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
type: type
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]]
set_key: tp.Callable[[Node, str, Leaf], None] | None
pop_key: tp.Callable[[Node, str], Leaf] | None
unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node] | None
create_empty: tp.Callable[[AuxData], Node] | None
init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None

def node_dict(self, node: Node) -> dict[str, Leaf]:
nodes, _ = self.flatten(node)
return dict(nodes)


@tp.overload
def register_node_type(
@dataclasses.dataclass(frozen=True)
class MutableNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  """Node implementation for node types that are modified in place.

  Adds the callbacks needed to build a node incrementally: create an empty
  instance from its auxiliary data, then set or pop individual keys on it.
  """

  # Callback `(node, key, value) -> None` that stores `value` under `key`.
  set_key: tp.Callable[[Node, str, Leaf], None]
  # Callback `(node, key) -> Leaf` that removes `key` from the node.
  pop_key: tp.Callable[[Node, str], Leaf]
  # Callback `(aux_data) -> node` that builds an empty node instance.
  create_empty: tp.Callable[[AuxData], Node]

  def init(self, node: Node, items: tuple[tuple[str, Leaf], ...]):
    """Populate `node` by calling `set_key` once per `(key, value)` pair.

    Args:
      node: the node to fill in.
      items: `(key, value)` pairs, applied in the given order.
    """
    for key, value in items:
      self.set_key(node, key, value)


@dataclasses.dataclass(frozen=True)
class ImmutableNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
  """Node implementation for node types that are rebuilt rather than mutated."""

  # Callback `(items, aux_data) -> node` that constructs a node from its
  # `(key, value)` pairs and auxiliary data.
  unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node]


# Any node implementation: either a `MutableNodeImpl` (filled in through
# `create_empty`/`set_key`/`pop_key`) or an `ImmutableNodeImpl` (built in one
# step through `unflatten`).
NodeImpl = tp.Union[
MutableNodeImpl[Node, Leaf, AuxData], ImmutableNodeImpl[Node, Leaf, AuxData]
]


def register_immutable_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]],
*,
unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node],
):
...
NODE_TYPES[type] = ImmutableNodeImpl(
type=type, flatten=flatten, unflatten=unflatten
)


@tp.overload
def register_node_type(
def register_mutable_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]],
*,
set_key: tp.Callable[[Node, str, Leaf], None],
pop_key: tp.Callable[[Node, str], Leaf],
create_empty: tp.Callable[[AuxData], Node],
init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None],
):
...


def register_node_type(
type: type,
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[str, Leaf]], AuxData]],
*,
set_key: tp.Callable[[Node, str, Leaf], None] | None = None,
pop_key: tp.Callable[[Node, str], Leaf] | None = None,
unflatten: tp.Callable[[tuple[tuple[str, Leaf], ...], AuxData], Node]
| None = None,
create_empty: tp.Callable[[AuxData], Node] | None = None,
init: tp.Callable[[Node, tuple[tuple[str, Leaf], ...]], None] | None = None,
):
if type in NODE_TYPES:
raise ValueError(f"Node type '{type}' already registered.")
NODE_TYPES[type] = NodeImpl(
NODE_TYPES[type] = MutableNodeImpl(
type=type,
flatten=flatten,
set_key=set_key,
pop_key=pop_key,
unflatten=unflatten,
create_empty=create_empty,
init=init,
)


def is_node(x: tp.Any) -> bool:
return type(x) in NODE_TYPES
if isinstance(x, Variable):
return False
elif type(x) in NODE_TYPES:
return True
return is_pytree_node(x)


def is_node_type(x: type[tp.Any]) -> bool:
  """Return whether the type `x` has an entry in `NODE_TYPES`."""
  return x in NODE_TYPES


@tp.overload
def get_node_impl(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
...
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
if isinstance(x, Variable):
raise ValueError(f'Variable is not a node: {x}')

node_type = type(x)

if node_type not in NODE_TYPES:
if is_pytree_node(x):
node_type = PytreeType
else:
raise ValueError(f'Unknown node type: {x}')

@tp.overload
def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
...
return NODE_TYPES[node_type]


def get_node_impl(x: type[Node] | Node) -> NodeImpl[Node, tp.Any, tp.Any]:
if not isinstance(x, type):
x = type(x)
if not is_node_type(x):
raise ValueError(f'Unknown node type: {x}')
def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
return NODE_TYPES[x]


Expand All @@ -147,12 +147,12 @@ def __len__(self) -> int:
return len(self._mapping)

def __hash__(self) -> int:
return hash(tuple(self._mapping.items()))
return hash(tuple(sorted(self._mapping.items())))

def __eq__(self, other: tp.Any) -> bool:
if not isinstance(other, tp.Mapping):
return False
return self._mapping == other
return (
isinstance(other, _HashableMapping) and self._mapping == other._mapping
)

def __repr__(self) -> str:
return repr(self._mapping)
Expand Down Expand Up @@ -253,7 +253,7 @@ def __init__(
variables: tp.Iterable[tuple[str, VariableDef | int]],
metadata: tp.Any,
):
self._type = type
self._type: type[Node] = type
self._index = index
self._attributes = attributes
self._subgraphs = _HashableMapping(subgraphs)
Expand Down Expand Up @@ -417,7 +417,7 @@ def _graph_flatten(
static_fields.append((key, value))

graphdef = GraphDef(
type=type(node),
type=node_impl.type,
index=index,
attributes=tuple(key for key, _ in values),
subgraphs=subgraphs,
Expand Down Expand Up @@ -449,7 +449,7 @@ def _graph_unflatten(

# TODO(cgarciae): why copy here?
state = state.copy()
node_impl = get_node_impl(graphdef.type)
node_impl = get_node_impl_for_type(graphdef.type)

def _get_children():
new_state: dict[str, tp.Any] = {}
Expand Down Expand Up @@ -520,8 +520,7 @@ def _get_children():

return new_state

if node_impl.create_empty:
assert node_impl.init is not None
if isinstance(node_impl, MutableNodeImpl):
# we create an empty node first and add it to the index
# this avoids infinite recursion when there is a reference cycle
node = node_impl.create_empty(graphdef.metadata)
Expand All @@ -531,7 +530,6 @@ def _get_children():
else:
# if the node type does not support the creation of an empty object it means
# that it cannot reference itself, so we can create its children first
assert node_impl.unflatten is not None
children = _get_children()
node = node_impl.unflatten(tuple(children.items()), graphdef.metadata)
index_to_node[graphdef.index] = node
Expand Down Expand Up @@ -581,7 +579,7 @@ def _graph_pop(
node_impl = get_node_impl(node)
for state, predicate in zip(states, predicates):
if predicate(path, value):
if node_impl.pop_key is None:
if isinstance(node_impl, ImmutableNodeImpl):
raise ValueError(
f'Cannot pop key {name!r} from node of type {type(node).__name__}'
)
Expand Down Expand Up @@ -621,7 +619,7 @@ def _graph_update_dynamic(
for key, value in state.items():
# case 1: new state is being added
if key not in node_dict:
if node_impl.set_key is None:
if isinstance(node_impl, ImmutableNodeImpl):
raise ValueError(
f'Cannot set key {key!r} on immutable node of '
f'type {type(node).__name__}'
Expand Down Expand Up @@ -717,7 +715,7 @@ def _graph_update_static(
)
else:
# case 3: adding a new subgraph
if node_impl.set_key is None:
if isinstance(node_impl, ImmutableNodeImpl):
raise ValueError(
f'Cannot set key {name!r} on immutable node of '
f'type {type(node).__name__}'
Expand All @@ -736,7 +734,7 @@ def _graph_update_static(

node_impl.set_key(node, name, value_updates)
else: # static field
if node_impl.set_key is None:
if isinstance(node_impl, ImmutableNodeImpl):
if name in node_dict and node_dict[name] == value_updates:
# if the value is the same, skip
continue
Expand Down Expand Up @@ -799,17 +797,12 @@ def _create_empty_dict(metadata: None) -> dict[str, tp.Any]:
return {}


def _init_dict(node: dict[str, tp.Any], items: tuple[tuple[str, tp.Any], ...]):
  """Fill `node` in place with the given `(key, value)` pairs.

  Existing keys are overwritten, following `dict.update` semantics.
  """
  node.update(items)


register_node_type(
register_mutable_node_type(
dict,
flatten=_flatten_dict,
set_key=_set_key_dict,
pop_key=_pop_key_dict,
create_empty=_create_empty_dict,
init=_init_dict,
)


Expand Down Expand Up @@ -838,18 +831,12 @@ def _create_empty_list(length: int) -> list[tp.Any]:
return [EMPTY] * length


def _init_list(node: list[tp.Any], items: tuple[tuple[str, tp.Any], ...]):
  """Fill `node` in place by calling `_set_key_list` for each `(key, value)` pair."""
  for key, value in items:
    _set_key_list(node, key, value)


register_node_type(
register_mutable_node_type(
type=list,
flatten=_flatten_list,
set_key=_set_key_list,
pop_key=_pop_key_list,
create_empty=_create_empty_list,
init=_init_list,
)


Expand All @@ -869,8 +856,51 @@ def _unflatten_tuple(
return tuple(node)


register_node_type(
register_immutable_node_type(
type=tuple,
flatten=_flatten_tuple,
unflatten=_unflatten_tuple,
)


# Pytree
class PytreeType:
  """Placeholder type used as the `NODE_TYPES` key for generic pytree nodes.

  It carries no state or behavior; it only serves as a registry key.
  """


def is_pytree_node(x: tp.Any) -> bool:
  """Return True when JAX does not treat `x` as a leaf.

  `jax.tree_util.all_leaves` takes an iterable of candidates, so `x` is
  wrapped in a one-element list and the result is negated.
  """
  return not jax.tree_util.all_leaves([x])


def _key_path_to_str(key: tp.Any) -> str:
  """Convert a single JAX key-path entry into a plain string key.

  Args:
    key: one entry of a key path, as produced by
      `jax.tree_util.tree_flatten_with_path`.

  Returns:
    The stringified index for `SequenceKey`, the stringified key for
    `DictKey` and `FlattenedIndexKey`, the attribute name for `GetAttrKey`,
    and `str(key)` for any other object.
  """
  if isinstance(key, jax.tree_util.SequenceKey):
    return str(key.idx)
  elif isinstance(
    key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
  ):
    return str(key.key)
  elif isinstance(key, jax.tree_util.GetAttrKey):
    return key.name
  else:
    return str(key)


def _flatten_pytree(pytree: tp.Any):
  """Flatten only the top level of `pytree` into string-keyed children.

  Args:
    pytree: the object to flatten.

  Returns:
    A `(nodes, treedef)` tuple. `nodes` is a tuple of `(key, child)` pairs
    whose keys are produced by `_key_path_to_str`; `treedef` is the tree
    definition returned by `jax.tree_util.tree_flatten_with_path`.
  """
  # Every object other than the root itself is treated as a leaf, so the
  # traversal stops at the root's direct children.
  leaves, treedef = jax.tree_util.tree_flatten_with_path(
    pytree, is_leaf=lambda x: x is not pytree
  )
  # `path[0]` is the key-path entry that leads from the root to the child.
  nodes = tuple((_key_path_to_str(path[0]), value) for path, value in leaves)

  return nodes, treedef


def _unflatten_pytree(
  nodes: tuple[tuple[str, tp.Any], ...], treedef: jax.tree_util.PyTreeDef
):
  """Rebuild a pytree from `(key, value)` pairs and a tree definition.

  Args:
    nodes: `(key, value)` pairs; only the values are used, in the given order.
    treedef: the tree definition whose `unflatten` method rebuilds the object.

  Returns:
    The object produced by `treedef.unflatten`.
  """
  pytree = treedef.unflatten(value for _, value in nodes)
  return pytree


# Register the `PytreeType` placeholder as an immutable node type, using
# `_flatten_pytree` / `_unflatten_pytree` as its flatten and unflatten callbacks.
register_immutable_node_type(
PytreeType, flatten=_flatten_pytree, unflatten=_unflatten_pytree
)
7 changes: 1 addition & 6 deletions flax/experimental/nnx/nnx/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,13 +455,12 @@ def modules(self) -> tp.Iterator[tuple[Path, Module]]:
def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()

graph_utils.register_node_type(
graph_utils.register_mutable_node_type(
type=cls,
flatten=_module_graph_flatten,
set_key=_module_graph_set_key,
pop_key=_module_graph_pop_key,
create_empty=_module_graph_create_empty,
init=_module_graph_init,
)

if experimental_pytree:
Expand Down Expand Up @@ -532,10 +531,6 @@ def _module_graph_create_empty(cls: tp.Type[M]) -> M:
return module


def _module_graph_init(node: Module, items: tuple[tuple[str, tp.Any], ...]):
  """Write the `(name, value)` pairs directly into the instance `__dict__` of `node`."""
  vars(node).update(items)


def first_from(*args: tp.Optional[A], error_msg: str) -> A:
"""Return the first non-None argument.

Expand Down
Loading
Loading