Skip to content

Commit e434176

Browse files
samzongcharlifu
authored andcommitted
refactor(benchmarks): add type annotations to wait_for_endpoint parameters (vllm-project#25218)
Signed-off-by: samzong <samzong.lu@gmail.com> Signed-off-by: charlifu <charlifu@amd.com>
1 parent a68facf commit e434176

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

vllm/benchmarks/lib/endpoint_request_func.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
import sys
99
import time
1010
import traceback
11+
from collections.abc import Awaitable
1112
from dataclasses import dataclass, field
12-
from typing import Optional, Union
13+
from typing import Optional, Protocol, Union
1314

1415
import aiohttp
1516
from tqdm.asyncio import tqdm
@@ -92,6 +93,16 @@ class RequestFuncOutput:
9293
start_time: float = 0.0
9394

9495

96+
class RequestFunc(Protocol):
97+
def __call__(
98+
self,
99+
request_func_input: RequestFuncInput,
100+
session: aiohttp.ClientSession,
101+
pbar: Optional[tqdm] = None,
102+
) -> Awaitable[RequestFuncOutput]:
103+
...
104+
105+
95106
async def async_request_openai_completions(
96107
request_func_input: RequestFuncInput,
97108
session: aiohttp.ClientSession,
@@ -507,7 +518,7 @@ async def async_request_openai_embeddings(
507518

508519

509520
# TODO: Add more request functions for different API protocols.
510-
ASYNC_REQUEST_FUNCS = {
521+
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
511522
"vllm": async_request_openai_completions,
512523
"openai": async_request_openai_completions,
513524
"openai-chat": async_request_openai_chat_completions,

vllm/benchmarks/lib/ready_checker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,12 @@
88
import aiohttp
99
from tqdm.asyncio import tqdm
1010

11-
from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
11+
from .endpoint_request_func import (RequestFunc, RequestFuncInput,
12+
RequestFuncOutput)
1213

1314

1415
async def wait_for_endpoint(
15-
request_func,
16+
request_func: RequestFunc,
1617
test_input: RequestFuncInput,
1718
session: aiohttp.ClientSession,
1819
timeout_seconds: int = 600,

0 commit comments

Comments
 (0)