|
8 | 8 | import sys |
9 | 9 | import time |
10 | 10 | import traceback |
| 11 | +from collections.abc import Awaitable |
11 | 12 | from dataclasses import dataclass, field |
12 | | -from typing import Optional, Union |
| 13 | +from typing import Optional, Protocol, Union |
13 | 14 |
|
14 | 15 | import aiohttp |
15 | 16 | from tqdm.asyncio import tqdm |
@@ -92,6 +93,16 @@ class RequestFuncOutput: |
92 | 93 | start_time: float = 0.0 |
93 | 94 |
|
94 | 95 |
|
class RequestFunc(Protocol):
    """Structural type for the values stored in ``ASYNC_REQUEST_FUNCS``.

    Any callable whose signature matches ``__call__`` below satisfies this
    protocol; no explicit subclassing is required.
    """

    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: Optional[tqdm] = None,
    ) -> Awaitable[RequestFuncOutput]:
        """Return an awaitable that resolves to a ``RequestFuncOutput``.

        Args:
            request_func_input: Description of the request to perform.
            session: The aiohttp client session handed to the callable.
            pbar: Optional tqdm instance; defaults to ``None``.
        """
| 104 | + |
| 105 | + |
95 | 106 | async def async_request_openai_completions( |
96 | 107 | request_func_input: RequestFuncInput, |
97 | 108 | session: aiohttp.ClientSession, |
@@ -507,7 +518,7 @@ async def async_request_openai_embeddings( |
507 | 518 |
|
508 | 519 |
|
509 | 520 | # TODO: Add more request functions for different API protocols. |
510 | | -ASYNC_REQUEST_FUNCS = { |
| 521 | +ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = { |
511 | 522 | "vllm": async_request_openai_completions, |
512 | 523 | "openai": async_request_openai_completions, |
513 | 524 | "openai-chat": async_request_openai_chat_completions, |
|
0 commit comments