diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 7b149ae05..79b403626 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -7,6 +7,7 @@
 import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
+from vllm.attention.backends.abstract import AttentionType
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
@@ -57,7 +58,8 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool) -> str:
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: AttentionType) -> str:
         from vllm.attention.backends.registry import _Backend
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)