Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support linen <-> nnx metadata box converging in nnx.bridge #4145

Merged
merged 1 commit into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from flax.typing import Initializer as Initializer

from .nnx.bridge import wrappers as wrappers
from .nnx.bridge.variables import (
register_variable_name_type_pair as register_variable_name_type_pair,
)
from .nnx import graph as graph
from .nnx import errors as errors
from .nnx import helpers as helpers
Expand Down Expand Up @@ -124,7 +127,6 @@
from .nnx.training import metrics as metrics
from .nnx.variables import (
Param as Param,
register_variable_name_type_pair as register_variable_name_type_pair,
)
# this needs to be imported before optimizer to prevent circular import
from .nnx.training import optimizer as optimizer
Expand Down
3 changes: 2 additions & 1 deletion flax/nnx/nnx/bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
from .wrappers import ToNNX as ToNNX
from .wrappers import lazy_init as lazy_init
from .wrappers import ToLinen as ToLinen
from .wrappers import to_linen as to_linen
from .wrappers import to_linen as to_linen
from .variables import NNXMeta as NNXMeta
138 changes: 138 additions & 0 deletions flax/nnx/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, TypeVar

import jax
from flax import struct
from flax.core import meta
from flax.nnx.nnx import variables as variableslib
import typing as tp


A = TypeVar('A')
B = TypeVar('B')


#######################################################
### Variable type <-> Linen collection name mapping ###
#######################################################
# Assumption: the mapping is 1-1 and unique.

# Module-level registry: Linen collection name -> NNX Variable type.
# `variable_type`, `variable_type_name` and `register_variable_name_type_pair`
# read and mutate this dict in place.
VariableTypeCache: dict[str, tp.Type[variableslib.Variable[tp.Any]]] = {}


def variable_type(name: str) -> tp.Type[variableslib.Variable[tp.Any]]:
  """Return the NNX Variable type registered for a Linen-style collection name.

  When `name` has no entry in `VariableTypeCache`, a new subclass of
  `variableslib.Variable` named after the collection is created, stored in the
  cache and returned, so later calls with the same name yield the same type.
  """
  try:
    return VariableTypeCache[name]
  except KeyError:
    # First time this collection name is seen: synthesize a dedicated subclass.
    created = type(name, (variableslib.Variable,), {})
    VariableTypeCache[name] = created
    return created


def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
  """Given an NNX Variable type, get or create its Linen-style collection name.

  Should output the exact inverse of `variable_type()`.

  The registry is searched for an entry whose type equals `typ`. If none is
  found, the type's `__name__` is registered as its collection name and
  returned.

  Raises:
    ValueError: if `typ` is not registered and its `__name__` is already used
      as a key in the registry.
  """
  for name, t in VariableTypeCache.items():
    if typ == t:
      return name
  name = typ.__name__
  if name in VariableTypeCache:
    # f-strings, so the message reports the actual name and the two types.
    raise ValueError(
      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name_type_pair(name, typ)
  return name


def register_variable_name_type_pair(name, typ, overwrite = False):
  """Record `typ` as the NNX type for the collection-style name `name`.

  An existing entry for `name` is replaced only when `overwrite` is truthy;
  otherwise a `ValueError` is raised and the registry is left untouched.
  """
  already_registered = name in VariableTypeCache
  if already_registered and not overwrite:
    raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
                     'To overwrite, call with `overwrite=True`.')
  VariableTypeCache[name] = typ
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should check that the name doesn't already exist.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should allow overwrite of names if user wants to. I'll add an overwrite option there.



# Pre-register the known Linen collection names with their NNX variable types.
# These run at import time, so the names below are already present in
# `VariableTypeCache` before any lookup happens.
register_variable_name_type_pair('params', variableslib.Param)
register_variable_name_type_pair('batch_stats', variableslib.BatchStat)
register_variable_name_type_pair('cache', variableslib.Cache)
register_variable_name_type_pair('intermediates', variableslib.Intermediate)


def sort_variable_types(types: tp.Iterable[type]):
  """Sort types so that the most derived `Variable` subclasses come first.

  Each type is ranked by the number of classes in its MRO that are subclasses
  of `variableslib.Variable`; higher counts sort earlier. Types with equal
  counts keep their input order, since `sorted` is stable.

  Args:
    types: any iterable of types. It is traversed exactly once, so one-shot
      iterators such as generators are handled correctly.

  Returns:
    A new list containing the input types in sorted order.
  """
  def _variable_parents_count(t: type):
    return sum(1 for p in t.mro() if issubclass(p, variableslib.Variable))

  # Compute the rank inside the sort key instead of in a separate pass, so the
  # input iterable is consumed only once.
  return sorted(types, key=lambda t: -_variable_parents_count(t))


#############################################
### NNX Variable <-> Linen metadata boxes ###
#############################################


