diff --git a/tpu_commons/platforms/tpu_jax.py b/tpu_commons/platforms/tpu_jax.py index 08dbbc38a..fe9c2d528 100644 --- a/tpu_commons/platforms/tpu_jax.py +++ b/tpu_commons/platforms/tpu_jax.py @@ -8,7 +8,7 @@ from torchax.ops.mappings import j2t_dtype from tpu_info import device from vllm.inputs import ProcessorInputs, PromptType -from vllm.platforms.interface import Platform, PlatformEnum, _Backend +from vllm.platforms.interface import Platform, PlatformEnum from vllm.sampling_params import SamplingParams, SamplingType from tpu_commons.logger import init_logger @@ -16,6 +16,7 @@ update_vllm_config_for_qwix_quantization if TYPE_CHECKING: + from vllm.attention.backends.registry import _Backend from vllm.config import BlockSize, ModelConfig, VllmConfig from vllm.pooling_params import PoolingParams else: @@ -23,6 +24,7 @@ ModelConfig = None VllmConfig = None PoolingParams = None + _Backend = None logger = init_logger(__name__) @@ -51,10 +53,11 @@ class TpuPlatform(Platform): ] @classmethod - def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, + def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int, dtype: jnp.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, has_sink: bool, use_sparse: bool) -> str: + from vllm.attention.backends.registry import _Backend if selected_backend != _Backend.PALLAS: logger.info("Cannot use %s backend on TPU.", selected_backend)