forked from js8544/vllm
Separate attention backends (vllm-project#3005)
1 parent 8c862d8 · commit 63e03d2
Showing 35 changed files with 566 additions and 277 deletions.
.gitignore
@@ -184,3 +184,6 @@ _build/

# Benchmark dataset
*.json

# Third-party Python packages.
vllm/thirdparty_files/
vllm/model_executor/layers/attention/__init__.py (new file)
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.attention.attention import Attention

__all__ = [
    "Attention",
]
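Because the package __init__ re-exports Attention, model code can import the wrapper from the package root rather than from the attention.attention submodule. A minimal illustration (only the import target is taken from this diff):

# Resolves to the same class as the longer submodule import above.
from vllm.model_executor.layers.attention import Attention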
vllm/model_executor/layers/attention/attention.py (new file)
@@ -0,0 +1,59 @@
"""Attention layer."""
from typing import List, Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip


class Attention(nn.Module):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:
    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
            # Ampere or later NVIDIA GPUs.
            # NOTE(woosuk): FlashAttention does not support FP32.
            from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
            self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                 num_kv_heads, alibi_slopes,
                                                 sliding_window)
        else:
            # Turing and Volta NVIDIA GPUs or AMD GPUs.
            # Or FP32 on any GPU.
            from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
            self.backend = XFormersBackend(num_heads, head_size, scale,
                                           num_kv_heads, alibi_slopes,
                                           sliding_window)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        return self.backend.forward(query, key, value, key_cache, value_cache,
                                    input_metadata)
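For orientation, a rough sketch of constructing this wrapper and observing the backend selection it performs. The parameter values are assumptions for illustration, not taken from this commit; note that construction queries torch.cuda.get_device_capability(), so a CUDA device must be available.

# Hypothetical usage sketch (example values, not part of this diff).
import torch
from vllm.model_executor.layers.attention import Attention

torch.set_default_dtype(torch.float16)
attn = Attention(
    num_heads=32,        # assumed model configuration
    head_size=128,
    scale=128 ** -0.5,   # conventional 1/sqrt(head_size) scaling
    num_kv_heads=8,      # grouped-query attention: 4 query heads per KV head
)
# On an Ampere-or-newer NVIDIA GPU with a fp16/bf16 default dtype, attn.backend
# is a FlashAttentionBackend; otherwise it falls back to XFormersBackend.
# attn.forward(query, key, value, key_cache, value_cache, input_metadata)
# simply delegates to that backend.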
File renamed without changes.
vllm/model_executor/layers/attention/backends/flash_attn.py (new file, 124 additions, 0 deletions)
@@ -0,0 +1,124 @@
"""Attention layer with Flash and PagedAttention."""
from typing import List, Optional

# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
from flash_attn import flash_attn_func
import torch

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
    PagedAttentionImpl)


class FlashAttentionBackend:

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        supported_head_sizes = PagedAttentionImpl.get_supported_head_sizes()
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
                f"Supported head sizes are: {supported_head_sizes}.")

        self.sliding_window = ((self.sliding_window, self.sliding_window) if
                               self.sliding_window is not None else (-1, -1))

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Forward pass with FlashAttention and PagedAttention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            PagedAttentionImpl.reshape_and_cache(key, value, key_cache,
                                                 value_cache, input_metadata)

        if input_metadata.is_prompt:
            # Prompt run.
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                # normal attention
                query = query.unflatten(0, (batch_size, seq_len))
                key = key.unflatten(0, (batch_size, seq_len))
                value = value.unflatten(0, (batch_size, seq_len))
                output = flash_attn_func(
                    query,
                    key,
                    value,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                # prefix-enabled attention
                output = PagedAttentionImpl.forward_prefix(
                    query,
                    key,
                    value,
                    key_cache,
                    value_cache,
                    input_metadata,
                    self.num_heads,
                    self.num_kv_heads,
                    self.alibi_slopes,
                )
        else:
            # Decoding run.
            output = PagedAttentionImpl.forward_decode(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
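The key_cache and value_cache shapes in the forward docstring describe vLLM's paged KV cache: a prompt run with no cache blocks calls flash_attn_func directly, while decoding reads previously cached blocks through PagedAttentionImpl.forward_decode. A rough shape walk-through, using example values that are assumptions rather than anything in this diff:

# Hypothetical shape illustration for the paged KV cache (example values only).
num_blocks, block_size = 1024, 16   # assumed cache configuration
num_kv_heads, head_size = 8, 128    # assumed model configuration
x = 8                               # assumed packing factor along the last dim (fp16)

key_cache_shape = (num_blocks, num_kv_heads, head_size // x, block_size, x)
value_cache_shape = (num_blocks, num_kv_heads, head_size, block_size)
print(key_cache_shape)    # (1024, 8, 16, 16, 8)
print(value_cache_shape)  # (1024, 8, 128, 16)
# Each block holds block_size token positions; a sequence's blocks are looked
# up via input_metadata.block_tables during the decoding path.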