Skip to content

Commit 9e4a464

Browse files
committed
ruff check --fix --unsafe-fixes (I checked them all)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent e4c8629 commit 9e4a464

File tree

13 files changed

+101
-112
lines changed

13 files changed

+101
-112
lines changed

tests/compile/test_full_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
import tempfile
7-
from typing import Any, Optional, Union
7+
from typing import Any, Union
88

99
import pytest
1010
import torch
@@ -17,7 +17,7 @@
1717
from ..utils import create_new_process_for_each_test
1818

1919

20-
def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
20+
def models_list(*, all: bool = True, keywords: list[str] | None = None):
2121
TEST_MODELS: list[tuple[str, dict[str, Any]]] = [
2222
("facebook/opt-125m", {}),
2323
(

tests/kernels/attention/test_lightning_attn.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,7 @@ def reference_lightning_attention(q, k, v, ed, block_size, kv_history):
3333

3434
# More efficient implementation
3535
# Convert decay factors to matrix form
36-
if ed.dim() == 1:
37-
decay = torch.exp(-ed).view(1, -1, 1, 1)
38-
else:
39-
decay = torch.exp(-ed)
36+
decay = torch.exp(-ed).view(1, -1, 1, 1) if ed.dim() == 1 else torch.exp(-ed)
4037

4138
for b in range(B):
4239
for step in range(S):

tests/kernels/quantization/test_cutlass_scaled_mm.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,7 @@ def cutlass_fp8_gemm_helper(
8888
# make scales K-major for blockwise quant, doesn't affect 1D scales
8989
scale_b = scale_b.t().contiguous().t()
9090

91-
if use_bias:
92-
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
93-
else:
94-
bias = None
91+
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
9592

9693
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
9794
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
@@ -122,10 +119,7 @@ def cutlass_int8_gemm_helper(
122119
scale_a = torch.randn(a_scales_shape, device=device, dtype=torch.float32)
123120
scale_b = torch.randn(b_scales_shape, device=device, dtype=torch.float32)
124121

125-
if use_bias:
126-
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10
127-
else:
128-
bias = None
122+
bias = torch.rand((n,), device=device, dtype=out_dtype) * 10 if use_bias else None
129123

130124
out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
131125
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

tests/v1/engine/test_llm_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import random
6-
from typing import TYPE_CHECKING, Optional
6+
from typing import TYPE_CHECKING
77

88
import pytest
99

@@ -78,7 +78,7 @@ def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
7878

7979
def _get_test_sampling_params(
8080
prompt_list: list[str],
81-
seed: Optional[int] = 42,
81+
seed: int | None = 42,
8282
structured_outputs: bool = False,
8383
) -> tuple[list[SamplingParams], list[int]]:
8484
"""Generate random sampling params for a batch."""

vllm/model_executor/guided_decoding/outlines_logits_processors.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import importlib.metadata
99
import json
1010
import os
11-
from typing import Optional, Union
11+
from typing import Union
1212

1313
import regex as re
1414
import torch
@@ -36,11 +36,11 @@
3636
class BaseLogitsProcessor:
3737

3838
def __init__(self, guide: Guide, eos_token_id: int,
39-
reasoner: Optional[ReasoningParser]) -> None:
39+
reasoner: ReasoningParser | None) -> None:
4040
self._guide: Guide = guide
4141
self._eos_token_id: int = eos_token_id
42-
self._reasoner: Optional[ReasoningParser] = reasoner
43-
self._mask: Optional[torch.Tensor] = None
42+
self._reasoner: ReasoningParser | None = reasoner
43+
self._mask: torch.Tensor | None = None
4444

4545
def __call__(self, input_ids: list[int],
4646
scores: torch.Tensor) -> torch.Tensor:
@@ -114,7 +114,7 @@ def _get_guide(cls, regex_string: str,
114114
return Guide(index)
115115

116116
def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase,
117-
reasoner: Optional[ReasoningParser]) -> None:
117+
reasoner: ReasoningParser | None) -> None:
118118
super().__init__(
119119
guide=RegexLogitsProcessor._get_guide(regex_string, tokenizer),
120120
eos_token_id=tokenizer.eos_token_id, # type: ignore
@@ -126,7 +126,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
126126
def __init__(self, schema: Union[str, dict, BaseModel],
127127
tokenizer: PreTrainedTokenizerBase,
128128
whitespace_pattern: Union[str, None],
129-
reasoner: Optional[ReasoningParser]) -> None:
129+
reasoner: ReasoningParser | None) -> None:
130130

131131
if isinstance(schema, type(BaseModel)):
132132
schema_str = json.dumps(schema.model_json_schema())

vllm/reasoning/abs_reasoning_parsers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from abc import abstractmethod
88
from collections.abc import Sequence
99
from functools import cached_property
10-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
10+
from typing import TYPE_CHECKING, Any, Callable, Union
1111

1212
from vllm.logger import init_logger
1313
from vllm.utils import import_from_path, is_list_of
@@ -77,7 +77,7 @@ def extract_reasoning_content(
7777
self,
7878
model_output: str,
7979
request: Union[ChatCompletionRequest, ResponsesRequest],
80-
) -> tuple[Optional[str], Optional[str]]:
80+
) -> tuple[str | None, str | None]:
8181
"""
8282
Extract reasoning content from a complete model-generated string.
8383
@@ -135,7 +135,7 @@ def get_reasoning_parser(cls, name: str | None) -> type[ReasoningParser]:
135135
def _register_module(
136136
cls,
137137
module: type,
138-
module_name: Optional[Union[str, list[str]]] = None,
138+
module_name: Union[str, list[str]] | None = None,
139139
force: bool = True,
140140
) -> None:
141141
if not issubclass(module, ReasoningParser):
@@ -155,7 +155,7 @@ def _register_module(
155155
@classmethod
156156
def register_module(
157157
cls,
158-
name: Optional[Union[str, list[str]]] = None,
158+
name: Union[str, list[str]] | None = None,
159159
force: bool = True,
160160
module: Union[type, None] = None,
161161
) -> Union[type, Callable]:

0 commit comments

Comments
 (0)