Separate attention backends #3005
.gitignore
@@ -184,3 +184,6 @@ _build/

 # Benchmark dataset
 *.json
+
+# Third-party Python packages.
+vllm/thirdparty_files/
vllm/model_executor/layers/attention/__init__.py (new file)
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.attention.attention import Attention

__all__ = [
    "Attention",
]
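With this re-export in place, model code only needs the package-level import rather than the backend-specific modules. A minimal sketch of that import (the path is taken from the diff above; nothing beyond it is implied):

```python
# Package-level import enabled by the __init__.py above; the concrete backend
# (FlashAttention or xFormers) is selected inside Attention itself.
from vllm.model_executor.layers.attention import Attention
```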
vllm/model_executor/layers/attention/attention.py (new file)
@@ -0,0 +1,59 @@
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
            # Ampere or later NVIDIA GPUs.
            # NOTE(woosuk): FlashAttention does not support FP32.
            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
            self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                 num_kv_heads, alibi_slopes,
                                                 sliding_window)
        else:
            # Turing and Volta NVIDIA GPUs or AMD GPUs.
            # Or FP32 on any GPU.
            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
            self.backend = XFormersBackend(num_heads, head_size, scale,
                                           num_kv_heads, alibi_slopes,
                                           sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        return self.backend.forward(query, key, value, key_cache, value_cache,
                                    input_metadata)
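To make the wrapper's interface concrete, here is a rough usage sketch from the point of view of a model layer. The tensor shapes follow the backend docstrings in this PR; the head counts and sequence lengths are illustrative, and the KV-cache tensors plus InputMetadata are assumed to come from vLLM's model runner, so the actual call is left as a comment.

```python
# Illustrative sketch only: the sizes are made up, and a CUDA build of vLLM at
# this commit is assumed. Shapes follow the [batch_size, seq_len,
# num_heads * head_size] convention used by the backends in this PR.
import torch
from vllm.model_executor.layers.attention import Attention

num_heads, num_kv_heads, head_size = 32, 8, 128   # hypothetical model config
attn = Attention(num_heads,
                 head_size,
                 scale=head_size**-0.5,
                 num_kv_heads=num_kv_heads)

batch_size, seq_len = 2, 16
query = torch.empty(batch_size, seq_len, num_heads * head_size)
key = torch.empty(batch_size, seq_len, num_kv_heads * head_size)
value = torch.empty(batch_size, seq_len, num_kv_heads * head_size)

# key_cache, value_cache, and input_metadata normally come from the model
# runner; passing None for the caches skips caching, as in the profiling run.
# output = attn(query, key, value, key_cache, value_cache, input_metadata)
```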
vllm/model_executor/layers/attention/backends/flash_attn.py (new file)
@@ -0,0 +1,124 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
    PagedAttentionImpl)


class FlashAttentionBackend:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        suppored_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in suppored_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {suppored_head_sizes}.")

        self.sliding_window = ((self.sliding_window, self.sliding_window) if
                               self.sliding_window is not None else (-1, -1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # normal attention
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.num_heads,
                    self.num_kv_heads,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
Comment on lines +81 to +124

I would still suggest separating this out into private methods (

@Yard1 What do you think about this?

I agree we should not make them part of the public API, but they can be done as private APIs for the backends that do have that distinction. Basically, we should try to modularize the forward method if possible, as it makes it easier to read and test.

Got it. First, I believe the current implementation is easy to read. That said, I do agree that modularizing the backends will make them easier to test. However, since this PR has already been delayed quite a bit, let's merge this PR and do the modularization in the next PR.
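For readers following this thread, a rough sketch of the kind of split being discussed is below. The helper names _forward_prompt and _forward_decode are hypothetical (they do not appear in this PR); their bodies would simply hold the two branches of the forward method shown above.

```python
# Hypothetical refactor sketch of FlashAttentionBackend.forward, illustrating
# the suggestion to move each branch into a private method. The helper names
# are made up for this sketch and are not part of the PR.
from vllm.model_executor.layers.attention.ops.paged_attn import (
    PagedAttentionImpl)


class FlashAttentionBackend:
    # __init__ unchanged from the diff above.

    def forward(self, query, key, value, key_cache, value_cache,
                input_metadata):
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Cache the new keys/values unless this is the profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            output = self._forward_prompt(query, key, value, key_cache,
                                          value_cache, input_metadata,
                                          batch_size, seq_len)
        else:
            output = self._forward_decode(query, key_cache, value_cache,
                                          input_metadata)
        return output.view(batch_size, seq_len, hidden_size)

    def _forward_prompt(self, query, key, value, key_cache, value_cache,
                        input_metadata, batch_size, seq_len):
        # Would hold the flash_attn_func / forward_prefix branch above.
        raise NotImplementedError

    def _forward_decode(self, query, key_cache, value_cache, input_metadata):
        # Would hold the PagedAttentionImpl.forward_decode call above.
        raise NotImplementedError
```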
I am just curious why flash_attn_with_kvcache was not used here. That kernel is faster than paged_attention_kernel; more benchmark details can be found in #2744.
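For context on this question: flash-attn's flash_attn_with_kvcache appends the new key/value tokens into an existing cache and attends over it in a single kernel. A rough decode-step sketch under that API is below; it assumes the flash-attn >= 2.2 signature and a contiguous [batch, max_seq_len, num_kv_heads, head_size] cache layout, which is not the block-table layout used by this PR's PagedAttention path, so it is illustrative only.

```python
# Rough sketch of a single decode step with flash_attn_with_kvcache, for
# comparison with the paged-attention decode path above. Sizes are made up;
# the cache layout here is contiguous, not vLLM's paged block layout.
import torch
from flash_attn import flash_attn_with_kvcache

batch_size, max_seq_len = 4, 1024
num_heads, num_kv_heads, head_size = 32, 8, 128
device, dtype = "cuda", torch.float16

# One new query/key/value token per sequence (decode step).
q = torch.randn(batch_size, 1, num_heads, head_size, device=device, dtype=dtype)
k_new = torch.randn(batch_size, 1, num_kv_heads, head_size, device=device, dtype=dtype)
v_new = torch.randn(batch_size, 1, num_kv_heads, head_size, device=device, dtype=dtype)

# Preallocated contiguous KV cache and the current length of each sequence.
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_size,
                      device=device, dtype=dtype)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.full((batch_size,), 100, dtype=torch.int32, device=device)

# Appends k_new/v_new at position cache_seqlens and attends over the cache.
out = flash_attn_with_kvcache(
    q, k_cache, v_cache,
    k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    softmax_scale=head_size**-0.5,
    causal=True,
)
# out: [batch_size, 1, num_heads, head_size]
```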