
Commit 8fcaaf6

Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent 9bb3813 commit 8fcaaf6
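
For readers skimming the 944-file diff below, the rewrite is purely mechanical, following PEP 604: Optional[X] becomes X | None, Union[X, Y] becomes X | Y, and the now-unused typing imports are dropped. A minimal illustrative sketch of the pattern (not code taken from this commit):

    # Before: typing-module spellings
    from typing import Optional, Union

    def load(path: Optional[str] = None) -> Union[bytes, str]: ...

    # After: PEP 604 builtin union operator; no typing import needed
    def load(path: str | None = None) -> bytes | str: ...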

944 files changed: +9491 / -10122 lines


benchmarks/backend_request_func.py

Lines changed: 13 additions & 14 deletions
@@ -8,7 +8,6 @@
 import time
 import traceback
 from dataclasses import dataclass, field
-from typing import Optional, Union
 
 import aiohttp
 import huggingface_hub.constants
@@ -28,13 +27,13 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
-    model_name: Optional[str] = None
-    logprobs: Optional[int] = None
-    extra_body: Optional[dict] = None
-    multi_modal_content: Optional[dict | list[dict]] = None
+    model_name: str | None = None
+    logprobs: int | None = None
+    extra_body: dict | None = None
+    multi_modal_content: dict | list[dict] | None = None
     ignore_eos: bool = False
-    language: Optional[str] = None
-    request_id: Optional[str] = None
+    language: str | None = None
+    request_id: str | None = None
 
 
 @dataclass
@@ -52,7 +51,7 @@ class RequestFuncOutput:
 
 async def async_request_tgi(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -133,7 +132,7 @@ async def async_request_tgi(
 
 async def async_request_trt_llm(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith("generate_stream")
@@ -204,7 +203,7 @@ async def async_request_trt_llm(
 
 async def async_request_deepspeed_mii(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(("completions", "profile")), (
@@ -267,7 +266,7 @@ async def async_request_deepspeed_mii(
 
 async def async_request_openai_completions(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(("completions", "profile")), (
@@ -367,7 +366,7 @@ async def async_request_openai_completions(
 
 async def async_request_openai_chat_completions(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     api_url = request_func_input.api_url
     assert api_url.endswith(("chat/completions", "profile")), (
@@ -476,7 +475,7 @@ async def async_request_openai_chat_completions(
 
 async def async_request_openai_audio(
     request_func_input: RequestFuncInput,
-    pbar: Optional[tqdm] = None,
+    pbar: tqdm | None = None,
 ) -> RequestFuncOutput:
     # Lazy import without PlaceholderModule to avoid vllm dep.
     import soundfile
@@ -610,7 +609,7 @@ def get_tokenizer(
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     **kwargs,
-) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
     if pretrained_model_name_or_path is not None and not os.path.exists(
         pretrained_model_name_or_path
     ):
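
One caveat worth noting (an assumption about the runtime, since the commit doesn't state the project's minimum supported Python): X | None evaluated in annotations at import time requires Python 3.10+. On 3.9 the same syntax only parses if annotation evaluation is deferred, e.g.:

    # Hedged compatibility sketch: the future import defers annotation
    # evaluation, so the PEP 604 syntax below also parses on Python 3.9.
    # On 3.10+ it is unnecessary.
    from __future__ import annotations

    def maybe_len(s: str | None) -> int | None:
        return None if s is None else len(s)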

benchmarks/benchmark_prefix_caching.py

Lines changed: 2 additions & 3 deletions
@@ -32,7 +32,6 @@
 import json
 import random
 import time
-from typing import Optional
 
 from transformers import PreTrainedTokenizerBase
 
@@ -80,7 +79,7 @@ def sample_requests_from_dataset(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     input_length_range: tuple[int, int],
-    fixed_output_len: Optional[int],
+    fixed_output_len: int | None,
 ) -> list[Request]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -128,7 +127,7 @@ def sample_requests_from_random(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     input_length_range: tuple[int, int],
-    fixed_output_len: Optional[int],
+    fixed_output_len: int | None,
     prefix_len: int,
 ) -> list[Request]:
     requests = []

benchmarks/benchmark_prioritization.py

Lines changed: 1 addition & 2 deletions
@@ -7,7 +7,6 @@
 import json
 import random
 import time
-from typing import Optional
 
 from transformers import AutoTokenizer, PreTrainedTokenizerBase
 
@@ -24,7 +23,7 @@ def sample_requests(
     dataset_path: str,
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int],
+    fixed_output_len: int | None,
 ) -> list[tuple[str, int, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

benchmarks/benchmark_serving_structured_output.py

Lines changed: 3 additions & 4 deletions
@@ -32,7 +32,6 @@
 import warnings
 from collections.abc import AsyncGenerator
 from dataclasses import dataclass
-from typing import Optional
 
 import datasets
 import numpy as np
@@ -316,7 +315,7 @@ def calculate_metrics(
     tokenizer: PreTrainedTokenizerBase,
     selected_percentile_metrics: list[str],
     selected_percentiles: list[float],
-    goodput_config_dict: Optional[dict[str, float]] = None,
+    goodput_config_dict: dict[str, float] | None = None,
 ) -> tuple[BenchmarkMetrics, list[int]]:
     actual_output_lens: list[int] = []
     total_input = 0
@@ -436,9 +435,9 @@ async def benchmark(
     selected_percentile_metrics: list[str],
     selected_percentiles: list[str],
     ignore_eos: bool,
-    max_concurrency: Optional[int],
+    max_concurrency: int | None,
     structured_output_ratio: float,
-    goodput_config_dict: Optional[dict[str, float]] = None,
+    goodput_config_dict: dict[str, float] | None = None,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]

benchmarks/benchmark_utils.py

Lines changed: 8 additions & 8 deletions
@@ -6,7 +6,7 @@
 import os
 import time
 from types import TracebackType
-from typing import Any, Optional, Union
+from typing import Any
 
 
 def convert_to_pytorch_benchmark_format(
@@ -92,7 +92,7 @@ class TimeCollector:
     def __init__(self, scale: int) -> None:
         self.cnt: int = 0
         self._sum: int = 0
-        self._max: Optional[int] = None
+        self._max: int | None = None
         self.scale = scale
         self.start_time: int = time.monotonic_ns()
 
@@ -104,22 +104,22 @@ def collect(self, v: int) -> None:
         else:
             self._max = max(self._max, v)
 
-    def avg(self) -> Union[float, str]:
+    def avg(self) -> float | str:
         return self._sum * 1.0 / self.cnt / self.scale if self.cnt > 0 else "N/A"
 
-    def max(self) -> Union[float, str]:
+    def max(self) -> float | str:
         return self._max / self.scale if self._max else "N/A"
 
-    def dump_avg_max(self) -> list[Union[float, str]]:
+    def dump_avg_max(self) -> list[float | str]:
         return [self.avg(), self.max()]
 
     def __enter__(self) -> None:
         self.start_time = time.monotonic_ns()
 
     def __exit__(
         self,
-        exc_type: Optional[type[BaseException]],
-        exc_value: Optional[BaseException],
-        exc_traceback: Optional[TracebackType],
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        exc_traceback: TracebackType | None,
     ) -> None:
         self.collect(time.monotonic_ns() - self.start_time)
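
The __exit__ signature above is the standard context-manager protocol, just re-spelled with | None. A self-contained sketch of the same idiom; Stopwatch is a simplified, hypothetical stand-in for TimeCollector, not a vLLM class:

    import time
    from types import TracebackType

    class Stopwatch:
        """Times a with-block in nanoseconds."""

        def __init__(self) -> None:
            self.elapsed_ns: int | None = None
            self._start: int = 0

        def __enter__(self) -> "Stopwatch":
            self._start = time.monotonic_ns()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException] | None,
            exc_value: BaseException | None,
            exc_traceback: TracebackType | None,
        ) -> None:
            self.elapsed_ns = time.monotonic_ns() - self._start

    with Stopwatch() as sw:
        sum(range(10_000))
    print(sw.elapsed_ns)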

benchmarks/cutlass_benchmarks/sparse_benchmarks.py

Lines changed: 1 addition & 2 deletions
@@ -6,8 +6,7 @@
 import itertools
 import pickle as pkl
 import time
-from collections.abc import Iterable
-from typing import Callable
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.utils.benchmark as TBenchmark
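
The Callable half of this change is a related cleanup: since Python 3.9 (PEP 585), typing.Callable is a deprecated alias for collections.abc.Callable, and the two are parameterized identically. A small sketch (apply_twice is a made-up example, not from the benchmark):

    from collections.abc import Callable

    # Subscription syntax is unchanged: [argument types], return type.
    def apply_twice(fn: Callable[[int], int], x: int) -> int:
        return fn(fn(x))

    print(apply_twice(lambda v: v + 1, 0))  # prints 2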

benchmarks/cutlass_benchmarks/w8a8_benchmarks.py

Lines changed: 5 additions & 6 deletions
@@ -6,8 +6,7 @@
 import itertools
 import pickle as pkl
 import time
-from collections.abc import Iterable
-from typing import Callable, Optional
+from collections.abc import Callable, Iterable
 
 import torch
 import torch.utils.benchmark as TBenchmark
@@ -53,7 +52,7 @@ def bench_int8(
     n: int,
     label: str,
    sub_label: str,
-    bench_kernels: Optional[list[str]] = None,
+    bench_kernels: list[str] | None = None,
 ) -> Iterable[TMeasurement]:
     """Benchmark INT8-based kernels."""
     assert dtype == torch.int8
@@ -108,7 +107,7 @@ def bench_fp8(
     n: int,
     label: str,
     sub_label: str,
-    bench_kernels: Optional[list[str]] = None,
+    bench_kernels: list[str] | None = None,
 ) -> Iterable[TMeasurement]:
     """Benchmark FP8-based kernels."""
     assert dtype == torch.float8_e4m3fn
@@ -183,7 +182,7 @@ def bench(
     n: int,
     label: str,
     sub_label: str,
-    bench_kernels: Optional[list[str]] = None,
+    bench_kernels: list[str] | None = None,
 ) -> Iterable[TMeasurement]:
     if dtype == torch.int8:
         return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
@@ -201,7 +200,7 @@ def print_timers(timers: Iterable[TMeasurement]):
 def run(
     dtype: torch.dtype,
     MKNs: Iterable[tuple[int, int, int]],
-    bench_kernels: Optional[list[str]] = None,
+    bench_kernels: list[str] | None = None,
 ) -> Iterable[TMeasurement]:
     results = []
     for m, k, n in MKNs:
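
bench_kernels: list[str] | None = None is the usual idiom for an optional list argument: default to None (never to a mutable []) and treat None as "no filter". A hedged sketch of how such a filter is typically applied; the actual bench_int8/bench_fp8 bodies are not shown in this diff, and select_kernels is a hypothetical helper:

    def select_kernels(
        available: list[str],
        bench_kernels: list[str] | None = None,
    ) -> list[str]:
        # None means "benchmark every kernel"; a list restricts the run.
        if bench_kernels is None:
            return available
        return [k for k in available if k in bench_kernels]

    print(select_kernels(["cutlass_i8", "torch_mm"], ["torch_mm"]))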

benchmarks/fused_kernels/layernorm_rms_benchmarks.py

Lines changed: 4 additions & 5 deletions
@@ -3,10 +3,9 @@
 
 import pickle as pkl
 import time
-from collections.abc import Iterable
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass
 from itertools import product
-from typing import Callable, Optional
 
 import torch
 import torch.utils.benchmark as TBenchmark
@@ -51,7 +50,7 @@ def get_bench_params() -> list[bench_params_t]:
 def unfused_int8_impl(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
-    residual: Optional[torch.Tensor],
+    residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
 ):
     # Norm
@@ -68,7 +67,7 @@ def unfused_int8_impl(
 def unfused_fp8_impl(
     rms_norm_layer: RMSNorm,
     x: torch.Tensor,
-    residual: Optional[torch.Tensor],
+    residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
 ):
     # Norm
@@ -85,7 +84,7 @@ def unfused_fp8_impl(
 def fused_impl(
     rms_norm_layer: RMSNorm,  # this stores the weights
     x: torch.Tensor,
-    residual: Optional[torch.Tensor],
+    residual: torch.Tensor | None,
     quant_dtype: torch.dtype,
 ):
     out, _ = ops.rms_norm_dynamic_per_token_quant(

benchmarks/kernels/bench_per_token_quant_fp8.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import itertools
-from typing import Callable
+from collections.abc import Callable
 from unittest.mock import patch
 
 import pandas as pd

benchmarks/kernels/benchmark_device_communicators.py

Lines changed: 3 additions & 3 deletions
@@ -22,8 +22,8 @@
 import json
 import os
 import time
+from collections.abc import Callable
 from contextlib import nullcontext
-from typing import Callable, Optional
 
 import torch
 import torch.distributed as dist
@@ -264,12 +264,12 @@ def benchmark_allreduce(
     def benchmark_allreduce_single(
         self,
         sequence_length: int,
-        allreduce_fn: Callable[[torch.Tensor], Optional[torch.Tensor]],
+        allreduce_fn: Callable[[torch.Tensor], torch.Tensor | None],
         should_use_fn: Callable[[torch.Tensor], bool],
         context,
         num_warmup: int,
         num_trials: int,
-    ) -> Optional[float]:
+    ) -> float | None:
         """Benchmark method with CUDA graph optimization."""
         try:
             # Create test tensor (2D: sequence_length x hidden_size)
