1 change: 1 addition & 0 deletions .buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -28,4 +28,5 @@ docker run \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'
1 change: 1 addition & 0 deletions docker/Dockerfile.xpu
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi

ENV VLLM_TARGET_DEVICE=xpu
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
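
Setting VLLM_WORKER_MULTIPROC_METHOD=spawn matters because, unlike fork, the spawn start method re-imports the launching module in every child process. The snippet below is a minimal standalone sketch of that constraint, not vLLM code; it only illustrates why spawn needs the `if __name__ == '__main__':` guard that the platform code warns about later in this diff.

import multiprocessing as mp

def worker(rank: int) -> None:
    # Stand-in for a vLLM worker process.
    print(f"worker {rank} started")

if __name__ == "__main__":
    # With spawn, children re-import this module; without the guard above,
    # each child would try to create more processes and fail.
    mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()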
1 change: 1 addition & 0 deletions requirements/xpu.txt
@@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding

torch==2.7.0+xpu
torchaudio
105 changes: 105 additions & 0 deletions vllm/_ipex_ops.py
@@ -228,6 +228,111 @@ def reshape_and_cache(
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)

@staticmethod
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: Optional[torch.Tensor] = None,
v_scale: Optional[torch.Tensor] = None,
k_scale_float: float = 1.0,
v_scale_float: float = 1.0,
) -> None:
assert kv_cache_dtype == "auto"
# TODO: support FP8 kv cache.
Contributor review comment (medium):
The TODO comment indicates that FP8 KV cache support is pending. While this is acceptable for now, it's good practice to create a tracking issue (e.g., on GitHub) for this TODO and reference it in the comment. This helps ensure the task isn't forgotten and provides visibility into planned enhancements.

ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping)
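
Following up on the review comment above, here is a hypothetical sketch of how the kv_cache_dtype restriction could carry a tracking-issue reference; the helper name and issue URL are placeholders, not part of this PR.

def _require_auto_kv_cache_dtype(kv_cache_dtype: str) -> None:
    # Hypothetical helper, not in the PR: fail loudly for unsupported dtypes
    # and point at a tracking issue instead of a bare TODO.
    if kv_cache_dtype != "auto":
        # TODO: support FP8 KV cache on XPU.
        # Tracking: https://github.com/vllm-project/vllm/issues/<tbd>
        raise NotImplementedError(
            f"kv_cache_dtype={kv_cache_dtype!r} is not supported by ipex_ops "
            "yet; only 'auto' is accepted.")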

@staticmethod
def flash_attn_varlen_func(
out: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
seqused_k: torch.Tensor, # we don't support this in ipex kernel
max_seqlen_q: int,
max_seqlen_k: int,
softmax_scale: float,
causal: bool,
block_table: torch.Tensor,
alibi_slopes: Optional[torch.Tensor],
window_size: Optional[list[int]] = None,
softcap: Optional[float] = 0.0,
cu_seqlens_k: Optional[torch.Tensor] = None,
# The following parameters are not used in the ipex kernel currently;
# we keep the API compatible with CUDA's.
scheduler_metadata=None,
fa_version: int = 2,
q_descale=None,
k_descale=None,
v_descale=None,
):
if cu_seqlens_k is None:
# cu_seqlens_k is not used in ipex kernel.
cu_seqlens_k = torch.cumsum(seqused_k, dim=0)
cu_seqlens_k = torch.cat([
torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
cu_seqlens_k
]).to(torch.int32)

real_window_size: tuple[int, int]
if window_size is None:
real_window_size = (-1, -1)
else:
assert len(window_size) == 2
real_window_size = (window_size[0], window_size[1])
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
out,
q.contiguous(),
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
softmax_scale,
causal,
block_table,
alibi_slopes,
softcap=softcap,
window_size_left=real_window_size[0],
window_size_right=real_window_size[1],
k_scale=1.0,
v_scale=1.0,
)
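
The cu_seqlens_k fallback above turns per-sequence KV lengths into cumulative offsets with a leading zero, which is the layout varlen attention kernels expect. A small self-contained example of that conversion (plain PyTorch, no IPEX required):

import torch

# KV lengths for three sequences in the batch.
seqused_k = torch.tensor([3, 5, 2], dtype=torch.int32)

# Same construction as the fallback above: prepend 0, then take cumulative sums.
cu_seqlens_k = torch.cat([
    torch.tensor([0], device=seqused_k.device, dtype=torch.int32),
    torch.cumsum(seqused_k, dim=0),
]).to(torch.int32)

print(cu_seqlens_k.tolist())  # [0, 3, 8, 10]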

@staticmethod
def get_scheduler_metadata(
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads_q,
num_heads_kv,
headdim,
cache_seqlens: torch.Tensor,
qkv_dtype=torch.bfloat16,
headdim_v=None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_size: Optional[int] = None,
max_seqlen_k_new=0,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
has_softcap=False,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
) -> None:
logger.warning_once(
"get_scheduler_metadata is not implemented for ipex_ops, "
"returning None.")
return None

@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
15 changes: 14 additions & 1 deletion vllm/attention/utils/fa_utils.py
@@ -4,13 +4,27 @@

from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)

if current_platform.is_cuda():
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
flash_attn_varlen_func = ops.flash_attn_varlen_func
get_scheduler_metadata = ops.get_scheduler_metadata


def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
@@ -50,6 +64,5 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:


def flash_attn_supports_fp8() -> bool:
from vllm.platforms import current_platform
return get_flash_attn_version() == 3 and \
current_platform.get_device_capability().major == 9
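
With the conditional imports above, call sites can use reshape_and_cache_flash, flash_attn_varlen_func, and get_scheduler_metadata without platform-specific branches. Below is a minimal sketch of that import-time dispatch idea in isolation; it is generic Python, not vLLM code, and the implementation functions are illustrative stand-ins.

import torch

def _cuda_impl(x: torch.Tensor) -> torch.Tensor:
    return x + 1  # stand-in for the CUDA kernel binding

def _xpu_impl(x: torch.Tensor) -> torch.Tensor:
    return x + 1  # stand-in for the IPEX kernel binding

def _cpu_impl(x: torch.Tensor) -> torch.Tensor:
    return x + 1  # CPU fallback so the sketch runs anywhere

# Bind a single public name at import time, mirroring fa_utils above,
# so callers never branch on the platform.
if torch.cuda.is_available():
    fused_op = _cuda_impl
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    fused_op = _xpu_impl
else:
    fused_op = _cpu_impl

print(fused_op(torch.zeros(2)))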
2 changes: 1 addition & 1 deletion vllm/executor/ray_distributed_executor.py
@@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
if envs.VLLM_USE_V1 and not current_platform.is_xpu():
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
104 changes: 70 additions & 34 deletions vllm/platforms/xpu.py
@@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import TYPE_CHECKING, Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None

logger = init_logger(__name__)
@@ -35,8 +38,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
use_v1 = envs.VLLM_USE_V1
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
else:
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

@classmethod
def get_device_capability(
@@ -67,25 +75,27 @@ def inference_mode(cls):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
# in V1 (or with ipex chunked prefill) block_size is 64
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
if envs.VLLM_USE_V1:
cache_config.block_size = 64
else:
cache_config.block_size = 16

# Instances created via VllmConfig() typically have model_config set to
# None by default, so guard against None here before checking and
# updating the model config.
if vllm_config.model_config is not None:
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"bfloat16 is only supported on Intel Data Center GPU, "
"Intel Arc GPU is not supported yet. Your device is %s,"
" which is not supported. will fallback to float16",
cls.get_device_name())
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True

if vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -96,21 +106,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# check and update parallel config
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
parallel_config.worker_cls =\
"vllm.v1.worker.xpu_worker.XPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

if parallel_config.distributed_executor_backend is None:
parallel_config.distributed_executor_backend = "ray"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "ray"
else:
parallel_config.distributed_executor_backend = "uni"
elif parallel_config.distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn requires the caller to guard entry with `if __name__ == '__main__':`,
# and fork is not supported for starting new processes on XPU.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, setting it to ray instead.")
parallel_config.distributed_executor_backend = "ray"

elif parallel_config.distributed_executor_backend != "ray":
if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
logger.warning(
"Please use spawn as start method if you want to use mp.")
elif parallel_config.distributed_executor_backend != "ray" and \
parallel_config.distributed_executor_backend != "uni":
logger.warning(
"%s is not supported on XPU, fallback to ray distributed"
" executor backend.",
@@ -142,15 +158,35 @@ def get_current_memory_usage(cls,
@classmethod
def device_support_bf16(cls) -> bool:
device_name = cls.get_device_name().lower()
if device_name.count("arc") > 0:
if cls.is_client_gpu_a770():
logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
" fallback to float16")
return False
Contributor review comment on lines +161 to 164 (medium):
The logic for device_support_bf16 has changed. Previously, it explicitly checked for "data center gpu" to return True and warned for unknown devices. Now, it defaults to True for devices not matching is_client_gpu_a770().

This implies a broader assumption of bf16 support. Ensure this is intended and well-tested across various Intel GPUs. The log message now encourages users to file an issue, which is good, but the previous more conservative approach might have prevented issues on unsupported/unknown hardware.

elif device_name.count("data center gpu") > 0:
return True
else:
logger.warning("Unknown device name %s, always use float16",
device_name)
return False
logger.info(
"Device name %s supports bfloat16. Please file an issue "
"if you encounter any accuracy problems with bfloat16.",
device_name)
return True
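
To make the behavior change flagged in the review comment concrete, the following is a simplified reconstruction of the old and new decisions, derived from this diff; the two helper functions are stand-ins, not the actual XPUPlatform methods.

def supports_bf16_old(device_name: str) -> bool:
    # Old behavior: only Data Center GPUs were trusted with bfloat16.
    name = device_name.lower()
    if "arc" in name:
        return False
    if "data center gpu" in name:
        return True
    return False  # unknown devices fell back to float16

def supports_bf16_new(device_name: str) -> bool:
    # New behavior: only the Arc A770 (known accuracy issue) is excluded;
    # every other device is assumed to support bfloat16.
    name = device_name.lower()
    return "a770" not in name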

@classmethod
def is_data_center_gpu(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("data center gpu") > 0

@classmethod
def is_client_gpu_a770(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("a770") > 0

@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa

@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
return True

@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()
12 changes: 5 additions & 7 deletions vllm/v1/attention/backends/flash_attn.py
@@ -14,10 +14,12 @@
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
flash_attn_varlen_func,
Collaborator review comment:
This has to be under some guard rather than unconditional, as it doesn't exist on ROCm. @jikunshang

Contributor reply:
Hi @gshtras, thanks for pointing out the issue; added a quick fix in #20143.

get_flash_attn_version,
get_scheduler_metadata,
reshape_and_cache_flash)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
@@ -28,10 +30,6 @@
if TYPE_CHECKING:
from vllm.v1.worker.gpu_model_runner import GPUModelRunner

if current_platform.is_cuda():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
get_scheduler_metadata)

logger = init_logger(__name__)
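
The guard the collaborator asks for is not shown in this diff (the follow-up landed in #20143). The sketch below is only a hypothetical illustration of one way such a platform guard could look, not the actual fix.

# Hypothetical guard, not the real #20143 change: only bind the fa_utils
# symbols on platforms where they exist, so ROCm can still import this module.
from vllm.platforms import current_platform

if current_platform.is_cuda() or current_platform.is_xpu():
    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                               get_scheduler_metadata,
                                               reshape_and_cache_flash)
else:
    flash_attn_varlen_func = None
    get_scheduler_metadata = None
    reshape_and_cache_flash = None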


@@ -443,7 +441,7 @@ def forward(
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
reshape_and_cache_flash(
key,
value,
key_cache,