[Hardware][Intel GPU] Add v1 Intel GPU support with Flash attention backend. #19560
`vllm/platforms/xpu.py`

```diff
@@ -1,18 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import os
 from typing import TYPE_CHECKING, Optional

 import torch

 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

 if TYPE_CHECKING:
-    from vllm.config import VllmConfig
+    from vllm.config import ModelConfig, VllmConfig
 else:
+    ModelConfig = None
     VllmConfig = None

 logger = init_logger(__name__)
```
```diff
@@ -35,8 +38,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                              use_mla: bool) -> str:
         if selected_backend != _Backend.IPEX:
             logger.info("Cannot use %s backend on XPU.", selected_backend)
-        logger.info("Using IPEX attention backend.")
-        return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
+        use_v1 = envs.VLLM_USE_V1
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
+        else:
+            logger.info("Using IPEX attention backend.")
+            return "vllm.attention.backends.ipex_attn.IpexAttnBackend"

     @classmethod
     def get_device_capability(
```
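With this change, turning on the V1 engine is what routes XPU attention to the FlashAttention backend. A minimal usage sketch, assuming a working XPU build of vLLM (the model name and prompt are placeholders, not part of this PR):

```python
# Minimal sketch: enable the V1 engine so the XPU platform resolves the attention
# backend to vllm.v1.attention.backends.flash_attn.FlashAttentionBackend instead
# of the legacy IPEX backend. Assumes an XPU build of vLLM; the model is only an
# example.
import os

os.environ["VLLM_USE_V1"] = "1"  # set before the engine is created

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", dtype="float16")  # placeholder model
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```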
```diff
@@ -67,25 +75,27 @@ def inference_mode(cls):
     @classmethod
     def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         cache_config = vllm_config.cache_config
+        # in V1(or with ipex chunked prefill) block_size is 64
         if cache_config and cache_config.block_size is None:
-            cache_config.block_size = 16
+            if envs.VLLM_USE_V1:
+                cache_config.block_size = 64
+            else:
+                cache_config.block_size = 16

-        # check and update model config
-        model_config = vllm_config.model_config
-        if model_config.dtype == torch.bfloat16:
-            bf16_supported = cls.device_support_bf16()
-            if not bf16_supported:
-                logger.warning(
-                    "bfloat16 is only supported on Intel Data Center GPU, "
-                    "Intel Arc GPU is not supported yet. Your device is %s,"
-                    " which is not supported. will fallback to float16",
-                    cls.get_device_name())
-                model_config.dtype = torch.float16
-        if not model_config.enforce_eager:
-            logger.warning(
-                "CUDA graph is not supported on XPU, fallback to the eager "
-                "mode.")
-            model_config.enforce_eager = True
+        # Instances created using VllmConfig() typically have model_config as
+        # None by default. The modification involves adding a check to prevent
+        # potential null exceptions check and update model config.
+        if vllm_config.model_config is not None:
+            model_config = vllm_config.model_config
+            if model_config.dtype == torch.bfloat16:
+                bf16_supported = cls.device_support_bf16()
+                if not bf16_supported:
+                    model_config.dtype = torch.float16
+            if not model_config.enforce_eager:
+                logger.warning(
+                    "CUDA graph is not supported on XPU, fallback to the eager "
+                    "mode.")
+                model_config.enforce_eager = True

         if vllm_config.speculative_config is not None:
             raise NotImplementedError(
```
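Taken together, the hunk above boils down to a few XPU defaults. A small illustrative paraphrase (not vLLM code; the block sizes and fallbacks are copied from the diff):

```python
# Illustrative paraphrase of the config decisions above (not vLLM code): how the
# XPU platform normalizes block size, dtype, and eager mode under V1 vs. V0.
import torch


def xpu_config_defaults(use_v1: bool, requested_dtype: torch.dtype,
                        bf16_supported: bool) -> tuple[int, torch.dtype, bool]:
    block_size = 64 if use_v1 else 16
    dtype = requested_dtype
    if requested_dtype == torch.bfloat16 and not bf16_supported:
        dtype = torch.float16  # bf16 falls back to fp16, matching the diff
    enforce_eager = True  # CUDA graphs are not supported on XPU
    return block_size, dtype, enforce_eager


print(xpu_config_defaults(True, torch.bfloat16, False))  # (64, torch.float16, True)
print(xpu_config_defaults(False, torch.float16, True))   # (16, torch.float16, True)
```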
```diff
@@ -96,21 +106,27 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         # check and update parallel config
         parallel_config = vllm_config.parallel_config
         if parallel_config.worker_cls == "auto":
-            parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
+            if envs.VLLM_USE_V1:
+                parallel_config.worker_cls =\
+                    "vllm.v1.worker.xpu_worker.XPUWorker"
+            else:
+                parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"

         if parallel_config.distributed_executor_backend is None:
-            parallel_config.distributed_executor_backend = "ray"
+            if parallel_config.world_size > 1:
+                parallel_config.distributed_executor_backend = "ray"
+            else:
+                parallel_config.distributed_executor_backend = "uni"
         elif parallel_config.distributed_executor_backend == "mp":
             # FIXME(kunshang):
             # spawn needs calling `if __name__ == '__main__':``
             # fork is not supported for xpu start new process.
-            logger.error(
-                "Both start methods (spawn and fork) have issue "
-                "on XPU if you use mp backend, setting it to ray instead.")
-            parallel_config.distributed_executor_backend = "ray"
-        elif parallel_config.distributed_executor_backend != "ray":
+            if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn":
+                os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
+                logger.warning(
+                    "Please use spawn as start method if you want to use mp.")
+        elif parallel_config.distributed_executor_backend != "ray" and \
+                parallel_config.distributed_executor_backend != "uni":
             logger.warning(
                 "%s is not supported on XPU, fallback to ray distributed"
                 " executor backend.",
```
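Because the mp executor now relies on the spawn start method on XPU, user scripts that use multiprocessing-based tensor parallelism need a `__main__` guard. A hypothetical launch script (the model name and TP degree are placeholders):

```python
# Hypothetical launch script: with distributed_executor_backend="mp" on XPU the
# worker start method is forced to "spawn", so the entry point must be guarded by
# __main__ (spawn re-imports this module in each worker process).
from vllm import LLM


def main() -> None:
    llm = LLM(
        model="facebook/opt-125m",          # placeholder model
        tensor_parallel_size=2,             # placeholder TP degree
        distributed_executor_backend="mp",  # multiprocessing executor
    )
    print(llm.generate(["Hello"])[0].outputs[0].text)


if __name__ == "__main__":
    main()
```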
```diff
@@ -142,15 +158,35 @@ def get_current_memory_usage(cls,
     @classmethod
     def device_support_bf16(cls) -> bool:
         device_name = cls.get_device_name().lower()
-        if device_name.count("arc") > 0:
+        if cls.is_client_gpu_a770():
+            logger.warning("Intel Arc A770 have bfloat16 accuracy known issue,"
+                           " fallback to float16")
             return False
-        elif device_name.count("data center gpu") > 0:
-            return True
         else:
-            logger.warning("Unknown device name %s, always use float16",
-                           device_name)
-            return False
+            logger.info(
+                "Device name %s supports bfloat16. Please file an issue "
+                "if you encounter any accuracy problems with bfloat16.",
+                device_name)
+            return True
+
+    @classmethod
+    def is_data_center_gpu(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("data center gpu") > 0
+
+    @classmethod
+    def is_client_gpu_a770(cls) -> bool:
+        device_name = cls.get_device_name().lower()
+        return device_name.count("a770") > 0

     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"  # noqa
+
+    @classmethod
+    def supports_v1(cls, model_config: ModelConfig) -> bool:
+        return True
+
+    @classmethod
+    def device_count(cls) -> int:
+        return torch.xpu.device_count()
```
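The new classmethods can also be queried directly. A hypothetical interactive check, assuming the platform class in this file is exported as `XPUPlatform` and that vLLM was built with XPU support (the printed values depend on the installed GPU):

```python
# Hypothetical check of the new XPU platform helpers (requires an XPU build of
# vLLM; output values depend on the installed GPU).
from vllm.platforms.xpu import XPUPlatform

print(XPUPlatform.get_device_name())       # e.g. an Arc or Data Center GPU name
print(XPUPlatform.is_data_center_gpu())    # True only for Data Center GPU parts
print(XPUPlatform.is_client_gpu_a770())    # True on Arc A770
print(XPUPlatform.device_support_bf16())   # False on A770 (known bf16 accuracy issue)
print(XPUPlatform.device_count())          # number of visible XPU devices
```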
`vllm/v1/attention/backends/flash_attn.py`
```diff
@@ -14,10 +14,12 @@
 from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
+                                           flash_attn_varlen_func,
+                                           get_flash_attn_version,
+                                           get_scheduler_metadata,
+                                           reshape_and_cache_flash)
 from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import (
     AttentionMetadataBuilder, CommonAttentionMetadata, get_kv_cache_layout,
```

Collaborator (on the new `flash_attn_varlen_func` import): This has to be under some guard rather than unconditional, as it doesn't exist on ROCm.
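One way to address the reviewer's concern, shown here only as a sketch and not as the change adopted in this PR, is to import the flash-attention helpers defensively so that platforms where they are unavailable (such as ROCm) can still import this module:

```python
# Sketch only (not the PR's actual fix): guard the optional flash-attention
# helpers so the module stays importable on platforms that lack them; callers
# would need to check for None before use.
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

try:
    from vllm.attention.utils.fa_utils import (flash_attn_varlen_func,
                                               get_scheduler_metadata,
                                               reshape_and_cache_flash)
except ImportError:
    # Not available on this platform; this backend should not be selected here,
    # but keep the import side-effect free.
    flash_attn_varlen_func = None
    get_scheduler_metadata = None
    reshape_and_cache_flash = None
```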
```diff
@@ -28,10 +30,6 @@
 if TYPE_CHECKING:
     from vllm.v1.worker.gpu_model_runner import GPUModelRunner

-if current_platform.is_cuda():
-    from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                      get_scheduler_metadata)
-
 logger = init_logger(__name__)
```
```diff
@@ -443,7 +441,7 @@ def forward(
         # and value[:num_actual_tokens] because the reshape_and_cache_flash
         # op uses the slot_mapping's shape to determine the number of
         # actual tokens.
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
+        reshape_and_cache_flash(
             key,
             value,
             key_cache,
```
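The cache write now goes through the `reshape_and_cache_flash` helper imported from `fa_utils` rather than calling the CUDA custom op directly, presumably so the helper can pick a platform-appropriate kernel. A hypothetical sketch of such a wrapper (illustration only; the real dispatch lives inside `vllm.attention.utils.fa_utils` and is not shown in this diff):

```python
# Hypothetical sketch of a platform-dispatching wrapper (illustration only, not
# the actual fa_utils implementation).
import torch

from vllm.platforms import current_platform


def reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping,
                            kv_cache_dtype, k_scale, v_scale):
    if current_platform.is_cuda():
        # CUDA builds ship the custom op that the old code path called directly.
        torch.ops._C_cache_ops.reshape_and_cache_flash(
            key, value, key_cache, value_cache, slot_mapping,
            kv_cache_dtype, k_scale, v_scale)
    else:
        # On XPU an equivalent SYCL/IPEX kernel would be invoked; elided here
        # because its name is an implementation detail not shown in this PR.
        raise NotImplementedError("non-CUDA dispatch elided in this sketch")
```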
Review comment: The TODO comment indicates that FP8 KV cache support is pending. While this is acceptable for now, it's good practice to create a tracking issue (e.g., on GitHub) for this TODO and reference it in the comment. This helps ensure the task isn't forgotten and provides visibility into planned enhancements.