Commit 280b389

fix tpu torch compile error
Signed-off-by: Chenyaaang <chenyangli@google.com>
Parent: 2a03f93

2 files changed: +6 −2 lines

vllm/config/vllm.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -330,9 +330,12 @@ def __post_init__(self):
             + self.compilation_config.custom_ops.count("all")
             == 0
         ):
+            from vllm.platforms import current_platform
+
             if (
                 self.compilation_config.level > 0
                 and self.compilation_config.backend != "eager"
+                and not current_platform.is_tpu()
             ):
                 self.compilation_config.custom_ops.append("none")
             else:
```
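The effect of the guard, as a minimal runnable sketch. The standalone helper `should_disable_custom_ops` below is hypothetical (the real check runs inline in `VllmConfig.__post_init__`, and only when the user expressed no explicit "none"/"all" custom-ops preference), but its boolean logic mirrors the hunk above:

```python
# Hypothetical helper mirroring the guard added above; not vLLM API.
def should_disable_custom_ops(level: int, backend: str, is_tpu: bool) -> bool:
    # Compiled, non-eager runs normally get custom ops forced to "none";
    # the new `not is_tpu` term keeps TPU runs out of that path.
    return level > 0 and backend != "eager" and not is_tpu


# Before this commit the TPU case also returned True, so "none" was
# appended and torch.compile errored on TPU (the bug this commit fixes).
assert should_disable_custom_ops(level=2, backend="inductor", is_tpu=False)
assert not should_disable_custom_ops(level=2, backend="inductor", is_tpu=True)
```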

vllm/platforms/tpu.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -139,8 +139,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             )
             compilation_config.cudagraph_mode = CUDAGraphMode.NONE
 
-        if compilation_config.backend == "":
-            compilation_config.backend = "openxla"
+        # Note: the default backend is set to inductor now
+        # we want to overwrite to openxla to execute the ops properly on TPU.
+        compilation_config.backend = "openxla"
 
         assert vllm_config.speculative_config is None, (
             "TPU does not support speculative decoding"
```
