
Commit

[FIX] Make flash_attn optional (vllm-project#3269)
WoosukKwon authored and AdrianAbeyta committed Mar 8, 2024
1 parent 90c2cd4 commit 9e6144a
Showing 5 changed files with 39 additions and 90 deletions.
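In short, this commit stops pip-installing flash-attn at build time and instead probes for the package at runtime, falling back to the xformers backend when it is missing. A minimal, standalone sketch of that optional-import pattern (illustrative names, not the exact vLLM code):

def pick_attention_backend() -> str:
    """Return 'flash_attn' if the optional package is importable, else 'xformers'."""
    try:
        import flash_attn  # noqa: F401  # optional dependency, may be absent
    except ImportError:
        return "xformers"
    return "flash_attn"

print(pick_attention_backend())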
3 changes: 0 additions & 3 deletions .gitignore
@@ -185,6 +185,3 @@ hip_compat.h

# Benchmark dataset
*.json

# Third-party Python packages.
vllm/thirdparty_files/
58 changes: 1 addition & 57 deletions setup.py
@@ -3,7 +3,6 @@
import os
import re
import subprocess
import sys
import warnings
from pathlib import Path
from typing import List, Set
@@ -15,8 +14,6 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"

# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
@@ -341,61 +338,9 @@ def get_torch_arch_list() -> Set[str]:
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_hip():
amd_archs = os.getenv("GPU_ARCHS")
if amd_archs is None:
amd_archs = get_amdgpu_offload_arch()
for arch in amd_archs.split(";"):
if arch not in ROCM_SUPPORTED_ARCHS:
raise RuntimeError(
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
f"amdgpu_arch_found: {arch}")
NVCC_FLAGS += [f"--offload-arch={arch}"]
NVCC_FLAGS += ["-DENABLE_FP8_E4M3"]

elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

# Download the FlashAttention package.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-q",
f"--target={install_dir}",
"einops", # Dependency of flash-attn.
f"flash-attn=={flash_attn_version}",
"--no-dependencies", # Required to avoid re-installing torch.
],
env=dict(os.environ, CC="gcc"),
)

# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):

def run(self):
super().run()
target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
self.copy_tree(install_dir, target_dir)

class BinaryDistribution(setuptools.Distribution):

def has_ext_modules(self):
return True

else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version()

vllm_extension_sources = [
"csrc/cache_kernels.cu",
"csrc/attention/attention_kernels.cu",
@@ -544,7 +489,6 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass=cmdclass,
distclass=distclass,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
package_data=package_data,
)
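With the bundled third-party install gone, the custom build_ext/BinaryDistribution machinery is no longer needed, and the native build hook is registered only on non-Neuron targets (where the CUDA/HIP extensions are compiled). A stripped-down sketch of that conditional cmdclass pattern (illustrative, not the full setup.py):

import setuptools
from torch.utils.cpp_extension import BuildExtension

def _is_neuron() -> bool:
    # Stand-in for the real target detection in setup.py.
    return False

setuptools.setup(
    name="example-pkg",
    # Only register the C++/CUDA build hook when there are extensions to build.
    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
)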
30 changes: 7 additions & 23 deletions vllm/__init__.py
@@ -1,28 +1,12 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""


# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
def _configure_system():
import os
import sys

# Importing flash-attn.
thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"thirdparty_files")
sys.path.insert(0, thirdparty_files)


_configure_system()
# Delete configuration function.
del _configure_system

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402
from vllm.engine.llm_engine import LLMEngine # noqa: E402
from vllm.engine.ray_utils import initialize_cluster # noqa: E402
from vllm.entrypoints.llm import LLM # noqa: E402
from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402
from vllm.sampling_params import SamplingParams # noqa: E402
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.3.3"

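The re-exported public API is unchanged apart from dropping the sys.path hack and the now-unneeded noqa markers. A rough usage sketch of that API (model name and parameters are illustrative):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model supported by vLLM
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)  # first completion for the first prompt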
37 changes: 31 additions & 6 deletions vllm/model_executor/layers/attention/attention.py
@@ -1,12 +1,16 @@
"""Attention layer."""
from functools import lru_cache
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

logger = init_logger(__name__)


class Attention(nn.Module):
"""Attention layer.
@@ -30,17 +34,12 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
if _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
@@ -57,3 +56,29 @@ def forward(
) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False

if is_hip():
# AMD GPUs.
return False
if torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False

logger.info("Using flash_attn backend.")
return True
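Thanks to the lru_cache, the probe (and its log message) runs only once per process, no matter how many Attention layers are constructed. A quick check of which backend will be selected on a given machine, assuming the module path above and a GPU-enabled torch build:

from vllm.model_executor.layers.attention.attention import _use_flash_attn

# True  -> FlashAttentionBackend
# False -> XFormersBackend (flash_attn missing, ROCm, pre-Ampere GPU, or fp32 default dtype)
print(_use_flash_attn())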
1 change: 0 additions & 1 deletion vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -1,7 +1,6 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch

