Commit

Implement single_query_cached_kv_attention kernel (#3)

WoosukKwon authored Mar 1, 2023
1 parent cbf8779 commit 0deacbc
Showing 12 changed files with 2,140 additions and 60 deletions.
4 changes: 3 additions & 1 deletion cacheflow/master/block_manager.py
@@ -15,7 +15,9 @@ def __init__(
block_size: int,
num_blocks: int,
) -> None:
assert block_size in [8, 16, 32]
if block_size not in [8, 16]:
raise ValueError(f'Unsupported block size: {block_size}. '
'The block size must be either 8 or 16.')
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
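For illustration, a minimal standalone sketch of the new check; the helper name here is hypothetical, and in the commit the check lives in the constructor above:

def check_block_size(block_size: int) -> None:
    # Mirrors the new constructor check: only 8 and 16 are accepted.
    if block_size not in [8, 16]:
        raise ValueError(f'Unsupported block size: {block_size}. '
                         'The block size must be either 8 or 16.')

check_block_size(16)      # OK
try:
    check_block_size(32)  # previously allowed by the assert, now rejected
except ValueError as e:
    print(e)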
54 changes: 19 additions & 35 deletions cacheflow/models/attention.py
@@ -3,15 +3,16 @@
import torch
import torch.nn as nn

from cacheflow import ops
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

def __init__(self, scale: float) -> None:
super().__init__()
self.scale = scale
self.scale = float(scale)

def _masked_attention(
self,
@@ -57,46 +57,29 @@ def single_query_cached_kv_attention(
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
block_size = value_cache.shape[2]
block_tables = input_metadata.block_tables

# FIXME(woosuk): Replace the following with a custom op.
for i in range(input_metadata.num_generation_tokens):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(input_metadata.context_lens[i])

keys = []
values = []
for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size

k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)

v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)

out = self._masked_attention(q, keys, values)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
)
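For reference, a rough pure-PyTorch sketch of what the fused kernel computes per generation token, modeled on the Python loop removed above but adjusted for the new [num_blocks, num_heads, head_size, block_size] value-cache layout; the function name and the einsum-based attention are illustrative, not part of the commit:

import torch

def paged_attention_reference(output, query, key_cache, value_cache,
                              scale, block_tables, context_lens):
    # query, output: [num_generation_tokens, num_heads, head_size]
    # key_cache:     [num_blocks, num_heads, head_size/x, block_size, x]
    # value_cache:   [num_blocks, num_heads, head_size, block_size]
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]
    for i in range(query.shape[0]):
        q = query[i].unsqueeze(0)                     # [1, num_heads, head_size]
        block_table = block_tables[i]
        context_len = int(context_lens[i])
        keys, values = [], []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size
            # Gather the j-th key/value of this sequence from the paged cache.
            k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
            v = value_cache[block_number, :, :, block_offset]  # [num_heads, head_size]
            keys.append(k)
            values.append(v)
        keys = torch.stack(keys, dim=0)               # [context_len, num_heads, head_size]
        values = torch.stack(values, dim=0)
        # Scaled dot-product attention over the gathered context (no mask needed:
        # a generation token attends to every cached position).
        attn = torch.einsum('qhd,khd->hqk', q.float() * scale, keys.float())
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('hqk,khd->qhd', attn, values.float())
        output[i].copy_(out.reshape(num_heads, head_size).to(output.dtype))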

def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
Expand All @@ -110,7 +94,7 @@ def forward(

# Reshape the input tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
Expand All @@ -125,7 +109,7 @@ def forward(
cache_event.wait()

# Reshape the keys and values and store them in the cache.
ops.reshape_and_cache(
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, input_metadata.slot_mapping)

if input_metadata.num_generation_tokens > 0:
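To make the cache write above concrete, here is a minimal sketch of what cache_ops.reshape_and_cache is expected to do, under the assumption that slot_mapping[i] encodes block_number * block_size + block_offset for token i (that encoding is inferred, not stated in this diff):

import torch

def reshape_and_cache_reference(key, value, key_cache, value_cache, slot_mapping):
    # key, value:  [num_tokens, num_heads, head_size]
    # key_cache:   [num_blocks, num_heads, head_size/x, block_size, x]
    # value_cache: [num_blocks, num_heads, head_size, block_size]
    num_heads, head_size = key.shape[1], key.shape[2]
    block_size = key_cache.shape[3]
    x = key_cache.shape[4]
    for i in range(key.shape[0]):
        slot = int(slot_mapping[i])
        block_number = slot // block_size
        block_offset = slot % block_size
        # The key's head dimension is split into head_size/x chunks of x elements
        # (presumably so the kernel can issue vectorized loads).
        key_cache[block_number, :, :, block_offset, :] = key[i].reshape(num_heads, head_size // x, x)
        # The value cache keeps head_size on dim 2 and the in-block slot on dim 3.
        value_cache[block_number, :, :, block_offset] = value[i]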
20 changes: 12 additions & 8 deletions cacheflow/worker/cache_engine.py
@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple

import torch
from cacheflow import ops
from cacheflow import cache_ops

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -57,20 +57,22 @@ def get_key_block_shape(self) -> Tuple[int, int, int, int]:
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.block_size,
self.head_size,
self.block_size,
)

def allocate_gpu_cache(self) -> List[KVCache]:
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
size=(self.num_gpu_blocks, *key_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
value_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
size=(self.num_gpu_blocks, *value_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
Expand All @@ -79,14 +81,16 @@ def allocate_gpu_cache(self) -> List[KVCache]:

def allocate_cpu_cache(self) -> List[KVCache]:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=True,
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=True,
)
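For a concrete sense of the shapes above, a small sketch; the choice of x is an assumption here (a common choice is 16 bytes divided by the element size, so one key fragment fills a 16-byte vector) and is not visible in this diff:

import torch

def kv_block_shapes(num_heads: int, head_size: int, block_size: int,
                    dtype: torch.dtype = torch.float16):
    # Assumed: x packs one 16-byte vector's worth of elements.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_block_shape = (num_heads, head_size // x, block_size, x)
    value_block_shape = (num_heads, head_size, block_size)
    return key_block_shape, value_block_shape

# Example: 12 heads of size 64, 16-token blocks, fp16.
k_shape, v_shape = kv_block_shapes(num_heads=12, head_size=64, block_size=16)
print(k_shape)  # (12, 8, 16, 8)
print(v_shape)  # (12, 64, 16)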
Expand All @@ -104,10 +108,10 @@ def _copy_blocks(
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_value_cache, dst_value_cache, src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)
2 changes: 1 addition & 1 deletion cacheflow/worker/worker.py
@@ -118,7 +118,7 @@ def prepare_inputs(
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=int, device=self.device)
padded_block_tables, dtype=torch.int, device=self.device)

input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,
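The dtype change is more than cosmetic: dtype=int resolves to torch.int64, while torch.int is int32, which presumably matches what the new CUDA kernel expects for its block tables. A small sketch, with _pad_to_max reimplemented here as an assumption (its actual body is not shown in this diff):

import torch

def _pad_to_max(block_table, max_len, pad=0):
    # Assumed behaviour: right-pad each block table with a dummy block number.
    return block_table + [pad] * (max_len - len(block_table))

generation_block_tables = [[3], [7, 2], [5, 1, 4]]
max_num_blocks_per_seq = max(len(t) for t in generation_block_tables)
padded_block_tables = [_pad_to_max(t, max_num_blocks_per_seq)
                       for t in generation_block_tables]
block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int)
print(block_tables_tensor.dtype)  # torch.int32 (dtype=int would give torch.int64)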
19 changes: 19 additions & 0 deletions csrc/attention.cpp
@@ -0,0 +1,19 @@
#include <torch/extension.h>

void single_query_cached_kv_attention(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
}
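Not part of this commit, but as a sketch of how a binding like this can be exercised on its own, one could JIT-compile it with torch.utils.cpp_extension.load, assuming a matching CUDA source (the .cu file name below is a placeholder) defines the kernel:

import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build; the kernel source file name is a placeholder.
attention_ops = load(
    name='attention_ops',
    sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
    extra_cuda_cflags=['-O3'],
)

num_tokens, num_heads, head_size = 4, 12, 64
num_blocks, block_size, x, max_context_len = 8, 16, 8, 32
out = torch.empty(num_tokens, num_heads, head_size, dtype=torch.half, device='cuda')
query = torch.randn_like(out)
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x,
                        dtype=torch.half, device='cuda')
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size,
                          dtype=torch.half, device='cuda')
block_tables = torch.randint(0, num_blocks, (num_tokens, max_context_len // block_size),
                             dtype=torch.int, device='cuda')
context_lens = torch.randint(1, max_context_len, (num_tokens,),
                             dtype=torch.int, device='cuda')

# The binding fills `out` in place, matching the signature declared above.
attention_ops.single_query_cached_kv_attention(
    out, query, key_cache, value_cache, 1.0 / head_size ** 0.5,
    block_tables, context_lens, block_size, max_context_len)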