[nnx] fix mypy and pytype #3894

Merged · 2 commits · May 8, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -86,7 +86,7 @@ jobs:
test-type: [doctest, pytest, pytype, mypy]
exclude:
- test-type: pytype
-  python-version: '3.11'
+  python-version: '3.9'
- test-type: pytype
python-version: '3.10'
- test-type: mypy
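Assuming GitHub's usual removed-line-first ordering, the exclude entry moves from Python 3.11 to 3.9, i.e. pytype now runs on 3.11 and is skipped on 3.9 and 3.10. A throwaway sketch of the resulting matrix (the version list is an assumption; only this one exclude entry is visible in the hunk):

```python
# Hypothetical reconstruction of the CI matrix after this change; the mypy
# excludes fall outside the hunk and are not shown here.
versions = ['3.9', '3.10', '3.11']
test_types = ['doctest', 'pytest', 'pytype', 'mypy']
excluded = {('pytype', '3.9'), ('pytype', '3.10')}
jobs = [(t, v) for t in test_types for v in versions if (t, v) not in excluded]
print(jobs)  # pytype appears only paired with '3.11'
```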
30 changes: 29 additions & 1 deletion flax/core/frozen_dict.py
@@ -16,7 +16,17 @@

import collections
from types import MappingProxyType
-from typing import Any, Dict, Hashable, Mapping, Tuple, TypeVar, Union
+from typing import (
+  Any,
+  Dict,
+  Hashable,
+  Iterable,
+  Mapping,
+  Tuple,
+  TypeVar,
+  Union,
+  overload,
+)

import jax

@@ -55,6 +65,24 @@ class FrozenDict(Mapping[K, V]):

__slots__ = ('_dict', '_hash')

+  @overload
+  def __init__(
+    self,
+    mapping: Mapping[K, V] = MappingProxyType({}),
+    /,
+    __unsafe_skip_copy__=False,
+    **kwargs,
+  ): ...
+
+  @overload
+  def __init__(
+    self,
+    mapping: Iterable[tuple[K, V]] = (),
+    /,
+    __unsafe_skip_copy__=False,
+    **kwargs,
+  ): ...

def __init__(self, *args, __unsafe_skip_copy__=False, **kwargs): # pylint: disable=invalid-name
# make sure the dict is as
xs = dict(*args, **kwargs)
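The two overloads only surface to type checkers what `dict(*args, **kwargs)` already accepted at runtime. A minimal sketch of both construction paths (behavior inferred from the signatures above, not stated in the PR):

```python
from flax.core.frozen_dict import FrozenDict

fd_mapping = FrozenDict({'a': 1, 'b': 2})    # matches the Mapping[K, V] overload
fd_pairs = FrozenDict([('a', 1), ('b', 2)])  # matches the Iterable[tuple[K, V]] overload
assert fd_mapping == fd_pairs                # the Mapping mixin supplies __eq__
```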
202 changes: 101 additions & 101 deletions flax/experimental/nnx/__init__.py
@@ -18,104 +18,104 @@
from flax.linen.pooling import pool as pool
from flax.typing import Initializer as Initializer

-from .nnx import compatibility as compatibility
-from .nnx import graph as graph
-from .nnx import errors as errors
-from .nnx import errors as helpers
-from .nnx.filterlib import All as All
-from .nnx.filterlib import Not as Not
-from .nnx.graph import GraphDef as GraphDef
-from .nnx.graph import GraphNode as GraphNode
-from .nnx.helpers import Dict as Dict
-from .nnx.helpers import List as List
-from .nnx.helpers import Sequential as Sequential
-from .nnx.helpers import TrainState as TrainState
-from .nnx.module import M as M
-from .nnx.module import Module as Module
-from .nnx.graph import merge as merge
-from .nnx.graph import UpdateContext as UpdateContext
-from .nnx.graph import split as split
-from .nnx.graph import update as update
-from .nnx.graph import clone as clone
-from .nnx.graph import pop as pop
-from .nnx.graph import state as state
-from .nnx.graph import graphdef as graphdef
-from .nnx.nn import initializers as initializers
-from .nnx.nn.activations import celu as celu
-from .nnx.nn.activations import elu as elu
-from .nnx.nn.activations import gelu as gelu
-from .nnx.nn.activations import glu as glu
-from .nnx.nn.activations import hard_sigmoid as hard_sigmoid
-from .nnx.nn.activations import hard_silu as hard_silu
-from .nnx.nn.activations import hard_swish as hard_swish
-from .nnx.nn.activations import hard_tanh as hard_tanh
-from .nnx.nn.activations import leaky_relu as leaky_relu
-from .nnx.nn.activations import log_sigmoid as log_sigmoid
-from .nnx.nn.activations import log_softmax as log_softmax
-from .nnx.nn.activations import logsumexp as logsumexp
-from .nnx.nn.activations import one_hot as one_hot
-from .nnx.nn.activations import relu as relu
-from .nnx.nn.activations import relu6 as relu6
-from .nnx.nn.activations import selu as selu
-from .nnx.nn.activations import sigmoid as sigmoid
-from .nnx.nn.activations import silu as silu
-from .nnx.nn.activations import soft_sign as soft_sign
-from .nnx.nn.activations import softmax as softmax
-from .nnx.nn.activations import softplus as softplus
-from .nnx.nn.activations import standardize as standardize
-from .nnx.nn.activations import swish as swish
-from .nnx.nn.activations import tanh as tanh
-from .nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
-from .nnx.nn.attention import combine_masks as combine_masks
-from .nnx.nn.attention import dot_product_attention as dot_product_attention
-from .nnx.nn.attention import make_attention_mask as make_attention_mask
-from .nnx.nn.attention import make_causal_mask as make_causal_mask
-from .nnx.nn.linear import Conv as Conv
-from .nnx.nn.linear import Embed as Embed
-from .nnx.nn.linear import Linear as Linear
-from .nnx.nn.linear import LinearGeneral as LinearGeneral
-from .nnx.nn.linear import Einsum as Einsum
-from .nnx.nn.normalization import BatchNorm as BatchNorm
-from .nnx.nn.normalization import LayerNorm as LayerNorm
-from .nnx.nn.normalization import RMSNorm as RMSNorm
-from .nnx.nn.stochastic import Dropout as Dropout
-from .nnx.rnglib import Rngs as Rngs
-from .nnx.rnglib import RngStream as RngStream
-from .nnx.rnglib import RngState as RngState
-from .nnx.rnglib import RngKey as RngKey
-from .nnx.rnglib import RngCount as RngCount
-from .nnx.rnglib import fork as fork
-from .nnx.spmd import PARTITION_NAME as PARTITION_NAME
-from .nnx.spmd import get_partition_spec as get_partition_spec
-from .nnx.spmd import get_named_sharding as get_named_sharding
-from .nnx.spmd import with_partitioning as with_partitioning
-from .nnx.spmd import with_sharding_constraint as with_sharding_constraint
-from .nnx.state import State as State
-from .nnx.training import metrics as metrics
-from .nnx.training import optimizer as optimizer
-from .nnx.training.metrics import Metric as Metric
-from .nnx.training.metrics import MultiMetric as MultiMetric
-from .nnx.training.optimizer import Optimizer as Optimizer
-from .nnx.transforms import Jit as Jit
-from .nnx.transforms import Remat as Remat
-from .nnx.transforms import Scan as Scan
-from .nnx.transforms import Vmap as Vmap
-from .nnx.transforms import grad as grad
-from .nnx.transforms import jit as jit
-from .nnx.transforms import remat as remat
-from .nnx.transforms import scan as scan
-from .nnx.transforms import value_and_grad as value_and_grad
-from .nnx.transforms import vmap as vmap
-from .nnx.transforms import eval_shape as eval_shape
-from .nnx.variables import EMPTY as EMPTY
-from .nnx.variables import A as A
-from .nnx.variables import BatchStat as BatchStat
-from .nnx.variables import Cache as Cache
-from .nnx.variables import Empty as Empty
-from .nnx.variables import Intermediate as Intermediate
-from .nnx.variables import Param as Param
-from .nnx.variables import Variable as Variable
-from .nnx.variables import VariableState as VariableState
-from .nnx.variables import VariableMetadata as VariableMetadata
-from .nnx.variables import with_metadata as with_metadata
-from .nnx.visualization import display as display
+from flax.experimental.nnx.nnx import compatibility as compatibility
+from flax.experimental.nnx.nnx import graph as graph
+from flax.experimental.nnx.nnx import errors as errors
+from flax.experimental.nnx.nnx import errors as helpers
+from flax.experimental.nnx.nnx.filterlib import All as All
+from flax.experimental.nnx.nnx.filterlib import Not as Not
+from flax.experimental.nnx.nnx.graph import GraphDef as GraphDef
+from flax.experimental.nnx.nnx.graph import GraphNode as GraphNode
+from flax.experimental.nnx.nnx.helpers import Dict as Dict
+from flax.experimental.nnx.nnx.helpers import List as List
+from flax.experimental.nnx.nnx.helpers import Sequential as Sequential
+from flax.experimental.nnx.nnx.helpers import TrainState as TrainState
+from flax.experimental.nnx.nnx.module import M as M
+from flax.experimental.nnx.nnx.module import Module as Module
+from flax.experimental.nnx.nnx.graph import merge as merge
+from flax.experimental.nnx.nnx.graph import UpdateContext as UpdateContext
+from flax.experimental.nnx.nnx.graph import split as split
+from flax.experimental.nnx.nnx.graph import update as update
+from flax.experimental.nnx.nnx.graph import clone as clone
+from flax.experimental.nnx.nnx.graph import pop as pop
+from flax.experimental.nnx.nnx.graph import state as state
+from flax.experimental.nnx.nnx.graph import graphdef as graphdef
+from flax.experimental.nnx.nnx.nn import initializers as initializers
+from flax.experimental.nnx.nnx.nn.activations import celu as celu
+from flax.experimental.nnx.nnx.nn.activations import elu as elu
+from flax.experimental.nnx.nnx.nn.activations import gelu as gelu
+from flax.experimental.nnx.nnx.nn.activations import glu as glu
+from flax.experimental.nnx.nnx.nn.activations import hard_sigmoid as hard_sigmoid
+from flax.experimental.nnx.nnx.nn.activations import hard_silu as hard_silu
+from flax.experimental.nnx.nnx.nn.activations import hard_swish as hard_swish
+from flax.experimental.nnx.nnx.nn.activations import hard_tanh as hard_tanh
+from flax.experimental.nnx.nnx.nn.activations import leaky_relu as leaky_relu
+from flax.experimental.nnx.nnx.nn.activations import log_sigmoid as log_sigmoid
+from flax.experimental.nnx.nnx.nn.activations import log_softmax as log_softmax
+from flax.experimental.nnx.nnx.nn.activations import logsumexp as logsumexp
+from flax.experimental.nnx.nnx.nn.activations import one_hot as one_hot
+from flax.experimental.nnx.nnx.nn.activations import relu as relu
+from flax.experimental.nnx.nnx.nn.activations import relu6 as relu6
+from flax.experimental.nnx.nnx.nn.activations import selu as selu
+from flax.experimental.nnx.nnx.nn.activations import sigmoid as sigmoid
+from flax.experimental.nnx.nnx.nn.activations import silu as silu
+from flax.experimental.nnx.nnx.nn.activations import soft_sign as soft_sign
+from flax.experimental.nnx.nnx.nn.activations import softmax as softmax
+from flax.experimental.nnx.nnx.nn.activations import softplus as softplus
+from flax.experimental.nnx.nnx.nn.activations import standardize as standardize
+from flax.experimental.nnx.nnx.nn.activations import swish as swish
+from flax.experimental.nnx.nnx.nn.activations import tanh as tanh
+from flax.experimental.nnx.nnx.nn.attention import MultiHeadAttention as MultiHeadAttention
+from flax.experimental.nnx.nnx.nn.attention import combine_masks as combine_masks
+from flax.experimental.nnx.nnx.nn.attention import dot_product_attention as dot_product_attention
+from flax.experimental.nnx.nnx.nn.attention import make_attention_mask as make_attention_mask
+from flax.experimental.nnx.nnx.nn.attention import make_causal_mask as make_causal_mask
+from flax.experimental.nnx.nnx.nn.linear import Conv as Conv
+from flax.experimental.nnx.nnx.nn.linear import Embed as Embed
+from flax.experimental.nnx.nnx.nn.linear import Linear as Linear
+from flax.experimental.nnx.nnx.nn.linear import LinearGeneral as LinearGeneral
+from flax.experimental.nnx.nnx.nn.linear import Einsum as Einsum
+from flax.experimental.nnx.nnx.nn.normalization import BatchNorm as BatchNorm
+from flax.experimental.nnx.nnx.nn.normalization import LayerNorm as LayerNorm
+from flax.experimental.nnx.nnx.nn.normalization import RMSNorm as RMSNorm
+from flax.experimental.nnx.nnx.nn.stochastic import Dropout as Dropout
+from flax.experimental.nnx.nnx.rnglib import Rngs as Rngs
+from flax.experimental.nnx.nnx.rnglib import RngStream as RngStream
+from flax.experimental.nnx.nnx.rnglib import RngState as RngState
+from flax.experimental.nnx.nnx.rnglib import RngKey as RngKey
+from flax.experimental.nnx.nnx.rnglib import RngCount as RngCount
+from flax.experimental.nnx.nnx.rnglib import fork as fork
+from flax.experimental.nnx.nnx.spmd import PARTITION_NAME as PARTITION_NAME
+from flax.experimental.nnx.nnx.spmd import get_partition_spec as get_partition_spec
+from flax.experimental.nnx.nnx.spmd import get_named_sharding as get_named_sharding
+from flax.experimental.nnx.nnx.spmd import with_partitioning as with_partitioning
+from flax.experimental.nnx.nnx.spmd import with_sharding_constraint as with_sharding_constraint
+from flax.experimental.nnx.nnx.state import State as State
+from flax.experimental.nnx.nnx.training import metrics as metrics
+from flax.experimental.nnx.nnx.training import optimizer as optimizer
+from flax.experimental.nnx.nnx.training.metrics import Metric as Metric
+from flax.experimental.nnx.nnx.training.metrics import MultiMetric as MultiMetric
+from flax.experimental.nnx.nnx.training.optimizer import Optimizer as Optimizer
+from flax.experimental.nnx.nnx.transforms import Jit as Jit
+from flax.experimental.nnx.nnx.transforms import jit as jit
+from flax.experimental.nnx.nnx.transforms import Remat as Remat
+from flax.experimental.nnx.nnx.transforms import Scan as Scan
+from flax.experimental.nnx.nnx.transforms import Vmap as Vmap
+from flax.experimental.nnx.nnx.transforms import grad as grad
+from flax.experimental.nnx.nnx.transforms import remat as remat
+from flax.experimental.nnx.nnx.transforms import scan as scan
+from flax.experimental.nnx.nnx.transforms import value_and_grad as value_and_grad
+from flax.experimental.nnx.nnx.transforms import vmap as vmap
+from flax.experimental.nnx.nnx.transforms import eval_shape as eval_shape
+from flax.experimental.nnx.nnx.variables import EMPTY as EMPTY
+from flax.experimental.nnx.nnx.variables import A as A
+from flax.experimental.nnx.nnx.variables import BatchStat as BatchStat
+from flax.experimental.nnx.nnx.variables import Cache as Cache
+from flax.experimental.nnx.nnx.variables import Empty as Empty
+from flax.experimental.nnx.nnx.variables import Intermediate as Intermediate
+from flax.experimental.nnx.nnx.variables import Param as Param
+from flax.experimental.nnx.nnx.variables import Variable as Variable
+from flax.experimental.nnx.nnx.variables import VariableState as VariableState
+from flax.experimental.nnx.nnx.variables import VariableMetadata as VariableMetadata
+from flax.experimental.nnx.nnx.variables import with_metadata as with_metadata
+from flax.experimental.nnx.nnx.visualization import display as display
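A note on the pattern, as an inference rather than anything the PR states: the redundant `import X as X` spelling marks names as explicit re-exports under mypy's no_implicit_reexport strictness, and absolute `flax.experimental.nnx.nnx` paths are the kind of spelling pytype's import resolver handles more reliably than the relative `.nnx` form. A minimal sketch of the re-export idea with a hypothetical package:

```python
# Hypothetical mypkg/__init__.py checked with mypy --no-implicit-reexport:
from mypkg.linear import Linear as Linear  # 'as' alias: deliberately re-exported
from mypkg import _internal                # plain import: not part of the public API
```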
2 changes: 1 addition & 1 deletion flax/experimental/nnx/docs/why.ipynb
@@ -492,7 +492,7 @@
],
"source": [
"# class transform:\n",
"ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n",
"ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)\n",
"\n",
"scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))\n",
"scanned.get_state()"
2 changes: 1 addition & 1 deletion flax/experimental/nnx/docs/why.md
@@ -256,7 +256,7 @@ Like linen, for convenience we still provide simple lifted transforms for standa
:outputId: c4800a49-efd1-4ee5-e703-6e63e18da4cb

# class transform:
-ScannedLinear = nnx.Scan(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)
+ScannedLinear = nnx.Scan.constructor(nnx.Linear, variable_axes={nnx.Param: 0}, length=4)

scanned = ScannedLinear(2, 2, rngs=nnx.Rngs(0))
scanned.get_state()
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/lm1b/models.py
@@ -33,7 +33,7 @@
from jax import lax

from flax.experimental import nnx
-from flax.experimental.nnx.examples.lm1b.configs import default
+from configs import default

Shape = tuple[int, ...]
Dtype = Any
17 changes: 10 additions & 7 deletions flax/experimental/nnx/examples/lm1b/models_test.py
@@ -28,19 +28,16 @@

from flax import traverse_util
from flax.experimental import nnx
-from flax.experimental.nnx.examples.lm1b.configs import default
-from flax.experimental.nnx.examples.lm1b.models import (
-  TransformerConfig,
-  TransformerLM,
-)
-from flax.experimental.nnx.examples.lm1b.utils import HasCache
+from configs import default
+from models import TransformerConfig, TransformerLM
+from utils import HasCache

jax.config.update('jax_disable_most_optimizations', True)

# add project_root to import lm1b Linen model
project_root = str(Path(__file__).absolute().parents[5])
sys.path.append(project_root)
-from examples.lm1b.models import TransformerLM as TransformerLinen
+from examples.lm1b.models import TransformerLM as TransformerLinen  # type: ignore[import-error]

sys.path.pop()

@@ -208,6 +205,9 @@ def test_forward_eval(self):
deterministic=True,
decode=False,
)
+    # Set dropout rates to 0 to avoid creating dropout states
+    config.dropout_rate = 0.0
+    config.attention_dropout_rate = 0.0

model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
_, params_nnx = nnx.split(model_nnx, nnx.Param)
@@ -242,6 +242,9 @@ def test_forward_decode(self):
deterministic=True,
decode=True,
)
+    # Set dropout rates to 0 to avoid creating dropout states
+    config.dropout_rate = 0.0
+    config.attention_dropout_rate = 0.0

model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
for _path, m in model_nnx.iter_modules():
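Both tests lean on the same nnx pattern: `eval_shape` builds the module abstractly, then `split` partitions its state by variable type. A minimal sketch of that pattern with a toy layer (API names taken from this diff; the layer and shapes are arbitrary):

```python
from flax.experimental import nnx

# Build the module without allocating real parameter arrays...
abstract_model = nnx.eval_shape(lambda: nnx.Linear(2, 2, rngs=nnx.Rngs(0)))
# ...then split out just the Params, mirroring `nnx.split(model_nnx, nnx.Param)`.
graphdef, params = nnx.split(abstract_model, nnx.Param)
```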
4 changes: 4 additions & 0 deletions flax/experimental/nnx/examples/lm1b/train_test.py
@@ -52,6 +52,10 @@ def test_train_and_evaluate(self):
config.max_eval_target_length = 32
config.max_predict_length = 32

+    # Set dropout rates to 0 to avoid creating dropout states
+    config.dropout_rate = 0.0
+    config.attention_dropout_rate = 0.0

workdir = tempfile.mkdtemp()

# Go two directories up to the root of the flax directory.
2 changes: 1 addition & 1 deletion flax/experimental/nnx/examples/lm1b/utils.py
@@ -22,7 +22,7 @@
import jax.numpy as jnp
import numpy as np
from jax.experimental import mesh_utils
-from flax.experimental.nnx.examples.lm1b.configs import default
+from configs import default
from models import TransformerConfig, TransformerLM

from flax.experimental import nnx
5 changes: 4 additions & 1 deletion flax/experimental/nnx/nnx/filterlib.py
@@ -43,7 +43,10 @@ def to_predicate(filter: Filter) -> Predicate:
elif isinstance(filter, type):
return OfType(filter)
elif isinstance(filter, bool):
-    return Everything() if filter else Nothing()
+    if filter:
+      return Everything()
+    else:
+      return Nothing()
elif filter is Ellipsis:
return Everything()
elif filter is None:
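The rewrite replaces a conditional expression with an explicit if/else, apparently to keep pytype happy; behavior is unchanged. For context, a small sketch of the mapping the visible branches implement (Everything, Nothing, and OfType are filterlib internals; calling them this way is an assumption, not something the PR shows):

```python
from flax.experimental.nnx.nnx import filterlib
from flax.experimental.nnx.nnx.variables import Param

assert isinstance(filterlib.to_predicate(True), filterlib.Everything)   # bool True
assert isinstance(filterlib.to_predicate(False), filterlib.Nothing)     # bool False
assert isinstance(filterlib.to_predicate(Param), filterlib.OfType)      # a type
assert isinstance(filterlib.to_predicate(...), filterlib.Everything)    # Ellipsis
```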