Implement single_query_cached_kv_attention kernel #3

Merged
merged 14 commits on Mar 1, 2023
4 changes: 3 additions & 1 deletion cacheflow/master/block_manager.py
@@ -15,7 +15,9 @@ def __init__(
block_size: int,
num_blocks: int,
) -> None:
assert block_size in [8, 16, 32]
if block_size not in [8, 16]:
raise ValueError(f'Unsupported block size: {block_size}. '
'The block size must be either 8 or 16.')
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks
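Note: a minimal sketch of the new validation path, shown in isolation. The `check_block_size` helper below is hypothetical and only mirrors the check added in `block_manager.py` above.

```python
def check_block_size(block_size: int) -> None:
    # Mirrors the validation added above: only block sizes 8 and 16 are accepted now.
    if block_size not in [8, 16]:
        raise ValueError(f'Unsupported block size: {block_size}. '
                         'The block size must be either 8 or 16.')

check_block_size(16)      # accepted
try:
    check_block_size(32)  # previously allowed by the assert, now rejected
except ValueError as e:
    print(e)
```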
54 changes: 19 additions & 35 deletions cacheflow/models/attention.py
@@ -3,15 +3,16 @@
import torch
import torch.nn as nn

from cacheflow import ops
from cacheflow import attention_ops
from cacheflow import cache_ops
from cacheflow.models import InputMetadata


class OPTCacheFlowAttention(nn.Module):

def __init__(self, scale: float) -> None:
super().__init__()
self.scale = scale
self.scale = float(scale)

def _masked_attention(
self,
@@ -57,46 +58,29 @@ def single_query_cached_kv_attention(
output: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
query: torch.Tensor, # [num_generation_tokens, num_heads, head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
) -> None:
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
block_size = value_cache.shape[2]
block_tables = input_metadata.block_tables

# FIXME(woosuk): Replace the following with a custom op.
for i in range(input_metadata.num_generation_tokens):
q = query[i].unsqueeze(0)
block_table = block_tables[i]
context_len = int(input_metadata.context_lens[i])

keys = []
values = []
for j in range(context_len):
block_number = int(block_table[j // block_size])
block_offset = j % block_size

k = key_cache[block_number, :, :, block_offset, :]
k = k.reshape(num_heads, head_size)
keys.append(k)

v = value_cache[block_number, :, block_offset, :]
values.append(v)
keys = torch.stack(keys, dim=0)
values = torch.stack(values, dim=0)

out = self._masked_attention(q, keys, values)
out = out.view(num_heads, head_size)
output[i].copy_(out, non_blocking=True)
block_size = value_cache.shape[3]
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
self.scale,
input_metadata.block_tables,
input_metadata.context_lens,
block_size,
input_metadata.max_context_len,
)

def forward(
self,
query: torch.Tensor, # [num_tokens, num_heads * head_size]
key: torch.Tensor, # [num_tokens, num_heads * head_size]
value: torch.Tensor, # [num_tokens, num_heads * head_size]
key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x]
value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size]
value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size]
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor: # [num_tokens, num_heads * head_size]
@@ -110,7 +94,7 @@ def forward(

# Reshape the input tensors.
num_heads = value_cache.shape[1]
head_size = value_cache.shape[3]
head_size = value_cache.shape[2]
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_heads, head_size)
value = value.view(-1, num_heads, head_size)
Expand All @@ -125,7 +109,7 @@ def forward(
cache_event.wait()

# Reshape the keys and values and store them in the cache.
ops.reshape_and_cache(
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, input_metadata.slot_mapping)

if input_metadata.num_generation_tokens > 0:
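Note: for reference, a pure-PyTorch sketch of what the new `attention_ops.single_query_cached_kv_attention` op is expected to compute. It follows the Python loop removed in this diff, updated for the new `value_cache` layout `[num_blocks, num_heads, head_size, block_size]`, with the attention math written out explicitly instead of going through `_masked_attention`. This illustrates the semantics only; it is not the CUDA kernel.

```python
import torch

def reference_single_query_cached_kv_attention(
    output: torch.Tensor,       # [num_generation_tokens, num_heads, head_size]
    query: torch.Tensor,        # [num_generation_tokens, num_heads, head_size]
    key_cache: torch.Tensor,    # [num_blocks, num_heads, head_size/x, block_size, x]
    value_cache: torch.Tensor,  # [num_blocks, num_heads, head_size, block_size]
    block_tables: torch.Tensor, # [num_generation_tokens, max_num_blocks_per_seq]
    context_lens: torch.Tensor, # [num_generation_tokens]
    scale: float,
) -> None:
    num_heads = value_cache.shape[1]
    head_size = value_cache.shape[2]
    block_size = value_cache.shape[3]

    for i in range(query.shape[0]):
        q = query[i].unsqueeze(0)              # [1, num_heads, head_size]
        block_table = block_tables[i]
        context_len = int(context_lens[i])

        keys, values = [], []
        for j in range(context_len):
            block_number = int(block_table[j // block_size])
            block_offset = j % block_size
            # Collapse the (head_size/x, x) split of the key cache back to head_size.
            k = key_cache[block_number, :, :, block_offset, :].reshape(num_heads, head_size)
            keys.append(k)
            # Read one token's values from the [num_heads, head_size, block_size] layout.
            values.append(value_cache[block_number, :, :, block_offset])
        keys = torch.stack(keys, dim=0)        # [context_len, num_heads, head_size]
        values = torch.stack(values, dim=0)    # [context_len, num_heads, head_size]

        # Scaled dot-product attention of the single query over its cached context.
        attn = torch.einsum('qhd,khd->hqk', q, keys) * scale
        attn = torch.softmax(attn, dim=-1)
        out = torch.einsum('hqk,khd->qhd', attn, values)   # [1, num_heads, head_size]
        output[i].copy_(out.reshape(num_heads, head_size))

# Tiny smoke test with toy shapes (x = 8 is the key-cache vector width here).
num_blocks, num_heads, head_size, block_size, x = 4, 2, 16, 8, 8
key_cache = torch.randn(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.randn(num_blocks, num_heads, head_size, block_size)
query = torch.randn(3, num_heads, head_size)
output = torch.empty_like(query)
block_tables = torch.tensor([[0, 1], [2, 3], [1, 0]])
context_lens = torch.tensor([5, 8, 11])
reference_single_query_cached_kv_attention(
    output, query, key_cache, value_cache,
    block_tables, context_lens, scale=head_size ** -0.5)
```

Presumably the value cache was transposed to put `block_size` last so that, for a fixed head and head dimension, the entries the kernel reads per block are contiguous; the kernel source is not in this diff, so that is an inference.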
20 changes: 12 additions & 8 deletions cacheflow/worker/cache_engine.py
@@ -1,7 +1,7 @@
from typing import Dict, List, Tuple

import torch
from cacheflow import ops
from cacheflow import cache_ops

KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -57,20 +57,22 @@ def get_key_block_shape(self) -> Tuple[int, int, int, int]:
def get_value_block_shape(self) -> Tuple[int, int, int]:
return (
self.num_heads,
self.block_size,
self.head_size,
self.block_size,
)

def allocate_gpu_cache(self) -> List[KVCache]:
gpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_key_block_shape()),
size=(self.num_gpu_blocks, *key_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
value_blocks = torch.empty(
size=(self.num_gpu_blocks, *self.get_value_block_shape()),
size=(self.num_gpu_blocks, *value_block_shape),
dtype=self.dtype,
device=self.gpu_id,
)
@@ -79,14 +81,16 @@ def allocate_cpu_cache(self) -> List[KVCache]:

def allocate_cpu_cache(self) -> List[KVCache]:
cpu_cache: List[KVCache] = []
key_block_shape = self.get_key_block_shape()
value_block_shape = self.get_value_block_shape()
for _ in range(self.num_layers):
key_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_key_block_shape()),
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=True,
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *self.get_value_block_shape()),
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=True,
)
@@ -104,10 +108,10 @@ def _copy_blocks(
src_key_cache, src_value_cache = src[i]
dst_key_cache, dst_value_cache = dst[i]
# Copy the key blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_key_cache, dst_key_cache, src_to_dst)
# Copy the value blocks.
ops.copy_cache_blocks(
cache_ops.copy_cache_blocks(
src_value_cache, dst_value_cache, src_to_dst)
event = self.events[i]
event.record(stream=self.cache_stream)
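Note: to make the layout change concrete, a small sketch that allocates one layer's caches with the block shapes produced after this diff and scatters one token's key/value into them in the spirit of `cache_ops.reshape_and_cache`. The flat slot encoding `slot = block_number * block_size + block_offset` and the value of `x` are assumptions made for illustration.

```python
import torch

# Assumed example dimensions; x is the vector width that splits head_size in the key cache.
num_blocks, num_heads, head_size, block_size, x = 4, 12, 64, 16, 8

key_cache = torch.zeros(num_blocks, num_heads, head_size // x, block_size, x)
value_cache = torch.zeros(num_blocks, num_heads, head_size, block_size)  # new layout

def put_kv(key: torch.Tensor, value: torch.Tensor, slot: int) -> None:
    """Store one token's key/value ([num_heads, head_size]) at a flat slot index."""
    block_number, block_offset = divmod(slot, block_size)
    key_cache[block_number, :, :, block_offset, :] = key.view(num_heads, head_size // x, x)
    value_cache[block_number, :, :, block_offset] = value

put_kv(torch.randn(num_heads, head_size), torch.randn(num_heads, head_size), slot=21)
```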
2 changes: 1 addition & 1 deletion cacheflow/worker/worker.py
@@ -118,7 +118,7 @@ def prepare_inputs(
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in generation_block_tables]
block_tables_tensor = torch.tensor(
padded_block_tables, dtype=int, device=self.device)
padded_block_tables, dtype=torch.int, device=self.device)

input_metadata = InputMetadata(
seq_ids=prompt_seq_ids + generation_seq_ids,
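Note: a small sketch of the tensor construction this one-line fix touches. `_pad_to_max` is re-created here only for illustration (the pad value of 0 is an assumption), and `torch.int` is int32, whereas the builtin `int` dtype maps to int64; presumably the 32-bit type is what the new kernel expects for block tables.

```python
import torch

def _pad_to_max(block_table, max_len, pad=0):
    # Hypothetical re-creation of the helper referenced above.
    return block_table + [pad] * (max_len - len(block_table))

generation_block_tables = [[3, 7], [5]]
max_num_blocks_per_seq = 4
padded_block_tables = [
    _pad_to_max(block_table, max_num_blocks_per_seq)
    for block_table in generation_block_tables]

block_tables_tensor = torch.tensor(padded_block_tables, dtype=torch.int)  # int32
```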
19 changes: 19 additions & 0 deletions csrc/attention.cpp
@@ -0,0 +1,19 @@
#include <torch/extension.h>

void single_query_cached_kv_attention(
torch::Tensor& out,
torch::Tensor& query,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
float scale,
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int block_size,
int max_context_len);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"single_query_cached_kv_attention",
&single_query_cached_kv_attention,
"Compute the attention between an input query and the cached key/value tensors");
}
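Note: the file above only declares and binds the op. One hedged sketch of how such an extension could be built and invoked ad hoc with `torch.utils.cpp_extension.load`; the `attention_kernels.cu` source name and the JIT-build approach are assumptions, and the project may build `attention_ops` through its own setup script instead.

```python
from torch.utils.cpp_extension import load

# JIT-compile the binding declared above together with a CUDA implementation file.
# 'csrc/attention_kernels.cu' is a placeholder name: the kernel definition is not
# part of this diff, only its declaration in csrc/attention.cpp.
attention_ops = load(
    name='attention_ops',
    sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'],
    verbose=True,
)

# The Python-side call then mirrors the bound signature, as used in
# cacheflow/models/attention.py:
#   attention_ops.single_query_cached_kv_attention(
#       output, query, key_cache, value_cache, self.scale,
#       input_metadata.block_tables, input_metadata.context_lens,
#       block_size, input_metadata.max_context_len)
```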