From 44753cf640188faa41c2f9363dbd6ef971b5c25d Mon Sep 17 00:00:00 2001 From: Chengji Yao Date: Mon, 18 Aug 2025 05:55:03 +0000 Subject: [PATCH] [TPU] make ptxla not imported when using tpu_commons Signed-off-by: Chengji Yao Signed-off-by: Chengji Yao --- .../device_communicators/tpu_communicator.py | 27 +++--- .../layers/fused_moe/moe_pallas.py | 2 +- .../model_loader/default_loader.py | 21 ++-- vllm/platforms/tpu.py | 3 + vllm/v1/attention/backends/pallas.py | 97 ++++++++++--------- vllm/v1/worker/tpu_worker.py | 22 +++-- 6 files changed, 94 insertions(+), 78 deletions(-) diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index c60a7a7eb25c..942dd67f065d 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -10,6 +10,7 @@ from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from .base_device_communicator import DeviceCommunicatorBase @@ -18,16 +19,17 @@ logger = init_logger(__name__) -if current_platform.is_tpu(): - import torch_xla - import torch_xla.core.xla_model as xm - import torch_xla.runtime as xr - from torch_xla._internal import pjrt - from torch_xla.distributed.xla_multiprocessing import ( - create_optimized_replica_groups) - - if USE_RAY: - from vllm.executor import ray_utils +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TpuCommunicator") + if current_platform.is_tpu(): + import torch_xla + import torch_xla.core.xla_model as xm + import torch_xla.runtime as xr + from torch_xla._internal import pjrt + from torch_xla.distributed.xla_multiprocessing import ( + create_optimized_replica_groups) + if USE_RAY: + from vllm.executor import ray_utils class TpuCommunicator(DeviceCommunicatorBase): @@ -94,10 +96,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: return xm.all_gather(input_, dim=dim) -try: +if USE_TPU_COMMONS: from tpu_commons.distributed.device_communicators import ( TpuCommunicator as TpuCommonsCommunicator) TpuCommunicator = TpuCommonsCommunicator # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TpuCommunicator") - pass diff --git a/vllm/model_executor/layers/fused_moe/moe_pallas.py b/vllm/model_executor/layers/fused_moe/moe_pallas.py index d35bd0098b3c..582ae3e12c28 100644 --- a/vllm/model_executor/layers/fused_moe/moe_pallas.py +++ b/vllm/model_executor/layers/fused_moe/moe_pallas.py @@ -3,7 +3,6 @@ import torch import torch.nn.functional as F -import torch_xla.experimental.custom_kernel # noqa: F401 def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor: @@ -41,6 +40,7 @@ def fused_moe( gating_output: [*, num_experts] """ assert expert_map is None, "expert_map is not supported for pallas MoE." 
+    import torch_xla.experimental.custom_kernel  # noqa: F401
     orig_shape = hidden_states.shape
     hidden_size = hidden_states.shape[-1]
     num_tokens = hidden_states.shape[:-1].numel()
diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py
index 2b8e4427591c..34b8d8e4ed62 100644
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -207,16 +207,21 @@ def _get_weights_iterator(
         )
 
         if current_platform.is_tpu():
-            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
-            # not too many ops are accumulated in the XLA program.
-            import torch_xla.core.xla_model as xm
+            from vllm.platforms.tpu import USE_TPU_COMMONS
 
-            def _xla_weights_iterator(iterator: Generator):
-                for weights in iterator:
-                    yield weights
-                    xm.mark_step()
+            if not USE_TPU_COMMONS:
+                # In PyTorch XLA, we should call
+                # `xm.mark_step` frequently so that
+                # not too many ops are accumulated
+                # in the XLA program.
+                import torch_xla.core.xla_model as xm
 
-            weights_iterator = _xla_weights_iterator(weights_iterator)
+                def _xla_weights_iterator(iterator: Generator):
+                    for weights in iterator:
+                        yield weights
+                        xm.mark_step()
+
+                weights_iterator = _xla_weights_iterator(weights_iterator)
 
         if self.counter_before_loading_weights == 0.0:
             self.counter_before_loading_weights = time.perf_counter()
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index ba06abd07f08..dc2be5c25090 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -24,6 +24,8 @@
 
 logger = init_logger(__name__)
 
+USE_TPU_COMMONS = False
+
 
 class TpuPlatform(Platform):
     _enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
 try:
     from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
     TpuPlatform = TpuCommonsPlatform  # type: ignore
+    USE_TPU_COMMONS = True
 except ImportError:
     logger.info("tpu_commons not found, using vLLM's TpuPlatform")
     pass
diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py
index 9b122136afb7..3eb4a0e7a574 100644
--- a/vllm/v1/attention/backends/pallas.py
+++ b/vllm/v1/attention/backends/pallas.py
@@ -5,12 +5,6 @@
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_builder as xb
-import torch_xla.experimental.custom_kernel  # noqa: F401
-# Required to register custom ops.
-from torch.library import impl -from torch_xla._internal.jax_workarounds import requires_jax -from torch_xla.experimental.custom_kernel import XLA_LIB from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionLayer, AttentionType) @@ -37,6 +31,57 @@ "uint8": torch.uint8, } +try: + import tpu_commons # noqa: F401 +except ImportError: + # Lazy import torch_xla + import torch_xla.core.xla_builder as xb + import torch_xla.experimental.custom_kernel # noqa: F401 + from torch.library import impl + from torch_xla._internal.jax_workarounds import requires_jax + from torch_xla.experimental.custom_kernel import XLA_LIB + + @requires_jax + def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, num_slices_per_block: int): + from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update + new_kv_cache = xb.call_jax( + kv_cache_update, + (kv, slot_mapping, kv_cache, num_kv_update_slices), { + "page_size": page_size, + "num_slices_per_block": num_slices_per_block + }) + return new_kv_cache + + + XLA_LIB.define( + "kv_cache_update_op(Tensor kv, Tensor slot_mapping," \ + "Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \ + "int num_slices_per_block)" \ + "-> Tensor", ) + + @impl(XLA_LIB, "kv_cache_update_op", "XLA") + def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, + num_kv_update_slices, page_size, + num_slices_per_block) + return new_kv_cache + + @impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") + def kv_cache_update_op_non_xla(kv: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache: torch.Tensor, + num_kv_update_slices: torch.Tensor, + page_size: int, + num_slices_per_block: int) -> torch.Tensor: + return kv_cache + class PallasAttentionBackend(AttentionBackend): @@ -313,46 +358,6 @@ def write_to_kv_cache( kv_cache.copy_(new_kv_cache) -@requires_jax -def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int): - from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update - new_kv_cache = xb.call_jax( - kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), { - "page_size": page_size, - "num_slices_per_block": num_slices_per_block - }) - return new_kv_cache - - -XLA_LIB.define( - "kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \ - "Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \ - "-> Tensor", ) - - -@impl(XLA_LIB, "kv_cache_update_op", "XLA") -def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, page_size: int, - num_slices_per_block: int) -> torch.Tensor: - new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache, - num_kv_update_slices, page_size, - num_slices_per_block) - return new_kv_cache - - -@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd") -def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor, - kv_cache: torch.Tensor, - num_kv_update_slices: torch.Tensor, - page_size: int, - num_slices_per_block: int) -> torch.Tensor: - return kv_cache - - # We can move this function to a common utils 
file if it's also useful for other # hardware. def dtype_bits(dtype: torch.dtype): diff --git a/vllm/v1/worker/tpu_worker.py b/vllm/v1/worker/tpu_worker.py index 72e0e4230a01..9adf8a14213f 100644 --- a/vllm/v1/worker/tpu_worker.py +++ b/vllm/v1/worker/tpu_worker.py @@ -1,15 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """A TPU worker class.""" + import os from typing import Any, Optional import torch import torch.distributed import torch.nn as nn -import torch_xla.core.xla_model as xm -import torch_xla.debug.profiler as xp -import torch_xla.runtime as xr import vllm.envs as envs from vllm.config import VllmConfig @@ -21,19 +19,27 @@ from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.platforms import current_platform +from vllm.platforms.tpu import USE_TPU_COMMONS from vllm.tasks import SupportedTask from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv -from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig, KVCacheSpec) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.tpu_model_runner import TPUModelRunner from vllm.v1.worker.utils import bind_kv_cache logger = init_logger(__name__) +if not USE_TPU_COMMONS: + logger.info("tpu_commons not found, using vLLM's TPUWorker.") + import torch_xla.core.xla_model as xm + import torch_xla.debug.profiler as xp + import torch_xla.runtime as xr + + from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT + from vllm.v1.worker.tpu_model_runner import TPUModelRunner + class TPUWorker: @@ -325,9 +331,7 @@ def _init_tpu_worker_distributed_environment( ensure_kv_transfer_initialized(vllm_config) -try: +if USE_TPU_COMMONS: from tpu_commons.worker import TPUWorker as TPUCommonsWorker + TPUWorker = TPUCommonsWorker # type: ignore -except ImportError: - logger.info("tpu_commons not found, using vLLM's TPUWorker.") - pass
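
Note (not part of the patch): every touched module applies the same import-gating pattern, which can be reduced to the small standalone sketch below. The only tpu_commons symbol used is tpu_commons.worker.TPUWorker, which the patch itself imports; the fallback class body and log text are illustrative, not vLLM's actual code.

import logging

logger = logging.getLogger(__name__)

# Probe for tpu_commons once, at import time, and remember the result in a
# module-level flag (the patch derives USE_TPU_COMMONS from the existing
# tpu_commons import attempt in vllm/platforms/tpu.py).
try:
    import tpu_commons  # noqa: F401
    USE_TPU_COMMONS = True
except ImportError:
    USE_TPU_COMMONS = False

# Other modules consult the flag instead of importing torch_xla
# unconditionally, so tpu_commons deployments never pull in PyTorch/XLA.
if not USE_TPU_COMMONS:
    logger.info("tpu_commons not found, using the PyTorch/XLA fallback")
    # This is where the real modules import torch_xla, e.g.
    #   import torch_xla.core.xla_model as xm


class TPUWorker:
    """Illustrative fallback used when tpu_commons is absent."""

    def backend(self) -> str:
        return "torch_xla"


# When tpu_commons is present, re-point the public name at its implementation,
# mirroring `TPUWorker = TPUCommonsWorker` at the bottom of tpu_worker.py.
if USE_TPU_COMMONS:
    from tpu_commons.worker import TPUWorker as TPUCommonsWorker

    TPUWorker = TPUCommonsWorker  # type: ignore

With a single flag, the "tpu_commons not found" log line also lives in exactly one place per module, instead of in a trailing except ImportError block as before.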