From 80aa7e91fcd547a7a1396f71b9bdce18e5c92245 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Fri, 14 Jun 2024 00:33:14 +0800 Subject: [PATCH] [Hardware][Intel] Optimize CPU backend and add more performance tips (#4971) Co-authored-by: Jianan Gu --- Dockerfile.cpu | 8 +- README.md | 2 +- .../getting_started/cpu-installation.rst | 23 +++- requirements-cpu.txt | 2 +- vllm/attention/backends/torch_sdpa.py | 23 +++- vllm/attention/ops/ipex_attn.py | 120 ++++++++++++++++++ 6 files changed, 165 insertions(+), 13 deletions(-) create mode 100644 vllm/attention/ops/ipex_attn.py diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 403a1cd0391b0..777bb08296ed9 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -3,9 +3,13 @@ FROM ubuntu:22.04 AS cpu-test-1 RUN apt-get update -y \ - && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip \ + && apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +RUN echo 'export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD' >> ~/.bashrc + +RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl + RUN pip install --upgrade pip \ && pip install wheel packaging ninja "setuptools>=49.4.0" numpy @@ -21,6 +25,6 @@ RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install WORKDIR /workspace/ -RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks CMD ["/bin/bash"] diff --git a/README.md b/README.md index 57374d2791d15..8e4480ac2b5c9 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs and AMD GPUs +- Support NVIDIA GPUs, AMD GPUs, and Intel CPUs - (Experimental) Prefix caching support - (Experimental) Multi-lora support diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 5270253cae9ab..a9544e8a59a3d 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -10,6 +10,7 @@ Table of contents: #. :ref:`Requirements ` #. :ref:`Quick start using Dockerfile ` #. :ref:`Build from source ` +#. :ref:`Intel Extension for PyTorch ` #. :ref:`Performance tips ` .. _cpu_backend_requirements: @@ -18,7 +19,7 @@ Requirements ------------ * OS: Linux -* Compiler: gcc/g++>=12.3.0 (recommended) +* Compiler: gcc/g++>=12.3.0 (optional, recommended) * Instruction set architecture (ISA) requirement: AVX512 is required. .. _cpu_backend_quick_start_dockerfile: @@ -41,7 +42,7 @@ Quick start using Dockerfile Build from source ----------------- -- First, install required compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: +- First, install recommended compiler. We recommend to use ``gcc/g++ >= 12.3.0`` as the default compiler to avoid potential problems. For example, on Ubuntu 22.4, you can run: .. code-block:: console @@ -70,6 +71,15 @@ Build from source - If you want to force enable AVX512_BF16 for the cross-compilation, please set environment variable VLLM_CPU_AVX512BF16=1 before the building. +.. _ipex_guidance: + +Intel Extension for PyTorch +--------------------------- + +- `Intel Extension for PyTorch (IPEX) `_ extends PyTorch with up-to-date features optimizations for an extra performance boost on Intel hardware. + +- IPEX after the ``2.3.0`` can be enabled in the CPU backend by default if it is installed. + .. _cpu_backend_performance_tips: Performance tips @@ -77,6 +87,15 @@ Performance tips - vLLM CPU backend uses environment variable ``VLLM_CPU_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_CPU_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. +- We highly recommend to use TCMalloc for high performance memory allocation and better cache locality. For example, on Ubuntu 22.4, you can run: + +.. code-block:: console + + $ sudo apt-get install libtcmalloc-minimal4 # install TCMalloc library + $ find / -name *libtcmalloc* # find the dynamic link library path + $ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:$LD_PRELOAD # prepend the library to LD_PRELOAD + $ python examples/offline_inference.py # run vLLM + - vLLM CPU backend uses OpenMP for thread-parallel computation. If you want the best performance on CPU, it will be very critical to isolate CPU cores for OpenMP threads with other thread pools (like web-service event-loop), to avoid CPU oversubscription. - If using vLLM CPU backend on a bare-metal machine, it is recommended to disable the hyper-threading. diff --git a/requirements-cpu.txt b/requirements-cpu.txt index b739642d8d344..8b7d86e686217 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.3.0+cpu +torch == 2.3.1+cpu triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 9b50adec5244d..4b08cce99afb0 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -8,8 +8,16 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata) -from vllm.attention.ops.paged_attn import (PagedAttention, - PagedAttentionMetadata) +from vllm.attention.ops.paged_attn import PagedAttentionMetadata +from vllm.utils import is_cpu + +if is_cpu(): + try: + from vllm.attention.ops.ipex_attn import PagedAttention + except ImportError: + from vllm.attention.ops.paged_attn import PagedAttention +else: + from vllm.attention.ops.paged_attn import PagedAttention class TorchSDPABackend(AttentionBackend): @@ -197,13 +205,14 @@ def forward( attn_metadata.attn_bias): end = start + seq_len sub_out = scaled_dot_product_attention( - query[:, start:end, :], - key[:, start:end, :], - value[:, start:end, :], + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], attn_mask=mask, dropout_p=0.0, is_causal=not self.need_mask, - scale=self.scale).movedim(query.dim() - 2, 0) + scale=self.scale).squeeze(0).movedim( + query.dim() - 2, 0) output[start:end, :, :] = sub_out start = end else: @@ -248,7 +257,7 @@ def _make_alibi_bias( num_heads = alibi_slopes.shape[0] bias = bias[None, :].repeat((num_heads, 1, 1)) - bias.mul_(alibi_slopes[:, None, None]) + bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0) inf_mask = torch.empty( (1, seq_len, seq_len), dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py new file mode 100644 index 0000000000000..5a5317b65004e --- /dev/null +++ b/vllm/attention/ops/ipex_attn.py @@ -0,0 +1,120 @@ +from typing import Dict, List, Optional, Tuple + +import intel_extension_for_pytorch.llm.modules as ipex_modules +import torch + +from vllm import _custom_ops as ops + + +class PagedAttention: + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 128, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, -1, head_size) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + kv_scale: float, + *args, + ) -> None: + ipex_modules.PagedAttention.reshape_and_cache( + key, value, key_cache, value_cache, + slot_mapping.flatten().int()) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + kv_scale: float, + *args, + ) -> torch.Tensor: + output = torch.empty_like(query) + block_size = value_cache.shape[2] + head_mapping = torch.arange( + 0, + num_kv_heads, + device="cpu", + dtype=torch.int32, + ).view(num_kv_heads, + 1).repeat_interleave(query.size(1) // num_kv_heads).flatten() + ipex_modules.PagedAttention.single_query_cached_kv_attention( + output, query.contiguous(), key_cache, value_cache, head_mapping, + scale, block_tables, context_lens, block_size, max_context_len, + alibi_slopes) + + return output + + @staticmethod + def forward_prefix( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + subquery_start_loc: torch.Tensor, + prompt_lens_tensor: torch.Tensor, + context_lens: torch.Tensor, + max_subquery_len: int, + alibi_slopes: Optional[torch.Tensor], + *args, + ) -> torch.Tensor: + raise NotImplementedError + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: Dict[int, int], + *args, + ) -> None: + raise NotImplementedError + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists)