Skip to content

Commit

Permalink
[nnx] add some optimizations to graph.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 16, 2024
1 parent 7fd97dc commit ba57524
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 29 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ flaxlib_src/build
flaxlib_src/builddir
flaxlib_src/dist
flaxlib_src/subprojects

target/
flaxlib.cpython-*
# used by direnv
.envrc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,23 @@

FLAGS = flags.FLAGS
flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in')
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 5, 'Depth of the model')



class Linear(nnx.Module):
def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
self.b = nnx.Param(jnp.zeros((dout,)))
self.list = [
nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
nnx.Param(jnp.zeros((dout,))),
]
self.dict = {
'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
'b': nnx.Param(jnp.zeros((dout,))),
}

def __call__(self, x):
return x @ self.w + self.b


class MLP(nnx.Module):
Expand Down
File renamed without changes.
101 changes: 79 additions & 22 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,17 @@ def __str__(self) -> str:
return repr(self)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, slots=True)
class NodeImplBase(tp.Generic[Node, Leaf, AuxData]):
type: type
type: type[Node]
flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]]

def node_dict(self, node: Node) -> dict[Key, Leaf]:
nodes, _ = self.flatten(node)
return dict(nodes)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, slots=True)
class GraphNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
set_key: tp.Callable[[Node, Key, Leaf], None]
pop_key: tp.Callable[[Node, Key], Leaf]
Expand All @@ -116,7 +116,7 @@ def init(self, node: Node, items: tuple[tuple[Key, Leaf], ...]):
self.set_key(node, key, value)


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, slots=True)
class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node]

Expand All @@ -126,7 +126,8 @@ class PytreeNodeImpl(NodeImplBase[Node, Leaf, AuxData]):
]


_node_impl_for_type: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
GRAPH_REGISTRY: dict[type, NodeImpl[tp.Any, tp.Any, tp.Any]] = {}
PYTREE_REGISTRY: dict[type, PytreeNodeImpl[tp.Any, tp.Any, tp.Any]] = {}


