|
7 | 7 | import random |
8 | 8 | import sys |
9 | 9 | from datetime import timedelta |
10 | | -from platform import uname |
11 | 10 | from typing import TYPE_CHECKING, Any, NamedTuple |
12 | 11 |
|
13 | 12 | import numpy as np |
14 | 13 | import torch |
15 | | -from torch.distributed import PrefixStore, ProcessGroup |
16 | 14 |
|
17 | | -from vllm.inputs import ProcessorInputs, PromptType |
18 | 15 | from vllm.logger import init_logger |
19 | 16 |
|
20 | 17 | if TYPE_CHECKING: |
| 18 | + from torch.distributed import PrefixStore, ProcessGroup |
| 19 | + |
21 | 20 | from vllm.attention.backends.registry import _Backend |
22 | | - from vllm.config import ModelConfig, VllmConfig |
| 21 | + from vllm.config import VllmConfig |
| 22 | + from vllm.inputs import ProcessorInputs, PromptType |
23 | 23 | from vllm.pooling_params import PoolingParams |
24 | 24 | from vllm.sampling_params import SamplingParams |
25 | 25 | from vllm.utils import FlexibleArgumentParser |
26 | 26 | else: |
27 | | - _Backend = object |
28 | | - ModelConfig = object |
29 | | - VllmConfig = object |
30 | | - PoolingParams = object |
31 | | - SamplingParams = object |
32 | 27 | FlexibleArgumentParser = object |
33 | 28 |
|
34 | 29 | logger = init_logger(__name__) |
35 | 30 |
|
36 | 31 |
|
37 | 32 | def in_wsl() -> bool: |
38 | 33 | # Reference: https://github.com/microsoft/WSL/issues/4071 |
39 | | - return "microsoft" in " ".join(uname()).lower() |
| 34 | + return "microsoft" in " ".join(platform.uname()).lower() |
40 | 35 |
|
41 | 36 |
|
42 | 37 | class PlatformEnum(enum.Enum): |
@@ -178,15 +173,16 @@ def import_kernels(cls) -> None: |
178 | 173 | import vllm._moe_C # noqa: F401 |
179 | 174 |
|
180 | 175 | @classmethod |
181 | | - def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend: |
| 176 | + def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend": |
| 177 | + # Import _Backend here to avoid circular import. |
182 | 178 | from vllm.attention.backends.registry import _Backend |
183 | 179 |
|
184 | 180 | return _Backend.TORCH_SDPA |
185 | 181 |
|
186 | 182 | @classmethod |
187 | 183 | def get_attn_backend_cls( |
188 | 184 | cls, |
189 | | - selected_backend: _Backend, |
| 185 | + selected_backend: "_Backend", |
190 | 186 | head_size: int, |
191 | 187 | dtype: torch.dtype, |
192 | 188 | kv_cache_dtype: str | None, |
@@ -317,7 +313,7 @@ def pre_register_and_update( |
317 | 313 | pass |
318 | 314 |
|
319 | 315 | @classmethod |
320 | | - def check_and_update_config(cls, vllm_config: VllmConfig) -> None: |
| 316 | + def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: |
321 | 317 | """ |
322 | 318 | Check and update the configuration for the current platform. |
323 | 319 |
|
@@ -498,9 +494,9 @@ def opaque_attention_op(cls) -> bool: |
498 | 494 | @classmethod |
499 | 495 | def validate_request( |
500 | 496 | cls, |
501 | | - prompt: PromptType, |
502 | | - params: SamplingParams | PoolingParams, |
503 | | - processed_inputs: ProcessorInputs, |
| 497 | + prompt: "PromptType", |
| 498 | + params: "SamplingParams | PoolingParams", |
| 499 | + processed_inputs: "ProcessorInputs", |
504 | 500 | ) -> None: |
505 | 501 | """Raises if this request is unsupported on this platform""" |
506 | 502 |
|
@@ -543,25 +539,16 @@ def get_static_graph_wrapper_cls(cls) -> str: |
543 | 539 | def stateless_init_device_torch_dist_pg( |
544 | 540 | cls, |
545 | 541 | backend: str, |
546 | | - prefix_store: PrefixStore, |
| 542 | + prefix_store: "PrefixStore", |
547 | 543 | group_rank: int, |
548 | 544 | group_size: int, |
549 | 545 | timeout: timedelta, |
550 | | - ) -> ProcessGroup: |
| 546 | + ) -> "ProcessGroup": |
551 | 547 | """ |
552 | 548 | Init platform-specific torch distributed process group. |
553 | 549 | """ |
554 | 550 | raise NotImplementedError |
555 | 551 |
|
556 | | - @classmethod |
557 | | - def is_kv_cache_dtype_supported( |
558 | | - cls, kv_cache_dtype: str, model_config: ModelConfig |
559 | | - ) -> bool: |
560 | | - """ |
561 | | - Returns if the kv_cache_dtype is supported by the current platform. |
562 | | - """ |
563 | | - return False |
564 | | - |
565 | 552 | @classmethod |
566 | 553 | def check_if_supports_dtype(cls, dtype: torch.dtype): |
567 | 554 | """ |
|
0 commit comments