|
8 | 8 | from torchax.ops.mappings import j2t_dtype |
9 | 9 | from tpu_info import device |
10 | 10 | from vllm.inputs import ProcessorInputs, PromptType |
11 | | -from vllm.platforms.interface import Platform, PlatformEnum, _Backend |
| 11 | +from vllm.platforms.interface import Platform, PlatformEnum |
12 | 12 | from vllm.sampling_params import SamplingParams, SamplingType |
13 | 13 |
|
14 | 14 | from tpu_commons.logger import init_logger |
|
18 | 18 | if TYPE_CHECKING: |
19 | 19 | from vllm.config import BlockSize, ModelConfig, VllmConfig |
20 | 20 | from vllm.pooling_params import PoolingParams |
| 21 | + from vllm.attention.backends.registry import _Backend |
21 | 22 | else: |
22 | 23 | BlockSize = None |
23 | 24 | ModelConfig = None |
24 | 25 | VllmConfig = None |
25 | 26 | PoolingParams = None |
| 27 | + _Backend = None |
26 | 28 |
|
27 | 29 | logger = init_logger(__name__) |
28 | 30 |
|
@@ -51,10 +53,11 @@ class TpuPlatform(Platform): |
51 | 53 | ] |
52 | 54 |
|
53 | 55 | @classmethod |
54 | | - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, |
| 56 | + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, |
55 | 57 | dtype: jnp.dtype, kv_cache_dtype: Optional[str], |
56 | 58 | block_size: int, use_v1: bool, use_mla: bool, |
57 | 59 | has_sink: bool, use_sparse: bool) -> str: |
| 60 | + from vllm.attention.backends.registry import _Backend |
58 | 61 | if selected_backend != _Backend.PALLAS: |
59 | 62 | logger.info("Cannot use %s backend on TPU.", selected_backend) |
60 | 63 |
|
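Taken together, the two hunks move `_Backend` off the module's import path: the class is imported only under `TYPE_CHECKING` (with a `None` placeholder at runtime so the name still resolves), the method annotation becomes a forward-reference string, and the real import is deferred into `get_attn_backend_cls` itself. A minimal, self-contained sketch of the same pattern, using the stdlib `decimal` module purely as a stand-in for `vllm.attention.backends.registry`:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to static type checkers only; never imported at runtime.
    from decimal import Decimal
else:
    Decimal = None  # placeholder so the name still resolves at runtime


def to_cents(amount: "Decimal") -> int:
    # Deferred import: the module is loaded only when the function is called.
    from decimal import Decimal
    return int(amount * Decimal(100))


if __name__ == "__main__":
    from decimal import Decimal

    print(to_cents(Decimal("19.99")))  # 1999
```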
|