diff --git a/.gitignore b/.gitignore index 0bc7f3cbe6..a1de4b1dda 100644 --- a/.gitignore +++ b/.gitignore @@ -17,7 +17,8 @@ flaxlib_src/build flaxlib_src/builddir flaxlib_src/dist flaxlib_src/subprojects - +target/ +flaxlib.cpython-* # used by direnv .envrc diff --git a/flax/nnx/benchmarks/graph_overhead.py b/benchmarks/nnx_graph_overhead.py similarity index 89% rename from flax/nnx/benchmarks/graph_overhead.py rename to benchmarks/nnx_graph_overhead.py index 19d908e751..73cff6d6d6 100644 --- a/flax/nnx/benchmarks/graph_overhead.py +++ b/benchmarks/nnx_graph_overhead.py @@ -26,7 +26,7 @@ FLAGS = flags.FLAGS flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in') -flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps') +flags.DEFINE_integer('total_steps', 100, 'Total number of training steps') flags.DEFINE_integer('width', 32, 'Hidden layer size') flags.DEFINE_integer('depth', 5, 'Depth of the model') @@ -34,11 +34,15 @@ class Linear(nnx.Module): def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): - self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) - self.b = nnx.Param(jnp.zeros((dout,))) + self.list = [ + nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + nnx.Param(jnp.zeros((dout,))), + ] + self.dict = { + 'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))), + 'b': nnx.Param(jnp.zeros((dout,))), + } - def __call__(self, x): - return x @ self.w + self.b class MLP(nnx.Module): diff --git a/flax/nnx/benchmarks/simple_training.py b/benchmarks/nnx_simple_training.py similarity index 100% rename from flax/nnx/benchmarks/simple_training.py rename to benchmarks/nnx_simple_training.py diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py index 2a92b8b5ad..fec21add20 100644 --- a/flax/nnx/graph.py +++ b/flax/nnx/graph.py @@ -94,9 +94,9 @@ def __str__(self) -> str: return repr(self) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): - type: type + type: type[Node] flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]] def node_dict(self, node: Node) -> dict[Key, Leaf]: @@ -104,7 +104,7 @@ def node_dict(self, node: Node) -> dict[Key, Leaf]: return dict(nodes) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]): set_key: tp.Callable[[Node, Key, Leaf], None] pop_key: tp.Callable[[Node, Key], Leaf] @@ -116,7 +116,7 @@ def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]): self.set_key(node, key, value) -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, slots=True) class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node] @@ -126,7 +126,8 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]): ] -_node_impl_for_type: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {} +PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {} def register_graph_node_type( @@ -137,7 +138,10 @@ def register_graph_node_type( create_empty: tp.Callable[[AuxData], Node], clear: tp.Callable[[Node], None], ): - _node_impl_for_type[type] = GraphNodeImpl( + if type in GRAPH_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + GRAPH_REGISTRY[type] = GraphNodeImpl( type=type, flatten=flatten, set_key=set_key, @@ -146,19 +150,30 @@ def register_graph_node_type( clear=clear, ) +def register_pytree_node_type( + type: type, + flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]], + unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node], +): + if type in PYTREE_REGISTRY: + raise ValueError(f'Node type {type} is already registered.') + + PYTREE_REGISTRY[type] = PytreeNodeImpl( + type=type, flatten=flatten, unflatten=unflatten + ) def is_node(x: tp.Any) -> bool: - if type(x) in _node_impl_for_type: + if type(x) in GRAPH_REGISTRY: return True return is_pytree_node(x) def is_graph_node(x: tp.Any) -> bool: - return type(x) in _node_impl_for_type + return type(x) in GRAPH_REGISTRY def is_node_type(x: type[tp.Any]) -> bool: - return x in _node_impl_for_type or x is PytreeType + return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: @@ -167,19 +182,23 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]: node_type = type(x) - if node_type not in _node_impl_for_type: - if is_pytree_node(x): - return PYTREE_NODE_IMPL - else: - raise ValueError(f'Unknown node type: {x}') - - return _node_impl_for_type[node_type] + if node_type in GRAPH_REGISTRY: + return GRAPH_REGISTRY[node_type] + elif node_type in PYTREE_REGISTRY: + return PYTREE_REGISTRY[node_type] + elif is_pytree_node(x): + return PYTREE_NODE_IMPL # type: ignore + else: + raise ValueError(f'Unknown node type: {x}') def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]: - if x is PytreeType: - return PYTREE_NODE_IMPL - return _node_impl_for_type[x] + if x is GenericPytree: + return PYTREE_NODE_IMPL # type: ignore + elif x in PYTREE_REGISTRY: + return PYTREE_REGISTRY[x] + else: + return GRAPH_REGISTRY[x] class HashableMapping(tp.Mapping[HA, HB], tp.Hashable): @@ -1751,11 +1770,23 @@ class Static(tp.Generic[A]): # --------------------------------------------------------- # Pytree # --------------------------------------------------------- -class PytreeType: ... +class GenericPytree: ... def is_pytree_node(x: tp.Any) -> bool: - return not jax.tree_util.all_leaves((x,)) + t = type(x) + if t in PYTREE_REGISTRY: + return True + elif t in GRAPH_REGISTRY: + return False + # known non-pytree types + elif isinstance(x, Variable): + return False + # knon pytree types + elif isinstance(x, (VariableState, State)): + return True + else: + return not jax.tree_util.all_leaves((x,)) def _key_path_to_key(key: tp.Any) -> Key: @@ -1792,7 +1823,33 @@ def _unflatten_pytree( PYTREE_NODE_IMPL = PytreeNodeImpl( - type=PytreeType, + type=GenericPytree, flatten=_flatten_pytree, unflatten=_unflatten_pytree, ) + +# common pytrees +# list +register_pytree_node_type( + list, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: [value for _, value in nodes], # type: ignore +) +# tuple +register_pytree_node_type( + tuple, + flatten=lambda x: (list(enumerate(x)), None), + unflatten=lambda nodes, _: tuple(value for _, value in nodes), # type: ignore +) +# dict +register_pytree_node_type( + dict, + flatten=lambda x: (sorted(x.items()), None), + unflatten=lambda nodes, _: {key: value for key, value in nodes}, # type: ignore +) +# None +register_pytree_node_type( + type(None), + flatten=lambda x: ([], None), + unflatten=lambda _, __: None, # type: ignore +) \ No newline at end of file diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 8983acbe7f..fb0496e07a 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -303,7 +303,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert graphdef.subgraphs['tree'].type is nnx.graph.PytreeType + assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree m2 = nnx.merge(graphdef, state)