class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
  """Default Flax metadata class for `nnx.VariableState`.

  Boxes a variable's value together with its NNX variable type and its
  metadata dict, so the three can travel together as one `AxisMetadata` box.
  """

  # NNX Variable type of the boxed variable (declared with pytree_node=False).
  var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
  # The boxed value itself; the only field declared with pytree_node=True.
  value: Any = struct.field(pytree_node=True)
  # Remaining variable metadata (declared with pytree_node=False).
  metadata: dict[str, tp.Any] = struct.field(pytree_node=False)

  def unbox(self) -> A:
    """Return the boxed value."""
    return self.value

def replace_boxed(self, val: B) -> 'NNXMeta[B]':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String types not needed in python >= 3.10

Suggested change
def replace_boxed(self, val: B) -> 'NNXMeta[B]':
def replace_boxed(self, val: B) -> NNXMeta[B]:

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got an error when trying this, saying NNXMeta is not found. I'll leave it as is for now.

return self.replace(value=val) # type: ignore

def add_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
  """Placeholder for the axis-addition hook; currently returns `self` unchanged.

  `index` and `params` are accepted but not used yet.
  """
  # TODO: implement this, supporting hooks
  return self

def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
  """Placeholder for the axis-removal hook; currently returns `self` unchanged.

  `index` and `params` are accepted but not used yet.
  """
  # TODO: implement this, supporting hooks
  return self


def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
  """Wrap an NNX `VariableState` in a Linen metadata box.

  If the state's metadata has no `'linen_meta_type'` entry, the result is an
  `NNXMeta` holding the state's type, value and metadata. If the entry is
  present it must be `meta.Partitioned`, and a `meta.Partitioned` box is built
  from the value and the `'sharding'` and `'mesh'` metadata entries.

  Raises:
    ValueError: if `'linen_meta_type'` is set to anything but
      `meta.Partitioned`.
  """
  metadata = vs.get_metadata()
  if 'linen_meta_type' not in metadata:
    # No Linen box type recorded: fall back to the generic NNX box.
    return NNXMeta(vs.type, vs.value, vs.get_metadata())
  if metadata['linen_meta_type'] is not meta.Partitioned:
    raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
  return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])


def get_col_name(keypath: tp.Sequence[Any]) -> str:
  """Return the Linen collection name encoded in a variable's key path.

  The first entry of `keypath` must be a `jax.tree_util.DictKey`; its `key`,
  converted to `str`, is the collection name.
  """
  # The collection name sits at the root of the leaf's path.
  root_key = keypath[0]
  assert isinstance(root_key, jax.tree_util.DictKey)
  return str(root_key.key)


def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
  """Convert a Linen variable into an NNX variable.

  The collection name `col` selects the NNX Variable type via
  `variable_type`. Handling then depends on what `x` is:

  * `NNXMeta`: rebuilt from its stored type, value and metadata. The stored
    type must equal the type inferred from `col`.
  * `meta.Partitioned`: its value, `names` and `mesh` become the variable's
    value, `sharding` and `mesh`, and `linen_meta_type` records the box type.
  * any other `meta.AxisMetadata`: rejected with `ValueError`.
  * anything else: wrapped directly in the inferred type.
  """
  vtype = variable_type(col)
  if isinstance(x, NNXMeta):
    assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
    return x.var_type(x.value, **x.metadata)
  if not isinstance(x, meta.AxisMetadata):
    # Plain, unboxed value.
    return vtype(x)
  if not isinstance(x, meta.Partitioned):
    raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
  return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
43 changes: 29 additions & 14 deletions flax/nnx/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@

from flax import nnx
from flax import linen
from flax.core import meta
from flax.nnx.nnx import graph
from flax.nnx.nnx import variables as variableslib
from flax.nnx.nnx.bridge import variables as bv
from flax.nnx.nnx.module import GraphDef, Module
from flax.nnx.nnx.rnglib import Rngs
from flax.nnx.nnx.state import State
Expand Down Expand Up @@ -120,7 +121,7 @@ def __init__(
):
self.module = module
self.rngs = rngs
self.linen_collections: set[str] = set()
self.linen_collections: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
return lazy_init(self, *args, **kwargs)
Expand All @@ -143,16 +144,20 @@ def __call__(
if 'params' not in _rngs and 'default' in _rngs:
_rngs['params'] = _rngs.pop('default')
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
def nn_var_to_nnx_state(kp, v):
assert isinstance(kp[0], jtu.DictKey)
vtype = variableslib.variable_type(kp[0].key)
return vtype(v)
for col, tree in jtu.tree_map_with_path(nn_var_to_nnx_state, variables).items():
self._setattr(col, tree)
self.linen_collections.add(col)

nnx_vars = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
linen_collections = set()
for col, tree in nnx_vars.items():
setattr(self, col, tree)
linen_collections.add(col)
self.linen_collections = tuple(linen_collections) # make it hashable

else:
variables = {col: jax.tree.map(lambda v: v.value, getattr(self, col))
variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()),
getattr(self, col),
is_leaf=lambda x: isinstance(x, nnx.Variable))
for col in self.linen_collections}
_rngs = (
{name: stream() for name, stream in rngs.items()} if rngs else {}
Expand All @@ -162,8 +167,11 @@ def nn_var_to_nnx_state(kp, v):
# Split out the updates if `mutable` is passed into the Flax module
if kwargs.get('mutable', False) != False:
out, updates = out
updates = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
for collection, value in updates.items():
self._setattr(collection, jax.tree.map(variableslib.variable_type(collection), value))
setattr(self, collection, value)

return out

Expand Down Expand Up @@ -214,6 +222,7 @@ class ToLinen(linen.Module):
args: tp.Sequence = ()
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
Expand All @@ -225,14 +234,16 @@ def update_variables(self, module):
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = variableslib.sort_variable_types(types)
types = bv.sort_variable_types(types)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`sort_variable_types` expects a list of types, but we can change the annotation to `Iterable`.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch! Done.

_, *state_by_types = nnx.split(module, *types)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use State.split:

Suggested change
_, *state_by_types = nnx.split(module, *types)
*state_by_types = state.split(*types)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, but nnx.split is more convenient because it always returns a tuple...

# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = variableslib.variable_type_name(typ)
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)

