Commit e4c8629

pre-commit run -a
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Parent: ceb7cc2

594 files changed: 33,881 additions, 28,948 deletions

Note: large commits have some content hidden by default; only a subset of the 594 changed files is shown below.
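
Nearly every hunk below is mechanical restyling: single quotes become double quotes, continuation lines aligned under the opening parenthesis become four-space hanging indents, and a trailing comma after the last argument keeps multi-line calls expanded. As a hedged illustration only (this snippet is a sketch assuming the old layout was yapf-style and the new one ruff-format/black-style; it is not vLLM code, though the endpoint mirrors the test diffs below), the two layouts compare like this:

# Sketch of the restyle. Both functions behave identically; only the
# layout differs.
import requests


def query_old_style(prompt: str, max_tokens: int = 5) -> dict:
    # Old: continuation arguments aligned under the opening parenthesis.
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": max_tokens,
                             })
    return response.json()


def query_new_style(prompt: str, max_tokens: int = 5) -> dict:
    # New: four-space hanging indent; the "magic" trailing comma after
    # the final argument forces one argument per line.
    response = requests.post(
        "http://localhost:8000/generate",
        json={
            "prompt": prompt,
            "max_tokens": max_tokens,
        },
    )
    return response.json()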

tests/async_engine/api_server_async_engine.py

Lines changed: 8 additions & 6 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """vllm.entrypoints.api_server with some extra logging for testing."""
+
 from collections.abc import Iterable
 from typing import Any

@@ -17,7 +18,6 @@


 class AsyncLLMEngineWithStats(AsyncLLMEngine):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._num_aborts = 0
@@ -47,8 +47,10 @@ def stats() -> Response:
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngineWithStats.from_engine_args(engine_args)
     vllm.entrypoints.api_server.engine = engine
-    uvicorn.run(app,
-                host=args.host,
-                port=args.port,
-                log_level="debug",
-                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE)
+    uvicorn.run(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level="debug",
+        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
+    )
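
For context, this test helper wraps vllm.entrypoints.api_server with an engine subclass that counts aborted requests and serves them over HTTP. A minimal, self-contained sketch of that pattern follows; the FastAPI app and the EngineWithStats stand-in are hypothetical, and only the /stats route, the num_aborted_requests key, and the uvicorn.run call mirror the real file:

from fastapi import FastAPI
from fastapi.responses import JSONResponse
import uvicorn

app = FastAPI()


class EngineWithStats:
    """Hypothetical stand-in for AsyncLLMEngineWithStats."""

    def __init__(self) -> None:
        # The real engine increments this when a request is aborted.
        self._num_aborts = 0


engine = EngineWithStats()


@app.get("/stats")
def stats() -> JSONResponse:
    # Mirrors the JSON key the API-server test below reads.
    return JSONResponse({"num_aborted_requests": engine._num_aborts})


if __name__ == "__main__":
    uvicorn.run(
        app,
        host="localhost",
        port=8000,
        log_level="debug",
    )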

tests/async_engine/conftest.py

Lines changed: 1 addition & 1 deletion

@@ -9,4 +9,4 @@ def use_v0_only(monkeypatch):
     Since this module is V0 only, set VLLM_USE_V1=0 for
     all tests in the module.
     """
-    monkeypatch.setenv('VLLM_USE_V1', '0')
+    monkeypatch.setenv("VLLM_USE_V1", "0")

tests/async_engine/test_api_server.py

Lines changed: 18 additions & 13 deletions

@@ -13,13 +13,15 @@


 def _query_server(prompt: str, max_tokens: int = 5) -> dict:
-    response = requests.post("http://localhost:8000/generate",
-                             json={
-                                 "prompt": prompt,
-                                 "max_tokens": max_tokens,
-                                 "temperature": 0,
-                                 "ignore_eos": True
-                             })
+    response = requests.post(
+        "http://localhost:8000/generate",
+        json={
+            "prompt": prompt,
+            "max_tokens": max_tokens,
+            "temperature": 0,
+            "ignore_eos": True,
+        },
+    )
     response.raise_for_status()
     return response.json()

@@ -30,8 +32,9 @@ def _query_server_long(prompt: str) -> dict:

 @pytest.fixture
 def api_server(distributed_executor_backend: str):
-    script_path = Path(__file__).parent.joinpath(
-        "api_server_async_engine.py").absolute()
+    script_path = (
+        Path(__file__).parent.joinpath("api_server_async_engine.py").absolute()
+    )
     commands = [
         sys.executable,
         "-u",
@@ -80,8 +83,9 @@ def test_api_server(api_server, distributed_executor_backend: str):
     for result in pool.map(_query_server, prompts):
         assert result

-    num_aborted_requests = requests.get(
-        "http://localhost:8000/stats").json()["num_aborted_requests"]
+    num_aborted_requests = requests.get("http://localhost:8000/stats").json()[
+        "num_aborted_requests"
+    ]
     assert num_aborted_requests == 0

     # Try with 100 prompts
@@ -101,8 +105,9 @@ def test_api_server(api_server, distributed_executor_backend: str):
     # give it some times to update the stats
     time.sleep(1)

-    num_aborted_requests = requests.get(
-        "http://localhost:8000/stats").json()["num_aborted_requests"]
+    num_aborted_requests = requests.get("http://localhost:8000/stats").json()[
+        "num_aborted_requests"
+    ]
     assert num_aborted_requests > 0

     # check that server still runs after cancellations

tests/async_engine/test_async_llm_engine.py

Lines changed: 30 additions & 30 deletions

@@ -36,7 +36,6 @@ class MockModelConfig:


 class MockEngine:
-
     def __init__(self):
         self.step_calls = 0
         self.add_request_calls = 0
@@ -49,8 +48,7 @@ def __init__(self):
     async def step_async(self, virtual_engine):
         # PP size is 1, ignore virtual engine
         self.step_calls += 1
-        return [RequestOutput(
-            request_id=self.request_id)] if self.request_id else []
+        return [RequestOutput(request_id=self.request_id)] if self.request_id else []

     async def process_model_inputs_async(self, *args, **kwargs):
         pass
@@ -67,7 +65,7 @@ def stop_generating(self):
     def add_request(self, **kwargs):
         del kwargs  # Unused
         self.add_request_calls += 1
-        print(f'Request calls: {self.add_request_calls}')
+        print(f"Request calls: {self.add_request_calls}")

     async def add_request_async(self, **kwargs):
         self.add_request_calls += 1
@@ -142,9 +140,12 @@ def start_engine():
     print(f"Starting engine with num_scheduler_steps={num_scheduler_steps}")

     return AsyncLLMEngine.from_engine_args(
-        AsyncEngineArgs(model="facebook/opt-125m",
-                        enforce_eager=True,
-                        num_scheduler_steps=num_scheduler_steps))
+        AsyncEngineArgs(
+            model="facebook/opt-125m",
+            enforce_eager=True,
+            num_scheduler_steps=num_scheduler_steps,
+        )
+    )


 def uid() -> str:
@@ -157,8 +158,9 @@ async def async_engine():
     # scoped fixture and monkeypatch is function scoped.
     previous_value = os.getenv("VLLM_USE_V1", None)
     os.environ["VLLM_USE_V1"] = "0"
-    engine = await asyncio.get_event_loop().run_in_executor(executor=None,
-                                                            func=start_engine)
+    engine = await asyncio.get_event_loop().run_in_executor(
+        executor=None, func=start_engine
+    )
     try:
         yield engine
     finally:
@@ -182,7 +184,6 @@ def should_do_global_cleanup_after_test(request) -> bool:
 @pytest.mark.asyncio(scope="module")
 @pytest.mark.parametrize("stop", [None, ["a stop string"]])
 async def test_asyncio_run(async_engine, stop):
-
     scheduler_config = await async_engine.get_scheduler_config()
     num_scheduler_steps = scheduler_config.num_scheduler_steps

@@ -196,9 +197,9 @@ async def run(prompt: str):

         output_count = 0
         final_output = None
-        async for output in async_engine.generate(prompt,
-                                                  sampling_params,
-                                                  request_id=uid()):
+        async for output in async_engine.generate(
+            prompt, sampling_params, request_id=uid()
+        ):
             output_count += 1
             final_output = output
         return final_output, output_count
@@ -247,18 +248,19 @@ async def run(prompt: str, kind: RequestOutputKind):

         output_count = 0
         final_output = None
-        async for output in async_engine.generate(prompt,
-                                                  params,
-                                                  request_id=uid()):
+        async for output in async_engine.generate(prompt, params, request_id=uid()):
             output_count += 1
             final_output = output

         assert final_output is not None
         assert final_output.finished

-        return (final_output.prompt_token_ids,
-                final_output.outputs[0].token_ids,
-                final_output.outputs[0].text, output_count)
+        return (
+            final_output.prompt_token_ids,
+            final_output.outputs[0].token_ids,
+            final_output.outputs[0].text,
+            output_count,
+        )

     async def run_deltas(prompt: str):
         params = copy(sampling_params)
@@ -269,9 +271,7 @@ async def run_deltas(prompt: str):
         output_text = ""
         output_count = 0
         final_output = None
-        async for output in async_engine.generate(prompt,
-                                                  params,
-                                                  request_id=uid()):
+        async for output in async_engine.generate(prompt, params, request_id=uid()):
             token_ids = output.outputs[0].token_ids
             text = output.outputs[0].text
             final_output = output
@@ -298,7 +298,8 @@ async def run_deltas(prompt: str):
     results = await asyncio.gather(
         run("common input prompt", RequestOutputKind.CUMULATIVE),
         run("common input prompt", RequestOutputKind.FINAL_ONLY),
-        run_deltas("common input prompt"))
+        run_deltas("common input prompt"),
+    )

     # Make sure outputs are the same
     prompt_set = set(tuple(prompt_ids) for prompt_ids, _, _, _ in results)
@@ -342,9 +343,9 @@ async def test_cancellation(async_engine, stop):

     i = 0
     with pytest.raises(CancelledError):
-        async for output in async_engine.generate("test2",
-                                                  sampling_params,
-                                                  request_id=request_id):
+        async for output in async_engine.generate(
+            "test2", sampling_params, request_id=request_id
+        ):
             assert not output.finished
             i += 1
             if i == stop_at:
@@ -402,8 +403,7 @@ async def test_invalid_argument(async_engine):

     # Targeting specific DP rank only supported in v1 multi-instance DP
     with pytest.raises(ValueError):
-        async for _ in async_engine.generate("test",
-                                             sampling_params,
-                                             request_id=uid(),
-                                             data_parallel_rank=0):
+        async for _ in async_engine.generate(
+            "test", sampling_params, request_id=uid(), data_parallel_rank=0
+        ):
             pass
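
The test_cancellation hunk above exercises the pattern these reformats touch throughout: iterate an async generator and abort it mid-stream, expecting CancelledError. A self-contained sketch of that flow follows; fake_generate is a hypothetical stand-in for async_engine.generate:

import asyncio


async def fake_generate(prompt: str):
    # Hypothetical stand-in for async_engine.generate(): an async
    # generator yielding a stream of incremental outputs.
    for i in range(100):
        await asyncio.sleep(0.01)
        yield f"{prompt}-output-{i}"


async def consume() -> None:
    async for output in fake_generate("test2"):
        print(output)


async def main() -> None:
    task = asyncio.create_task(consume())
    await asyncio.sleep(0.05)  # let a few outputs stream first
    task.cancel()  # abort the request mid-stream
    try:
        await task
    except asyncio.CancelledError:
        # The condition the test asserts via pytest.raises(CancelledError).
        print("generation cancelled")


asyncio.run(main())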

tests/async_engine/test_request_tracker.py

Lines changed: 2 additions & 1 deletion

@@ -60,7 +60,8 @@ async def test_request_tracker():
     stream_5 = tracker.add_request("5")
     assert tracker.new_requests_event.is_set()
     tracker.process_request_output(
-        RequestOutput("2", "output", [], [], [], finished=True))
+        RequestOutput("2", "output", [], [], [], finished=True)
+    )
     await tracker.wait_for_new_requests()
     new, aborted = tracker.get_new_and_aborted_requests()
     assert not tracker.new_requests_event.is_set()
