1 change: 1 addition & 0 deletions .buildkite/scripts/hardware_ci/run-xpu-test.sh
@@ -28,4 +28,5 @@ docker run \
sh -c '
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
'
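For reference, a minimal offline-inference equivalent of the new V1 smoke test — a sketch, not part of this PR, assuming the standard `LLM` entry point and setting `VLLM_USE_V1` before vLLM is imported, as the CI script does:

import os

# Select the V1 engine, mirroring VLLM_USE_V1=1 in the CI invocation above.
os.environ["VLLM_USE_V1"] = "1"

from vllm import LLM, SamplingParams

# Same settings as the new CI test: block size 64 and eager mode on XPU.
llm = LLM(model="facebook/opt-125m", block_size=64, enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)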
1 change: 1 addition & 0 deletions docker/Dockerfile.xpu
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi

ENV VLLM_TARGET_DEVICE=xpu
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn

RUN --mount=type=cache,target=/root/.cache/pip \
--mount=type=bind,source=.git,target=.git \
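Because the image now sets `VLLM_WORKER_MULTIPROC_METHOD=spawn`, entry-point scripts need the usual `if __name__ == '__main__':` guard (see the comment in `vllm/platforms/xpu.py` further down). A minimal sketch of a guarded launch, assuming a 2-card XPU setup as in the CI test:

from vllm import LLM, SamplingParams

def main():
    # With the spawn start method, worker processes re-import this module,
    # so the actual work must live behind the __main__ guard.
    llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2)
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)

if __name__ == "__main__":
    main()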
1 change: 1 addition & 0 deletions requirements/xpu.txt
@@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding

torch==2.7.0+xpu
torchaudio
48 changes: 48 additions & 0 deletions vllm/_ipex_ops.py
@@ -228,6 +228,54 @@ def reshape_and_cache(
ipex.llm.modules.PagedAttention.reshape_and_cache(
key, value, key_cache, value_cache, slot_mapping)

@staticmethod
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale_float: float,
v_scale_float: float,
) -> None:
assert kv_cache_dtype == "auto"
# TODO: support FP8 kv cache.
ipex.llm.modules.PagedAttention.reshape_and_cache_flash(
key, value, key_cache, value_cache, slot_mapping)

@staticmethod
def flash_attn_varlen_func(
output: torch.Tensor,
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_kv: torch.Tensor,
max_seqlen_q: int,
max_seqlen_kv: int,
scale: float,
is_causal: bool,
block_table: torch.Tensor,
alibi_slopes: Optional[torch.Tensor],
):
return ipex.llm.modules.PagedAttention.flash_attn_varlen_func(
output,
query.contiguous(),
key_cache,
value_cache,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
scale,
is_causal,
block_table,
alibi_slopes,
k_scale=1.0,
v_scale=1.0,
)

@staticmethod
def copy_blocks(key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
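For orientation, a hedged sketch of how a caller might exercise the new wrapper — the real call sites live in the V1 IPEX attention backend, which is not shown in this diff. It assumes the enclosing class is `ipex_ops` (as in `vllm/_ipex_ops.py`), a flash-style KV cache layout of `(num_blocks, block_size, num_kv_heads, head_size)`, and an available XPU device with IPEX installed; all shapes are hypothetical:

import torch
from vllm._ipex_ops import ipex_ops

# One sequence: 8 query tokens attending causally to 16 cached tokens.
num_heads, num_kv_heads, head_size, block_size, num_blocks = 8, 8, 64, 64, 4
query = torch.randn(8, num_heads, head_size, dtype=torch.float16, device="xpu")
output = torch.empty_like(query)
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="xpu")
value_cache = torch.randn_like(key_cache)
cu_seqlens_q = torch.tensor([0, 8], dtype=torch.int32, device="xpu")
cu_seqlens_kv = torch.tensor([0, 16], dtype=torch.int32, device="xpu")
block_table = torch.zeros(1, num_blocks, dtype=torch.int32, device="xpu")

ipex_ops.flash_attn_varlen_func(
    output, query, key_cache, value_cache,
    cu_seqlens_q, cu_seqlens_kv,
    max_seqlen_q=8, max_seqlen_kv=16,
    scale=head_size ** -0.5,
    is_causal=True,
    block_table=block_table,
    alibi_slopes=None,
)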
2 changes: 2 additions & 0 deletions vllm/attention/utils/fa_utils.py
@@ -11,6 +11,8 @@
def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]:
# import here to avoid circular dependencies
from vllm.platforms import current_platform
if current_platform.is_xpu():
return 2
try:
from vllm.vllm_flash_attn.flash_attn_interface import (
fa_version_unsupported_reason, is_fa_version_supported)
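A quick way to observe the effect of this change on an XPU build (sketch; the function path is as in this diff):

from vllm.attention.utils.fa_utils import get_flash_attn_version

# On XPU this now short-circuits to FA2 instead of probing vllm_flash_attn.
print(get_flash_attn_version())  # -> 2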
4 changes: 3 additions & 1 deletion vllm/engine/arg_utils.py
@@ -1407,6 +1407,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"FLASHMLA",
"FLASHINFER",
"FLASHINFER_VLLM_V1",
"IPEX_V1",
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
@@ -1440,10 +1441,11 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
_raise_or_fallback(feature_name=name, recommend_to_remove=False)
return False

# Non-[CUDA, TPU, x86 CPU] may be supported on V1,
# Non-[CUDA, TPU, x86 CPU, XPU] may be supported on V1,
# but off by default for now.
v0_hardware = not any(
(current_platform.is_cuda_alike(), current_platform.is_tpu(),
current_platform.is_xpu(),
(current_platform.is_cpu()
and current_platform.get_cpu_architecture() == CpuArchEnum.X86)))
if v0_hardware and _warn_or_fallback( # noqa: SIM103
2 changes: 1 addition & 1 deletion vllm/executor/ray_distributed_executor.py
@@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase):

def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
if envs.VLLM_USE_V1:
if envs.VLLM_USE_V1 and not current_platform.is_xpu():
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
@@ -57,6 +57,7 @@ class _Backend(enum.Enum):
PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto()
IPEX = enum.auto()
IPEX_V1 = enum.auto()
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
99 changes: 67 additions & 32 deletions vllm/platforms/xpu.py
@@ -5,14 +5,16 @@

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import ModelConfig, VllmConfig
else:
ModelConfig = None
VllmConfig = None

logger = init_logger(__name__)
@@ -35,8 +37,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
use_v1 = envs.VLLM_USE_V1
if use_v1:
logger.info("Using IPEX_V1 attention backend.")
return "vllm.v1.attention.backends.ipex_attn.IPEXAttentionBackend"
else:
logger.info("Using IPEX attention backend.")
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

@classmethod
def get_device_capability(
@@ -67,25 +74,28 @@ def inference_mode(cls):
@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
cache_config = vllm_config.cache_config
# in V1 (or with ipex chunked prefill) the block_size is 64
if cache_config and \
cache_config.block_size is None and \
envs.VLLM_USE_V1:
cache_config.block_size = 64
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

# check and update model config
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
# Instances created via VllmConfig() typically have model_config set to
# None by default, so guard against it before checking and updating the
# model config.
if vllm_config.model_config is not None:
model_config = vllm_config.model_config
if model_config.dtype == torch.bfloat16:
bf16_supported = cls.device_support_bf16()
if not bf16_supported:
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"bfloat16 is only supported on Intel Data Center GPU, "
"Intel Arc GPU is not supported yet. Your device is %s,"
" which is not supported. will fallback to float16",
cls.get_device_name())
model_config.dtype = torch.float16
if not model_config.enforce_eager:
logger.warning(
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True
"CUDA graph is not supported on XPU, fallback to the eager "
"mode.")
model_config.enforce_eager = True

if vllm_config.speculative_config is not None:
raise NotImplementedError(
@@ -96,21 +106,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

# check and update parallel config
parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
parallel_config.worker_cls =\
"vllm.v1.worker.xpu_worker.XPUWorker"
else:
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

if parallel_config.distributed_executor_backend is None:
parallel_config.distributed_executor_backend = "ray"
if parallel_config.world_size > 1:
parallel_config.distributed_executor_backend = "ray"
else:
parallel_config.distributed_executor_backend = "uni"
elif parallel_config.distributed_executor_backend == "mp":
# FIXME(kunshang):
# spawn requires the caller to use `if __name__ == '__main__':`
# fork is not supported for starting new processes on XPU.
logger.error(
"Both start methods (spawn and fork) have issue "
"on XPU if you use mp backend, setting it to ray instead.")
parallel_config.distributed_executor_backend = "ray"

elif parallel_config.distributed_executor_backend != "ray":
if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
logger.warning(
"Please use spawn as start method if you want to use mp.")
elif parallel_config.distributed_executor_backend != "ray" and \
parallel_config.distributed_executor_backend != "uni":
logger.warning(
"%s is not supported on XPU, fallback to ray distributed"
" executor backend.",
@@ -142,15 +157,35 @@ def get_current_memory_usage(cls,
@classmethod
def device_support_bf16(cls) -> bool:
device_name = cls.get_device_name().lower()
if device_name.count("arc") > 0:
if cls.is_client_gpu_a770():
logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
" fallback to float16")
return False
elif device_name.count("data center gpu") > 0:
return True
else:
logger.warning("Unknown device name %s, always use float16",
device_name)
return False
logger.info(
"Device name %s supports bfloat16. Please file an issue "
"if you encounter any accuracy problems with bfloat16.",
device_name)
return True

@classmethod
def is_data_center_gpu(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("data center gpu") > 0

@classmethod
def is_client_gpu_a770(cls) -> bool:
device_name = cls.get_device_name().lower()
return device_name.count("a770") > 0

@classmethod
def get_device_communicator_cls(cls) -> str:
return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa

@classmethod
def supports_v1(cls, model_config: ModelConfig) -> bool:
return True

@classmethod
def device_count(cls) -> int:
return torch.xpu.device_count()
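Taken together, the new `device_support_bf16()` / `is_client_gpu_a770()` logic means a bfloat16 request is downgraded on consumer Arc parts with a warning, while Data Center GPUs keep bfloat16. A sketch of the observable behavior, assuming an Intel Arc A770:

from vllm import LLM

# On an A770, check_and_update_config() logs the accuracy warning added in
# this PR and swaps the dtype to float16; Data Center GPUs keep bfloat16.
llm = LLM(model="facebook/opt-125m", dtype="bfloat16", enforce_eager=True)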