27 changes: 13 additions & 14 deletions vllm/distributed/device_communicators/tpu_communicator.py
@@ -10,6 +10,7 @@
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS

from .base_device_communicator import DeviceCommunicatorBase

@@ -18,16 +19,17 @@

logger = init_logger(__name__)

if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)

if USE_RAY:
from vllm.executor import ray_utils
if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups)
if USE_RAY:
from vllm.executor import ray_utils


class TpuCommunicator(DeviceCommunicatorBase):
@@ -94,10 +96,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
return xm.all_gather(input_, dim=dim)


try:
if USE_TPU_COMMONS:
from tpu_commons.distributed.device_communicators import (
TpuCommunicator as TpuCommonsCommunicator)
TpuCommunicator = TpuCommonsCommunicator # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuCommunicator")
pass
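
The hunk above replaces the old try/except ImportError re-export with a check of the USE_TPU_COMMONS flag defined in vllm/platforms/tpu.py. A minimal sketch of this flag-gated override pattern, using hypothetical module and class names rather than the real vLLM ones:

import logging

logger = logging.getLogger(__name__)

# Hypothetical optional plugin; stands in for tpu_commons.
USE_PLUGIN = False
try:
    import my_plugin  # noqa: F401
    USE_PLUGIN = True
except ImportError:
    pass


class Communicator:
    """Default in-tree implementation used when the plugin is absent."""

    def all_reduce(self, x):
        return x


if USE_PLUGIN:
    # Re-bind the public name so importers transparently pick up the plugin class.
    from my_plugin import Communicator as PluginCommunicator
    Communicator = PluginCommunicator  # type: ignore
else:
    logger.info("my_plugin not found, using the default Communicator")

Probing for the plugin once (in the platform module, in this PR) and exporting a boolean avoids repeating the try/except in every consumer.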
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/fused_moe/moe_pallas.py
@@ -3,7 +3,6 @@

import torch
import torch.nn.functional as F
import torch_xla.experimental.custom_kernel # noqa: F401


def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
@@ -41,6 +40,7 @@ def fused_moe(
gating_output: [*, num_experts]
"""
assert expert_map is None, "expert_map is not supported for pallas MoE."
import torch_xla.experimental.custom_kernel # noqa: F401
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
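
The moe_pallas.py hunk moves the torch_xla import from module scope into fused_moe, so the module can be imported on hosts without torch_xla installed. A minimal sketch of the same deferred-import pattern (the function name here is hypothetical):

import torch


def pallas_op(hidden_states: torch.Tensor) -> torch.Tensor:
    # Deferred import: only executed when the op actually runs on TPU, so
    # merely importing this module does not require torch_xla.
    import torch_xla.experimental.custom_kernel  # noqa: F401
    return hidden_states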
21 changes: 13 additions & 8 deletions vllm/model_executor/model_loader/default_loader.py
@@ -207,16 +207,21 @@ def _get_weights_iterator(
)

if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
from vllm.platforms.tpu import USE_TPU_COMMONS

def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
if not USE_TPU_COMMONS:
# In PyTorch XLA, we should call `xm.mark_step`
# frequently so that not too many ops are accumulated
# in the XLA program.
import torch_xla.core.xla_model as xm

weights_iterator = _xla_weights_iterator(weights_iterator)
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()

weights_iterator = _xla_weights_iterator(weights_iterator)

if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
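
In default_loader.py the xm.mark_step wrapping of the weights iterator now only applies when tpu_commons is absent. A standalone sketch of that wrapper, assuming torch_xla is available (simplified from the diff above):

from typing import Generator, Iterable

import torch_xla.core.xla_model as xm


def xla_weights_iterator(iterator: Iterable) -> Generator:
    for weights in iterator:
        yield weights
        # Cut the XLA graph after each weight so the traced program does not
        # accumulate too many ops during loading.
        xm.mark_step()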
3 changes: 3 additions & 0 deletions vllm/platforms/tpu.py
@@ -24,6 +24,8 @@

logger = init_logger(__name__)

USE_TPU_COMMONS = False


class TpuPlatform(Platform):
_enum = PlatformEnum.TPU
@@ -201,6 +203,7 @@ def is_kv_cache_dtype_supported(cls, kv_cache_dtype: str) -> bool:
try:
from tpu_commons.platforms import TpuPlatform as TpuCommonsPlatform
TpuPlatform = TpuCommonsPlatform # type: ignore
USE_TPU_COMMONS = True
except ImportError:
logger.info("tpu_commons not found, using vLLM's TpuPlatform")
pass
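
tpu.py is now the single place that probes for tpu_commons: USE_TPU_COMMONS starts as False and flips to True only if the import at the bottom of the module succeeds. Consumers read the flag like this (a sketch mirroring the imports added elsewhere in this PR):

from vllm.platforms.tpu import USE_TPU_COMMONS

if not USE_TPU_COMMONS:
    # Fall back to the in-tree torch_xla code path.
    import torch_xla.core.xla_model as xm  # noqa: F401

Because `from ... import USE_TPU_COMMONS` copies the value at import time, the flag must be final before any consumer imports it; running the try/except during module import of tpu.py guarantees that ordering.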
97 changes: 51 additions & 46 deletions vllm/v1/attention/backends/pallas.py
@@ -5,12 +5,6 @@
from typing import Optional

import torch
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
# Required to register custom ops.
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
@@ -37,6 +31,57 @@
"uint8": torch.uint8,
}

try:
import tpu_commons # noqa: F401
except ImportError:
# Lazy import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.experimental.custom_kernel # noqa: F401
from torch.library import impl
from torch_xla._internal.jax_workarounds import requires_jax
from torch_xla.experimental.custom_kernel import XLA_LIB

@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int, num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update,
(kv, slot_mapping, kv_cache, num_kv_update_slices), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache


XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping," \
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size," \
"int num_slices_per_block)" \
"-> Tensor", )

@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache

@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache


class PallasAttentionBackend(AttentionBackend):

@@ -313,46 +358,6 @@ def write_to_kv_cache(
kv_cache.copy_(new_kv_cache)


@requires_jax
def kv_cache_update_op_impl(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int):
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
new_kv_cache = xb.call_jax(
kv_cache_update, (kv, slot_mapping, kv_cache, num_kv_update_slices), {
"page_size": page_size,
"num_slices_per_block": num_slices_per_block
})
return new_kv_cache


XLA_LIB.define(
"kv_cache_update_op(Tensor kv, Tensor slot_mapping, Tensor kv_cache," \
"Tensor num_kv_update_slices, int page_size, int num_slices_per_block)" \
"-> Tensor", )


@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor, page_size: int,
num_slices_per_block: int) -> torch.Tensor:
new_kv_cache = kv_cache_update_op_impl(kv, slot_mapping, kv_cache,
num_kv_update_slices, page_size,
num_slices_per_block)
return new_kv_cache


@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(kv: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache: torch.Tensor,
num_kv_update_slices: torch.Tensor,
page_size: int,
num_slices_per_block: int) -> torch.Tensor:
return kv_cache


# We can move this function to a common utils file if it's also useful for other
# hardware.
def dtype_bits(dtype: torch.dtype):
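
With the pallas.py change, the torch.library registration of kv_cache_update_op only happens when tpu_commons is not installed. Call sites go through the dispatcher; assuming XLA_LIB registers into torch_xla's `xla` namespace, an invocation would look roughly like the sketch below (argument names follow the diff, the wrapper function is hypothetical):

import torch


def update_kv_cache(kv, slot_mapping, kv_cache, num_kv_update_slices,
                    page_size: int, num_slices_per_block: int) -> torch.Tensor:
    # Dispatches to kv_cache_update_op_xla for XLA tensors and to the
    # CompositeExplicitAutograd fallback (which returns kv_cache unchanged)
    # everywhere else.
    return torch.ops.xla.kv_cache_update_op(kv, slot_mapping, kv_cache,
                                             num_kv_update_slices, page_size,
                                             num_slices_per_block)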
22 changes: 13 additions & 9 deletions vllm/v1/worker/tpu_worker.py
@@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""

import os
from typing import Any, Optional

import torch
import torch.distributed
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import VllmConfig
@@ -21,19 +19,27 @@
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.tasks import SupportedTask
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
KVCacheSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.tpu_model_runner import TPUModelRunner
from vllm.v1.worker.utils import bind_kv_cache

logger = init_logger(__name__)

if not USE_TPU_COMMONS:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr

from vllm.v1.attention.backends.pallas import TPU_HEAD_SIZE_ALIGNMENT
from vllm.v1.worker.tpu_model_runner import TPUModelRunner


class TPUWorker:

@@ -325,9 +331,7 @@ def _init_tpu_worker_distributed_environment(
ensure_kv_transfer_initialized(vllm_config)


try:
if USE_TPU_COMMONS:
from tpu_commons.worker import TPUWorker as TPUCommonsWorker

TPUWorker = TPUCommonsWorker # type: ignore
except ImportError:
logger.info("tpu_commons not found, using vLLM's TPUWorker.")
pass
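
A quick, hypothetical way to verify which implementation won after these changes is to inspect the flag and the re-bound worker class:

from vllm.platforms.tpu import USE_TPU_COMMONS
from vllm.v1.worker.tpu_worker import TPUWorker

print("tpu_commons active:", USE_TPU_COMMONS)
print("TPUWorker resolves to:", TPUWorker.__module__)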