def register_graph_node_type(
Expand All @@ -137,7 +138,10 @@ def register_graph_node_type(
create_empty: tp.Callable[[AuxData], Node],
clear: tp.Callable[[Node], None],
):
_node_impl_for_type[type] = GraphNodeImpl(
if type in GRAPH_REGISTRY:
raise ValueError(f'Node type {type} is already registered.')

GRAPH_REGISTRY[type] = GraphNodeImpl(
type=type,
flatten=flatten,
set_key=set_key,
Expand All @@ -146,19 +150,30 @@ def register_graph_node_type(
clear=clear,
)

def register_pytree_node_type(
  type: type,
  flatten: tp.Callable[[Node], tuple[tp.Sequence[tuple[Key, Leaf]], AuxData]],
  unflatten: tp.Callable[[tuple[tuple[Key, Leaf], ...], AuxData], Node],
):
  """Adds a ``PytreeNodeImpl`` for ``type`` to ``PYTREE_REGISTRY``.

  Args:
    type: the class being registered; it is used as the registry key.
    flatten: callable taking a node and returning its ``(key, leaf)`` pairs
      together with auxiliary data.
    unflatten: callable taking ``(key, leaf)`` pairs and auxiliary data and
      returning a node.

  Raises:
    ValueError: if ``type`` already has an entry in ``PYTREE_REGISTRY``.
  """
  already_registered = type in PYTREE_REGISTRY
  if already_registered:
    raise ValueError(f'Node type {type} is already registered.')

  impl = PytreeNodeImpl(type=type, flatten=flatten, unflatten=unflatten)
  PYTREE_REGISTRY[type] = impl

def is_node(x: tp.Any) -> bool:
if type(x) in _node_impl_for_type:
if type(x) in GRAPH_REGISTRY:
return True
return is_pytree_node(x)


def is_graph_node(x: tp.Any) -> bool:
return type(x) in _node_impl_for_type
return type(x) in GRAPH_REGISTRY


def is_node_type(x: type[tp.Any]) -> bool:
return x in _node_impl_for_type or x is PytreeType
return x in GRAPH_REGISTRY or x in PYTREE_REGISTRY or x is GenericPytree


def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:
Expand All @@ -167,19 +182,23 @@ def get_node_impl(x: Node) -> NodeImpl[Node, tp.Any, tp.Any]:

node_type = type(x)

if node_type not in _node_impl_for_type:
if is_pytree_node(x):
return PYTREE_NODE_IMPL
else:
raise ValueError(f'Unknown node type: {x}')

return _node_impl_for_type[node_type]
if node_type in GRAPH_REGISTRY:
return GRAPH_REGISTRY[node_type]
elif node_type in PYTREE_REGISTRY:
return PYTREE_REGISTRY[node_type]
elif is_pytree_node(x):
return PYTREE_NODE_IMPL # type: ignore
else:
raise ValueError(f'Unknown node type: {x}')


def get_node_impl_for_type(x: type[Node]) -> NodeImpl[Node, tp.Any, tp.Any]:
if x is PytreeType:
return PYTREE_NODE_IMPL
return _node_impl_for_type[x]
if x is GenericPytree:
return PYTREE_NODE_IMPL # type: ignore
elif x in PYTREE_REGISTRY:
return PYTREE_REGISTRY[x]
else:
return GRAPH_REGISTRY[x]


class HashableMapping(tp.Mapping[HA, HB], tp.Hashable):
Expand Down Expand Up @@ -1751,11 +1770,23 @@ class Static(tp.Generic[A]):
# ---------------------------------------------------------
# Pytree
# ---------------------------------------------------------
class PytreeType: ...
class GenericPytree: ...


def is_pytree_node(x: tp.Any) -> bool:
return not jax.tree_util.all_leaves((x,))
t = type(x)
if t in PYTREE_REGISTRY:
return True
elif t in GRAPH_REGISTRY:
return False
# known non-pytree types
elif isinstance(x, Variable):
return False
# known pytree types
elif isinstance(x, (VariableState, State)):
return True
else:
return not jax.tree_util.all_leaves((x,))


def _key_path_to_key(key: tp.Any) -> Key:
Expand Down Expand Up @@ -1792,7 +1823,33 @@ def _unflatten_pytree(


PYTREE_NODE_IMPL = PytreeNodeImpl(
type=PytreeType,
type=GenericPytree,
flatten=_flatten_pytree,
unflatten=_unflatten_pytree,
)

# Register the common built-in containers in PYTREE_REGISTRY at import time.

# list: keys are the positional indices produced by enumerate.
register_pytree_node_type(
  list,
  flatten=lambda x: ([*enumerate(x)], None),
  unflatten=lambda nodes, _: [leaf for _key, leaf in nodes],  # type: ignore
)
# tuple: flattened like a list, rebuilt as a tuple.
register_pytree_node_type(
  tuple,
  flatten=lambda x: ([*enumerate(x)], None),
  unflatten=lambda nodes, _: tuple([leaf for _key, leaf in nodes]),  # type: ignore
)
# dict: items are emitted in sorted order rather than insertion order.
register_pytree_node_type(
  dict,
  flatten=lambda x: (sorted(x.items()), None),
  unflatten=lambda nodes, _: dict(nodes),  # type: ignore
)
# None: has no children and always unflattens back to None.
register_pytree_node_type(
  type(None),
  flatten=lambda x: ([], None),
  unflatten=lambda _, __: None,  # type: ignore
)
2 changes: 1 addition & 1 deletion tests/nnx/graph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def __init__(self):

assert 'tree' in state
assert 'a' in state.tree
assert graphdef.subgraphs['tree'].type is nnx.graph.PytreeType
assert graphdef.subgraphs['tree'].type is nnx.graph.GenericPytree

m2 = nnx.merge(graphdef, state)

Expand Down

0 comments on commit ba57524

Please sign in to comment.