-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support linen <-> nnx metadata box converging in nnx.bridge
#4145
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,138 @@ | ||||||
# Copyright 2024 The Flax Authors. | ||||||
# | ||||||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
# you may not use this file except in compliance with the License. | ||||||
# You may obtain a copy of the License at | ||||||
# | ||||||
# http://www.apache.org/licenses/LICENSE-2.0 | ||||||
# | ||||||
# Unless required by applicable law or agreed to in writing, software | ||||||
# distributed under the License is distributed on an "AS IS" BASIS, | ||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
|
||||||
from typing import Any, TypeVar | ||||||
|
||||||
import jax | ||||||
from flax import struct | ||||||
from flax.core import meta | ||||||
from flax.nnx.nnx import variables as variableslib | ||||||
import typing as tp | ||||||
|
||||||
|
||||||
A = TypeVar('A') | ||||||
B = TypeVar('B') | ||||||
|
||||||
|
||||||
####################################################### | ||||||
### Variable type <-> Linen collection name mapping ### | ||||||
####################################################### | ||||||
# Assumption: the mapping is 1-1 and unique. | ||||||
|
||||||
VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {} | ||||||
|
||||||
|
||||||
def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]: | ||||||
"""Given a Linen-style collection name, get or create its corresponding NNX Variable type.""" | ||||||
if name not in VariableTypeCache: | ||||||
VariableTypeCache[name] = type(name, (variableslib.Variable,), {}) | ||||||
return VariableTypeCache[name] | ||||||
|
||||||
|
||||||
def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str: | ||||||
"""Given an NNX Variable type, get or create its Linen-style collection name. | ||||||
|
||||||
Should output the exact inversed result of `variable_type()`.""" | ||||||
for name, t in VariableTypeCache.items(): | ||||||
if typ == t: | ||||||
return name | ||||||
name = typ.__name__ | ||||||
if name in VariableTypeCache: | ||||||
raise ValueError( | ||||||
'Name {name} is already registered in the registry as {VariableTypeCache[name]}. ' | ||||||
'It cannot be linked with this type {typ}.' | ||||||
) | ||||||
register_variable_name_type_pair(name, typ) | ||||||
return name | ||||||
|
||||||
|
||||||
def register_variable_name_type_pair(name, typ, overwrite = False): | ||||||
"""Register a pair of variable type name (like Linen collections) and its NNX type.""" | ||||||
if not overwrite and name in VariableTypeCache: | ||||||
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. ' | ||||||
'To overwrite, call with `overwrite=True`.') | ||||||
VariableTypeCache[name] = typ | ||||||
|
||||||
|
||||||
# add known variable type names | ||||||
register_variable_name_type_pair('params', variableslib.Param) | ||||||
register_variable_name_type_pair('batch_stats', variableslib.BatchStat) | ||||||
register_variable_name_type_pair('cache', variableslib.Cache) | ||||||
register_variable_name_type_pair('intermediates', variableslib.Intermediate) | ||||||
|
||||||
|
||||||
def sort_variable_types(types: tp.Iterable[type]): | ||||||
def _variable_parents_count(t: type): | ||||||
return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable)) | ||||||
parent_count = {t: _variable_parents_count(t) for t in types} | ||||||
return sorted(types, key=lambda t: -parent_count[t]) | ||||||
|
||||||
|
||||||
############################################# | ||||||
### NNX Variable <-> Linen metadata boxes ### | ||||||
############################################# | ||||||
|
||||||
|
||||||
class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]): | ||||||
"""Default Flax metadata class for `nnx.VariableState`. | ||||||
""" | ||||||
|
||||||
var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False) | ||||||
value: Any = struct.field(pytree_node=True) | ||||||
metadata: dict[str, tp.Any] = struct.field(pytree_node=False) | ||||||
|
||||||
def unbox(self) -> A: | ||||||
return self.value | ||||||
|
||||||
def replace_boxed(self, val: B) -> 'NNXMeta[B]': | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. String types not needed in python >= 3.10
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I got error trying to do this, saying |
||||||
return self.replace(value=val) # type: ignore | ||||||
|
||||||
def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': | ||||||
# TODO: implement this, supporting hooks | ||||||
return self | ||||||
|
||||||
def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': | ||||||
# TODO: implement this, supporting hooks | ||||||
return self | ||||||
|
||||||
|
||||||
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: | ||||||
metadata = vs.get_metadata() | ||||||
if 'linen_meta_type' in metadata: | ||||||
if metadata['linen_meta_type'] is not meta.Partitioned: | ||||||
raise ValueError('Not supporting Linen metadata types other than nn.Partitioned') | ||||||
return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh']) | ||||||
return NNXMeta(vs.type, vs.value, vs.get_metadata()) | ||||||
|
||||||
|
||||||
def get_col_name(keypath: tp.Sequence[Any]) -> str: | ||||||
"""Given the keypath of a Flax variable type, return its Linen collection name.""" | ||||||
# Infer variable type from the leaf's path, which contains its Linen collection name | ||||||
assert isinstance(keypath[0], jax.tree_util.DictKey) | ||||||
return str(keypath[0].key) | ||||||
|
||||||
|
||||||
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: | ||||||
"""Convert a Linen variable to an NNX variable. | ||||||
This process needs the collection name, | ||||||
""" | ||||||
vtype = variable_type(col) | ||||||
if isinstance(x, NNXMeta): | ||||||
assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' | ||||||
return x.var_type(x.value, **x.metadata) | ||||||
if isinstance(x, meta.AxisMetadata): | ||||||
if isinstance(x, meta.Partitioned): | ||||||
return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned) | ||||||
raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta') | ||||||
return vtype(x) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -18,8 +18,9 @@ | |||||
|
||||||
from flax import nnx | ||||||
from flax import linen | ||||||
from flax.core import meta | ||||||
from flax.nnx.nnx import graph | ||||||
from flax.nnx.nnx import variables as variableslib | ||||||
from flax.nnx.nnx.bridge import variables as bv | ||||||
from flax.nnx.nnx.module import GraphDef, Module | ||||||
from flax.nnx.nnx.rnglib import Rngs | ||||||
from flax.nnx.nnx.state import State | ||||||
|
@@ -120,7 +121,7 @@ def __init__( | |||||
): | ||||||
self.module = module | ||||||
self.rngs = rngs | ||||||
self.linen_collections: set[str] = set() | ||||||
self.linen_collections: tuple[str, ...] = () | ||||||
|
||||||
def lazy_init(self, *args, **kwargs): | ||||||
return lazy_init(self, *args, **kwargs) | ||||||
|
@@ -143,16 +144,20 @@ def __call__( | |||||
if 'params' not in _rngs and 'default' in _rngs: | ||||||
_rngs['params'] = _rngs.pop('default') | ||||||
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs) | ||||||
def nn_var_to_nnx_state(kp, v): | ||||||
assert isinstance(kp[0], jtu.DictKey) | ||||||
vtype = variableslib.variable_type(kp[0].key) | ||||||
return vtype(v) | ||||||
for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items(): | ||||||
self._setattr(col, tree) | ||||||
self.linen_collections.add(col) | ||||||
|
||||||
nnx_vars = jtu.tree_map_with_path( | ||||||
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x), | ||||||
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) | ||||||
linen_collections = set() | ||||||
for col, tree in nnx_vars.items(): | ||||||
setattr(self, col, tree) | ||||||
linen_collections.add(col) | ||||||
self.linen_collections = tuple(linen_collections) # make it hashable | ||||||
|
||||||
else: | ||||||
variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col)) | ||||||
variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()), | ||||||
getattr(self, col), | ||||||
is_leaf=lambda x: isinstance(x, nnx.Variable)) | ||||||
for col in self.linen_collections} | ||||||
_rngs = ( | ||||||
{name: stream() for name, stream in rngs.items()} if rngs else {} | ||||||
|
@@ -162,8 +167,11 @@ def nn_var_to_nnx_state(kp, v): | |||||
# Split out the updates if `mutable` is passed into the Flax module | ||||||
if kwargs.get('mutable', False) != False: | ||||||
out, updates = out | ||||||
updates = jtu.tree_map_with_path( | ||||||
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x), | ||||||
updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) | ||||||
for collection, value in updates.items(): | ||||||
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value)) | ||||||
setattr(self, collection, value) | ||||||
|
||||||
return out | ||||||
|
||||||
|
@@ -214,6 +222,7 @@ class ToLinen(linen.Module): | |||||
args: tp.Sequence = () | ||||||
kwargs: tp.Mapping = dataclasses.field(default_factory=dict) | ||||||
skip_rng: bool = False | ||||||
metadata_type: tp.Type = bv.NNXMeta | ||||||
|
||||||
def update_variables(self, module): | ||||||
"""Store the NNX module's graph def and state inside Linen module variables.""" | ||||||
|
@@ -225,14 +234,16 @@ def update_variables(self, module): | |||||
types = set(jax.tree.leaves( | ||||||
jax.tree.map(lambda x: x.type, state, | ||||||
is_leaf=lambda x: isinstance(x, nnx.VariableState)))) | ||||||
types = variableslib.sort_variable_types(types) | ||||||
types = bv.sort_variable_types(types) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the catch! Done. |
||||||
_, *state_by_types = nnx.split(module, *types) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks but |
||||||
# Each variable type goes to its own linen collection, and | ||||||
# each attribute goes to its own linen variable | ||||||
for typ, state in zip(types, state_by_types): | ||||||
collection = variableslib.variable_type_name(typ) | ||||||
collection = bv.variable_type_name(typ) | ||||||
if self.is_mutable_collection(collection): | ||||||
for k, v in state.raw_mapping.items(): | ||||||
v = jax.tree.map(bv.to_linen_var, v, | ||||||
is_leaf=lambda x: isinstance(x, nnx.VariableState)) | ||||||
self.put_variable(collection, k, v) | ||||||
|
||||||
@linen.compact | ||||||
|
@@ -250,7 +261,11 @@ def __call__(self, *args, **kwargs): | |||||
# apply codepath | ||||||
gdef = self.get_variable('nnx', 'graphdef') | ||||||
assert gdef, 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?' | ||||||
states = [State(state) for col, state in self.variables.items() if col != 'nnx'] | ||||||
variables = {col: v for col, v in self.variables.items() if col != 'nnx'} | ||||||
states = jtu.tree_map_with_path( | ||||||
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(), | ||||||
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) | ||||||
states = [State(v) for v in states.values()] | ||||||
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({}) | ||||||
module = nnx.merge(gdef, nnx_state) | ||||||
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should check that the
name
does exist alreadyThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should allow overwrite of names if user wants to. I'll add an overwrite option there.