diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index f54010c4231f..827649bfcf54 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -28,4 +28,5 @@ docker run \ sh -c ' VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager ' diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index 681102b9d18b..466ba9833363 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \ if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi ENV VLLM_TARGET_DEVICE=xpu +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ diff --git a/requirements/xpu.txt b/requirements/xpu.txt index 3cb6a4a8adda..0d95dc57152d 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -9,6 +9,7 @@ setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts +numba == 0.60.0 # v0.61 doesn't support Python 3.9. Required for N-gram speculative decoding torch==2.7.0+xpu torchaudio diff --git a/vllm/_ipex_ops.py b/vllm/_ipex_ops.py index ae63e06030dd..2af138873624 100644 --- a/vllm/_ipex_ops.py +++ b/vllm/_ipex_ops.py @@ -228,6 +228,54 @@ def reshape_and_cache( ipex.llm.modules.PagedAttention.reshape_and_cache( key, value, key_cache, value_cache, slot_mapping) + @staticmethod + def reshape_and_cache_flash( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale_float: float, + v_scale_float: float, + ) -> None: + assert kv_cache_dtype == "auto" + # TODO: support FP8 kv cache.
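+ # The k_scale_float and v_scale_float arguments are accepted for interface + # parity with the other attention backends, but they are ignored here + # because only the "auto" kv_cache_dtype is supported so far.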
+ ipex.llm.modules.PagedAttention.reshape_and_cache_flash( + key, value, key_cache, value_cache, slot_mapping) + + @staticmethod + def flash_attn_varlen_func( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_q: int, + max_seqlen_kv: int, + scale: float, + is_causal: bool, + block_table: torch.Tensor, + alibi_slopes: Optional[torch.Tensor], + ): + return ipex.llm.modules.PagedAttention.flash_attn_varlen_func( + output, + query.contiguous(), + key_cache, + value_cache, + cu_seqlens_q, + cu_seqlens_kv, + max_seqlen_q, + max_seqlen_kv, + scale, + is_causal, + block_table, + alibi_slopes, + k_scale=1.0, + v_scale=1.0, + ) + + @staticmethod + def copy_blocks(key_caches: list[torch.Tensor], value_caches: list[torch.Tensor], diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index 69cde06fd72e..da187f4ac585 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -11,6 +11,8 @@ def get_flash_attn_version(requires_alibi: bool = False) -> Optional[int]: # import here to avoid circular dependencies from vllm.platforms import current_platform + if current_platform.is_xpu(): + return 2 try: from vllm.vllm_flash_attn.flash_attn_interface import ( fa_version_unsupported_reason, is_fa_version_supported) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 38d567acfd8a..4abdd5fa6540 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1407,6 +1407,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "FLASHMLA", "FLASHINFER", "FLASHINFER_VLLM_V1", + "IPEX_V1", "ROCM_AITER_MLA", "TORCH_SDPA_VLLM_V1", "FLEX_ATTENTION", @@ -1440,10 +1441,11 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - # Non-[CUDA, TPU, x86 CPU] may be supported on V1, + # Non-[CUDA, TPU, x86 CPU, XPU] may be supported on V1, # but off by default for now.
v0_hardware = not any( (current_platform.is_cuda_alike(), current_platform.is_tpu(), + current_platform.is_xpu(), (current_platform.is_cpu() and current_platform.get_cpu_architecture() == CpuArchEnum.X86))) if v0_hardware and _warn_or_fallback( # noqa: SIM103 diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py index bdc2b1f4c27c..7c568d099849 100644 --- a/vllm/executor/ray_distributed_executor.py +++ b/vllm/executor/ray_distributed_executor.py @@ -73,7 +73,7 @@ class RayDistributedExecutor(DistributedExecutorBase): def _init_executor(self) -> None: self.forward_dag: Optional[ray.dag.CompiledDAG] = None - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and not current_platform.is_xpu(): # V1 uses SPMD worker and compiled DAG os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1" os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f91f222b25e5..d60ebaca9475 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -57,6 +57,7 @@ class _Backend(enum.Enum): PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() + IPEX_V1 = enum.auto() BLOCK_SPARSE_FLASH_ATTN = enum.auto() DUAL_CHUNK_FLASH_ATTN = enum.auto() NO_ATTENTION = enum.auto() diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 73f6f3d41767..be24e8cf6783 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -5,14 +5,16 @@ import torch +import vllm.envs as envs from vllm.logger import init_logger from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS from .interface import DeviceCapability, Platform, PlatformEnum, _Backend if TYPE_CHECKING: - from vllm.config import VllmConfig + from vllm.config import ModelConfig, VllmConfig else: + ModelConfig = None VllmConfig = None logger = init_logger(__name__) @@ -35,8 +37,13 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, use_mla: bool) -> str: if selected_backend != _Backend.IPEX: logger.info("Cannot use %s backend on XPU.", selected_backend) - logger.info("Using IPEX attention backend.") - return "vllm.attention.backends.ipex_attn.IpexAttnBackend" + use_v1 = envs.VLLM_USE_V1 + if use_v1: + logger.info("Using IPEX_V1 attention backend.") + return "vllm.v1.attention.backends.ipex_attn.IPEXAttentionBackend" + else: + logger.info("Using IPEX attention backend.") + return "vllm.attention.backends.ipex_attn.IpexAttnBackend" @classmethod def get_device_capability( @@ -67,25 +74,28 @@ def inference_mode(cls): @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config + # In V1 (or with IPEX chunked prefill), the block size is 64. + if cache_config and \ + cache_config.block_size is None and \ + envs.VLLM_USE_V1: + cache_config.block_size = 64 if cache_config and cache_config.block_size is None: cache_config.block_size = 16 - # check and update model config - model_config = vllm_config.model_config - if model_config.dtype == torch.bfloat16: - bf16_supported = cls.device_support_bf16() - if not bf16_supported: + # Check and update the model config. Instances created via VllmConfig() + # may have model_config set to None by default, so guard against it + # before touching its attributes.
+ if vllm_config.model_config is not None: + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + bf16_supported = cls.device_support_bf16() + if not bf16_supported: + model_config.dtype = torch.float16 + if not model_config.enforce_eager: logger.warning( - "bfloat16 is only supported on Intel Data Center GPU, " - "Intel Arc GPU is not supported yet. Your device is %s," - " which is not supported. will fallback to float16", - cls.get_device_name()) - model_config.dtype = torch.float16 - if not model_config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - model_config.enforce_eager = True + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True if vllm_config.speculative_config is not None: raise NotImplementedError( @@ -96,21 +106,26 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # check and update parallel config parallel_config = vllm_config.parallel_config - if parallel_config.worker_cls == "auto": + if envs.VLLM_USE_V1: + parallel_config.worker_cls =\ + "vllm.v1.worker.xpu_worker.XPUWorker" + else: parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" if parallel_config.distributed_executor_backend is None: - parallel_config.distributed_executor_backend = "ray" + if parallel_config.world_size > 1: + parallel_config.distributed_executor_backend = "ray" + else: + parallel_config.distributed_executor_backend = "uni" elif parallel_config.distributed_executor_backend == "mp": # FIXME(kunshang): # spawn needs calling `if __name__ == '__main__':`` # fork is not supported for xpu start new process. - logger.error( - "Both start methods (spawn and fork) have issue " - "on XPU if you use mp backend, setting it to ray instead.") - parallel_config.distributed_executor_backend = "ray" - - elif parallel_config.distributed_executor_backend != "ray": + if envs.VLLM_WORKER_MULTIPROC_METHOD != "spawn": + logger.warning( + "Please use spawn as the start method if you want to use the mp backend.") + elif parallel_config.distributed_executor_backend != "ray" and \ + parallel_config.distributed_executor_backend != "uni": logger.warning( "%s is not supported on XPU, fallback to ray distributed" " executor backend.", @@ -142,15 +157,35 @@ def get_current_memory_usage(cls, @classmethod def device_support_bf16(cls) -> bool: device_name = cls.get_device_name().lower() - if device_name.count("arc") > 0: + if cls.is_client_gpu_a770(): + logger.warning("Intel Arc A770 has a known bfloat16 accuracy issue," + " falling back to float16") return False - elif device_name.count("data center gpu") > 0: - return True else: - logger.warning("Unknown device name %s, always use float16", - device_name) - return False + logger.info( + "Device name %s supports bfloat16. 
Please file an issue " + "if you encounter any accuracy problems with bfloat16.", + device_name) + return True + + @classmethod + def is_data_center_gpu(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("data center gpu") > 0 + + @classmethod + def is_client_gpu_a770(cls) -> bool: + device_name = cls.get_device_name().lower() + return device_name.count("a770") > 0 @classmethod def get_device_communicator_cls(cls) -> str: return "vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator" # noqa + + @classmethod + def supports_v1(cls, model_config: ModelConfig) -> bool: + return True + + @classmethod + def device_count(cls) -> int: + return torch.xpu.device_count() diff --git a/vllm/v1/attention/backends/ipex_attn.py b/vllm/v1/attention/backends/ipex_attn.py new file mode 100644 index 000000000000..e689f13ace3d --- /dev/null +++ b/vllm/v1/attention/backends/ipex_attn.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm._ipex_ops import ipex_ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, + AttentionLayer, + AttentionMetadata, AttentionType) +from vllm.attention.utils.fa_utils import get_flash_attn_version +from vllm.v1.attention.backends.flash_attn import ( + FlashAttentionMetadata, FlashAttentionMetadataBuilder) +from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.worker.block_table import BlockTable + +if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch + from vllm.v1.worker.xpu_model_runner import XPUModelRunner + + +@dataclass +class IPEXAttentionMetadata(FlashAttentionMetadata): + seq_start_loc: torch.Tensor = torch.tensor([0], dtype=torch.int32) + + def __init__(self, + flash_attn_metadata: FlashAttentionMetadata, + seq_start_loc: Optional[torch.Tensor] = None, + **kwargs) -> None: + super().__init__(**flash_attn_metadata.__dict__, **kwargs) + if seq_start_loc is not None: + self.seq_start_loc = seq_start_loc + else: + self.seq_start_loc = torch.tensor([0], + dtype=torch.int32, + device=self.block_table.device) + + +class IPEXAttentionMetadataBuilder(FlashAttentionMetadataBuilder): + + def __init__(self, runner: "XPUModelRunner", kv_cache_spec: AttentionSpec, + block_table: BlockTable): + super().__init__(runner, kv_cache_spec, block_table) + # Annotate runner as XPUModelRunner to avoid "has no attribute" errors. + self.runner: XPUModelRunner = runner + self.aot_schedule = (get_flash_attn_version() == 3) + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): + attn_metadata = super().build(num_reqs, num_actual_tokens, + max_query_len, common_prefix_len, + common_attn_metadata) + seq_start_loc_cpu = self.runner.seq_start_loc_cpu[:num_reqs + 1] + seq_start_loc = seq_start_loc_cpu.to(self.runner.device, + non_blocking=True) + return IPEXAttentionMetadata(attn_metadata, + seq_start_loc=seq_start_loc) + + +class IPEXAttentionBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_supported_head_sizes() -> list[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return 
"IPEX_V1" + + @staticmethod + def get_impl_cls() -> type["IPEXAttentionImpl"]: + return IPEXAttentionImpl + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return IPEXAttentionMetadata + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> tuple[int, ...]: + if block_size % 16 != 0: + raise ValueError("Block size must be a multiple of 16.") + return (2, num_blocks, block_size, num_kv_heads, head_size) + + @staticmethod + def get_builder_cls() -> type["IPEXAttentionMetadataBuilder"]: + return IPEXAttentionMetadataBuilder + + +class IPEXAttentionImpl(AttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + blocksparse_params: Optional[dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, + kv_sharing_target_layer_name: Optional[str] = None, + use_irope: bool = False, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + if sliding_window is None: + self.sliding_window = (-1, -1) + else: + self.sliding_window = (sliding_window - 1, 0) + self.kv_cache_dtype = kv_cache_dtype + self.use_irope = use_irope + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + self.kv_sharing_target_layer_name = kv_sharing_target_layer_name + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + support_head_sizes = IPEXAttentionBackend.get_supported_head_sizes() + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") + self.use_irope = use_irope + self.vllm_flash_attn_version = get_flash_attn_version() + + def forward( + self, + layer: AttentionLayer, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: IPEXAttentionMetadata, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass with IPEXAttention. + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + kv_cache = [2, num_blocks, block_size, num_kv_heads, head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output + + # NOTE(woosuk): IPEXAttention does not support FP8 KV cache. 
+ assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( + "key/v_scale is not supported in IPEXAttention.") + + num_actual_tokens = attn_metadata.num_actual_tokens + num_heads = self.num_heads + head_size = self.head_size + num_kv_heads = self.num_kv_heads + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + # Reshape the input keys and values and store them in the cache. + key_cache, value_cache = kv_cache.unbind(0) + + ipex_ops.reshape_and_cache_flash( + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, + attn_metadata.slot_mapping, + self.kv_cache_dtype, + layer._k_scale_float, + layer._v_scale_float, + ) + use_local_attn = \ (self.use_irope and attn_metadata.local_attn_metadata is not None) + + if use_local_attn: + assert attn_metadata.local_attn_metadata is not None + local_metadata = attn_metadata.local_attn_metadata + cu_seqlens_q = local_metadata.local_query_start_loc + seqused_k = local_metadata.local_seqused_k + max_seqlen_q = local_metadata.local_max_query_len + max_seqlen_k = local_metadata.local_max_seq_len + block_table = local_metadata.local_block_table + else: + cu_seqlens_q = attn_metadata.query_start_loc + seqused_k = attn_metadata.seq_lens + max_seqlen_q = attn_metadata.max_query_len + max_seqlen_k = attn_metadata.max_seq_len + block_table = attn_metadata.block_table + + if not hasattr(attn_metadata, "seq_start_loc"): + cumsum = torch.cumsum(seqused_k, dim=0) + cu_seqlens_k = torch.cat([ + torch.tensor([0], device=seqused_k.device, dtype=torch.int32), + cumsum + ]).to(torch.int32) + else: + cu_seqlens_k = attn_metadata.seq_start_loc + + ipex_ops.flash_attn_varlen_func( + output[:num_actual_tokens], + query[:num_actual_tokens], + key_cache, + value_cache, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.scale, + is_causal=True, + block_table=block_table, + alibi_slopes=self.alibi_slopes, + ) + return output diff --git a/vllm/v1/worker/xpu_model_runner.py b/vllm/v1/worker/xpu_model_runner.py new file mode 100644 index 000000000000..252dcda61451 --- /dev/null +++ b/vllm/v1/worker/xpu_model_runner.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, Any, Optional + +import numpy as np +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.spec_decode.metadata import SpecDecodeMetadata +from vllm.v1.worker.gpu_model_runner import GPUModelRunner + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) + + +class XPUModelRunner(GPUModelRunner): + """A model runner for XPU devices.""" + + def __init__( + self, + vllm_config: VllmConfig, + device: torch.device, + ): + super().__init__(vllm_config, device) + # FIXME: To be verified.
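+ # Cascade attention has not been validated on XPU yet, so it stays disabled.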
+ self.cascade_attn_enabled = False + # this is XPU specific + self.seq_start_loc_cpu = torch.zeros(self.max_num_reqs + 1, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + self.seq_start_loc_np = self.seq_start_loc_cpu.numpy() + + def _init_device_properties(self) -> None: + pass + + def _sync_device(self) -> None: + torch.xpu.synchronize() + + def _prepare_inputs( + self, scheduler_output: "SchedulerOutput" + ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata]]: + total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + assert total_num_scheduled_tokens > 0 + num_reqs = self.input_batch.num_reqs + assert num_reqs > 0 + # Get the number of scheduled tokens for each request. + req_ids = self.input_batch.req_ids + tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids] + num_scheduled_tokens = np.array(tokens, dtype=np.int32) + # ======== XPU start ========= + seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + self.seq_start_loc_np[0] = 0 + np.cumsum(seq_lens, out=self.seq_start_loc_np[1:num_reqs + 1]) + # ======== XPU end ========= + return super()._prepare_inputs(scheduler_output) diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py new file mode 100644 index 000000000000..d9ea03986566 --- /dev/null +++ b/vllm/v1/worker/xpu_worker.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +import os + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.v1.worker.gpu_worker import (Worker, + init_worker_distributed_environment) +from vllm.v1.worker.xpu_model_runner import XPUModelRunner + +logger = init_logger(__name__) + + +class XPUWorker(Worker): + """An XPU worker class.""" + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ): + super().__init__(vllm_config, local_rank, rank, + distributed_init_method, is_driver_worker) + device_config = self.device_config + assert device_config.device_type == "xpu" + assert current_platform.is_xpu() + + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.XPU, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None + + # We provide this function because `torch.xpu.mem_get_info()` doesn't + # return the correct free_gpu_memory on Intel client GPUs. We need to + # calculate/estimate it ourselves. + def xpu_get_mem_info(self): + if current_platform.is_data_center_gpu(): + return torch.xpu.mem_get_info() + else: + _, total_gpu_memory = torch.xpu.mem_get_info() + # FIXME: memory_allocated() doesn't count non-torch allocations, + # and we don't have any API to get them, so we assume 128MB.
+ used_memory = torch.xpu.memory_allocated() + non_torch_allocations = 128 * 1024 * 1024 + free_gpu_memory = total_gpu_memory - (used_memory + + non_torch_allocations) + return free_gpu_memory, total_gpu_memory + + @torch.inference_mode() + def determine_available_memory(self) -> int: + """Profiles the peak memory usage of the model to determine how many + KV blocks may be allocated without OOMs. + The engine will first conduct a profiling of the existing memory usage. + Then, it calculates the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.xpu.empty_cache() + torch.xpu.reset_peak_memory_stats() + + free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info() + current_allocated_bytes = torch.xpu.memory_allocated() + msg = ("Before memory profiling run, " + f"total GPU memory: {total_gpu_memory / 1024**2:.2f} MB, " + f"model load takes {current_allocated_bytes / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + free_gpu_memory, _ = self.xpu_get_mem_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. + assert self.init_gpu_memory > free_gpu_memory, ( + "Error in memory profiling. " + f"Initial free memory {self.init_gpu_memory}, current free memory" + f" {free_gpu_memory}. This happens when the GPU memory was " + "not properly cleaned up before initializing the vLLM instance.") + + # Get the peak memory allocation recorded by torch + peak_memory = torch.xpu.memory_stats()["allocated_bytes.all.peak"] + + torch.xpu.empty_cache() + torch_allocated_bytes = torch.xpu.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = self.xpu_get_mem_info( + )[1] - self.xpu_get_mem_info()[0] + + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + msg = ("After memory profiling run, " + f"peak memory usage is {peak_memory / 1024**2:.2f} MB, " + f"torch mem is {torch_allocated_bytes / 1024**2:.2f} MB, " + f"non-torch mem is {non_torch_allocations / 1024**2:.2f} MB, " + f"free gpu memory is {free_gpu_memory / 1024**2:.2f} MB.") + logger.info(msg) + + return int(available_kv_cache_memory) + + def init_device(self): + if self.device_config.device.type == "xpu" and current_platform.is_xpu( + ): + self.device = torch.device(f"xpu:{self.local_rank}") + torch.xpu.set_device(self.device) + torch.xpu.empty_cache() + self.init_gpu_memory = torch.xpu.get_device_properties( + self.local_rank).total_memory + else: + raise RuntimeError( + f"Unsupported device type: {self.device_config.device}") + + ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "drmfd") + ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi") + ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", + str(self.parallel_config.world_size)) + os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE + os.environ["CCL_ATL_TRANSPORT"] = 
ENV_CCL_ATL_TRANSPORT + os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE + os.environ["LOCAL_RANK"] = str(self.local_rank) + dist_backend = "ccl" + + init_worker_distributed_environment(self.vllm_config, self.rank, + self.distributed_init_method, + self.local_rank, dist_backend) + + # A global all_reduce is needed to warm up oneCCL. + torch.distributed.all_reduce(torch.zeros(1).xpu()) + + # Set random seed. + set_random_seed(self.model_config.seed) + + # Construct the model runner + self.model_runner = XPUModelRunner( # type: ignore + self.vllm_config, self.device)