diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 5cf1c16672..b30f7d8161 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -19,6 +19,9 @@
 from flax.typing import Initializer as Initializer
 
 from .nnx.bridge import wrappers as wrappers
+from .nnx.bridge.variables import (
+  register_variable_name_type_pair as register_variable_name_type_pair,
+)
 from .nnx import graph as graph
 from .nnx import errors as errors
 from .nnx import helpers as helpers
@@ -124,7 +127,6 @@
 from .nnx.training import metrics as metrics
 from .nnx.variables import (
   Param as Param,
-  register_variable_name_type_pair as register_variable_name_type_pair,
 )
 # this needs to be imported before optimizer to prevent circular import
 from .nnx.training import optimizer as optimizer
diff --git a/flax/nnx/nnx/bridge/__init__.py b/flax/nnx/nnx/bridge/__init__.py
index a5924d0bd8..4f6757fc83 100644
--- a/flax/nnx/nnx/bridge/__init__.py
+++ b/flax/nnx/nnx/bridge/__init__.py
@@ -22,4 +22,5 @@
 from .wrappers import ToNNX as ToNNX
 from .wrappers import lazy_init as lazy_init
 from .wrappers import ToLinen as ToLinen
-from .wrappers import to_linen as to_linen
\ No newline at end of file
+from .wrappers import to_linen as to_linen
+from .variables import NNXMeta as NNXMeta
\ No newline at end of file
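
Note: the re-export above keeps `nnx.register_variable_name_type_pair` available at the top level while its implementation moves into the bridge. A minimal usage sketch; the `LossScale` type and the 'loss_scales' collection name are hypothetical, not part of this patch:

  from flax import nnx

  class LossScale(nnx.Variable):  # hypothetical custom variable type
    pass

  # Link the Linen collection name 'loss_scales' to the NNX type, in both directions.
  # `overwrite=True` replaces any mapping previously cached under that name.
  nnx.register_variable_name_type_pair('loss_scales', LossScale, overwrite=True)
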
diff --git a/flax/nnx/nnx/bridge/variables.py b/flax/nnx/nnx/bridge/variables.py
new file mode 100644
index 0000000000..57dd37a796
--- /dev/null
+++ b/flax/nnx/nnx/bridge/variables.py
@@ -0,0 +1,138 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, TypeVar
+
+import jax
+from flax import struct
+from flax.core import meta
+from flax.nnx.nnx import variables as variableslib
+import typing as tp
+
+
+A = TypeVar('A')
+B = TypeVar('B')
+
+
+#######################################################
+### Variable type <-> Linen collection name mapping ###
+#######################################################
+# Assumption: the mapping is 1-1 and unique.
+
+VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {}
+
+
+def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]:
+  """Given a Linen-style collection name, get or create its corresponding NNX Variable type."""
+  if name not in VariableTypeCache:
+    VariableTypeCache[name] = type(name, (variableslib.Variable,), {})
+  return VariableTypeCache[name]
+
+
+def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
+  """Given an NNX Variable type, get or create its Linen-style collection name.
+
+  Should output the exact inverse of `variable_type()`."""
+  for name, t in VariableTypeCache.items():
+    if typ == t:
+      return name
+  name = typ.__name__
+  if name in VariableTypeCache:
+    raise ValueError(
+      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
+      f'It cannot be linked with this type {typ}.'
+    )
+  register_variable_name_type_pair(name, typ)
+  return name
+
+
+def register_variable_name_type_pair(name, typ, overwrite: bool = False):
+  """Register a pair of variable type name (like Linen collections) and its NNX type."""
+  if not overwrite and name in VariableTypeCache:
+    raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
+                     'To overwrite, call with `overwrite=True`.')
+  VariableTypeCache[name] = typ
+
+
+# add known variable type names
+register_variable_name_type_pair('params', variableslib.Param)
+register_variable_name_type_pair('batch_stats', variableslib.BatchStat)
+register_variable_name_type_pair('cache', variableslib.Cache)
+register_variable_name_type_pair('intermediates', variableslib.Intermediate)
+
+
+def sort_variable_types(types: tp.Iterable[type]):
+  def _variable_parents_count(t: type):
+    return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable))
+  parent_count = {t: _variable_parents_count(t) for t in types}
+  return sorted(types, key=lambda t: -parent_count[t])
+
+
+#############################################
+### NNX Variable <-> Linen metadata boxes ###
+#############################################
+
+
+class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
+  """Default Flax metadata class for `nnx.VariableState`.
+  """
+
+  var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
+  value: Any = struct.field(pytree_node=True)
+  metadata: dict[str, tp.Any] = struct.field(pytree_node=False)
+
+  def unbox(self) -> A:
+    return self.value
+
+  def replace_boxed(self, val: B) -> 'NNXMeta[B]':
+    return self.replace(value=val)  # type: ignore
+
+  def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
+    # TODO: implement this, supporting hooks
+    return self
+
+  def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
+    # TODO: implement this, supporting hooks
+    return self
+
+
+def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
+  metadata = vs.get_metadata()
+  if 'linen_meta_type' in metadata:
+    if metadata['linen_meta_type'] is not meta.Partitioned:
+      raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
+    return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
+  return NNXMeta(vs.type, vs.value, vs.get_metadata())
+
+
+def get_col_name(keypath: tp.Sequence[Any]) -> str:
+  """Given the keypath of a Flax variable type, return its Linen collection name."""
+  # Infer variable type from the leaf's path, which contains its Linen collection name
+  assert isinstance(keypath[0], jax.tree_util.DictKey)
+  return str(keypath[0].key)
+
+
+def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
+  """Convert a Linen variable to an NNX variable.
+  This process needs the Linen collection name to pick the matching NNX variable type.
+  """
+  vtype = variable_type(col)
+  if isinstance(x, NNXMeta):
+    assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
+    return x.var_type(x.value, **x.metadata)
+  if isinstance(x, meta.AxisMetadata):
+    if isinstance(x, meta.Partitioned):
+      return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
+    raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
+  return vtype(x)
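
Note: a rough sketch of the `nn.Partitioned` round trip these helpers implement. The import path mirrors this patch's internal layout, and `Variable.to_state()` is used the same way `wrappers.py` below uses it:

  import jax.numpy as jnp
  from flax.core import meta
  from flax.nnx.nnx.bridge import variables as bv

  boxed = meta.Partitioned(jnp.ones((2, 3)), names=('in', 'out'))
  var = bv.to_nnx_var('params', boxed)      # nnx.Param carrying sharding=('in', 'out')
  assert var.sharding == ('in', 'out')
  back = bv.to_linen_var(var.to_state())    # `linen_meta_type` routes it back to nn.Partitioned
  assert isinstance(back, meta.Partitioned) and back.names == ('in', 'out')

  # Anything without Linen metadata gets the default NNXMeta box instead.
  plain = bv.to_linen_var(bv.to_nnx_var('params', jnp.zeros(4)).to_state())
  assert isinstance(plain, bv.NNXMeta)
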
diff --git a/flax/nnx/nnx/bridge/wrappers.py b/flax/nnx/nnx/bridge/wrappers.py
index 8ef6628099..6878be0a19 100644
--- a/flax/nnx/nnx/bridge/wrappers.py
+++ b/flax/nnx/nnx/bridge/wrappers.py
@@ -18,8 +18,9 @@
 
 from flax import nnx
 from flax import linen
+from flax.core import meta
 from flax.nnx.nnx import graph
-from flax.nnx.nnx import variables as variableslib
+from flax.nnx.nnx.bridge import variables as bv
 from flax.nnx.nnx.module import GraphDef, Module
 from flax.nnx.nnx.rnglib import Rngs
 from flax.nnx.nnx.state import State
@@ -120,7 +121,7 @@ def __init__(
   ):
     self.module = module
     self.rngs = rngs
-    self.linen_collections: set[str] = set()
+    self.linen_collections: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
     return lazy_init(self, *args, **kwargs)
@@ -143,16 +144,20 @@ def __call__(
 
       if 'params' not in _rngs and 'default' in _rngs:
         _rngs['params'] = _rngs.pop('default')
       out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
-      def nn_var_to_nnx_state(kp, v):
-        assert isinstance(kp[0], jtu.DictKey)
-        vtype = variableslib.variable_type(kp[0].key)
-        return vtype(v)
-      for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items():
-        self._setattr(col, tree)
-        self.linen_collections.add(col)
+
+      nnx_vars = jtu.tree_map_with_path(
+        lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
+        variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
+      linen_collections = set()
+      for col, tree in nnx_vars.items():
+        setattr(self, col, tree)
+        linen_collections.add(col)
+      self.linen_collections = tuple(linen_collections)  # make it hashable
     else:
-      variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col))
+      variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()),
+                                     getattr(self, col),
+                                     is_leaf=lambda x: isinstance(x, nnx.Variable))
                    for col in self.linen_collections}
       _rngs = (
         {name: stream() for name, stream in rngs.items()} if rngs else {}
@@ -162,8 +167,11 @@ def nn_var_to_nnx_state(kp, v):
 
       # Split out the updates if `mutable` is passed into the Flax module
       if kwargs.get('mutable', False) != False:
         out, updates = out
+        updates = jtu.tree_map_with_path(
+          lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
+          updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
         for collection, value in updates.items():
-          self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))
+          setattr(self, collection, value)
 
     return out
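
Note: with these changes `ToNNX` keeps Linen metadata instead of stripping variables down to raw arrays. A short sketch, mirroring the test_linen_to_nnx_metadata test added below and assuming the `nnx.bridge` access path the tests use:

  import jax
  from flax import linen as nn
  from flax import nnx

  linen_module = nn.Dense(
    features=64,
    kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')))
  x = jax.numpy.ones((1, 32))
  model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)

  # The kernel arrives as a real nnx.Variable and keeps the partitioning annotation.
  assert isinstance(model.params['kernel'], nnx.Variable)
  assert model.params['kernel'].sharding == ('in', 'out')
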
@@ -214,6 +222,7 @@ class ToLinen(linen.Module):
   args: tp.Sequence = ()
   kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
   skip_rng: bool = False
+  metadata_type: tp.Type = bv.NNXMeta
 
   def update_variables(self, module):
     """Store the NNX module's graph def and state inside Linen module variables."""
@@ -225,14 +234,16 @@ def update_variables(self, module):
     types = set(jax.tree.leaves(
       jax.tree.map(lambda x: x.type, state,
                    is_leaf=lambda x: isinstance(x, nnx.VariableState))))
-    types = variableslib.sort_variable_types(types)
+    types = bv.sort_variable_types(types)
     _, *state_by_types = nnx.split(module, *types)
     # Each variable type goes to its own linen collection, and
     # each attribute goes to its own linen variable
     for typ, state in zip(types, state_by_types):
-      collection = variableslib.variable_type_name(typ)
+      collection = bv.variable_type_name(typ)
       if self.is_mutable_collection(collection):
         for k, v in state.raw_mapping.items():
+          v = jax.tree.map(bv.to_linen_var, v,
+                           is_leaf=lambda x: isinstance(x, nnx.VariableState))
           self.put_variable(collection, k, v)
 
   @linen.compact
@@ -250,7 +261,11 @@ def __call__(self, *args, **kwargs):
     # apply codepath
     gdef = self.get_variable('nnx', 'graphdef')
     assert gdef, 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?'
-    states = [State(state) for col, state in self.variables.items() if col != 'nnx']
+    variables = {col: v for col, v in self.variables.items() if col != 'nnx'}
+    states = jtu.tree_map_with_path(
+      lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
+      variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
+    states = [State(v) for v in states.values()]
     nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
     module = nnx.merge(gdef, nnx_state)
    nnx.reseed(module, **linen_rngs_dict(self))  # reseed with keys from linen apply call.
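
Note: on the `ToLinen` side, every NNX variable now reaches the Linen variable dict inside a metadata box, `NNXMeta` by default. A sketch of what a caller sees, mirroring the test_nnx_to_linen_metadata test added below (same import-path assumptions as above):

  import jax
  from flax import nnx

  model = nnx.bridge.to_linen(
    nnx.Linear, 32, 64,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out')))
  x = jax.numpy.ones((1, 32))
  y, variables = model.init_with_output(jax.random.key(0), x)

  kernel_box = variables['params']['kernel']
  assert isinstance(kernel_box, nnx.bridge.NNXMeta)
  assert kernel_box.metadata['sharding'] == ('in', 'out')
  kernel = kernel_box.unbox()              # plain array for code that wants raw values
  assert kernel.shape == (32, 64)
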
diff --git a/flax/nnx/nnx/variables.py b/flax/nnx/nnx/variables.py
index a88ba0cc00..8c402b8ff0 100644
--- a/flax/nnx/nnx/variables.py
+++ b/flax/nnx/nnx/variables.py
@@ -11,20 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-# Copyright 2023 The Flax Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the 'License');
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an 'AS IS' BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 
 # pytype: skip-file
 from __future__ import annotations
@@ -998,48 +984,3 @@ def wrapper(*args):
 
   return wrapper  # type: ignore
 
-
-### Variable type <-> name mapping ###
-# Assumption: the mapping is 1-1 and unique.
-
-def variable_type(name: str) -> tp.Type[Variable[tp.Any]]:
-  """Given a Linen-style collection name, get or create its corresponding NNX Variable type."""
-  if name not in VariableTypeCache:
-    VariableTypeCache[name] = type(name, (Variable,), {})
-  return VariableTypeCache[name]
-
-
-def variable_type_name(typ: tp.Type[Variable[tp.Any]]) -> str:
-  """Given an NNX Variable type, get or create its Linen-style collection name.
-
-  Should output the exact inversed result of `variable_type()`."""
-  for name, t in VariableTypeCache.items():
-    if typ == t:
-      return name
-  name = typ.__name__
-  if name in VariableTypeCache:
-    raise ValueError(
-      'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
-      'It cannot be linked with this type {typ}.'
-    )
-  register_variable_name_type_pair(name, typ)
-  return name
-
-
-def register_variable_name_type_pair(name, typ):
-  """Register a pair of variable type name (like Linen collections) and its NNX type."""
-  VariableTypeCache[name] = typ
-
-
-# add known variable type names
-register_variable_name_type_pair('params', Param)
-register_variable_name_type_pair('batch_stats', BatchStat)
-register_variable_name_type_pair('cache', Cache)
-register_variable_name_type_pair('intermediates', Intermediate)
-
-
-def sort_variable_types(types: list[type]):
-  def _variable_parents_count(t: type):
-    return sum(1 for p in t.mro() if issubclass(p, nnx.Variable))
-  parent_count = {t: _variable_parents_count(t) for t in types}
-  return sorted(types, key=lambda t: -parent_count[t])
diff --git a/flax/nnx/tests/bridge/wrappers_test.py b/flax/nnx/tests/bridge/wrappers_test.py
index d28d84ec18..72d42eb6d4 100644
--- a/flax/nnx/tests/bridge/wrappers_test.py
+++ b/flax/nnx/tests/bridge/wrappers_test.py
@@ -42,6 +42,11 @@ def test_linen_to_nnx(self):
     model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)  # like linen init
     y = model(x)  # like linen apply
     assert y.shape == (1, 64)
+    self.assertIsInstance(model.params['kernel'], nnx.Variable)
+    # NNX automatically adds metadata box regardless of original Linen module.
+    linen_vars = linen_module.init(jax.random.key(0), x)
+    np.testing.assert_array_equal(linen_vars['params']['kernel'],
+                                  model.params['kernel'].value)
 
   def test_linen_to_nnx_submodule(self):
     class NNXOuter(nnx.Module):
@@ -127,6 +132,26 @@ def vmap_fn(inner, x):
     self.assertEqual(model.inner.params['kernel'].shape, (5, 4, 3))
     self.assertEqual(model.inner.params['bias'].shape, (5, 3))
 
+  def test_linen_to_nnx_metadata(self):
+    linen_module = nn.Dense(
+      features=64,
+      kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')))
+    x = jax.numpy.ones((1, 32))
+    linen_vars = linen_module.init(jax.random.key(0), x)
+    nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
+    # nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box.
+    self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
+    self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable)
+    np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
+                                  nnx_model.params['kernel'].value)
+    assert nnx_model.params['kernel'].sharding == ('in', 'out')
+    _, nnx_state = nnx.split(nnx_model)
+    self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState)
+    np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
+                                  nnx_state['params']['kernel'].value)
+    assert nnx_state['params']['kernel'].sharding == ('in', 'out')
+
+
   ##################
   ### NNXToLinen ###
   ##################
@@ -181,7 +206,7 @@ def __call__(self, x):
 
   def test_nnx_to_linen_mutable(self):
     class Count(nnx.Variable): pass
-    nnx.register_variable_name_type_pair('Count', Count)
+    nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
 
     class Counter(nnx.Module):
       def __init__(self):
@@ -199,7 +224,7 @@ def __call__(self):
 
   def test_nnx_to_linen_mutated_static_data(self):
     class Count(nnx.Variable): pass
-    nnx.register_variable_name_type_pair('Count', Count)
+    nnx.register_variable_name_type_pair('Count', Count, overwrite=True)
 
     class Counter(nnx.Module):
      def __init__(self):
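
Note: the `overwrite=True` calls above are needed because `variable_type_name` auto-registers unknown `Variable` subclasses under their class name, so a second registration of 'Count' (each test defines its own local `Count` class) would otherwise raise. A small sketch of that auto-registration, assuming 'Count' has not been registered yet in the current process:

  from flax import nnx
  from flax.nnx.nnx.bridge import variables as bv

  class Count(nnx.Variable):
    pass

  assert bv.variable_type_name(Count) == 'Count'   # registers 'Count' <-> Count on first use
  assert bv.variable_type('Count') is Count        # the reverse lookup now resolves to the same type
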
@@ -246,35 +271,56 @@ def __call__(self, x):
     np.testing.assert_allclose(y, jnp.einsum('ab,abc->ac', x, k))
     assert 'nnx' in var
 
+  def test_nnx_to_linen_metadata(self):
+    model = bridge.to_linen(
+      nnx.Linear, 32, 64,
+      kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ('in', 'out')))
+    x = jax.numpy.ones((1, 32))
+    y, variables = model.init_with_output(jax.random.key(0), x)
+    assert y.shape == (1, 64)
+    self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta)
+    assert variables['params']['kernel'].metadata['sharding'] == ('in', 'out')
+    np.testing.assert_allclose(y, x @ variables['params']['kernel'].value)
+
+  def test_nnx_to_linen_metadata_transform(self):
+    # TODO: add support and testing after axis add/remove in transform is fixed.
+    pass
+
   ############################
   ### Hybrid mix-and-match ###
   ############################
 
   def test_nnx_linen_nnx(self):
     class NNXInner(nnx.Module):
-      def __init__(self, din, dout, rngs):
-        self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
-        self.dropout = nnx.Dropout(rate=0.5, rngs=rngs)
+      def __init__(self, din, dout, dropout_rate, rngs):
+        self.w = nnx.Param(
+          nnx.with_partitioning(nnx.initializers.lecun_normal(), sharding=('in', 'out')
+          )(rngs.params(), (din, dout)))
+        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
       def __call__(self, x):
         return self.dropout(x @ self.w.value)
 
     class LinenMiddle(nn.Module):
       dout: int
+      dropout_rate: float
       @nn.compact
       def __call__(self, x):
-        dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, name='linen')
-        b = self.param('b', nn.zeros_init(), (1, self.dout))
+        dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot')
+        b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
         return dot(x) + b
 
     class NNXOuter(nnx.Module):
-      def __init__(self, dout: int, *, rngs: nnx.Rngs):
-        self.inner = bridge.ToNNX(LinenMiddle(dout), rngs=rngs)
+      def __init__(self, dout: int, dropout_rate: float, *, rngs: nnx.Rngs):
+        self.inner = bridge.ToNNX(LinenMiddle(dout, dropout_rate), rngs=rngs)
         self.rngs = rngs
       def __call__(self, x):
         return self.inner(x)
 
     x = jax.random.normal(jax.random.key(0), (2, 4))
-    model = bridge.lazy_init(NNXOuter(3, rngs=nnx.Rngs(default=1, dropout=2)), x)
+
+    # Test the RNG
+    model = bridge.lazy_init(NNXOuter(dout=3, dropout_rate=0.5,
+                                      rngs=nnx.Rngs(default=1, dropout=2)), x)
     y1, y2 = model(x), model(x)
     # The dropout key of lowest NNX level still changes over stateful calls
     assert not jnp.allclose(y1, y2)
@@ -282,5 +328,19 @@ def __call__(self, x):
     nnx.reseed(model, dropout=2)
     np.testing.assert_array_equal(y1, model(x))
 
+    # Test the param value with disabled dropout
+    model = bridge.lazy_init(NNXOuter(dout=3, dropout_rate=0.,
+                                      rngs=nnx.Rngs(default=1, dropout=2)), x)
+    w, b = model.inner.params['dot']['w'], model.inner.params['b']
+    self.assertIsInstance(w, nnx.Param)
+    np.testing.assert_allclose(model(x), x @ w + b)
+    assert hasattr(w, 'sharding') and w.sharding == ('in', 'out')
+
+  def test_linen_nnx_linen(self):
+    # TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without
+    # messing up the stateful part of the NNX module.
+    pass
+
+
 
 if __name__ == '__main__':
   absltest.main()
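
Note: one property of the default box worth calling out: `NNXMeta` marks only `value` as a pytree leaf, so Linen-side tree transforms (casting, optimizer updates) pass through it while `var_type` and the stored metadata ride along untouched. A minimal sketch:

  import jax
  import jax.numpy as jnp
  from flax import nnx
  from flax.nnx.nnx.bridge.variables import NNXMeta

  box = NNXMeta(nnx.Param, jnp.ones((2, 3)), {'sharding': ('in', 'out')})
  cast = jax.tree.map(lambda x: x.astype(jnp.bfloat16), box)
  assert cast.value.dtype == jnp.bfloat16            # the value leaf was transformed
  assert cast.metadata['sharding'] == ('in', 'out')  # metadata is not a leaf, so it is preserved
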