@linen.compact
Expand All @@ -250,7 +261,11 @@ def __call__(self, *args, **kwargs):
# apply codepath
gdef = self.get_variable('nnx', 'graphdef')
assert gdef, 'GraphDef not found in variables. Was the collection "nnx" dropped somewhere?'
states = [State(state) for col, state in self.variables.items() if col != 'nnx']
variables = {col: v for col, v in self.variables.items() if col != 'nnx'}
states = jtu.tree_map_with_path(
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
states = [State(v) for v in states.values()]
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
Expand Down
59 changes: 0 additions & 59 deletions flax/nnx/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pytype: skip-file
from __future__ import annotations

Expand Down Expand Up @@ -998,48 +984,3 @@ def wrapper(*args):

return wrapper # type: ignore


### Variable type <-> name mapping ###
# Assumption: the mapping is 1-1 and unique.

def variable_type(name: str) -> tp.Type[Variable[tp.Any]]:
  """Return the NNX Variable type registered for a Linen-style collection name.

  When `name` has no entry in `VariableTypeCache`, a new `Variable` subclass
  named after the collection is created, cached and returned.
  """
  try:
    return VariableTypeCache[name]
  except KeyError:
    # Unknown collection name: create and remember a dedicated subclass.
    new_type = type(name, (Variable,), {})
    VariableTypeCache[name] = new_type
    return new_type


def variable_type_name(typ: tp.Type[Variable[tp.Any]]) -> str:
  """Given an NNX Variable type, get or create its Linen-style collection name.

  Should output the exact inverse of `variable_type()`.

  The registry is searched for an entry whose type equals `typ`. If none is
  found, the type's `__name__` is registered as its collection name and
  returned.

  Raises:
    ValueError: if `typ` is not registered and its `__name__` is already used
      as a key in the registry.
  """
  for name, t in VariableTypeCache.items():
    if typ == t:
      return name
  name = typ.__name__
  if name in VariableTypeCache:
    # f-strings, so the message reports the actual name and the two types.
    raise ValueError(
      f'Name {name} is already registered in the registry as {VariableTypeCache[name]}. '
      f'It cannot be linked with this type {typ}.'
    )
  register_variable_name_type_pair(name, typ)
  return name


def register_variable_name_type_pair(name, typ):
  """Register a pair of variable type name (like Linen collections) and its NNX type.

  Stores `typ` under `name` in `VariableTypeCache`; an existing entry for the
  same name is replaced without any check.
  """
  VariableTypeCache[name] = typ


# Pre-register the known Linen collection names with their NNX variable types.
# These run at import time, so the names below are already present in
# `VariableTypeCache` before any lookup happens.
register_variable_name_type_pair('params', Param)
register_variable_name_type_pair('batch_stats', BatchStat)
register_variable_name_type_pair('cache', Cache)
register_variable_name_type_pair('intermediates', Intermediate)


def sort_variable_types(types: list[type]):
  """Order types so that the most derived `Variable` subclasses come first.

  A type's rank is the number of classes in its MRO that are subclasses of
  `nnx.Variable`. Types are returned in descending rank, and equal ranks keep
  their input order.
  """
  def _variable_parents_count(t: type):
    variable_ancestors = [p for p in t.mro() if issubclass(p, nnx.Variable)]
    return len(variable_ancestors)

  # Rank every type up front, then sort on the precomputed ranks.
  ranks = {}
  for t in types:
    ranks[t] = _variable_parents_count(t)
  return sorted(types, key=lambda t: ranks[t], reverse=True)
Loading
Loading