Skip to content

Commit 90f70e6

Browse files
committed
fix CI error about attn backend cls
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent 96f523d commit 90f70e6

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

tpu_commons/platforms/tpu_jax.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torchax.ops.mappings import j2t_dtype
99
from tpu_info import device
1010
from vllm.inputs import ProcessorInputs, PromptType
11-
from vllm.platforms.interface import Platform, PlatformEnum, _Backend
11+
from vllm.platforms.interface import Platform, PlatformEnum
1212
from vllm.sampling_params import SamplingParams, SamplingType
1313

1414
from tpu_commons.logger import init_logger
@@ -18,11 +18,13 @@
1818
if TYPE_CHECKING:
1919
from vllm.config import BlockSize, ModelConfig, VllmConfig
2020
from vllm.pooling_params import PoolingParams
21+
from vllm.attention.backends.registry import _Backend
2122
else:
2223
BlockSize = None
2324
ModelConfig = None
2425
VllmConfig = None
2526
PoolingParams = None
27+
_Backend = None
2628

2729
logger = init_logger(__name__)
2830

@@ -51,10 +53,11 @@ class TpuPlatform(Platform):
5153
]
5254

5355
@classmethod
54-
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
56+
def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
5557
dtype: jnp.dtype, kv_cache_dtype: Optional[str],
5658
block_size: int, use_v1: bool, use_mla: bool,
5759
has_sink: bool, use_sparse: bool) -> str:
60+
from vllm.attention.backends.registry import _Backend
5861
if selected_backend != _Backend.PALLAS:
5962
logger.info("Cannot use %s backend on TPU.", selected_backend)
6063

0 commit comments

Comments
 (0)