Skip to content

Commit

Permalink
Separate attention backends (vllm-project#3005)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and dbogunowicz committed Mar 26, 2024
1 parent 8c862d8 commit 63e03d2
Show file tree
Hide file tree
Showing 35 changed files with 566 additions and 277 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,6 @@ _build/

# Benchmark dataset
*.json

# Third-party Python packages.
vllm/thirdparty_files/
48 changes: 45 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import re
import subprocess
import sys
import warnings
from pathlib import Path
from typing import List, Set
Expand All @@ -14,6 +15,8 @@
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)
# This is a temporary directory to store third-party packages.
THIRDPARTY_SUBDIR = "vllm/thirdparty_files"

# If you are developing the C++ backend of vLLM, consider building vLLM with
# `python setup.py develop` since it will give you incremental builds.
Expand Down Expand Up @@ -324,8 +327,46 @@ def get_torch_arch_list() -> Set[str]:
"nvcc": NVCC_FLAGS_PUNICA,
},
))
elif _is_neuron():
neuronxcc_version = get_neuronxcc_version()

# Download the FlashAttention package.
# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
flash_attn_version = "2.5.6"
install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
"-q",
f"--target={install_dir}",
"einops", # Dependency of flash-attn.
f"flash-attn=={flash_attn_version}",
"--no-dependencies", # Required to avoid re-installing torch.
],
env=dict(os.environ, CC="gcc"),
)

# Copy the FlashAttention package into the vLLM package after build.
class build_ext(BuildExtension):

def run(self):
super().run()
target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
if not os.path.exists(target_dir):
os.makedirs(target_dir)
self.copy_tree(install_dir, target_dir)

class BinaryDistribution(setuptools.Distribution):

def has_ext_modules(self):
return True

else:
build_ext = BuildExtension
BinaryDistribution = setuptools.Distribution
if _is_neuron():
neuronxcc_version = get_neuronxcc_version()

vllm_extension_sources = [
"csrc/cache_kernels.cu",
Expand Down Expand Up @@ -468,6 +509,7 @@ def get_requirements() -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
distclass=BinaryDistribution,
package_data=package_data,
)
2 changes: 1 addition & 1 deletion tests/kernels/test_prefix_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time

import torch
from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
from vllm.model_executor.layers.attention.ops.prefix_prefill import (
context_attention_fwd)
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
Expand Down
30 changes: 23 additions & 7 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
def _configure_system():
import os
import sys

# Importing flash-attn.
thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
"thirdparty_files")
sys.path.insert(0, thirdparty_files)


_configure_system()
# Delete configuration function.
del _configure_system

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs # noqa: E402
from vllm.engine.async_llm_engine import AsyncLLMEngine # noqa: E402
from vllm.engine.llm_engine import LLMEngine # noqa: E402
from vllm.engine.ray_utils import initialize_cluster # noqa: E402
from vllm.entrypoints.llm import LLM # noqa: E402
from vllm.outputs import CompletionOutput, RequestOutput # noqa: E402
from vllm.sampling_params import SamplingParams # noqa: E402

__version__ = "0.3.3"

Expand Down
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.attention.attention import Attention

__all__ = [
"Attention",
]
59 changes: 59 additions & 0 deletions vllm/model_executor/layers/attention/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip


class Attention(nn.Module):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
# Ampere or later NVIDIA GPUs.
# NOTE(woosuk): FlashAttention does not support FP32.
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
else:
# Turing and Volta NVIDIA GPUs or AMD GPUs.
# Or FP32 on any GPU.
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
self.backend = XFormersBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
return self.backend.forward(query, key, value, key_cache, value_cache,
input_metadata)
124 changes: 124 additions & 0 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)


class FlashAttentionBackend:

def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
if head_size not in suppored_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {suppored_head_sizes}.")

self.sliding_window = ((self.sliding_window, self.sliding_window) if
self.sliding_window is not None else (-1, -1))

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: Optional[torch.Tensor],
value_cache: Optional[torch.Tensor],
input_metadata: InputMetadata,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
block_size, x]
value_cache: shape = [num_blocks, num_kv_heads, head_size,
block_size]
input_metadata: metadata for the inputs.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
batch_size, seq_len, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)

# Reshape the keys and values and store them in the cache.
# If key_cache and value_cache are not provided, the new key and value
# vectors will not be cached. This happens during the initial memory
# profiling run.
if key_cache is not None and value_cache is not None:
PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
value_cache, input_metadata)

if input_metadata.is_prompt:
# Prompt run.
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# normal attention
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
query,
key,
value,
key_cache,
value_cache,
input_metadata,
self.num_heads,
self.num_kv_heads,
self.alibi_slopes,
)
else:
# Decoding run.
output = PagedAttentionImpl.forward_decode(
query,
key_cache,
value_cache,
input_metadata,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
Loading

0 comments on commit 63e03d2

Please sign in to comment.