 
 import torch
 
+import vllm.envs as envs
 from vllm.logger import init_logger
 
 from .interface import Platform, PlatformEnum, _Backend
@@ -33,22 +34,28 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              dtype: torch.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool,
                              use_mla: bool) -> str:
-        if selected_backend != _Backend.PALLAS:
+        if (selected_backend != _Backend.PALLAS
+                and selected_backend != _Backend.PALLAS_VLLM_V1):
             logger.info("Cannot use %s backend on TPU.", selected_backend)
-        logger.info("Using Pallas backend.")
-        return "vllm.attention.backends.pallas.PallasAttentionBackend"
+
+        if use_v1:
+            logger.info("Using Pallas V1 backend.")
+            return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
+        else:
+            logger.info("Using Pallas backend.")
+            return "vllm.attention.backends.pallas.PallasAttentionBackend"
 
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
-        raise NotImplementedError
+        return "tpu"
 
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
 
     @classmethod
     def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
-        return True
+        return not envs.VLLM_USE_V1
 
     @classmethod
     def inference_mode(cls):
@@ -63,22 +70,18 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             cache_config.block_size = 16
 
         compilation_config = vllm_config.compilation_config
-        if compilation_config.level == CompilationLevel.NO_COMPILATION:
-            # TPU does not support NO_COMPILATION
+
+        # TPU only supports DYNAMO_ONCE compilation level
+        if compilation_config.level != CompilationLevel.DYNAMO_ONCE:
+            logger.info("[TPU] Forcing DYNAMO_ONCE compilation level")
             compilation_config.level = CompilationLevel.DYNAMO_ONCE
-        assert compilation_config.level < CompilationLevel.PIECEWISE,\
-            "TPU does not support Inductor."
 
         if compilation_config.backend == "":
             compilation_config.backend = "openxla"
 
         assert vllm_config.speculative_config is None, \
             "TPU does not support speculative decoding"
 
-        assert not vllm_config.scheduler_config.chunked_prefill_enabled, (
-            "Chunked prefill is not yet supported for TPU backend")
-        assert not vllm_config.speculative_config, (
-            "Speculative decoding is not yet supported for TPU backend")
         if vllm_config.model_config.dtype in (torch.float16, torch.float32):
             logger.warning(
                 "The TPU backend currently does not support %s. "
@@ -88,8 +91,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         parallel_config = vllm_config.parallel_config
         scheduler_config = vllm_config.scheduler_config
         if parallel_config.worker_cls == "auto":
-            if scheduler_config.is_multi_step:
+            if envs.VLLM_USE_V1:
                 parallel_config.worker_cls = \
-                    "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                    "vllm.v1.worker.tpu_worker.TPUWorker"
             else:
-                parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker"
+                if scheduler_config.is_multi_step:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
+                else:
+                    parallel_config.worker_cls = \
+                        "vllm.worker.tpu_worker.TPUWorker"
+
+        # Adjust scheduler config for V1
+        # TODO: Add support for these
+        if envs.VLLM_USE_V1 and vllm_config.cache_config.enable_prefix_caching:
+            logger.warning("[V1][TPU] Disable prefix caching")
+            vllm_config.cache_config.enable_prefix_caching = False
+
+        assert not vllm_config.speculative_config, (
+            "Speculative decoding is not yet supported for TPU backend")
+
+    @classmethod
+    def is_pin_memory_available(cls):
+        logger.warning("Pin memory is not supported on TPU.")
+        return False
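For reference, a minimal sketch (not part of this commit) of how the changed attention-backend dispatch could be exercised, assuming this file is vllm/platforms/tpu.py and the platform class is named TpuPlatform; those names are assumptions, while the method signature and _Backend.PALLAS_VLLM_V1 come from the diff above:

# Hypothetical usage sketch; module and class names are assumptions.
import torch

from vllm.platforms.tpu import TpuPlatform          # assumed location/name
from vllm.platforms.interface import _Backend        # assumed location

# Per the diff, use_v1=True should resolve to the V1 Pallas backend path.
backend = TpuPlatform.get_attn_backend_cls(
    selected_backend=_Backend.PALLAS_VLLM_V1,
    head_size=128,
    dtype=torch.bfloat16,
    kv_cache_dtype=None,
    block_size=16,
    use_v1=True,
    use_mla=False)
assert backend == "vllm.v1.attention.backends.pallas.PallasAttentionBackend"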