From 895b78337169bfdc2f0fc6128adb8518aa6ace3e Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 1 May 2024 08:14:02 +0100 Subject: [PATCH 1/2] [nnx] enable pytype and mypy on CI --- flax/core/frozen_dict.py | 30 ++- flax/experimental/nnx/__init__.py | 2 +- flax/experimental/nnx/docs/why.ipynb | 2 +- flax/experimental/nnx/docs/why.md | 2 +- flax/experimental/nnx/nnx/filterlib.py | 5 +- flax/experimental/nnx/nnx/graph.py | 17 +- flax/experimental/nnx/nnx/helpers.py | 2 +- flax/experimental/nnx/nnx/module.py | 8 +- flax/experimental/nnx/nnx/nn/attention.py | 38 +-- flax/experimental/nnx/nnx/nn/dtypes.py | 3 +- flax/experimental/nnx/nnx/nn/linear.py | 1 + flax/experimental/nnx/nnx/proxy_caller.py | 4 +- flax/experimental/nnx/nnx/reprlib.py | 3 +- flax/experimental/nnx/nnx/rnglib.py | 12 +- flax/experimental/nnx/nnx/state.py | 28 +-- flax/experimental/nnx/nnx/training/metrics.py | 5 +- flax/experimental/nnx/nnx/transforms.py | 221 +++++++++--------- flax/experimental/nnx/nnx/variables.py | 7 +- flax/experimental/nnx/nnx/visualization.py | 6 +- flax/experimental/nnx/tests/nn/test_conv.py | 4 +- flax/experimental/nnx/tests/nn/test_embed.py | 1 + flax/experimental/nnx/tests/nn/test_linear.py | 2 + .../nnx/tests/test_graph_utils.py | 2 +- flax/experimental/nnx/tests/test_optimizer.py | 8 +- .../experimental/nnx/tests/test_transforms.py | 40 ++-- flax/experimental/nnx/tests/test_variable.py | 2 +- flax/typing.py | 3 +- pyproject.toml | 5 +- 28 files changed, 240 insertions(+), 223 deletions(-) diff --git a/flax/core/frozen_dict.py b/flax/core/frozen_dict.py index b78319a14..9bc241bcf 100644 --- a/flax/core/frozen_dict.py +++ b/flax/core/frozen_dict.py @@ -16,7 +16,17 @@ import collections from types import MappingProxyType -from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union +from typing import ( + Any, + Dict, + Hashable, + Iterable, + Mapping, + Tuple, + TypeVar, + Union, + overload, +) import jax @@ -55,6 +65,24 @@ class FrozenDict(Mapping[K, V]): __slots__ = ('_dict', '_hash') + @overload + def __init__( + self, + mapping: Mapping[K, V] = MappingProxyType({}), + /, + __unsafe_skip_copy__=False, + **kwargs, + ): ... + + @overload + def __init__( + self, + mapping: Iterable[tuple[K, V]] = (), + /, + __unsafe_skip_copy__=False, + **kwargs, + ): ... 
+ def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name # make sure the dict is as xs = dict(*args, **kwargs) diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 827542835..1b237768e 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -97,11 +97,11 @@ from .nnx.training.metrics import MultiMetric as MultiMetric from .nnx.training.optimizer import Optimizer as Optimizer from .nnx.transforms import Jit as Jit +from .nnx.transforms import jit as jit from .nnx.transforms import Remat as Remat from .nnx.transforms import Scan as Scan from .nnx.transforms import Vmap as Vmap from .nnx.transforms import grad as grad -from .nnx.transforms import jit as jit from .nnx.transforms import remat as remat from .nnx.transforms import scan as scan from .nnx.transforms import value_and_grad as value_and_grad diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb index 4b1cfcfad..04cad17da 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/experimental/nnx/docs/why.ipynb @@ -492,7 +492,7 @@ ], "source": [ "# class transform:\n", - "ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n", + "ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n", "\n", "scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n", "scanned.get_state()" diff --git a/flax/experimental/nnx/docs/why.md b/flax/experimental/nnx/docs/why.md index c52d9b637..3dce4ad63 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/experimental/nnx/docs/why.md @@ -256,7 +256,7 @@ Like linen, for convenience we still provide simple lifted transforms for standa :outputId: c4800a49-efd1-4ee5-e703-6e63e18da4cb # class transform: -ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4) +ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4) scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0)) scanned.get_state() diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/experimental/nnx/nnx/filterlib.py index b6c406f11..6d30d264b 100644 --- a/flax/experimental/nnx/nnx/filterlib.py +++ b/flax/experimental/nnx/nnx/filterlib.py @@ -43,7 +43,10 @@ def to_predicate(filter: Filter) -> Predicate: elif isinstance(filter, type): return OfType(filter) elif isinstance(filter, bool): - return Everything() if filter else Nothing() + if filter: + return Everything() + else: + return Nothing() elif filter is Ellipsis: return Everything() elif filter is None: diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/experimental/nnx/nnx/graph.py index 71dafa259..799b361f1 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/experimental/nnx/nnx/graph.py @@ -267,7 +267,7 @@ def create( type: tp.Type[Node], index: int, attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, tp.Union['GraphDef[tp.Any]', Index]]], + subgraphs: tp.Iterable[tuple[Key, tp.Union['NodeDef[tp.Any]', Index]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], variables: tp.Iterable[tuple[Key, Index]], metadata: tp.Any, @@ -456,7 +456,7 @@ def unflatten( def _graph_unflatten( nodedef: tp.Union[NodeDef[Node], int], - state: dict[Key, StateLeaf | dict[Key, tp.Any]], + state: tp.Mapping[Key, StateLeaf | tp.Mapping[Key, tp.Any]], index_to_ref: dict[Index, tp.Any], idxmap: dict[Index, tp.Any] | None, ) -> Node: @@ -656,7 +656,7 @@ def _graph_pop( pass -def _graph_update_dynamic(node: tp.Any, state: dict[Key, tp.Any]): +def 
_graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]): if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -741,7 +741,7 @@ def _graph_update_static( if id(updates) in cache: if cache[id(updates)] != status: - str_path = '/'.join(path) + str_path = '/'.join(str(p) for p in path) if status is _StaticModuleStatus.NEW: raise ValueError( f'Trying to add a new node at path {str_path!r} but a' @@ -859,6 +859,7 @@ def split( ) graphdef, state, refmap = flatten(node, idxmap=self.idxmap) + states: State | tuple[State, ...] if len(filters) == 0: states = (state,) elif len(filters) == 1: @@ -938,6 +939,7 @@ def split( ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: graphdef, state, _ = flatten(node) + states: State | tuple[State, ...] if len(filters) == 0: states = (state,) elif len(filters) == 1: @@ -995,6 +997,7 @@ def state( ) -> tp.Union[State, tuple[State, ...]]: state = flatten(node)[1] + states: State | tuple[State, ...] if len(filters) == 0: states = state elif len(filters) == 1: @@ -1181,7 +1184,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G: node = cls.__new__(cls, *args, **kwargs) vars(node)['_graph_node__state'] = ModuleState() - node.__init__(*args, **kwargs) + node.__init__(*args, **kwargs) # type: ignore[misc] return node @@ -1321,6 +1324,10 @@ def _key_path_to_key(key: tp.Any) -> Key: elif isinstance( key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey) ): + if not isinstance(key.key, Key): + raise ValueError( + f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.' + ) return key.key elif isinstance(key, jax.tree_util.GetAttrKey): return key.name diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py index 8aeca4176..d50a93271 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/experimental/nnx/nnx/helpers.py @@ -134,7 +134,7 @@ def __call__(self, *args, rngs: tp.Optional[Rngs] = None, **kwargs) -> tp.Any: if not callable(f): raise TypeError(f'Sequence[{i}] is not callable: {f}') if i > 0: - if isinstance(output, tp.Tuple): + if isinstance(output, tuple): args = output kwargs = {} elif isinstance(output, dict): diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 69dd37d75..5019a400e 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -30,8 +30,7 @@ CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import State -from flax.experimental.nnx.nnx.variables import Variable +from flax.experimental.nnx.nnx.state import State, StateLeaf from flax.typing import Path, PathParts A = tp.TypeVar('A') @@ -329,7 +328,7 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None: jtu.register_pytree_with_keys( cls, partial(_module_flatten, with_keys=True), - _module_unflatten, + _module_unflatten, # type: ignore[arg-type] flatten_func=partial(_module_flatten, with_keys=False), ) @@ -342,6 +341,7 @@ def _module_flatten(module: Module, *, with_keys: bool): key_values = sorted(state.raw_mapping.items()) keys = tuple(key for key, _ in key_values) + children: tuple[tp.Any, ...] 
if with_keys: children = tuple((jtu.DictKey(key), value) for key, value in key_values) else: @@ -352,7 +352,7 @@ def _module_flatten(module: Module, *, with_keys: bool): def _module_unflatten( paths_moduledef: tuple[tuple[Path, ...], GraphDef[M]], - variables: tuple[Variable[tp.Any], ...], + variables: tuple[StateLeaf, ...], ) -> M: paths, graphdef = paths_moduledef return graph.merge(graphdef, State(zip(paths, variables))) diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py index fa47421b0..56de71be4 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/experimental/nnx/nnx/nn/attention.py @@ -17,7 +17,7 @@ from __future__ import annotations import functools -from typing import Any, Callable, Optional, overload +from typing import Any, Callable, Optional import jax import jax.numpy as jnp @@ -368,6 +368,8 @@ def __init__( self.key = linear_general(rngs=rngs) self.value = linear_general(rngs=rngs) + self.query_ln: LayerNorm | None + self.key_ln: LayerNorm | None if self.normalize_qk: # Normalizing query and key projections stabilizes training with higher # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis. @@ -403,37 +405,7 @@ def __init__( dot_general_cls=self.out_dot_general_cls, rngs=rngs, ) - - @overload - def __call__( - self, - inputs_q: Array, - inputs_k: Optional[Array] = None, - inputs_v: Optional[Array] = None, - *, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None, - dropout_rng: Optional[Array] = None, - rngs: rnglib.Rngs | None = None, - sow_weights: bool = False, - decode: bool | None = None, - ): - ... - - @overload - def __call__( - self, - inputs_q: Array, - *, - inputs_kv: Array | None = None, - mask: Array | None = None, - deterministic: bool | None = None, - dropout_rng: Array | None = None, - rngs: rnglib.Rngs | None = None, - sow_weights: bool = False, - decode: bool | None = None, - ): - ... + self.rngs = rngs if dropout_rate > 0.0 else None def __call__( self, @@ -476,6 +448,8 @@ def __call__( Returns: output of shape `[batch_sizes..., length, features]`. """ + if rngs is None: + rngs = self.rngs if inputs_k is None: if inputs_v is not None: diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/experimental/nnx/nnx/nn/dtypes.py index 9f59b10e0..c204e0426 100644 --- a/flax/experimental/nnx/nnx/nn/dtypes.py +++ b/flax/experimental/nnx/nnx/nn/dtypes.py @@ -79,4 +79,5 @@ def promote_dtype( The arguments cast to arrays of the same dtype. 
""" dtype = canonicalize_dtype(*args, dtype=dtype, inexact=inexact) - return tuple(jnp.asarray(x, dtype) if x is not None else None for x in args) + arrays = tuple(jnp.asarray(x, dtype) if x is not None else None for x in args) + return arrays # type: ignore[return-value] diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py index 990f6c6f3..dffe9c742 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -419,6 +419,7 @@ def __init__( kernel_key = rngs.params() self.kernel = nnx.Param(kernel_init(kernel_key, kernel_shape, param_dtype)) + self.bias: nnx.Param | None if bias_shape is not None: bias_key = rngs.params() self.bias = nnx.Param(bias_init(bias_key, bias_shape, param_dtype)) diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/experimental/nnx/nnx/proxy_caller.py index f3b3aab51..e1c2bd3cf 100644 --- a/flax/experimental/nnx/nnx/proxy_caller.py +++ b/flax/experimental/nnx/nnx/proxy_caller.py @@ -30,7 +30,6 @@ import dataclasses import typing as tp -import typing_extensions as tpe A = tp.TypeVar('A') @@ -53,11 +52,10 @@ def __getitem__(self, key): return DelayedAccessor(lambda x: x[key]) -class _AccessorCall(tpe.Protocol): +class _AccessorCall(tp.Protocol): def __call__(self, accessor: DelayedAccessor, /, *args, **kwargs) -> tp.Any: ... - class CallableProxy: def __init__( self, callable: _AccessorCall, accessor: DelayedAccessor | None = None diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/experimental/nnx/nnx/reprlib.py index 5efc065ed..855a3049b 100644 --- a/flax/experimental/nnx/nnx/reprlib.py +++ b/flax/experimental/nnx/nnx/reprlib.py @@ -93,8 +93,7 @@ def _repr_elem(elem: tp.Any) -> str: return f'{config.elem_indent}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}' with add_indent(config.elem_indent): - elems = list(map(_repr_elem, iterator)) - elems = ',\n'.join(elems) + elems = ',\n'.join(map(_repr_elem, iterator)) if elems: elems = '\n' + elems + '\n' diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py index 8e3802bef..094f21a04 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/experimental/nnx/nnx/rnglib.py @@ -99,6 +99,7 @@ def fork(self, pattern: SplitPattern) -> jax.Array: # broadcast key key = self() else: + num_splits: int | tuple[int, ...] if isinstance(pattern, int): num_splits = pattern else: @@ -137,7 +138,7 @@ def __init__( ) setattr(self, name, stream) - def _get_stream(self, name: str, error_type: Exception) -> RngStream: + def _get_stream(self, name: str, error_type: type[Exception]) -> RngStream: rngs_vars = vars(self) if name not in rngs_vars: if 'default' not in rngs_vars: @@ -252,6 +253,7 @@ def _split_rng_flatten(rngs: ForkedKeys, *, with_keys: bool): items = [(name, rngs.broadcasts[name]) for name in broadcast_names] items += [(name, rngs.splits[name]) for name in split_names] + nodes: tuple[jax.Array | tuple[jax.tree_util.DictKey, jax.Array], ...] 
if with_keys: nodes = tuple((jax.tree_util.DictKey(name), value) for name, value in items) else: @@ -277,9 +279,9 @@ def _split_rng_unflatten( jax.tree_util.register_pytree_with_keys( ForkedKeys, - functools.partial(_split_rng_flatten, with_keys=True), - _split_rng_unflatten, - flatten_func=functools.partial(_split_rng_flatten, with_keys=False), + functools.partial(_split_rng_flatten, with_keys=True), # type: ignore + _split_rng_unflatten, # type: ignore + flatten_func=functools.partial(_split_rng_flatten, with_keys=False), # type: ignore ) def fork( @@ -289,6 +291,8 @@ def fork( ) -> tuple[State, State]: if split_pattern is None: raise RuntimeError('Split pattern cannot be None, this is a bug.') + + num_splits: int | tuple[int, ...] if isinstance(split_pattern, int): num_splits = split_pattern else: diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index ffe77ea76..2c9272dcf 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -77,10 +77,10 @@ def __init__( super().__setattr__('_mapping', dict(mapping)) @property - def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: - return self._mapping + def raw_mapping(self) -> tp.Mapping[Key, tp.Mapping[Key, tp.Any] | StateLeaf]: + return self._mapping # type: ignore - def __contains__(self, key: Key) -> bool: + def __contains__(self, key) -> bool: return key in self._mapping def __getitem__(self, key: Key) -> State | StateLeaf: @@ -147,7 +147,7 @@ def split( self, first: filterlib.Filter, /, *filters: filterlib.Filter ) -> tp.Union['State', tuple['State', ...]]: filters = (first, *filters) - *states, rest = _split_state(self, *filters) + *states_, rest = _split_state(self, *filters) if rest: raise ValueError( @@ -155,10 +155,11 @@ def split( f'{list(rest.keys())}.\nUse `...` to match all remaining elements.' ) - if len(states) == 1: - states = states[0] + states: State | tuple[State, ...] + if len(states_) == 1: + states = states_[0] else: - states = tuple(states) + states = tuple(states_) return states @tp.overload @@ -185,14 +186,15 @@ def filter( /, *filters: filterlib.Filter, ) -> tp.Union['State', tuple['State', ...]]: - *states, _rest = _split_state(self, first, *filters) + *states_, _rest = _split_state(self, first, *filters) - assert len(states) == len(filters) + 1 + assert len(states_) == len(filters) + 1 - if len(states) == 1: - states = states[0] + states: State | tuple[State, ...] 
+ if len(states_) == 1: + states = states_[0] else: - states = tuple(states) + states = tuple(states_) return states @@ -241,7 +243,7 @@ def _state_unflatten( jax.tree_util.register_pytree_with_keys( State, _state_flatten_with_keys, - _state_unflatten, + _state_unflatten, # type: ignore[arg-type] ) diff --git a/flax/experimental/nnx/nnx/training/metrics.py b/flax/experimental/nnx/nnx/training/metrics.py index 9c434f040..71f6f6fd1 100644 --- a/flax/experimental/nnx/nnx/training/metrics.py +++ b/flax/experimental/nnx/nnx/training/metrics.py @@ -44,7 +44,8 @@ def __init__(self): raise NotImplementedError('Must override `__init__()` method.') def reset(self): raise NotImplementedError('Must override `reset()` method.') - def update(self): + + def update(self, **kwargs) -> None: raise NotImplementedError('Must override `update()` method.') def compute(self): raise NotImplementedError('Must override `compute()` method.') @@ -71,7 +72,7 @@ def compute(self): return self.total.value / self.count.value class Accuracy(Average): - def update(self, *, logits: jax.Array, labels: jax.Array, **_): + def update(self, *, logits: jax.Array, labels: jax.Array, **_): # type: ignore[override] if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32: raise ValueError( f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}==" diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index ad5a4cdc3..1bc5fe912 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -47,7 +47,7 @@ spmd, variables, ) -from flax.experimental.nnx.nnx.module import GraphDef, Module, ModuleMeta +from flax.experimental.nnx.nnx.module import GraphDef, Module from flax.experimental.nnx.nnx.proxy_caller import ( CallableProxy, DelayedAccessor, @@ -61,12 +61,14 @@ F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any]) G = tp.TypeVar('G', bound=tp.Callable[..., tp.Any]) M = tp.TypeVar('M', bound=Module) +MA = tp.TypeVar('MA', bound=Module) N = tp.TypeVar('N', bound=Module) StrInt = tp.TypeVar('StrInt', str, int) AxisName = tp.Hashable Leaves = tp.List[Leaf] Index = int + def _normalize_sequence( x: StrInt | tp.Iterable[StrInt] | None, / ) -> tuple[StrInt, ...]: @@ -98,7 +100,7 @@ def call(self) -> tp.Any: def check_and_call(accessor: DelayedAccessor, *args, **kwargs): return self._call(accessor, *args, **kwargs) - proxy = CallableProxy(check_and_call) + proxy = CallableProxy(check_and_call) # type: ignore[arg-type] while isinstance(module._submodule, LiftedModule): module = module._submodule @@ -113,6 +115,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): UNSPECIFIED = object() + @dataclasses.dataclass(frozen=True) class JitStaticInputs: graphdef: GraphDef[tuple[tp.Any, ...]] @@ -130,6 +133,7 @@ class JitStaticOutputs: jax.tree_util.register_static(JitStaticOutputs) + def _default_constrain_object_state(state: State) -> State: state_spec = spmd.get_partition_spec(state) state = jax.lax.with_sharding_constraint(state, state_spec) @@ -212,53 +216,6 @@ def get_jit_kwargs(self) -> dict[str, tp.Any]: return kwargs -class JITMeta(ModuleMeta): - def __call__( - self, - module_constructor: tp.Callable[..., M], - *, - in_shardings: tp.Any = UNSPECIFIED, - out_shardings: tp.Any = UNSPECIFIED, - static_argnums: int | tp.Sequence[int] | None = None, - static_argnames: str | tp.Iterable[str] | None = None, - donate_argnums: int | tp.Sequence[int] | None = None, - donate_argnames: str | tp.Iterable[str] | None = None, - keep_unused: bool = 
False, - device: tp.Optional[jax.Device] = None, - backend: tp.Optional[str] = None, - inline: bool = False, - abstracted_axes: tp.Optional[tp.Any] = None, - # nnx specific - donate_state: bool = False, - constrain_state: bool | tp.Callable[[State], State] = False, - ) -> tp.Callable[..., 'Jit[M]']: - super_call = super().__call__ - - def _create_jit(*args, **kwargs) -> Jit[M]: - return super_call( - module_constructor=module_constructor, - in_shardings=in_shardings, - out_shardings=out_shardings, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - donate_argnames=donate_argnames, - keep_unused=keep_unused, - device=device, - backend=backend, - inline=inline, - abstracted_axes=abstracted_axes, - # nnx specific - donate_state=donate_state, - constrain_state=constrain_state, - # submodule args - module_init_args=args, - module_init_kwargs=kwargs, - ) - - return _create_jit - - class JittedFn(tp.Protocol): def __call__( self, @@ -330,7 +287,50 @@ def jit_apply( return out -class Jit(LiftedModule[M], metaclass=JITMeta): +class Jit(LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + *, + in_shardings: tp.Any = UNSPECIFIED, + out_shardings: tp.Any = UNSPECIFIED, + static_argnums: int | tp.Sequence[int] | None = None, + static_argnames: str | tp.Iterable[str] | None = None, + donate_argnums: int | tp.Sequence[int] | None = None, + donate_argnames: str | tp.Iterable[str] | None = None, + keep_unused: bool = False, + device: tp.Optional[jax.Device] = None, + backend: tp.Optional[str] = None, + inline: bool = False, + abstracted_axes: tp.Optional[tp.Any] = None, + # nnx specific + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, + ) -> tp.Callable[..., 'Jit[MA]']: + def _create_jit(*args, **kwargs): + return Jit( + module_constructor=module_constructor, + in_shardings=in_shardings, + out_shardings=out_shardings, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + donate_argnames=donate_argnames, + keep_unused=keep_unused, + device=device, + backend=backend, + inline=inline, + abstracted_axes=abstracted_axes, + # nnx specific + donate_state=donate_state, + constrain_state=constrain_state, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, + ) + + return _create_jit + def __init__( self, module_constructor: tp.Callable[..., M], @@ -375,7 +375,7 @@ def jit_call_module(module, *args, **kwargs): method = self.accessor(module) return method(*args, **kwargs) - self.jitted_fn: JittedFn[M] = get_jitted_fn(jit_call_module, self.options) + self.jitted_fn: JittedFn = get_jitted_fn(jit_call_module, self.options) self.module_constructor = module_constructor self.jit_module = self.module_constructor( *module_init_args, **module_init_kwargs @@ -584,10 +584,10 @@ class GradOptions: wrt: filterlib.Filter -class GradMeta(ModuleMeta): - def __call__( - self, - module_constructor: tp.Callable[..., M], +class Grad(LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, @@ -595,11 +595,9 @@ def __call__( return_value: bool = False, *, wrt: filterlib.Filter = variables.Param, - ) -> tp.Callable[..., 'Grad[M]']: - super_call = super().__call__ - - def _create_grad(*args, **kwargs) -> Grad[M]: - return super_call( + ) -> tp.Callable[..., 'Grad[MA]']: + def _create_grad(*args, **kwargs): + return Grad( 
module_constructor=module_constructor, wrt=wrt, has_aux=has_aux, @@ -614,8 +612,6 @@ def _create_grad(*args, **kwargs) -> Grad[M]: return _create_grad - -class Grad(LiftedModule[M], metaclass=GradMeta): def __init__( self, module_constructor: tp.Callable[..., M], @@ -668,7 +664,7 @@ def grad_apply(options: GradOptions, f, args: tuple[tp.Any, ...]): if i in options.argnums and graph.is_node(arg) } - _, diff_state, _ = graph.split(diff_graph_nodes, options.wrt, ...) + _, diff_state, _ = graph.split(diff_graph_nodes, options.wrt, ...) # type: ignore[misc] for i in diff_graph_nodes: _args[i] = diff_state[i] @@ -821,8 +817,6 @@ def grad_wrapper(*args): return grad_wrapper # type: ignore - - def value_and_grad( f: tp.Callable[..., tp.Any], argnums: int | tp.Sequence[int] = 0, @@ -858,6 +852,7 @@ def value_and_grad_wrapper(*args): # scan # ------------------------------- + @dataclasses.dataclass class ScanOptions: length: int | None @@ -876,10 +871,10 @@ class ScanOptions: scan_output: bool -class ScanMeta(ModuleMeta): - def __call__( - self, - module_constructor: tp.Callable[..., M], +class Scan(LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], *, length: int | None = None, reverse: bool = False, @@ -895,11 +890,9 @@ def __call__( split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, - ) -> tp.Callable[..., 'Scan[M]']: - super_call = super().__call__ - - def _create_scan(*args, **kwargs) -> Scan[M]: - return super_call( + ) -> tp.Callable[..., 'Scan[MA]']: + def _create_scan(*args, **kwargs): + return Scan( module_constructor=module_constructor, module_init_args=args, module_init_kwargs=kwargs, @@ -922,8 +915,6 @@ def _create_scan(*args, **kwargs) -> Scan[M]: return _create_scan - -class Scan(LiftedModule[M], metaclass=ScanMeta): def __init__( self, module_constructor: tp.Callable[..., M], @@ -964,7 +955,7 @@ def __init__( scan_output=scan_output, ) # use Vmap to handle initialisation - vmapped_module = Vmap( + vmapped_module = Vmap.constructor( module_constructor, in_axes=in_axes, out_axes=None, @@ -1014,7 +1005,7 @@ def scan_apply( ctx = graph.UpdateContext() # split module state filters = (*options.state_axes.keys(), ...) - graphdef, rng_state, *scan_states, carry_state = ctx.split( + graphdef, rng_state, *scan_states, carry_state = ctx.split( # type: ignore[misc] input_graph_nodes, rnglib.RngState, *filters ) @@ -1122,7 +1113,7 @@ def scan_fn( rng_state_out, *scan_states_out, carry_state_out, - ) = ctx.split( + ) = ctx.split( # type: ignore[misc] (input_graph_nodes, output_graph_nodes), rnglib.RngState, *filters, @@ -1212,8 +1203,10 @@ class FlatDef(tp.Generic[A]): treedef: jax.tree_util.PyTreeDef flat_axes: list[int | None] + jax.tree_util.register_static(FlatDef) + def _transpose_tree(tree: A, axes, /, *, axis_is_source: bool) -> A: flatdef, flat_transposes, _ = _transpose_and_split( tree, axes, allow_none=False, axis_is_source=axis_is_source @@ -1268,6 +1261,7 @@ def _transpose_and_split( return flatdef, flat_transposes, flat_broadcasts + def _unflatten_splits( flatdef: FlatDef[A], flat_transposes: list[jax.Array | None], @@ -1373,27 +1367,6 @@ def scan_apply_wrapper(*args, **kwargs) -> C | tuple[C, tp.Any]: # ------------------------------- -class RematMeta(ModuleMeta): - def __call__( - self, - module_constructor: tp.Callable[..., M], - prevent_cse: bool = True, - static_argnums: int | tuple[int, ...] 
= (), - policy: tp.Callable[..., bool] | None = None, - ) -> tp.Callable[..., 'Remat[M]']: - super_call = super().__call__ - - def create_remat(*args, **kwargs) -> Remat[M]: - return super_call( - module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - prevent_cse=prevent_cse, - static_argnums=static_argnums, - policy=policy, - ) - - return create_remat @dataclasses.dataclass @@ -1412,7 +1385,26 @@ def __post_init__(self): ) -class Remat(LiftedModule[M], metaclass=RematMeta): +class Remat(LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], + prevent_cse: bool = True, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, + ) -> tp.Callable[..., 'Remat[MA]']: + def create_remat(*args, **kwargs): + return Remat( + module_constructor=module_constructor, + module_init_args=args, + module_init_kwargs=kwargs, + prevent_cse=prevent_cse, + static_argnums=static_argnums, + policy=policy, + ) + + return create_remat + def __init__( self, *, @@ -1505,6 +1497,7 @@ def remat_wrapper(*args): # vmap # ------------------------------- + @dataclasses.dataclass class VmapOptions: in_axes: int | None | tp.Sequence[tp.Any] @@ -1519,10 +1512,10 @@ class VmapOptions: transform_metadata: tp.Mapping[str, tp.Any] -class VmapMeta(ModuleMeta): - def __call__( - self, - module_constructor: tp.Callable[..., M], +class Vmap(LiftedModule[M]): + @staticmethod + def constructor( + module_constructor: tp.Callable[..., MA], *, in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, @@ -1534,11 +1527,9 @@ def __call__( state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), split_rngs: filterlib.Filter = ..., transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), - ) -> tp.Callable[..., 'Vmap[M]']: - super_call = super().__call__ - - def _create_vmap(*args, **kwargs) -> Scan[M]: - return super_call( + ) -> tp.Callable[..., 'Vmap[MA]']: + def _create_vmap(*args, **kwargs): + return Vmap( module_constructor=module_constructor, in_axes=in_axes, out_axes=out_axes, @@ -1557,8 +1548,6 @@ def _create_vmap(*args, **kwargs) -> Scan[M]: return _create_vmap - -class Vmap(LiftedModule[M], metaclass=VmapMeta): def __init__( self, module_constructor: tp.Callable[..., M], @@ -1625,6 +1614,7 @@ def vmap_apply_call(module, *args, **kwargs): kwargs, ) + def vmap_apply( options: VmapOptions, f: tp.Callable[..., A], @@ -1637,7 +1627,7 @@ def vmap_apply( ctx = graph.UpdateContext() # split module state filters = (*options.state_axes.keys(), ...) 
- graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( + graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( # type: ignore[misc] input_graph_nodes, rnglib.RngState, *filters ) @@ -1741,7 +1731,7 @@ def vmap_fn( rng_state_out, *vectorized_states_out, broadcast_state_out, - ) = ctx.split( + ) = ctx.split( # type: ignore[misc] (input_graph_nodes, output_graph_nodes), rnglib.RngState, *filters, @@ -1828,6 +1818,7 @@ def vmap_apply_wrapper(*args, **kwargs) -> tp.Any: return wrapper # type: ignore + # ------------------------------- # eval_shape # ------------------------------- @@ -1856,4 +1847,4 @@ def _eval_shape_fn(state: State, *args, **kwargs): output_nodes = graph.merge(graphdef_out, state_out) out = graph.insert_graph_nodes(out, output_nodes) - return out \ No newline at end of file + return out diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 4d76980e5..4b7d8eae5 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -69,7 +69,8 @@ def __hash__(self): lambda _0, _1: EMPTY, ) -EMPTY = Empty() +EMPTY: Empty = Empty() + class _Missing: pass @@ -391,7 +392,7 @@ def __jax_array__(self): return self.value def __getitem__(self, key) -> tp.Any: - return self.value.__getitem__(key) + return self.value.__getitem__(key) # type: ignore def __add__(self, other) -> A: return self.value.__add__(other) # type: ignore @@ -572,7 +573,7 @@ class Intermediate(Variable[A]): class VariableState(tp.Generic[A], reprlib.Representable): def __init__( self, - type: tp.Type[Variable[A]], + type: type[Variable[tp.Any]], value: A, **metadata, ): diff --git a/flax/experimental/nnx/nnx/visualization.py b/flax/experimental/nnx/nnx/visualization.py index 94e1aae27..65c09f771 100644 --- a/flax/experimental/nnx/nnx/visualization.py +++ b/flax/experimental/nnx/nnx/visualization.py @@ -40,7 +40,7 @@ def display(*args): print(x) return - from penzai import pz + from penzai import pz # type: ignore[import-not-found] with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): for x in args: @@ -70,7 +70,7 @@ def _to_dataclass(x, seen_nodes: set[int]): } dc_type = _make_dataclass_obj( type(x), - node_dict, + {str(key): value for key, value in node_dict.items()}, ) return dc_type elif isinstance(x, (nnx.Variable, nnx.VariableState)): @@ -109,7 +109,7 @@ def _to_dataclass_fn(x): def _make_dataclass_obj( - cls, fields: dict[str, tp.Any], penzai_dataclass: bool = True + cls, fields: tp.Mapping[str, tp.Any], penzai_dataclass: bool = True ) -> tp.Type: from penzai import pz diff --git a/flax/experimental/nnx/tests/nn/test_conv.py b/flax/experimental/nnx/tests/nn/test_conv.py index 1d6016eb4..764b8b3fa 100644 --- a/flax/experimental/nnx/tests/nn/test_conv.py +++ b/flax/experimental/nnx/tests/nn/test_conv.py @@ -58,7 +58,9 @@ def test_nnx_linen_equivalence( kernel_size = (7, 4) # Cannot use string padding specification for transpose conv - if isinstance(input_dilation, Sequence) or input_dilation > 1: + if isinstance(input_dilation, Sequence) or ( + isinstance(input_dilation, int) and input_dilation > 1 + ): padding = (4, 2) x = jax.numpy.ones(INPUT_SHAPE) diff --git a/flax/experimental/nnx/tests/nn/test_embed.py b/flax/experimental/nnx/tests/nn/test_embed.py index bf761346e..bed5ab1a8 100644 --- a/flax/experimental/nnx/tests/nn/test_embed.py +++ b/flax/experimental/nnx/tests/nn/test_embed.py @@ -67,5 +67,6 @@ def test_nnx_linen_equivalence( x = jax.numpy.ones((10,), dtype=input_dtype) * 10 
out_nnx = model_nnx(x) out = model.apply(variables, x) + assert isinstance(out, jax.Array) assert_array_equal(out, out_nnx) assert_array_equal(jax.numpy.isnan(out).all(), jax.numpy.array([True])) diff --git a/flax/experimental/nnx/tests/nn/test_linear.py b/flax/experimental/nnx/tests/nn/test_linear.py index a2eee40cd..944f03b97 100644 --- a/flax/experimental/nnx/tests/nn/test_linear.py +++ b/flax/experimental/nnx/tests/nn/test_linear.py @@ -135,6 +135,7 @@ def test_nnx_einsum_equivalence( variables = model.init(key, x) variables['params']['kernel'] = model_nnx.kernel.value if bias_shape is not None: + assert model_nnx.bias is not None variables['params']['bias'] = model_nnx.bias.value out_nnx = model_nnx(x) out = model.apply(variables, x) @@ -143,6 +144,7 @@ def test_nnx_einsum_equivalence( variables = model.init(key, x) model_nnx.kernel.value = variables['params']['kernel'] if bias_shape is not None: + assert model_nnx.bias is not None model_nnx.bias.value = variables['params']['bias'] out_nnx = model_nnx(x) out = model.apply(variables, x) diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index d976a4422..c158209af 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -300,7 +300,7 @@ def __init__(self, *, rngs: nnx.Rngs): self.b = nnx.BatchNorm(2, rngs=rngs) def f(m: Foo): - m.a, m.b = m.b, m.a + m.a, m.b = m.b, m.a # type: ignore m = Foo(rngs=nnx.Rngs(0)) a = m.a diff --git a/flax/experimental/nnx/tests/test_optimizer.py b/flax/experimental/nnx/tests/test_optimizer.py index 689005690..fb749cc23 100644 --- a/flax/experimental/nnx/tests/test_optimizer.py +++ b/flax/experimental/nnx/tests/test_optimizer.py @@ -64,14 +64,14 @@ def test_jit(self, module_cls, jit_decorator, optimizer): ).mean() initial_loss = loss_fn(model_static, model_state, x, y) - def train_step(graphdef, state, x, y): + def jax_jit_train_step(graphdef, state, x, y): state = nnx.merge(graphdef, state) model_static, model_state = nnx.split(state.model) grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y) state.update(grads) return state.split() - graphdef, state = jit_decorator(train_step)(*state.split(), x, y) + graphdef, state = jit_decorator(jax_jit_train_step)(*state.split(), x, y) state = nnx.merge(graphdef, state) new_loss = loss_fn(*nnx.split(state.model), x, y) @@ -79,11 +79,11 @@ def train_step(graphdef, state, x, y): loss_fn = lambda model, x, y: ((model(x)-y)**2).mean() initial_loss = loss_fn(state.model, x, y) - def train_step(optimizer: nnx.Optimizer, x, y): + def nnx_jit_train_step(optimizer: nnx.Optimizer, x, y): grads = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model, x, y) optimizer.update(grads) - jit_decorator(train_step)(state, x, y) + jit_decorator(nnx_jit_train_step)(state, x, y) new_loss = loss_fn(state.model, x, y) self.assertTrue(new_loss < initial_loss) diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index db6cb2428..67b26e46b 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -110,7 +110,7 @@ def __call__(self, x: jax.Array) -> jax.Array: n += 1 return jnp.dot(x, self.w.value) - m = nnx.Jit(Foo)(2, 3, rngs=nnx.Rngs(0)) + m = nnx.Jit.constructor(Foo)(2, 3, rngs=nnx.Rngs(0)) y = m(jnp.ones((1, 2))) assert y.shape == (1, 3) @@ -130,7 +130,7 @@ def __init__(self, *, rngs: nnx.Rngs): def f(m: Foo): nonlocal n n += 1 - m.a, m.b = m.b, m.a 
+ m.a, m.b = m.b, m.a # type: ignore m = Foo(rngs=nnx.Rngs(0)) a = m.a @@ -533,7 +533,7 @@ def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: x = nnx.gelu(x) return x, None - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -562,7 +562,7 @@ def __call__(self, x: jax.Array): x = nnx.gelu(x) return x - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -591,7 +591,7 @@ def __call__(self, x: jax.Array): x = nnx.gelu(x) return x, (x, x) - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -626,7 +626,7 @@ def __call__( x = nnx.gelu(x) return x, None - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -661,7 +661,7 @@ def __call__( x = nnx.gelu(x) return x, None - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -697,7 +697,7 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: x = nnx.gelu(x) return x - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, scan_output=False ) @@ -728,7 +728,7 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: x = nnx.gelu(x) return x - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -814,14 +814,14 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: # test sharding layer axes is not present inside scan state = nnx.state(self.linear) - assert state.kernel.value.shape == (3, 3) - assert state.kernel.sharding == ('din', 'dout') - assert state.bias.value.shape == (3,) - assert state.bias.sharding == ('dout',) + assert state.kernel.value.shape == (3, 3) # type: ignore + assert state.kernel.sharding == ('din', 'dout') # type: ignore + assert state.bias.value.shape == (3,) # type: ignore + assert state.bias.sharding == ('dout',) # type: ignore return x, None - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -873,7 +873,7 @@ def __init__(self, *, rngs: nnx.Rngs): def __call__(self): return None, None - MLP = nnx.Scan( + MLP = nnx.Scan.constructor( Block, state_axes={nnx.Param: 0}, length=5, @@ -889,7 +889,7 @@ def __call__(self): class TestRemat: def test_basic_remat(self): - RematLinear = nnx.Remat(nnx.Linear) + RematLinear = nnx.Remat.constructor(nnx.Linear) module = RematLinear(2, 3, rngs=nnx.Rngs(0)) @@ -922,9 +922,9 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: x = self.linear(x) return x, None - RematLinear = nnx.Remat(LinearBlock) + RematLinear = nnx.Remat.constructor(LinearBlock) - ScanRematLinear = nnx.Scan( + ScanRematLinear = nnx.Scan.constructor( RematLinear, state_axes={nnx.Param: 0}, length=5, @@ -1121,7 +1121,7 @@ def __call__(self, x: jax.Array) -> jax.Array: x = nnx.gelu(x) return x - MLP = nnx.Vmap(Block, state_axes={nnx.Param: 0}, axis_size=5) + MLP = nnx.Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) module = MLP(rngs=nnx.Rngs(0)) @@ -1148,7 +1148,7 @@ def __call__(self, x: jax.Array) -> jax.Array: x = nnx.gelu(x) return x - MLP = nnx.Vmap(Block, state_axes={nnx.Param: 0}, axis_size=5) + MLP = nnx.Vmap.constructor(Block, state_axes={nnx.Param: 0}, axis_size=5) module = MLP(graphdef='hello', rngs=nnx.Rngs(0)) diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/experimental/nnx/tests/test_variable.py index 604806056..af297eeae 100644 --- a/flax/experimental/nnx/tests/test_variable.py +++ 
b/flax/experimental/nnx/tests/test_variable.py @@ -56,7 +56,7 @@ def __init__(self, din, dout, rngs: nnx.Rngs): self.b = nnx.Param(jax.numpy.zeros((dout,))) def __call__(self, x: jax.Array): - return jnp.dot(x, self.w) + self.b + return jnp.dot(x, self.w) + self.b # type: ignore[arg-type] linear = Linear(3, 4, nnx.Rngs(0)) x = jax.numpy.ones((3,)) diff --git a/flax/typing.py b/flax/typing.py index d6ecf02ca..8d0fc5855 100644 --- a/flax/typing.py +++ b/flax/typing.py @@ -25,6 +25,7 @@ Tuple, TypeVar, Union, + runtime_checkable, ) import jax @@ -42,7 +43,7 @@ Shape = Sequence[int] K = TypeVar('K') - +@runtime_checkable class Key(Hashable, Protocol): def __lt__(self: K, value: K, /) -> bool: ... diff --git a/pyproject.toml b/pyproject.toml index 1595e1c08..9b3992306 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,9 +112,10 @@ module = [ "yaml", ] ignore_missing_imports = true -# exclude nnx +disable_error_code = "annotation-unchecked" +# exclude nnx examples [[tool.mypy.overrides]] -module = "flax.experimental.nnx.*" +module = "flax.experimental.nnx.examples.*" ignore_errors = true [tool.pytest.ini_options] From 0459456f491f906f86eb074e4123d759832e0eb7 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Thu, 2 May 2024 10:30:27 +0100 Subject: [PATCH 2/2] [nnx] enable pytype on CI --- .github/workflows/build.yml | 2 +- flax/experimental/nnx/__init__.py | 202 +++++++++--------- flax/experimental/nnx/examples/lm1b/models.py | 2 +- .../nnx/examples/lm1b/models_test.py | 17 +- .../nnx/examples/lm1b/train_test.py | 4 + flax/experimental/nnx/examples/lm1b/utils.py | 2 +- flax/experimental/nnx/nnx/graph.py | 2 +- flax/experimental/nnx/nnx/helpers.py | 4 +- flax/experimental/nnx/nnx/nn/attention.py | 4 +- flax/experimental/nnx/nnx/nn/dtypes.py | 8 +- flax/experimental/nnx/nnx/nn/linear.py | 18 +- flax/experimental/nnx/nnx/proxy_caller.py | 2 +- flax/experimental/nnx/nnx/state.py | 12 +- .../experimental/nnx/nnx/training/__init__.py | 0 flax/experimental/nnx/nnx/transforms.py | 22 +- flax/experimental/nnx/nnx/visualization.py | 4 +- flax/experimental/nnx/tests/test_optimizer.py | 3 +- flax/experimental/nnx/tests/test_spmd.py | 4 +- .../experimental/nnx/tests/test_transforms.py | 4 +- tests/run_all_tests.sh | 17 +- 20 files changed, 176 insertions(+), 157 deletions(-) create mode 100644 flax/experimental/nnx/nnx/training/__init__.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 193f7ced5..98f4c64e3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -86,7 +86,7 @@ jobs: test-type: [doctest, pytest, pytype, mypy] exclude: - test-type: pytype - python-version: '3.11' + python-version: '3.9' - test-type: pytype python-version: '3.10' - test-type: mypy diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 1b237768e..ba6384c95 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -18,104 +18,104 @@ from flax.linen.pooling import pool as pool from flax.typing import Initializer as Initializer -from .nnx import compatibility as compatibility -from .nnx import graph as graph -from .nnx import errors as errors -from .nnx import errors as helpers -from .nnx.filterlib import All as All -from .nnx.filterlib import Not as Not -from .nnx.graph import GraphDef as GraphDef -from .nnx.graph import GraphNode as GraphNode -from .nnx.helpers import Dict as Dict -from .nnx.helpers import List as List -from .nnx.helpers import Sequential as Sequential -from .nnx.helpers import TrainState 
as TrainState -from .nnx.module import M as M -from .nnx.module import Module as Module -from .nnx.graph import merge as merge -from .nnx.graph import UpdateContext as UpdateContext -from .nnx.graph import split as split -from .nnx.graph import update as update -from .nnx.graph import clone as clone -from .nnx.graph import pop as pop -from .nnx.graph import state as state -from .nnx.graph import graphdef as graphdef -from .nnx.nn import initializers as initializers -from .nnx.nn.activations import celu as celu -from .nnx.nn.activations import elu as elu -from .nnx.nn.activations import gelu as gelu -from .nnx.nn.activations import glu as glu -from .nnx.nn.activations import hard_sigmoid as hard_sigmoid -from .nnx.nn.activations import hard_silu as hard_silu -from .nnx.nn.activations import hard_swish as hard_swish -from .nnx.nn.activations import hard_tanh as hard_tanh -from .nnx.nn.activations import leaky_relu as leaky_relu -from .nnx.nn.activations import log_sigmoid as log_sigmoid -from .nnx.nn.activations import log_softmax as log_softmax -from .nnx.nn.activations import logsumexp as logsumexp -from .nnx.nn.activations import one_hot as one_hot -from .nnx.nn.activations import relu as relu -from .nnx.nn.activations import relu6 as relu6 -from .nnx.nn.activations import selu as selu -from .nnx.nn.activations import sigmoid as sigmoid -from .nnx.nn.activations import silu as silu -from .nnx.nn.activations import soft_sign as soft_sign -from .nnx.nn.activations import softmax as softmax -from .nnx.nn.activations import softplus as softplus -from .nnx.nn.activations import standardize as standardize -from .nnx.nn.activations import swish as swish -from .nnx.nn.activations import tanh as tanh -from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention -from .nnx.nn.attention import combine_masks as combine_masks -from .nnx.nn.attention import dot_product_attention as dot_product_attention -from .nnx.nn.attention import make_attention_mask as make_attention_mask -from .nnx.nn.attention import make_causal_mask as make_causal_mask -from .nnx.nn.linear import Conv as Conv -from .nnx.nn.linear import Embed as Embed -from .nnx.nn.linear import Linear as Linear -from .nnx.nn.linear import LinearGeneral as LinearGeneral -from .nnx.nn.linear import Einsum as Einsum -from .nnx.nn.normalization import BatchNorm as BatchNorm -from .nnx.nn.normalization import LayerNorm as LayerNorm -from .nnx.nn.normalization import RMSNorm as RMSNorm -from .nnx.nn.stochastic import Dropout as Dropout -from .nnx.rnglib import Rngs as Rngs -from .nnx.rnglib import RngStream as RngStream -from .nnx.rnglib import RngState as RngState -from .nnx.rnglib import RngKey as RngKey -from .nnx.rnglib import RngCount as RngCount -from .nnx.rnglib import fork as fork -from .nnx.spmd import PARTITION_NAME as PARTITION_NAME -from .nnx.spmd import get_partition_spec as get_partition_spec -from .nnx.spmd import get_named_sharding as get_named_sharding -from .nnx.spmd import with_partitioning as with_partitioning -from .nnx.spmd import with_sharding_constraint as with_sharding_constraint -from .nnx.state import State as State -from .nnx.training import metrics as metrics -from .nnx.training import optimizer as optimizer -from .nnx.training.metrics import Metric as Metric -from .nnx.training.metrics import MultiMetric as MultiMetric -from .nnx.training.optimizer import Optimizer as Optimizer -from .nnx.transforms import Jit as Jit -from .nnx.transforms import jit as jit -from .nnx.transforms import Remat as Remat -from 
.nnx.transforms import Scan as Scan -from .nnx.transforms import Vmap as Vmap -from .nnx.transforms import grad as grad -from .nnx.transforms import remat as remat -from .nnx.transforms import scan as scan -from .nnx.transforms import value_and_grad as value_and_grad -from .nnx.transforms import vmap as vmap -from .nnx.transforms import eval_shape as eval_shape -from .nnx.variables import EMPTY as EMPTY -from .nnx.variables import A as A -from .nnx.variables import BatchStat as BatchStat -from .nnx.variables import Cache as Cache -from .nnx.variables import Empty as Empty -from .nnx.variables import Intermediate as Intermediate -from .nnx.variables import Param as Param -from .nnx.variables import Variable as Variable -from .nnx.variables import VariableState as VariableState -from .nnx.variables import VariableMetadata as VariableMetadata -from .nnx.variables import with_metadata as with_metadata -from .nnx.visualization import display as display +from flax.experimental.nnx.nnx import compatibility as compatibility +from flax.experimental.nnx.nnx import graph as graph +from flax.experimental.nnx.nnx import errors as errors +from flax.experimental.nnx.nnx import errors as helpers +from flax.experimental.nnx.nnx.filterlib import All as All +from flax.experimental.nnx.nnx.filterlib import Not as Not +from flax.experimental.nnx.nnx.graph import GraphDef as GraphDef +from flax.experimental.nnx.nnx.graph import GraphNode as GraphNode +from flax.experimental.nnx.nnx.helpers import Dict as Dict +from flax.experimental.nnx.nnx.helpers import List as List +from flax.experimental.nnx.nnx.helpers import Sequential as Sequential +from flax.experimental.nnx.nnx.helpers import TrainState as TrainState +from flax.experimental.nnx.nnx.module import M as M +from flax.experimental.nnx.nnx.module import Module as Module +from flax.experimental.nnx.nnx.graph import merge as merge +from flax.experimental.nnx.nnx.graph import UpdateContext as UpdateContext +from flax.experimental.nnx.nnx.graph import split as split +from flax.experimental.nnx.nnx.graph import update as update +from flax.experimental.nnx.nnx.graph import clone as clone +from flax.experimental.nnx.nnx.graph import pop as pop +from flax.experimental.nnx.nnx.graph import state as state +from flax.experimental.nnx.nnx.graph import graphdef as graphdef +from flax.experimental.nnx.nnx.nn import initializers as initializers +from flax.experimental.nnx.nnx.nn.activations import celu as celu +from flax.experimental.nnx.nnx.nn.activations import elu as elu +from flax.experimental.nnx.nnx.nn.activations import gelu as gelu +from flax.experimental.nnx.nnx.nn.activations import glu as glu +from flax.experimental.nnx.nnx.nn.activations import hard_sigmoid as hard_sigmoid +from flax.experimental.nnx.nnx.nn.activations import hard_silu as hard_silu +from flax.experimental.nnx.nnx.nn.activations import hard_swish as hard_swish +from flax.experimental.nnx.nnx.nn.activations import hard_tanh as hard_tanh +from flax.experimental.nnx.nnx.nn.activations import leaky_relu as leaky_relu +from flax.experimental.nnx.nnx.nn.activations import log_sigmoid as log_sigmoid +from flax.experimental.nnx.nnx.nn.activations import log_softmax as log_softmax +from flax.experimental.nnx.nnx.nn.activations import logsumexp as logsumexp +from flax.experimental.nnx.nnx.nn.activations import one_hot as one_hot +from flax.experimental.nnx.nnx.nn.activations import relu as relu +from flax.experimental.nnx.nnx.nn.activations import relu6 as relu6 +from 
flax.experimental.nnx.nnx.nn.activations import selu as selu +from flax.experimental.nnx.nnx.nn.activations import sigmoid as sigmoid +from flax.experimental.nnx.nnx.nn.activations import silu as silu +from flax.experimental.nnx.nnx.nn.activations import soft_sign as soft_sign +from flax.experimental.nnx.nnx.nn.activations import softmax as softmax +from flax.experimental.nnx.nnx.nn.activations import softplus as softplus +from flax.experimental.nnx.nnx.nn.activations import standardize as standardize +from flax.experimental.nnx.nnx.nn.activations import swish as swish +from flax.experimental.nnx.nnx.nn.activations import tanh as tanh +from flax.experimental.nnx.nnx.nn.attention import MultiHeadAttention as MultiHeadAttention +from flax.experimental.nnx.nnx.nn.attention import combine_masks as combine_masks +from flax.experimental.nnx.nnx.nn.attention import dot_product_attention as dot_product_attention +from flax.experimental.nnx.nnx.nn.attention import make_attention_mask as make_attention_mask +from flax.experimental.nnx.nnx.nn.attention import make_causal_mask as make_causal_mask +from flax.experimental.nnx.nnx.nn.linear import Conv as Conv +from flax.experimental.nnx.nnx.nn.linear import Embed as Embed +from flax.experimental.nnx.nnx.nn.linear import Linear as Linear +from flax.experimental.nnx.nnx.nn.linear import LinearGeneral as LinearGeneral +from flax.experimental.nnx.nnx.nn.linear import Einsum as Einsum +from flax.experimental.nnx.nnx.nn.normalization import BatchNorm as BatchNorm +from flax.experimental.nnx.nnx.nn.normalization import LayerNorm as LayerNorm +from flax.experimental.nnx.nnx.nn.normalization import RMSNorm as RMSNorm +from flax.experimental.nnx.nnx.nn.stochastic import Dropout as Dropout +from flax.experimental.nnx.nnx.rnglib import Rngs as Rngs +from flax.experimental.nnx.nnx.rnglib import RngStream as RngStream +from flax.experimental.nnx.nnx.rnglib import RngState as RngState +from flax.experimental.nnx.nnx.rnglib import RngKey as RngKey +from flax.experimental.nnx.nnx.rnglib import RngCount as RngCount +from flax.experimental.nnx.nnx.rnglib import fork as fork +from flax.experimental.nnx.nnx.spmd import PARTITION_NAME as PARTITION_NAME +from flax.experimental.nnx.nnx.spmd import get_partition_spec as get_partition_spec +from flax.experimental.nnx.nnx.spmd import get_named_sharding as get_named_sharding +from flax.experimental.nnx.nnx.spmd import with_partitioning as with_partitioning +from flax.experimental.nnx.nnx.spmd import with_sharding_constraint as with_sharding_constraint +from flax.experimental.nnx.nnx.state import State as State +from flax.experimental.nnx.nnx.training import metrics as metrics +from flax.experimental.nnx.nnx.training import optimizer as optimizer +from flax.experimental.nnx.nnx.training.metrics import Metric as Metric +from flax.experimental.nnx.nnx.training.metrics import MultiMetric as MultiMetric +from flax.experimental.nnx.nnx.training.optimizer import Optimizer as Optimizer +from flax.experimental.nnx.nnx.transforms import Jit as Jit +from flax.experimental.nnx.nnx.transforms import jit as jit +from flax.experimental.nnx.nnx.transforms import Remat as Remat +from flax.experimental.nnx.nnx.transforms import Scan as Scan +from flax.experimental.nnx.nnx.transforms import Vmap as Vmap +from flax.experimental.nnx.nnx.transforms import grad as grad +from flax.experimental.nnx.nnx.transforms import remat as remat +from flax.experimental.nnx.nnx.transforms import scan as scan +from flax.experimental.nnx.nnx.transforms import 
value_and_grad as value_and_grad +from flax.experimental.nnx.nnx.transforms import vmap as vmap +from flax.experimental.nnx.nnx.transforms import eval_shape as eval_shape +from flax.experimental.nnx.nnx.variables import EMPTY as EMPTY +from flax.experimental.nnx.nnx.variables import A as A +from flax.experimental.nnx.nnx.variables import BatchStat as BatchStat +from flax.experimental.nnx.nnx.variables import Cache as Cache +from flax.experimental.nnx.nnx.variables import Empty as Empty +from flax.experimental.nnx.nnx.variables import Intermediate as Intermediate +from flax.experimental.nnx.nnx.variables import Param as Param +from flax.experimental.nnx.nnx.variables import Variable as Variable +from flax.experimental.nnx.nnx.variables import VariableState as VariableState +from flax.experimental.nnx.nnx.variables import VariableMetadata as VariableMetadata +from flax.experimental.nnx.nnx.variables import with_metadata as with_metadata +from flax.experimental.nnx.nnx.visualization import display as display diff --git a/flax/experimental/nnx/examples/lm1b/models.py b/flax/experimental/nnx/examples/lm1b/models.py index 58274de4e..1731ec7f3 100644 --- a/flax/experimental/nnx/examples/lm1b/models.py +++ b/flax/experimental/nnx/examples/lm1b/models.py @@ -33,7 +33,7 @@ from jax import lax from flax.experimental import nnx -from flax.experimental.nnx.examples.lm1b.configs import default +from configs import default Shape = tuple[int, ...] Dtype = Any diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/experimental/nnx/examples/lm1b/models_test.py index e66a2949b..76296ae50 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/experimental/nnx/examples/lm1b/models_test.py @@ -28,19 +28,16 @@ from flax import traverse_util from flax.experimental import nnx -from flax.experimental.nnx.examples.lm1b.configs import default -from flax.experimental.nnx.examples.lm1b.models import ( - TransformerConfig, - TransformerLM, -) -from flax.experimental.nnx.examples.lm1b.utils import HasCache +from configs import default +from models import TransformerConfig, TransformerLM +from utils import HasCache jax.config.update('jax_disable_most_optimizations', True) # add project_root to import lm1b Linen model project_root = str(Path(__file__).absolute().parents[5]) sys.path.append(project_root) -from examples.lm1b.models import TransformerLM as TransformerLinen +from examples.lm1b.models import TransformerLM as TransformerLinen # type: ignore[import-error] sys.path.pop() @@ -208,6 +205,9 @@ def test_forward_eval(self): deterministic=True, decode=False, ) + # Set dropout rates to avoid create dropout states + config.dropout_rate = 0.0 + config.attention_dropout_rate = 0.0 model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) _, params_nnx = nnx.split(model_nnx, nnx.Param) @@ -242,6 +242,9 @@ def test_forward_decode(self): deterministic=True, decode=True, ) + # Set dropout rates to avoid create dropout states + config.dropout_rate = 0.0 + config.attention_dropout_rate = 0.0 model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) for _path, m in model_nnx.iter_modules(): diff --git a/flax/experimental/nnx/examples/lm1b/train_test.py b/flax/experimental/nnx/examples/lm1b/train_test.py index b279ee921..9040c4f26 100644 --- a/flax/experimental/nnx/examples/lm1b/train_test.py +++ b/flax/experimental/nnx/examples/lm1b/train_test.py @@ -52,6 +52,10 @@ def test_train_and_evaluate(self): config.max_eval_target_length = 32 config.max_predict_length 
= 32
+ # Set dropout rates to avoid creating dropout states
+ config.dropout_rate = 0.0
+ config.attention_dropout_rate = 0.0
+
 workdir = tempfile.mkdtemp()
 # Go two directories up to the root of the flax directory.
diff --git a/flax/experimental/nnx/examples/lm1b/utils.py b/flax/experimental/nnx/examples/lm1b/utils.py
index 9ba2e280f..1bf2d7d8c 100644
--- a/flax/experimental/nnx/examples/lm1b/utils.py
+++ b/flax/experimental/nnx/examples/lm1b/utils.py
@@ -22,7 +22,7 @@
 import jax.numpy as jnp
 import numpy as np
 from jax.experimental import mesh_utils
-from flax.experimental.nnx.examples.lm1b.configs import default
+from configs import default
 from models import TransformerConfig, TransformerLM
 from flax.experimental import nnx
diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/experimental/nnx/nnx/graph.py
index 799b361f1..5c6e225e2 100644
--- a/flax/experimental/nnx/nnx/graph.py
+++ b/flax/experimental/nnx/nnx/graph.py
@@ -649,7 +649,7 @@ def _graph_pop(
 node_impl.pop_key(node, name)
 if isinstance(value, Variable):
 value = value.to_state()
- state[node_path] = value
+ state[node_path] = value # type: ignore[index] # mypy is wrong here?
 break
 else:
 # NOTE: should we raise an error here?
diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py
index d50a93271..4a2382697 100644
--- a/flax/experimental/nnx/nnx/helpers.py
+++ b/flax/experimental/nnx/nnx/helpers.py
@@ -161,9 +161,9 @@ def __call__(
 class TrainState(tp.Generic[M], struct.PyTreeNode):
 graphdef: GraphDef[M]
 params: State
- tx: optax.GradientTransformation = struct.field(pytree_node=False)
 opt_state: optax.OptState
 step: jax.Array
+ tx: optax.GradientTransformation = struct.field(pytree_node=False)
 @classmethod
 def create(
@@ -178,9 +178,9 @@ def create(
 return cls(
 graphdef=graphdef,
 params=params,
- tx=tx,
 opt_state=tx.init(params),
 step=jnp.asarray(step),
+ tx=tx,
 **kwargs,
 )
diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py
index 56de71be4..2a4c6fd88 100644
--- a/flax/experimental/nnx/nnx/nn/attention.py
+++ b/flax/experimental/nnx/nnx/nn/attention.py
@@ -88,7 +88,7 @@ def dot_product_attention_weights(
 Returns:
 Output of shape `[batch..., num_heads, q_length, kv_length]`.
 """
- query, key = promote_dtype(query, key, dtype=dtype)
+ query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking]
 dtype = query.dtype
 assert query.ndim == key.ndim, 'q, k must have same rank.'
@@ -184,7 +184,7 @@ def dot_product_attention(
 Returns:
 Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
 """
- query, key, value = promote_dtype(query, key, value, dtype=dtype)
+ query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking]
 dtype = query.dtype
 assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert ( diff --git a/flax/experimental/nnx/nnx/nn/dtypes.py b/flax/experimental/nnx/nnx/nn/dtypes.py index c204e0426..070de099f 100644 --- a/flax/experimental/nnx/nnx/nn/dtypes.py +++ b/flax/experimental/nnx/nnx/nn/dtypes.py @@ -15,9 +15,9 @@ from typing import Optional from flax.typing import Dtype from jax import numpy as jnp -from typing_extensions import TypeVarTuple, Unpack +import typing as tp -T = TypeVarTuple('T') +T = tp.TypeVar('T', bound=tuple) def canonicalize_dtype( @@ -52,9 +52,7 @@ def canonicalize_dtype( return dtype -def promote_dtype( - *args: Unpack[T], dtype=None, inexact=True -) -> tuple[Unpack[T]]: +def promote_dtype(args: T, /, dtype=None, inexact=True) -> T: """ "Promotes input arguments to a specified or inferred dtype. All args are cast to the same dtype. See ``canonicalize_dtype`` for how diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py index dffe9c742..f775e6009 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -269,7 +269,7 @@ def __call__(self, inputs: Array) -> Array: contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) inputs, kernel, bias = dtypes.promote_dtype( - inputs, kernel, bias, dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) if self.dot_general_cls is not None: @@ -353,7 +353,7 @@ def __call__(self, inputs: Array) -> Array: bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - inputs, kernel, bias, dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.dot_general( inputs, @@ -461,9 +461,11 @@ def __call__( self._einsum_str_check(einsum_str) inputs, kernel, bias = dtypes.promote_dtype( - inputs, - self.kernel.value, - self.bias.value if self.bias is not None else self.bias, + ( + inputs, + self.kernel.value, + self.bias.value if self.bias is not None else self.bias, + ), dtype=self.dtype, ) @@ -702,7 +704,7 @@ def maybe_broadcast( bias = self.bias.value inputs, kernel, bias = dtypes.promote_dtype( - inputs, kernel, bias, dtype=self.dtype + (inputs, kernel, bias), dtype=self.dtype ) y = self.conv_general_dilated( @@ -788,7 +790,7 @@ def __call__(self, inputs: Array) -> Array: # Use take because fancy indexing numpy arrays with JAX indices does not # work correctly. (embedding,) = dtypes.promote_dtype( - self.embedding.value, dtype=self.dtype, inexact=False + (self.embedding.value,), dtype=self.dtype, inexact=False ) if self.num_embeddings == 1: return jnp.where( @@ -813,6 +815,6 @@ def attend(self, query: Array) -> Array: in NLP models. """ query, embedding = dtypes.promote_dtype( - query, self.embedding.value, dtype=self.dtype + (query, self.embedding.value), dtype=self.dtype ) return jnp.dot(query, embedding.T) diff --git a/flax/experimental/nnx/nnx/proxy_caller.py b/flax/experimental/nnx/nnx/proxy_caller.py index e1c2bd3cf..7d8719486 100644 --- a/flax/experimental/nnx/nnx/proxy_caller.py +++ b/flax/experimental/nnx/nnx/proxy_caller.py @@ -31,7 +31,7 @@ import typing as tp -A = tp.TypeVar('A') +A = tp.TypeVar('A', covariant=True) # type: ignore[not-supported-yet] def _identity(x): diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index 2c9272dcf..207dbb8d6 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -152,7 +152,7 @@ def split( if rest: raise ValueError( 'Non-exhaustive filters, got a non-empty remainder: ' - f'{list(rest.keys())}.\nUse `...` to match all remaining elements.' 
+ f'{rest}.\nUse `...` to match all remaining elements.' ) states: State | tuple[State, ...] @@ -160,7 +160,7 @@ def split( states = states_[0] else: states = tuple(states_) - return states + return states # type: ignore[bad-return-type] @tp.overload def filter( @@ -196,7 +196,7 @@ def filter( else: states = tuple(states_) - return states + return states # type: ignore[bad-return-type] @staticmethod def merge(state: 'State', /, *states: 'State') -> 'State': @@ -208,7 +208,7 @@ def merge(state: 'State', /, *states: 'State') -> 'State': new_state: FlatState = {} for state in states: - new_state.update(state.flat_state()) + new_state.update(state.flat_state()) # type: ignore[attribute-error] # pytype is wrong here return State.from_flat_path(new_state) @@ -272,10 +272,10 @@ def _split_state( for path, value in flat_state.items(): for i, predicate in enumerate(predicates): if predicate(path, value): - flat_states[i][path] = value + flat_states[i][path] = value # type: ignore[index] # mypy is wrong here? break else: # if we didn't break, set leaf to last state - flat_states[-1][path] = value + flat_states[-1][path] = value # type: ignore[index] # mypy is wrong here? return tuple(State.from_flat_path(flat_state) for flat_state in flat_states) diff --git a/flax/experimental/nnx/nnx/training/__init__.py b/flax/experimental/nnx/nnx/training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 1bc5fe912..b1fd3b6dc 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -80,15 +80,15 @@ def _normalize_sequence( return tuple(x) -class LiftedModule(Module, tp.Generic[M]): +class LiftedModule(tp.Generic[M], Module): # type: ignore[ignored-abstractmethod] @abstractmethod def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: - ... + pass @property @abstractmethod def _submodule(self) -> M: - ... + pass # type: ignore[bad-return-type] # why pytype? 
def __call__(self, *args, **kwargs) -> tp.Any: return self.call(*args, **kwargs) # type: ignore @@ -287,7 +287,7 @@ def jit_apply( return out -class Jit(LiftedModule[M]): +class Jit(tp.Generic[M], LiftedModule[M]): @staticmethod def constructor( module_constructor: tp.Callable[..., MA], @@ -584,7 +584,7 @@ class GradOptions: wrt: filterlib.Filter -class Grad(LiftedModule[M]): +class Grad(tp.Generic[M], LiftedModule[M]): @staticmethod def constructor( module_constructor: tp.Callable[..., MA], @@ -871,7 +871,7 @@ class ScanOptions: scan_output: bool -class Scan(LiftedModule[M]): +class Scan(tp.Generic[M], LiftedModule[M]): @staticmethod def constructor( module_constructor: tp.Callable[..., MA], @@ -1095,10 +1095,10 @@ def scan_fn( 'Expected a tuple of length 2 as the output of the scan function, ' f'got {out}' ) - out = tp.cast(tuple[C, B], out) + out = tp.cast(tuple[C, B], out) # type: ignore[invalid-annotation] carry_arg_out, scan_args_out = out else: - out = tp.cast(C, out) + out = tp.cast(C, out) # type: ignore[invalid-annotation] carry_arg_out = out scan_args_out = None @@ -1356,7 +1356,7 @@ def scan( ) @functools.wraps(f) - def scan_apply_wrapper(*args, **kwargs) -> C | tuple[C, tp.Any]: + def scan_apply_wrapper(*args, **kwargs) -> tp.Any: return scan_apply(options, f, args, kwargs) return scan_apply_wrapper # type: ignore @@ -1385,7 +1385,7 @@ def __post_init__(self): ) -class Remat(LiftedModule[M]): +class Remat(tp.Generic[M], LiftedModule[M]): @staticmethod def constructor( module_constructor: tp.Callable[..., MA], @@ -1512,7 +1512,7 @@ class VmapOptions: transform_metadata: tp.Mapping[str, tp.Any] -class Vmap(LiftedModule[M]): +class Vmap(tp.Generic[M], LiftedModule[M]): @staticmethod def constructor( module_constructor: tp.Callable[..., MA], diff --git a/flax/experimental/nnx/nnx/visualization.py b/flax/experimental/nnx/nnx/visualization.py index 65c09f771..0f657363c 100644 --- a/flax/experimental/nnx/nnx/visualization.py +++ b/flax/experimental/nnx/nnx/visualization.py @@ -40,7 +40,7 @@ def display(*args): print(x) return - from penzai import pz # type: ignore[import-not-found] + from penzai import pz # type: ignore[import-not-found,import-untyped] with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): for x in args: @@ -111,7 +111,7 @@ def _to_dataclass_fn(x): def _make_dataclass_obj( cls, fields: tp.Mapping[str, tp.Any], penzai_dataclass: bool = True ) -> tp.Type: - from penzai import pz + from penzai import pz # type: ignore[import-error] dataclass = pz.pytree_dataclass if penzai_dataclass else dataclasses.dataclass base = pz.Layer if penzai_dataclass else object diff --git a/flax/experimental/nnx/tests/test_optimizer.py b/flax/experimental/nnx/tests/test_optimizer.py index fb749cc23..d1de7cc55 100644 --- a/flax/experimental/nnx/tests/test_optimizer.py +++ b/flax/experimental/nnx/tests/test_optimizer.py @@ -97,7 +97,8 @@ class TrainState(nnx.Optimizer): def __init__(self, model, tx, metrics): self.metrics = metrics super().__init__(model, tx) - def update(self, *, grads, **updates): + + def update(self, *, grads, **updates): # type: ignore[signature-mismatch] self.metrics.update(**updates) super().update(grads) diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py index 8abd5827f..83e2f1b10 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -44,7 +44,7 @@ def create_module(): mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('model', 'data')) with mesh: - m: Foo 
= nnx.merge(*create_module()) + m: Foo = nnx.merge(*create_module()) # type: ignore[invalid-annotation] assert m.w.shape == (8, 2) assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) @@ -69,7 +69,7 @@ def create_module(): mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ('model', 'data')) with mesh: - m: Foo = nnx.merge(*create_module()) + m: Foo = nnx.merge(*create_module()) # type: ignore[invalid-annotation] assert m.w.value.shape == (8, 2) assert m.w.value.sharding.shard_shape(m.w.value.shape) == (8, 2) diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index 67b26e46b..d386c0dc9 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -260,7 +260,7 @@ def test_cached_unflatten_add_self_reference(self): class Foo(nnx.Module): def __init__(self): - self.ref: tp.Optional[Foo] = None + self.ref: tp.Optional[Foo] = None # type: ignore[name-error] @nnx.jit def f(m: Foo): @@ -290,7 +290,7 @@ def test_cached_unflatten_ref_in_output(self): class Foo(nnx.Module): def __init__(self): - self.ref: tp.Optional[Foo] = None + self.ref: tp.Optional[Foo] = None # type: ignore[name-error] @nnx.jit def f(m: Foo): diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index 9bc9cb853..788c0a7ad 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -134,20 +134,31 @@ if $RUN_PYTEST; then fi pytest $egd done - fi if $RUN_PYTYPE; then echo "=== RUNNING PYTYPE ===" + # Validate types in NNX examples. + for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + # skip if folder starts with "_" or is "toy_examples" + if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then + continue + fi + # use cd to make sure pytype cache lives in example dir and doesn't name clash + # use *.py to avoid importing configs as a top-level import which leads to import errors + # because config files use relative imports (e.g. from config import ...). + (cd $egd ; pytype "*.py" --jobs auto --config ../../../../../pyproject.toml) + done # Validate types in library code. - pytype --jobs auto --config pyproject.toml flax/ --exclude flax/experimental/nnx + pytype --jobs auto --config pyproject.toml flax/ \ + --exclude flax/experimental/nnx/examples # Validate types in examples. for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do # use cd to make sure pytype cache lives in example dir and doesn't name clash # use *.py to avoid importing configs as a top-level import which leads to import errors # because config files use relative imports (e.g. from config import ...). - (cd $egd ; pytype --jobs auto --exclude flax/experimental/nnx --config ../../pyproject.toml "*.py") + (cd $egd ; pytype "*.py" --jobs auto --config ../../pyproject.toml) done fi
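Note for reviewers: the recurring change above is that promote_dtype now takes its arrays as a single
positional-only tuple instead of *args. Below is a minimal sketch of the new calling convention; the
import path follows the file layout in this patch and the array shapes/dtypes are illustrative only,
not taken from the codebase.

# Sketch (illustrative, not part of the patch): calling the updated promote_dtype.
import jax.numpy as jnp
from flax.experimental.nnx.nnx.nn import dtypes

x = jnp.ones((2, 3), dtype=jnp.float32)
w = jnp.ones((3, 4), dtype=jnp.bfloat16)

# Before this patch the call was dtypes.promote_dtype(x, w, dtype=jnp.bfloat16);
# now the arguments are wrapped in one tuple and the result unpacks the same way.
x, w = dtypes.promote_dtype((x, w), dtype=jnp.bfloat16)
assert x.dtype == w.dtype == jnp.bfloat16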