refactor
comaniac committed Jul 11, 2024
1 parent 52b7fcb commit cf3b724
Showing 12 changed files with 1,106 additions and 474 deletions.
6 changes: 5 additions & 1 deletion tests/worker/test_model_input.py
@@ -3,7 +3,7 @@

import torch

from vllm.attention import AttentionMetadata
from vllm.attention import AttentionMetadata, AttentionMetadataBuilder
from vllm.attention.backends.abstract import AttentionBackend
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.pooling_metadata import PoolingMetadata
@@ -26,6 +26,10 @@ def get_impl_cls():
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        return AttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
4 changes: 3 additions & 1 deletion vllm/attention/__init__.py
@@ -1,12 +1,14 @@
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata)
                                              AttentionMetadata,
                                              AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend

__all__ = [
    "Attention",
    "AttentionBackend",
    "AttentionMetadata",
    "AttentionMetadataBuilder",
"Attention",
"get_attn_backend",
]
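With this export in place, the new name is importable from the package top level; the updated test file above uses exactly this form:

from vllm.attention import AttentionMetadata, AttentionMetadataBuilder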
27 changes: 27 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -35,6 +35,16 @@ def get_metadata_cls() -> Type["AttentionMetadata"]:
    def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
        return cls.get_metadata_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
        raise NotImplementedError

    @classmethod
    def make_metadata_builder(cls, *args,
                              **kwargs) -> "AttentionMetadataBuilder":
        return cls.get_builder_cls()(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def get_kv_cache_shape(
@@ -110,6 +120,23 @@ def asdict_zerocopy(self,
T = TypeVar("T", bound=AttentionMetadata)


class AttentionMetadataBuilder(ABC, Generic[T]):
    """Abstract class for attention metadata builders."""

    @abstractmethod
    def __init__(self, input_builder) -> None:
        raise NotImplementedError

    @abstractmethod
    def add_seq_group(self, *args, **kwargs) -> None:
        raise NotImplementedError

    @abstractmethod
    def build(self, runner, seq_lens, query_lens, use_captured_graph: bool,
              cuda_graph_pad_size: int, batch_size: int) -> T:
        raise NotImplementedError


class AttentionImpl(ABC, Generic[T]):

    @abstractmethod
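For orientation, a backend now advertises its builder class via get_builder_cls(), and make_metadata_builder() instantiates it. Below is a minimal, self-contained sketch of that wiring, outside the diff; all class names here are stand-ins, not vLLM's real ones, and the untyped input_builder argument mirrors the abstract signature above:

from abc import ABC, abstractmethod
from typing import Generic, List, Type, TypeVar

T = TypeVar("T")


class MetadataBuilder(ABC, Generic[T]):
    # Stand-in for AttentionMetadataBuilder.

    @abstractmethod
    def build(self) -> T:
        raise NotImplementedError


class MyMetadata:

    def __init__(self, seq_lens: List[int]):
        self.seq_lens = seq_lens


class MyMetadataBuilder(MetadataBuilder[MyMetadata]):

    def __init__(self, input_builder) -> None:
        self.seq_lens: List[int] = []

    def add_seq_group(self, seq_len: int) -> None:
        self.seq_lens.append(seq_len)

    def build(self) -> MyMetadata:
        return MyMetadata(self.seq_lens)


class MyBackend:
    # Mirrors AttentionBackend.get_builder_cls / make_metadata_builder.

    @staticmethod
    def get_builder_cls() -> Type[MyMetadataBuilder]:
        return MyMetadataBuilder

    @classmethod
    def make_metadata_builder(cls, *args, **kwargs) -> MyMetadataBuilder:
        return cls.get_builder_cls()(*args, **kwargs)


builder = MyBackend.make_metadata_builder(input_builder=None)
builder.add_seq_group(16)
assert builder.build().seq_lens == [16]

The class-method indirection lets generic runner code construct the right builder for whatever backend was selected, without knowing its concrete type.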
54 changes: 52 additions & 2 deletions vllm/attention/backends/blocksparse_attn.py
@@ -1,15 +1,24 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType)
                                              AttentionMetadata,
                                              AttentionMetadataBuilder,
                                              AttentionType)
from vllm.attention.backends.utils import (
    metadata_builder_add_seq_group_common, metadata_builder_build_common)
from vllm.attention.ops.blocksparse_attention.interface import (
    LocalStridedBlockSparseAttn, get_head_sliding_step)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.sequence import SequenceGroupMetadata

if TYPE_CHECKING:
    from vllm.worker.model_runner import (GPUModelRunnerBase,
                                          ModelInputForGPUBuilder)


@dataclass
@@ -93,6 +102,10 @@ def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]:
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return BlocksparseFlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["BlocksparseFlashAttentionMetadataBuilder"]:
        return BlocksparseFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -244,6 +257,43 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
        return self._cached_decode_metadata


class BlocksparseFlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[BlocksparseFlashAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.decode_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      sliding_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
                      chunked_prefill_enabled):
        metadata_builder_add_seq_group_common(
            self, seq_group_metadata, token_lens, seq_lens, sliding_seq_lens,
            query_lens, context_lens, curr_sliding_window_blocks,
            prefix_cache_hit, chunked_prefill_enabled)

    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
              use_captured_graph: bool, cuda_graph_pad_size: int,
              batch_size: int):
        return metadata_builder_build_common(
            self, BlocksparseFlashAttentionMetadata, runner, seq_lens,
            query_lens, use_captured_graph, cuda_graph_pad_size, batch_size)


class BlocksparseFlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
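Aside: metadata_builder_add_seq_group_common and metadata_builder_build_common are shared helpers in vllm/attention/backends/utils (that file is among the 12 changed but is not rendered on this page), so the blocksparse builder is pure delegation. A rough, heavily simplified sketch of the pattern, with made-up names and signatures rather than the real helpers:

from typing import List


def add_seq_group_common(builder, seq_len: int, is_prompt: bool) -> None:
    # Shared bookkeeping factored out of the per-backend builders.
    if is_prompt:
        builder.num_prefills += 1
        builder.prefill_seq_lens.append(seq_len)
    else:
        builder.num_decode_tokens += 1
        builder.decode_seq_lens.append(seq_len)


class SimpleBuilder:

    def __init__(self) -> None:
        self.num_prefills = 0
        self.num_decode_tokens = 0
        self.prefill_seq_lens: List[int] = []
        self.decode_seq_lens: List[int] = []

    def add_seq_group(self, seq_len: int, is_prompt: bool) -> None:
        # The backend-specific method just forwards to the shared helper.
        add_seq_group_common(self, seq_len, is_prompt)


b = SimpleBuilder()
b.add_seq_group(16, is_prompt=True)
b.add_seq_group(1, is_prompt=False)
assert (b.num_prefills, b.num_decode_tokens) == (1, 1)

FlashAttention, by contrast, keeps its own open-coded add_seq_group/build below, since its block-table handling differs from the common path.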
177 changes: 175 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -1,13 +1,24 @@
"""Attention layer with FlashAttention."""
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.sequence import SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad

if TYPE_CHECKING:
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder)


class FlashAttentionBackend(AttentionBackend):
@@ -28,6 +39,10 @@ def get_impl_cls() -> Type["FlashAttentionImpl"]:
    def get_metadata_cls() -> Type["AttentionMetadata"]:
        return FlashAttentionMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
        return FlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
@@ -184,6 +199,164 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
        return self._cached_decode_metadata


class FlashAttentionMetadataBuilder(
        AttentionMetadataBuilder[FlashAttentionMetadata]):

    def __init__(self, input_builder: "ModelInputForGPUBuilder"):
        self.slot_mapping: List[int] = []
        self.prefill_seq_lens: List[int] = []
        self.context_lens: List[int] = []
        self.block_tables: List[List[int]] = []
        self.decode_seq_lens: List[int] = []
        self.num_prefills = 0
        self.num_prefill_tokens = 0
        self.num_decode_tokens = 0

        self.sliding_window = input_builder.sliding_window
        self.block_size = input_builder.block_size
        self.use_v2_block_manager = (
            input_builder.scheduler_config.use_v2_block_manager)

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata,
                      token_lens: List[int], seq_lens: List[int],
                      sliding_seq_lens: List[int], query_lens: List[int],
                      context_lens: List[int],
                      curr_sliding_window_blocks: List[int], prefix_cache_hit,
                      chunked_prefill_enabled):
        is_prompt = seq_group_metadata.is_prompt
        block_tables = seq_group_metadata.block_tables

        for (seq_id, token_len, seq_len, sliding_seq_len, query_len,
             context_len, curr_sliding_window_block) in zip(
                 seq_group_metadata.seq_data.keys(), token_lens, seq_lens,
                 sliding_seq_lens, query_lens, context_lens,
                 curr_sliding_window_blocks):
            self.context_lens.append(context_len)

            if is_prompt:
                self.num_prefills += 1
                self.num_prefill_tokens += token_len
                self.prefill_seq_lens.append(seq_len)
            else:
                assert query_len == 1, (
                    "seq_len: {}, context_len: {}, query_len: {}".format(
                        seq_len, context_len, query_len))
                self.num_decode_tokens += query_len
                self.decode_seq_lens.append(sliding_seq_len)

            # Compute block table.
            # TODO(sang): Combine chunked prefill and prefix caching by
            # only allowing multiple of block_size chunk size.
            # NOTE: This only works for oooooooxxx style attention.
            block_table = []
            if prefix_cache_hit:
                # NOTE(woosuk): For flash-attn, the block table should
                # include the entries for the incoming prefill tokens.
                block_table = block_tables[seq_id]
            elif ((chunked_prefill_enabled or not is_prompt)
                  and block_tables is not None):
                block_table = block_tables[seq_id][-curr_sliding_window_block:]
            self.block_tables.append(block_table)

            # Compute slot mapping.
            is_profile_run = is_block_tables_empty(block_tables)
            start_idx = compute_slot_mapping_start_idx(
                is_prompt, query_len, context_len, self.sliding_window,
                self.use_v2_block_manager)
            compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
                                 seq_len, context_len, start_idx,
                                 self.block_size,
                                 seq_group_metadata.block_tables)

    def build(self, runner: "GPUModelRunnerBase", seq_lens, query_lens,
              use_captured_graph: bool, cuda_graph_pad_size: int,
              batch_size: int):
        device = runner.device

        logits_soft_cap = getattr(runner.model_config.hf_config,
                                  "attn_logit_softcapping", None)
        if logits_soft_cap is not None:
            raise ValueError(
                "Please use Flashinfer backend for models with logits_soft_cap"
                " (i.e., Gemma-2). Otherwise, the output might be wrong."
                " Set Flashinfer backend by "
                "export VLLM_ATTENTION_BACKEND=FLASHINFER.")

        max_query_len = max(query_lens)
        max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
        max_decode_seq_len = max(self.decode_seq_lens, default=0)
        num_decode_tokens = self.num_decode_tokens

        if use_captured_graph:
            self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
            # Pad with empty block tables for the graph-padded entries.
            self.block_tables.extend([[]] * cuda_graph_pad_size)
            num_decode_tokens = batch_size + cuda_graph_pad_size

            # The shape of graph_block_tables is
            # [max batch size, max context len // block size].
            input_block_tables = runner.graph_block_tables[:batch_size]
            for i, block_table in enumerate(self.block_tables):
                if block_table:
                    input_block_tables[i, :len(block_table)] = block_table
            block_tables = torch.tensor(input_block_tables, device=device)
        else:
            max_block_table_len = max(
                len(block_table) for block_table in self.block_tables)
            block_tables = make_tensor_with_pad(
                self.block_tables,
                max_len=max_block_table_len,
                pad=0,
                dtype=torch.int,
                device=device,
            )
        assert max_query_len > 0, ("query_lens: {}".format(query_lens))

        context_lens_tensor = torch.tensor(self.context_lens,
                                           dtype=torch.int,
                                           device=device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=device)
        query_lens_tensor = torch.tensor(query_lens,
                                         dtype=torch.long,
                                         device=device)
        query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                                      dtype=torch.int32,
                                      device=device)
        seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1,
                                    dtype=torch.int32,
                                    device=device)
        torch.cumsum(seq_lens_tensor,
                     dim=0,
                     dtype=seq_start_loc.dtype,
                     out=seq_start_loc[1:])
        torch.cumsum(query_lens_tensor,
                     dim=0,
                     dtype=query_start_loc.dtype,
                     out=query_start_loc[1:])

        slot_mapping_tensor = torch.tensor(self.slot_mapping,
                                           dtype=torch.long,
                                           device=device)

        return FlashAttentionMetadata(
            num_prefills=self.num_prefills,
            slot_mapping=slot_mapping_tensor,
            num_prefill_tokens=self.num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            seq_lens=seq_lens,
            seq_lens_tensor=seq_lens_tensor,
            max_query_len=max_query_len,
            max_prefill_seq_len=max_prefill_seq_len,
            max_decode_seq_len=max_decode_seq_len,
            query_start_loc=query_start_loc,
            seq_start_loc=seq_start_loc,
            context_lens_tensor=context_lens_tensor,
            block_tables=block_tables,
            use_cuda_graph=use_captured_graph,
        )


class FlashAttentionImpl(AttentionImpl):
    """
    If the input tensors contain prompt tokens, the layout is as follows:
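One subtle piece of FlashAttentionMetadataBuilder.build() above is the query_start_loc/seq_start_loc construction: both are exclusive prefix sums, where entry i marks the offset at which segment i starts in the flattened token stream, with a leading zero. A standalone demonstration using toy lengths and CPU tensors (the real code writes to CUDA tensors on runner.device):

import torch

query_lens = [4, 1, 1]  # three sequences: one prefill chunk, two decodes

query_lens_tensor = torch.tensor(query_lens, dtype=torch.long)
query_start_loc = torch.zeros(query_lens_tensor.shape[0] + 1,
                              dtype=torch.int32)
# Writing the cumulative sum into the [1:] view leaves the leading
# zero in place, so segment i spans
# [query_start_loc[i], query_start_loc[i + 1]).
torch.cumsum(query_lens_tensor,
             dim=0,
             dtype=query_start_loc.dtype,
             out=query_start_loc[1:])

assert query_start_loc.tolist() == [0, 4, 5, 6]

These offset tensors are what flash_attn_varlen_func consumes in place of per-sequence padding, which is why build() materializes them once per batch.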