Skip to content

Commit f14fb11

Browse files
authored
[Lint] Enable pyupgrade linter in ruff (#963)
* update rules * ruff check * other fixes * fmt * do not touch examples * fmt
1 parent 4f3523d commit f14fb11

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

99 files changed

+836
-829
lines changed

docs/conf.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
1-
# -*- coding: utf-8 -*-
2-
31
# General information about the project.
42
project = "Tile Language <br>"
53
author = "Tile Lang Contributors"
6-
copyright = "2025-2025, %s" % author
4+
copyright = f"2025-2025, {author}"
75

86
# Version information.
9-
with open("../VERSION", "r") as f:
7+
with open("../VERSION") as f:
108
version = f.read().strip()
119
release = version
1210

pyproject.toml

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,25 @@ target-version = "py38"
8787
line-length = 100
8888
output-format = "full"
8989

90+
exclude = [
91+
"3rdparty",
92+
"examples/deepseek_v32/inference",
93+
]
94+
95+
[tool.ruff.lint.per-file-ignores]
96+
# Do not upgrade type hint in testing and examples.
97+
# See https://github.com/tile-ai/tilelang/issues/1079 for more information.
98+
"testing/**.py" = ["UP", "FA"]
99+
"examples/**.py" = ["UP", "FA"]
100+
90101
[tool.ruff.lint]
91102
select = [
92103
# pycodestyle
93104
"E", "W",
94105
# Pyflakes
95106
"F",
96107
# pyupgrade
97-
# "UP",
108+
"UP", "FA",
98109
# flake8-bugbear
99110
"B",
100111
# flake8-simplify
@@ -115,16 +126,15 @@ ignore = [
115126
"SIM108",
116127
# key in dict.keys()
117128
"SIM118",
129+
# open file w.o. ctx manager
130+
"SIM115",
118131
# memory leaks
119132
"B019",
120133
# zip without explicit strict
121134
"B905",
122135
# No such file or directory
123136
"E902",
124137
]
125-
[tool.ruff.lint.per-file-ignores]
126-
"3rdparty/**/*" = ["ALL"]
127-
"examples/deepseek_v32/inference/**/*" = ["ALL"]
128138

129139
[tool.pytest.ini_options]
130140
verbosity_assertions = 3

tilelang/autotuner/capture.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from __future__ import annotations
12
import threading
2-
from typing import List, Any, Optional
3+
from typing import Any
34

45
# Use thread local to store the stack
56
# This is to avoid the cross-thread interference
@@ -87,7 +88,7 @@ class AutotuneInputsCapture:
8788

8889
__slots__ = ("tensors")
8990

90-
def __init__(self, tensors: List[Any]):
91+
def __init__(self, tensors: list[Any]):
9192
self.tensors = tensors
9293

9394
def __enter__(self) -> None:
@@ -118,7 +119,7 @@ def set_autotune_inputs(*args) -> AutotuneInputsCapture:
118119
return AutotuneInputsCapture(tensors)
119120

120121

121-
def get_autotune_inputs() -> Optional[List[Any]]:
122+
def get_autotune_inputs() -> list[Any] | None:
122123
"""
123124
Get the current autotune inputs from the stack.
124125
"""

tilelang/autotuner/param.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
"""The auto-tune parameters.
22
"""
3+
from __future__ import annotations
34

45
import tilelang
56
from tilelang import tvm as tvm
67
from tvm.tir import PrimFunc
78
from tvm.target import Target
8-
from typing import Callable, List, Literal, Any, Optional, Union, Dict
9+
from typing import Callable, Literal, Any
910
from dataclasses import dataclass
1011
from pathlib import Path
1112

@@ -40,12 +41,12 @@ class CompileArgs:
4041
Refer to `tilelang.PassConfigKey` for supported options.
4142
"""
4243

43-
out_idx: Optional[Union[List[int], int]] = None
44+
out_idx: list[int] | int | None = None
4445
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython"
4546
target: Literal['auto', 'cuda', 'hip'] = 'auto'
46-
target_host: Union[str, Target] = None
47+
target_host: str | Target = None
4748
verbose: bool = False
48-
pass_configs: Optional[Dict[str, Any]] = None
49+
pass_configs: dict[str, Any] | None = None
4950

5051
def compile_program(self, program: PrimFunc):
5152
return tilelang.compile(
@@ -135,12 +136,12 @@ class AutotuneResult:
135136
func: Optimized function.
136137
kernel: Compiled kernel function.
137138
"""
138-
latency: Optional[float] = None
139-
config: Optional[dict] = None
140-
ref_latency: Optional[float] = None
141-
libcode: Optional[str] = None
142-
func: Optional[Callable] = None
143-
kernel: Optional[Callable] = None
139+
latency: float | None = None
140+
config: dict | None = None
141+
ref_latency: float | None = None
142+
libcode: str | None = None
143+
func: Callable | None = None
144+
kernel: Callable | None = None
144145

145146
def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: bool = False):
146147
"""
@@ -204,9 +205,9 @@ def _save_kernel_to_disk(self, cache_path: Path, kernel: JITKernel, verbose: boo
204205
def _load_kernel_from_disk(
205206
self,
206207
cache_path: Path,
207-
target: Union[str, Target] = "auto",
208-
target_host: Union[str, Target] = None,
209-
out_idx: Optional[Union[List[int], int]] = None,
208+
target: str | Target = "auto",
209+
target_host: str | Target = None,
210+
out_idx: list[int] | int | None = None,
210211
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
211212
pass_configs: dict = None,
212213
func: Callable = None,
@@ -232,14 +233,14 @@ def _load_kernel_from_disk(
232233
if not os.path.exists(cache_path):
233234
return None
234235

235-
kernel_global_source: Optional[str] = None
236-
kernel_params: Optional[List[KernelParam]] = None
236+
kernel_global_source: str | None = None
237+
kernel_params: list[KernelParam] | None = None
237238

238239
try:
239240
wrapped_kernel_path = os.path.join(cache_path, WRAPPED_KERNEL_PATH)
240241
if verbose:
241242
logger.debug(f"Loading wrapped kernel source code from file: {wrapped_kernel_path}")
242-
with open(wrapped_kernel_path, "r") as f:
243+
with open(wrapped_kernel_path) as f:
243244
kernel_global_source = f.read()
244245
except Exception as e:
245246
logger.error(f"Error loading wrapped kernel source code from disk: {e}")
@@ -300,15 +301,15 @@ def save_to_disk(self, path: Path, verbose: bool = False):
300301
self._save_kernel_to_disk(path, self.kernel)
301302

302303
@classmethod
303-
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResult':
304+
def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> AutotuneResult:
304305
if not os.path.exists(path):
305306
return None
306307

307308
verbose = compile_args.verbose
308309
# load best config
309310
if verbose:
310311
logger.debug(f"Loading best config from file: {path / BEST_CONFIG_PATH}")
311-
with open(path / BEST_CONFIG_PATH, "r") as f:
312+
with open(path / BEST_CONFIG_PATH) as f:
312313
config = json.load(f)
313314

314315
# load function
@@ -320,7 +321,7 @@ def load_from_disk(cls, path: Path, compile_args: CompileArgs) -> 'AutotuneResul
320321
# load latency
321322
if verbose:
322323
logger.debug(f"Loading latency from file: {path / LATENCY_PATH}")
323-
with open(path / LATENCY_PATH, "r") as f:
324+
with open(path / LATENCY_PATH) as f:
324325
latency = json.load(f)
325326
latency, ref_latency = latency["latency"], latency["ref_latency"]
326327

tilelang/autotuner/tuner.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
This module provides functionality for auto-tuning tilelang programs, including JIT compilation
44
and performance optimization through configuration search.
55
"""
6+
from __future__ import annotations
67

78
import tilelang
89
from tilelang import tvm as tvm
910
from tvm.tir import PrimFunc, Var
1011
from tvm.target import Target
1112
import inspect
1213
from functools import partial
13-
from typing import (Callable, List, Literal, Any, Optional, Union, Dict, overload, Tuple)
14+
from typing import (Callable, Literal, Any, overload)
1415
from tqdm import tqdm
1516
import logging
1617
import functools
@@ -103,8 +104,8 @@ class AutoTuner:
103104
compile_args = CompileArgs()
104105
profile_args = ProfileArgs()
105106

106-
_kernel_parameters: Optional[Tuple[str, ...]] = None
107-
_function_parameters: Optional[Dict[str, Any]] = None
107+
_kernel_parameters: tuple[str, ...] | None = None
108+
_function_parameters: dict[str, Any] | None = None
108109
_lock = threading.Lock() # For thread safety
109110
_memory_cache = {} # In-memory cache dictionary
110111
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
@@ -131,12 +132,12 @@ def from_kernel(cls, kernel: Callable, configs):
131132
return cls(kernel, configs)
132133

133134
def set_compile_args(self,
134-
out_idx: Union[List[int], int, None] = None,
135+
out_idx: list[int] | int | None = None,
135136
target: Literal['auto', 'cuda', 'hip'] = 'auto',
136137
execution_backend: Literal["dlpack", "ctypes", "cython"] = "cython",
137-
target_host: Union[str, Target] = None,
138+
target_host: str | Target = None,
138139
verbose: bool = False,
139-
pass_configs: Optional[Dict[str, Any]] = None):
140+
pass_configs: dict[str, Any] | None = None):
140141
"""Set compilation arguments for the auto-tuner.
141142
142143
Args:
@@ -223,12 +224,12 @@ def set_profile_args(self,
223224

224225
return self
225226

226-
def set_kernel_parameters(self, k_parameters: Tuple[str, ...], f_parameters: Dict[str, Any]):
227+
def set_kernel_parameters(self, k_parameters: tuple[str, ...], f_parameters: dict[str, Any]):
227228
# for cache key generation
228229
self._kernel_parameters = k_parameters
229230
self._function_parameters = f_parameters
230231

231-
def generate_cache_key(self, parameters: Dict[str, Any]) -> Optional[AutotuneResult]:
232+
def generate_cache_key(self, parameters: dict[str, Any]) -> AutotuneResult | None:
232233
"""Generate a cache key for the auto-tuning process.
233234
"""
234235

@@ -307,8 +308,8 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
307308
return result
308309

309310
best_latency: float = 1e8
310-
best_config: Optional[Dict[str, Any]] = None
311-
best_kernel: Optional[tilelang.JITKernel] = None
311+
best_config: dict[str, Any] | None = None
312+
best_kernel: tilelang.JITKernel | None = None
312313

313314
def _compile(**config_arg) -> tilelang.JITKernel:
314315
compile_args = self.compile_args
@@ -591,7 +592,7 @@ class _AutoTunerImplementation:
591592
warmup: int = 25
592593
rep: int = 100
593594
timeout: int = 100
594-
configs: Union[Dict, Callable] = None
595+
configs: dict | Callable = None
595596
supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Auto
596597
ref_prog: Callable = None
597598
supply_prog: Callable = None
@@ -603,7 +604,7 @@ class _AutoTunerImplementation:
603604
cache_input_tensors: bool = False
604605

605606
def __init__(self,
606-
configs: Union[Dict, Callable],
607+
configs: dict | Callable,
607608
warmup: int = 25,
608609
rep: int = 100,
609610
timeout: int = 100,
@@ -653,12 +654,12 @@ def __init__(self,
653654
self.cache_input_tensors = cache_input_tensors # Reuse inputs
654655

655656
# Cache for storing tuned kernel implementations
656-
self._tuner_cache: Dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
657+
self._tuner_cache: dict[tuple, tilelang.JITKernel] = {} # (args, kwargs) -> compiled kernel
657658

658659
# This tells the type checker what the *wrapper* function will return.
659660
# this is for linting, please do not remove it.
660661
@overload
661-
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, Tuple[_RProg, AutotuneResult]]:
662+
def __call__(self, fn: Callable[_P, _RProg]) -> Callable[_P, tuple[_RProg, AutotuneResult]]:
662663
...
663664

664665
@overload
@@ -720,9 +721,9 @@ def jit_compile(**config_arg):
720721

721722

722723
def autotune( # This is the new public interface
723-
func: Union[Callable[_P, _RProg], PrimFunc, None] = None,
724+
func: Callable[_P, _RProg] | PrimFunc | None = None,
724725
*, # Indicates subsequent arguments are keyword-only
725-
configs: Union[Dict, Callable],
726+
configs: dict | Callable,
726727
# profile arguments
727728
warmup: int = 25,
728729
rep: int = 100,

tilelang/cache/__init__.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""The cache utils with class and database persistence - Init file"""
2+
from __future__ import annotations
23

3-
from typing import List, Union, Literal, Optional
4+
from typing import Literal
45
from tvm.target import Target
56
from tvm.tir import PrimFunc
67
from tilelang.jit import JITKernel
@@ -13,14 +14,14 @@
1314

1415
def cached(
1516
func: PrimFunc = None,
16-
out_idx: List[int] = None,
17+
out_idx: list[int] = None,
1718
*args,
18-
target: Union[str, Target] = "auto",
19-
target_host: Union[str, Target] = None,
20-
execution_backend: Optional[Literal["dlpack", "ctypes", "cython", "nvrtc"]] = "cython",
21-
verbose: Optional[bool] = False,
22-
pass_configs: Optional[dict] = None,
23-
compile_flags: Optional[Union[List[str], str]] = None,
19+
target: str | Target = "auto",
20+
target_host: str | Target = None,
21+
execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] | None = "cython",
22+
verbose: bool | None = False,
23+
pass_configs: dict | None = None,
24+
compile_flags: list[str] | str | None = None,
2425
) -> JITKernel:
2526
"""
2627
Caches and reuses compiled kernels (using KernelCache class).

0 commit comments

Comments
 (0)