Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade python to 3.10 + use pyupgrade #4038

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/analytics/get_repo_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import json
import os
from datetime import datetime
from typing import Callable, List
from collections.abc import Callable

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -279,7 +279,7 @@ def _rolling_window(
last_month = _start_of_month(df.iloc[-1]['created_at'])
last_month = _shift_n_months(last_month, 1)

rows: List[pd.Series] = []
rows: list[pd.Series] = []
while end < last_month:
row = f(df[(df['created_at'] >= start) & (df['created_at'] < end)])
row['period_start'] = start
Expand Down
10 changes: 4 additions & 6 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: 3.9
python-version: '3.10'
- uses: pre-commit/action@v2.0.3
commit-count:
name: Check commit count
Expand Down Expand Up @@ -63,7 +63,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand All @@ -82,18 +82,16 @@ jobs:
runs-on: ubuntu-20.04-16core
strategy:
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version: ['3.10', '3.11']
test-type: [doctest, pytest, pytype, mypy]
jax-version: [newest]
exclude:
- test-type: pytype
python-version: '3.9'
- test-type: pytype
python-version: '3.10'
- test-type: mypy
python-version: '3.11'
include:
- python-version: '3.9'
- python-version: '3.10'
test-type: pytest
jax-version: '0.4.27' # keep in sync with jax pin in pyproject.toml
steps:
Expand Down
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,8 @@ repos:
# Disable Ruff formatter for now
# # Run the Ruff formatter.
# - id: ruff-format
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py310-plus]
11 changes: 5 additions & 6 deletions docs/_ext/codediff.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
In order to highlight a line of code, append "#!" to it.
"""

from typing import List, Optional, Tuple

import sphinx
from docutils import nodes
Expand All @@ -40,10 +39,10 @@
class CodeDiffParser:
def parse(
self,
lines: List[str],
lines: list[str],
title: str,
groups: Optional[List[str]] = None,
skip_test: Optional[str] = None,
groups: list[str] | None = None,
skip_test: str | None = None,
code_sep: str = '---',
sync: object = MISSING,
):
Expand Down Expand Up @@ -104,7 +103,7 @@ def parse(
sync = sync is not MISSING
# skip legacy code snippets in upgrade guides
if skip_test is not None:
skip_tests = set([index.strip() for index in skip_test.split(',')])
skip_tests = {index.strip() for index in skip_test.split(',')}
else:
skip_tests = set()

Expand Down Expand Up @@ -154,7 +153,7 @@ def _code_block(self, lines):
# Indent code and add empty line so the code is picked up by the directive.
return directive + [''] + list(map(lambda x: ' ' + x, code))

def _tabs(self, *contents: Tuple[str, List[str]], sync):
def _tabs(self, *contents: tuple[str, list[str]], sync):
output = ['.. tab-set::'] + [' ']

for title, content in contents:
Expand Down
28 changes: 14 additions & 14 deletions docs/conf_sphinx_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#
# We should consider sending a PR to sphinx so we can get rid of this.
# Original source: https://github.com/sphinx-doc/sphinx/blob/0aedcc9a916daa92d477226da67d33ce1831822e/sphinx/ext/autosummary/generate.py#L211-L351
from typing import Any, Dict, List, Set, Tuple
from typing import Any

import sphinx.ext.autodoc
import sphinx.ext.autosummary.generate as ag
Expand All @@ -38,7 +38,7 @@ def generate_autosummary_content(
imported_members: bool,
app: Any,
recursive: bool,
context: Dict,
context: dict,
modname: str = None,
qualname: str = None,
) -> str:
Expand All @@ -61,13 +61,13 @@ def skip_member(obj: Any, name: str, objtype: str) -> bool:
)
return False

def get_class_members(obj: Any) -> Dict[str, Any]:
def get_class_members(obj: Any) -> dict[str, Any]:
members = sphinx.ext.autodoc.get_class_members(
obj, [qualname], ag.safe_getattr
)
return {name: member.object for name, member in members.items()}

def get_module_members(obj: Any) -> Dict[str, Any]:
def get_module_members(obj: Any) -> dict[str, Any]:
members = {}
for name in ag.members_of(obj, app.config):
try:
Expand All @@ -76,7 +76,7 @@ def get_module_members(obj: Any) -> Dict[str, Any]:
continue
return members

def get_all_members(obj: Any) -> Dict[str, Any]:
def get_all_members(obj: Any) -> dict[str, Any]:
if doc.objtype == 'module':
return get_module_members(obj)
elif doc.objtype == 'class':
Expand All @@ -85,12 +85,12 @@ def get_all_members(obj: Any) -> Dict[str, Any]:

def get_members(
obj: Any,
types: Set[str],
include_public: List[str] = [],
types: set[str],
include_public: list[str] = [],
imported: bool = True,
) -> Tuple[List[str], List[str]]:
items: List[str] = []
public: List[str] = []
) -> tuple[list[str], list[str]]:
items: list[str] = []
public: list[str] = []

all_members = get_all_members(obj)
for name, value in all_members.items():
Expand All @@ -112,7 +112,7 @@ def get_members(
public.append(name)
return public, items

def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]:
def get_module_attrs(members: Any) -> tuple[list[str], list[str]]:
"""Find module attributes with docstrings."""
attrs, public = [], []
try:
Expand All @@ -127,8 +127,8 @@ def get_module_attrs(members: Any) -> Tuple[List[str], List[str]]:
pass # give up if ModuleAnalyzer fails to parse code
return public, attrs

def get_modules(obj: Any) -> Tuple[List[str], List[str]]:
items: List[str] = []
def get_modules(obj: Any) -> tuple[list[str], list[str]]:
items: list[str] = []
for _, modname, _ispkg in ag.pkgutil.iter_modules(obj.__path__):
fullname = name + '.' + modname
try:
Expand All @@ -142,7 +142,7 @@ def get_modules(obj: Any) -> Tuple[List[str], List[str]]:
public = [x for x in items if not x.split('.')[-1].startswith('_')]
return public, items

ns: Dict[str, Any] = {}
ns: dict[str, Any] = {}
ns.update(context)

if doc.objtype == 'module':
Expand Down
2 changes: 1 addition & 1 deletion examples/cloud/launch_gce.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import re
import subprocess
import time
from typing import Sequence
from collections.abc import Sequence

from absl import app
from absl import flags
Expand Down
7 changes: 4 additions & 3 deletions examples/imagenet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
# pytype: disable=wrong-arg-count

from functools import partial
from typing import Any, Callable, Sequence, Tuple
from typing import Any, Tuple
from collections.abc import Callable, Sequence

from flax import linen as nn
import jax.numpy as jnp
Expand All @@ -33,7 +34,7 @@ class ResNetBlock(nn.Module):
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
strides: tuple[int, int] = (1, 1)

@nn.compact
def __call__(
Expand Down Expand Up @@ -63,7 +64,7 @@ class BottleneckResNetBlock(nn.Module):
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
strides: tuple[int, int] = (1, 1)

@nn.compact
def __call__(self, x):
Expand Down
11 changes: 6 additions & 5 deletions examples/linen_design_test/attention_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import functools
from pprint import pprint
from typing import Any, Callable, Optional, Sequence
from typing import Any, Optional
from collections.abc import Callable, Sequence
from flax.core.frozen_dict import unfreeze
from flax.linen import initializers
from flax.linen import Module, compact, vmap
Expand Down Expand Up @@ -112,8 +113,8 @@ def __call__(self, query, key, value, bias=None, dtype=jnp.float32):


class DotProductAttention(Module):
qkv_features: Optional[int] = None
out_features: Optional[int] = None
qkv_features: int | None = None
out_features: int | None = None
attn_module: Callable = SoftmaxAttn

@compact
Expand Down Expand Up @@ -154,8 +155,8 @@ def concise_vmap(module, in_axes, out_axes, axis_size=None, **var_specs):


class MultiHeadDotProductAttention(Module):
qkv_features: Optional[int] = None
out_features: Optional[int] = None
qkv_features: int | None = None
out_features: int | None = None
attn_module: Callable = SoftmaxAttn
batch_axes: Sequence[int] = (0,)
num_heads: int = 1
Expand Down
5 changes: 3 additions & 2 deletions examples/linen_design_test/autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Tuple
from typing import Tuple
from collections.abc import Iterable

import jax
from jax import numpy as jnp, random
Expand All @@ -37,7 +38,7 @@ def __call__(self, x):
class AutoEncoder(Module):
encoder_widths: Iterable
decoder_widths: Iterable
input_shape: Tuple = None
input_shape: tuple = None

def setup(self):
# Submodules attached in `setup` get names via attribute assignment
Expand Down
2 changes: 1 addition & 1 deletion examples/linen_design_test/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from jax import lax
from flax.linen import initializers
from typing import Callable
from collections.abc import Callable
from flax.linen import Module, compact


Expand Down
2 changes: 1 addition & 1 deletion examples/linen_design_test/mlp_explicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
# Add `in_features` to the built-in Dense layer that normally works
# via shape inference.
class DenseExplicit(Dense):
in_features: Optional[int] = None
in_features: int | None = None

def setup(self):
# We feed a fake batch through the module, which initialized parameters.
Expand Down
2 changes: 1 addition & 1 deletion examples/linen_design_test/mlp_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import jax
from jax import numpy as jnp
from flax import linen as nn
from typing import Iterable
from collections.abc import Iterable
from flax.linen import Module, compact
from dense import Dense

Expand Down
12 changes: 6 additions & 6 deletions examples/lm1b/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import tokenizer

AUTOTUNE = tf.data.experimental.AUTOTUNE
Features = Dict[str, tf.Tensor]
Features = dict[str, tf.Tensor]


class NormalizeFeatureNamesOp:
Expand Down Expand Up @@ -68,8 +68,8 @@ def get_raw_dataset(

def pack_dataset(
dataset: tf.data.Dataset,
key2length: Union[int, Dict[str, int]],
keys: Optional[List[str]] = None,
key2length: int | dict[str, int],
keys: list[str] | None = None,
) -> tf.data.Dataset:
"""Creates a 'packed' version of a dataset on-the-fly.

Expand Down Expand Up @@ -150,7 +150,7 @@ def my_fn(x):


def _pack_with_tf_ops(
dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int]
dataset: tf.data.Dataset, keys: list[str], key2length: dict[str, int]
) -> tf.data.Dataset:
"""Helper-function for packing a dataset which has already been batched.

Expand Down Expand Up @@ -276,7 +276,7 @@ def true_fn():
def preprocess_data(
dataset,
shuffle: bool,
num_epochs: Optional[int] = 1,
num_epochs: int | None = 1,
pack_examples: bool = True,
shuffle_buffer_size: int = 1024,
max_length: int = 512,
Expand Down Expand Up @@ -322,7 +322,7 @@ def get_datasets(
config: ml_collections.ConfigDict,
*,
n_devices: int,
vocab_path: Optional[str] = None,
vocab_path: str | None = None,
):
"""Load and return dataset of batched examples for use during training."""
if vocab_path is None:
Expand Down
7 changes: 4 additions & 3 deletions examples/lm1b/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
# pytype: disable=wrong-keyword-args
# pytype: disable=attribute-error

from typing import Callable, Any, Optional
from typing import Any, Optional
from collections.abc import Callable

from flax import linen as nn
from flax import struct
Expand Down Expand Up @@ -53,7 +54,7 @@ class TransformerConfig:
decode: bool = False
kernel_init: Callable = nn.initializers.xavier_uniform()
bias_init: Callable = nn.initializers.normal(stddev=1e-6)
posemb_init: Optional[Callable] = None
posemb_init: Callable | None = None


def shift_right(x, axis=1):
Expand Down Expand Up @@ -176,7 +177,7 @@ class MlpBlock(nn.Module):
"""

config: TransformerConfig
out_dim: Optional[int] = None
out_dim: int | None = None

@nn.compact
def __call__(self, inputs):
Expand Down
Loading
Loading