From 9e6144a6a8ee724e47626b8ffb022e85c9fbc84b Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Fri, 8 Mar 2024 10:52:20 -0800
Subject: [PATCH] [FIX] Make `flash_attn` optional (#3269)

---
 .gitignore                                  |  3 -
 setup.py                                    | 58 +------------------
 vllm/__init__.py                            | 30 +++-------
 .../layers/attention/attention.py           | 37 ++++++++++--
 .../layers/attention/backends/flash_attn.py |  1 -
 5 files changed, 39 insertions(+), 90 deletions(-)

diff --git a/.gitignore b/.gitignore
index 6ff62f1c75806..b1513ef0ddb0c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -185,6 +185,3 @@ hip_compat.h
 
 # Benchmark dataset
 *.json
-
-# Third-party Python packages.
-vllm/thirdparty_files/
diff --git a/setup.py b/setup.py
index 286b90fdf6fbc..879ffaa3ae732 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,6 @@
 import os
 import re
 import subprocess
-import sys
 import warnings
 from pathlib import Path
 from typing import List, Set
@@ -15,8 +14,6 @@
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
-# This is a temporary directory to store third-party packages.
-THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
 
 # If you are developing the C++ backend of vLLM, consider building vLLM with
 # `python setup.py develop` since it will give you incremental builds.
@@ -341,61 +338,9 @@ def get_torch_arch_list() -> Set[str]:
                 "nvcc": NVCC_FLAGS_PUNICA,
             },
         ))
-elif _is_hip():
-    amd_archs = os.getenv("GPU_ARCHS")
-    if amd_archs is None:
-        amd_archs = get_amdgpu_offload_arch()
-    for arch in amd_archs.split(";"):
-        if arch not in ROCM_SUPPORTED_ARCHS:
-            raise RuntimeError(
-                f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
-                f"amdgpu_arch_found: {arch}")
-        NVCC_FLAGS += [f"--offload-arch={arch}"]
-    NVCC_FLAGS += ["-DENABLE_FP8_E4M3"]
-
 elif _is_neuron():
     neuronxcc_version = get_neuronxcc_version()
 
-    # Download the FlashAttention package.
-    # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
-    flash_attn_version = "2.5.6"
-    install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
-    subprocess.check_call(
-        [
-            sys.executable,
-            "-m",
-            "pip",
-            "install",
-            "-q",
-            f"--target={install_dir}",
-            "einops",  # Dependency of flash-attn.
-            f"flash-attn=={flash_attn_version}",
-            "--no-dependencies",  # Required to avoid re-installing torch.
-        ],
-        env=dict(os.environ, CC="gcc"),
-    )
-
-    # Copy the FlashAttention package into the vLLM package after build.
-    class build_ext(BuildExtension):
-
-        def run(self):
-            super().run()
-            target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
-            if not os.path.exists(target_dir):
-                os.makedirs(target_dir)
-            self.copy_tree(install_dir, target_dir)
-
-    class BinaryDistribution(setuptools.Distribution):
-
-        def has_ext_modules(self):
-            return True
-
-else:
-    build_ext = BuildExtension
-    BinaryDistribution = setuptools.Distribution
-
-if _is_neuron():
-    neuronxcc_version = get_neuronxcc_version()
-
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
     "csrc/attention/attention_kernels.cu",
@@ -544,7 +489,6 @@ def get_requirements() -> List[str]:
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass=cmdclass,
-    distclass=distclass,
+    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
     package_data=package_data,
 )
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 59f1345b58d42..f1e30f5eb6e6e 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -1,28 +1,12 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 
-
-# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
-def _configure_system():
-    import os
-    import sys
-
-    # Importing flash-attn.
-    thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
-                                    "thirdparty_files")
-    sys.path.insert(0, thirdparty_files)
-
-
-_configure_system()
-# Delete configuration function.
-del _configure_system
-
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
-from vllm.engine.async_llm_engine import AsyncLLMEngine  # noqa: E402
-from vllm.engine.llm_engine import LLMEngine  # noqa: E402
-from vllm.engine.ray_utils import initialize_cluster  # noqa: E402
-from vllm.entrypoints.llm import LLM  # noqa: E402
-from vllm.outputs import CompletionOutput, RequestOutput  # noqa: E402
-from vllm.sampling_params import SamplingParams  # noqa: E402
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
 
 __version__ = "0.3.3"
 
diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py
index 830e82e10f7ad..724dd0511c5aa 100644
--- a/vllm/model_executor/layers/attention/attention.py
+++ b/vllm/model_executor/layers/attention/attention.py
@@ -1,12 +1,16 @@
 """Attention layer."""
+from functools import lru_cache
 from typing import List, Optional
 
 import torch
 import torch.nn as nn
 
+from vllm.logger import init_logger
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.utils import is_hip
 
+logger = init_logger(__name__)
+
 
 class Attention(nn.Module):
     """Attention layer.
@@ -30,17 +34,12 @@ def __init__(
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
-        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
-                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
-            # Ampere or later NVIDIA GPUs.
-            # NOTE(woosuk): FlashAttention does not support FP32.
+        if _use_flash_attn():
             from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
             self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                  num_kv_heads, alibi_slopes,
                                                  sliding_window)
         else:
-            # Turing and Volta NVIDIA GPUs or AMD GPUs.
-            # Or FP32 on any GPU.
             from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
             self.backend = XFormersBackend(num_heads, head_size, scale,
                                            num_kv_heads, alibi_slopes,
@@ -57,3 +56,29 @@ def forward(
     ) -> torch.Tensor:
         return self.backend.forward(query, key, value, key_cache, value_cache,
                                     input_metadata)
+
+
+@lru_cache(maxsize=1)
+def _use_flash_attn() -> bool:
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info("flash_attn is not found. Using xformers backend.")
+        return False
+
+    if is_hip():
+        # AMD GPUs.
+        return False
+    if torch.cuda.get_device_capability()[0] < 8:
+        # Volta and Turing NVIDIA GPUs.
+        logger.info("flash_attn is not supported on Turing or older GPUs. "
+                    "Using xformers backend.")
+        return False
+    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
+        logger.info(
+            "flash_attn only supports torch.float16 or torch.bfloat16. "
+            "Using xformers backend.")
+        return False
+
+    logger.info("Using flash_attn backend.")
+    return True
diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py
index 512f4e49c7eb2..4abe195f274a7 100644
--- a/vllm/model_executor/layers/attention/backends/flash_attn.py
+++ b/vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -1,7 +1,6 @@
 """Attention layer with Flash and PagedAttention."""
 from typing import List, Optional
 
-# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
 from flash_attn import flash_attn_func
 import torch
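
Note on the pattern introduced in attention.py above: `_use_flash_attn()` probes the optional dependency once at runtime and silently falls back to the xformers backend when `flash_attn` is missing or unsupported. Below is a minimal, self-contained sketch of that cached optional-import probe for reference; the names `select_backend` and `_fast_backend_available` are illustrative only and are not part of vLLM's API.

# Minimal sketch (illustrative, not vLLM code): cache the probe so the
# optional import and capability checks run only once per process.
from functools import lru_cache


@lru_cache(maxsize=1)
def _fast_backend_available() -> bool:
    """Return True only if the optional flash_attn package imports cleanly."""
    try:
        import flash_attn  # noqa: F401  # optional dependency
    except ImportError:
        return False
    return True


def select_backend() -> str:
    # First call runs the probe; later calls are served from the lru_cache.
    return "flash_attn" if _fast_backend_available() else "xformers"


if __name__ == "__main__":
    print(f"Selected attention backend: {select_backend()}")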