From 0deacbce6e96a1af5885babc4e470ce2a0cecf95 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 1 Mar 2023 15:02:19 -0800 Subject: [PATCH] Implement `single_query_cached_kv_attention` kernel (#3) --- cacheflow/master/block_manager.py | 4 +- cacheflow/models/attention.py | 54 +- cacheflow/worker/cache_engine.py | 20 +- cacheflow/worker/worker.py | 2 +- csrc/attention.cpp | 19 + csrc/attention_kernels.cu | 400 +++++++ csrc/attention_utils.h | 204 ++++ csrc/cache_kernels.cu | 10 +- csrc/cuda_primitives.h | 1318 ++++++++++++++++++++++++ setup.py | 11 +- tests/kernels/attention.py | 142 +++ tests/{kernels.py => kernels/cache.py} | 16 +- 12 files changed, 2140 insertions(+), 60 deletions(-) create mode 100644 csrc/attention.cpp create mode 100644 csrc/attention_kernels.cu create mode 100644 csrc/attention_utils.h create mode 100644 csrc/cuda_primitives.h create mode 100644 tests/kernels/attention.py rename tests/{kernels.py => kernels/cache.py} (79%) diff --git a/cacheflow/master/block_manager.py b/cacheflow/master/block_manager.py index cc3b3849422ef..6d749b1a44d94 100644 --- a/cacheflow/master/block_manager.py +++ b/cacheflow/master/block_manager.py @@ -15,7 +15,9 @@ def __init__( block_size: int, num_blocks: int, ) -> None: - assert block_size in [8, 16, 32] + if block_size not in [8, 16]: + raise ValueError(f'Unsupported block size: {block_size}' + 'The block size must be either 8 or 16.') self.device = device self.block_size = block_size self.num_blocks = num_blocks diff --git a/cacheflow/models/attention.py b/cacheflow/models/attention.py index 71218f7e8221a..0babcc8b16621 100644 --- a/cacheflow/models/attention.py +++ b/cacheflow/models/attention.py @@ -3,7 +3,8 @@ import torch import torch.nn as nn -from cacheflow import ops +from cacheflow import attention_ops +from cacheflow import cache_ops from cacheflow.models import InputMetadata @@ -11,7 +12,7 @@ class OPTCacheFlowAttention(nn.Module): def __init__(self, scale: float) -> None: super().__init__() - self.scale = scale + self.scale = float(scale) def _masked_attention( self, @@ -57,38 +58,21 @@ def single_query_cached_kv_attention( output: torch.Tensor, # [num_generation_tokens, num_heads, head_size] query: torch.Tensor, # [num_generation_tokens, num_heads, head_size] key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] input_metadata: InputMetadata, ) -> None: - num_heads = value_cache.shape[1] - head_size = value_cache.shape[3] - block_size = value_cache.shape[2] - block_tables = input_metadata.block_tables - - # FIXME(woosuk): Replace the following with a custom op. - for i in range(input_metadata.num_generation_tokens): - q = query[i].unsqueeze(0) - block_table = block_tables[i] - context_len = int(input_metadata.context_lens[i]) - - keys = [] - values = [] - for j in range(context_len): - block_number = int(block_table[j // block_size]) - block_offset = j % block_size - - k = key_cache[block_number, :, :, block_offset, :] - k = k.reshape(num_heads, head_size) - keys.append(k) - - v = value_cache[block_number, :, block_offset, :] - values.append(v) - keys = torch.stack(keys, dim=0) - values = torch.stack(values, dim=0) - - out = self._masked_attention(q, keys, values) - out = out.view(num_heads, head_size) - output[i].copy_(out, non_blocking=True) + block_size = value_cache.shape[3] + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + self.scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + ) def forward( self, @@ -96,7 +80,7 @@ def forward( key: torch.Tensor, # [num_tokens, num_heads * head_size] value: torch.Tensor, # [num_tokens, num_heads * head_size] key_cache: torch.Tensor, # [num_blocks, num_heads, head_size/x, block_size, x] - value_cache: torch.Tensor, # [num_blocks, num_heads, block_size, head_size] + value_cache: torch.Tensor, # [num_blocks, num_heads, head_size, block_size] input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # [num_tokens, num_heads * head_size] @@ -110,7 +94,7 @@ def forward( # Reshape the input tensors. num_heads = value_cache.shape[1] - head_size = value_cache.shape[3] + head_size = value_cache.shape[2] query = query.view(-1, num_heads, head_size) key = key.view(-1, num_heads, head_size) value = value.view(-1, num_heads, head_size) @@ -125,7 +109,7 @@ def forward( cache_event.wait() # Reshape the keys and values and store them in the cache. - ops.reshape_and_cache( + cache_ops.reshape_and_cache( key, value, key_cache, value_cache, input_metadata.slot_mapping) if input_metadata.num_generation_tokens > 0: diff --git a/cacheflow/worker/cache_engine.py b/cacheflow/worker/cache_engine.py index 7f4d291b90168..d8597a3b1cf1e 100644 --- a/cacheflow/worker/cache_engine.py +++ b/cacheflow/worker/cache_engine.py @@ -1,7 +1,7 @@ from typing import Dict, List, Tuple import torch -from cacheflow import ops +from cacheflow import cache_ops KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -57,20 +57,22 @@ def get_key_block_shape(self) -> Tuple[int, int, int, int]: def get_value_block_shape(self) -> Tuple[int, int, int]: return ( self.num_heads, - self.block_size, self.head_size, + self.block_size, ) def allocate_gpu_cache(self) -> List[KVCache]: gpu_cache: List[KVCache] = [] + key_block_shape = self.get_key_block_shape() + value_block_shape = self.get_value_block_shape() for _ in range(self.num_layers): key_blocks = torch.empty( - size=(self.num_gpu_blocks, *self.get_key_block_shape()), + size=(self.num_gpu_blocks, *key_block_shape), dtype=self.dtype, device=self.gpu_id, ) value_blocks = torch.empty( - size=(self.num_gpu_blocks, *self.get_value_block_shape()), + size=(self.num_gpu_blocks, *value_block_shape), dtype=self.dtype, device=self.gpu_id, ) @@ -79,14 +81,16 @@ def allocate_gpu_cache(self) -> List[KVCache]: def allocate_cpu_cache(self) -> List[KVCache]: cpu_cache: List[KVCache] = [] + key_block_shape = self.get_key_block_shape() + value_block_shape = self.get_value_block_shape() for _ in range(self.num_layers): key_blocks = torch.empty( - size=(self.num_cpu_blocks, *self.get_key_block_shape()), + size=(self.num_cpu_blocks, *key_block_shape), dtype=self.dtype, pin_memory=True, ) value_blocks = torch.empty( - size=(self.num_cpu_blocks, *self.get_value_block_shape()), + size=(self.num_cpu_blocks, *value_block_shape), dtype=self.dtype, pin_memory=True, ) @@ -104,10 +108,10 @@ def _copy_blocks( src_key_cache, src_value_cache = src[i] dst_key_cache, dst_value_cache = dst[i] # Copy the key blocks. - ops.copy_cache_blocks( + cache_ops.copy_cache_blocks( src_key_cache, dst_key_cache, src_to_dst) # Copy the value blocks. - ops.copy_cache_blocks( + cache_ops.copy_cache_blocks( src_value_cache, dst_value_cache, src_to_dst) event = self.events[i] event.record(stream=self.cache_stream) diff --git a/cacheflow/worker/worker.py b/cacheflow/worker/worker.py index 9b5f4661efb21..aabbb7e3f1e72 100644 --- a/cacheflow/worker/worker.py +++ b/cacheflow/worker/worker.py @@ -118,7 +118,7 @@ def prepare_inputs( _pad_to_max(block_table, max_num_blocks_per_seq) for block_table in generation_block_tables] block_tables_tensor = torch.tensor( - padded_block_tables, dtype=int, device=self.device) + padded_block_tables, dtype=torch.int, device=self.device) input_metadata = InputMetadata( seq_ids=prompt_seq_ids + generation_seq_ids, diff --git a/csrc/attention.cpp b/csrc/attention.cpp new file mode 100644 index 0000000000000..bb2766c1d6b67 --- /dev/null +++ b/csrc/attention.cpp @@ -0,0 +1,19 @@ +#include + +void single_query_cached_kv_attention( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "single_query_cached_kv_attention", + &single_query_cached_kv_attention, + "Compute the attention between an input query and the cached key/value tensors"); +} diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu new file mode 100644 index 0000000000000..23c2b2cda1e2f --- /dev/null +++ b/csrc/attention_kernels.cu @@ -0,0 +1,400 @@ +#include +#include + +#include "attention_utils.h" +#include "cuda_primitives.h" + +#include + +#define WARP_SIZE 32 + +namespace cacheflow { + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq) { + constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int seq_idx = blockIdx.y; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or comput 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + const scalar_t* q_ptr = q + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + Q_vec q_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 logits and accumulation. + float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = thread_group_idx % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + K_vec k_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE + + physical_block_offset * x; + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + const bool mask = token_idx >= context_len; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); + using V_vec = typename Vec::Type; + using L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec = *reinterpret_cast(logits + token_idx); + + const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, cast_to_float(v_vec)); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + convert_from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +} // namespace cacheflow + +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + cacheflow::single_query_cached_kv_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq); + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + case 32: + LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + break; + case 64: + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + case 160: + LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + break; + case 192: + LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + break; + case 256: + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + assert(false); + break; + } +} + +void single_query_cached_kv_attention( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int block_size, + int max_context_len) { + // TODO(woosuk): Support BF16. + if (query.element_size() == 2) { + // Half. + if (block_size == 8) { + single_query_cached_kv_attention_launcher( + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else if (block_size == 16) { + single_query_cached_kv_attention_launcher( + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else { + assert(false); + } + } else if (query.element_size() == 4) { + // Float. + if (block_size == 8) { + single_query_cached_kv_attention_launcher( + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else if (block_size == 16) { + single_query_cached_kv_attention_launcher( + out, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + max_context_len); + } else { + assert(false); + } + } else { + assert(false); + } +} + +#undef WARP_SIZE diff --git a/csrc/attention_utils.h b/csrc/attention_utils.h new file mode 100644 index 0000000000000..ff59f43d79887 --- /dev/null +++ b/csrc/attention_utils.h @@ -0,0 +1,204 @@ +#pragma once + +#include "cuda_primitives.h" + +#include +#include + +#define MMHA_USE_FP32_ACUM_FOR_FMA +#define MMHA_USE_FP32_ACUM_FOR_OUT + +namespace cacheflow { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; +template<> +struct Vec { + using Type = uint16_t; +}; +template<> +struct Vec { + using Type = uint32_t; +}; +template<> +struct Vec { + using Type = uint2; +}; +template<> +struct Vec { + using Type = uint4; +}; + +template +struct FloatVec {}; +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ + using K_vec_acum = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using K_vec_acum = typename FloatVec::Type; + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +} // namespace cacheflow + +#undef MMHA_USE_FP32_ACUM_FOR_FMA +#undef MMHA_USE_FP32_ACUM_FOR_OUT diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 1cce1bc76551c..2e9ca8f0df7ba 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -48,7 +48,7 @@ __global__ void reshape_and_cache_kernel( const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] scalar_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, block_size, head_size] + scalar_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size] const int* __restrict__ slot_mapping, // [num_tokens] const int num_heads, const int head_size, @@ -73,10 +73,10 @@ __global__ void reshape_and_cache_kernel( + x_idx * block_size * x + block_offset * x + x_offset; - const int tgt_value_idx = block_idx * num_heads * block_size * head_size - + head_idx * block_size * head_size - + block_offset * head_size - + head_offset; + const int tgt_value_idx = block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + + head_offset * block_size + + block_offset; key_cache[tgt_key_idx] = __ldg(&key[src_idx]); value_cache[tgt_value_idx] = __ldg(&value[src_idx]); } diff --git a/csrc/cuda_primitives.h b/csrc/cuda_primitives.h new file mode 100644 index 0000000000000..f8f137a7eb56a --- /dev/null +++ b/csrc/cuda_primitives.h @@ -0,0 +1,1318 @@ +#pragma once + +#include + +namespace cacheflow { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float4_ { + float2 x; + float2 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, float b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(float2 a, float2 b) +{ + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 add(float4 a, float4 b) +{ + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t add(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t add(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 add(uint2 a, uint2 b) +{ + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 add(uint4 a, uint4 b) +{ + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float_to_half(float f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? + float zero = 0.f; + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#endif + return tmp.u16[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float2_to_half2(float2 f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float half_to_float(uint16_t h) +{ + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 half2_to_float2(uint32_t v) +{ + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(uint32_t a, float2 fb) +{ + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(uint2 a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(uint4 a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t h0_h0(uint16_t a) +{ + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(float a, float b, float c) +{ + return a * b + c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float2 a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float4 a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) +{ + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) +{ + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) +{ + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) +{ + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) +{ + return fma(h0_h0(a), b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) +{ + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) +{ + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) +{ + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) +{ + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) +{ + return fma(h0_h0(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) +{ + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) +{ + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(bf162bf162(a), b, c); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) +{ + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) +{ + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) +{ + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) +{ + return fma(bf162bf162(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ Acc mul(A a, B b); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(float a, float b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float2 a, float2 b) +{ + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float a, float2 b) +{ + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float4 a, float4 b) +{ + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float a, float4 b) +{ + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) +{ + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) +{ + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmul(a, b); +#else + return bf16hmul(a, b); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ + float fa = (float)a; + float fb = (float)b; + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float v) +{ + return v; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float2 v) +{ + return v.x + v.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float4 v) +{ + return v.x + v.y + v.z + v.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float sum(__nv_bfloat162 v) +{ + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_4_t v) +{ + return sum(v.x) + sum(v.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_8_t v) +{ + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint16_t v) +{ + return half_to_float(v); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint32_t v) +{ + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint2 v) +{ + uint32_t c = add(v.x, v.y); + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint4 v) +{ +#if 1 + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); +#else + uint32_t c = add(v.x, v.y); + uint32_t d = add(v.z, v.w); + c = add(c, d); +#endif + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float4_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float8_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float dot(Float4_ a, Float4_ b) +{ + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float dot(Float8_ a, Float8_ b) +{ + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t& dst) +{ + dst = uint16_t(0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void zero(T& dst) +{ + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +} diff --git a/setup.py b/setup.py index 4bcb921777ec2..428088a8682ad 100644 --- a/setup.py +++ b/setup.py @@ -9,15 +9,22 @@ # Cache operations. cache_extension = cpp_extension.CUDAExtension( - name='cacheflow.ops', + name='cacheflow.cache_ops', sources=['csrc/cache.cpp', 'csrc/cache_kernels.cu'], extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS}, ) ext_modules.append(cache_extension) +# Attention kernels. +attention_extension = cpp_extension.CUDAExtension( + name='cacheflow.attention_ops', + sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'], + extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS}, +) +ext_modules.append(attention_extension) + setuptools.setup( name='cacheflow', - requires_python='>=3.9', ext_modules=ext_modules, cmdclass={'build_ext': cpp_extension.BuildExtension}, ) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py new file mode 100644 index 0000000000000..550e2b28ad7da --- /dev/null +++ b/tests/kernels/attention.py @@ -0,0 +1,142 @@ +import random +from typing import Optional + +import torch + +from cacheflow import attention_ops + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + query = query * scale + attn = torch.einsum('qhd,khd->hqk', query, key) + if attn_mask is not None: + attn = attn + attn_mask + attn = torch.softmax(attn, dim=-1) + out = torch.einsum('hqk,khd->qhd', attn, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, +) -> None: + num_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + + num_input_tokens = query.shape[0] + for i in range(num_input_tokens): + q = query[i].unsqueeze(0) + block_table = block_tables[i] + context_len = int(context_lens[i]) + + keys = [] + values = [] + for j in range(context_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_heads, head_size) + keys.append(k) + + v = value_cache[block_number, :, :, block_offset] + values.append(v) + keys = torch.stack(keys, dim=0) + values = torch.stack(values, dim=0) + + scale = 1.0 / (head_size ** 0.5) + out = ref_masked_attention(q, keys, values, scale) + out = out.view(num_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +def test_single_query_cached_kv_attention( + num_tokens: int, + num_heads: int, + head_size: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, +) -> None: + query = torch.randn( + num_tokens, num_heads, head_size, dtype=dtype, device='cuda') + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_block_shape = (num_heads, head_size // x, block_size, x) + key_cache = torch.randn( + size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') + value_block_shape = (num_heads, head_size, block_size) + value_cache = torch.randn( + size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') + + context_lens = [random.randint(1, 4096) for _ in range(num_tokens)] + max_context_len = max(context_lens) + context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') + + max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size + block_tables = [] + for _ in range(num_tokens): + block_table = [ + random.randint(0, num_blocks - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables.append(block_table) + block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') + + scale = float(1.0 / (head_size ** 0.5)) + output = torch.empty_like(query) + attention_ops.single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + ) + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + key_cache, + value_cache, + block_tables, + context_lens, + ) + # NOTE(woosuk): Due to the difference in the data types the two + # implementations use for attention softmax logits and accumulation, + # there is a small difference in the final outputs. + # We should use a relaxed tolerance for the test. + assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) + + +@torch.inference_mode() +def test_attention() -> None: + for dtype in [torch.half, torch.float]: + for block_size in [8, 16]: + for head_size in [64, 80, 96, 128, 256]: + test_single_query_cached_kv_attention( + num_tokens=37, + num_heads=3, + head_size=head_size, + block_size=block_size, + num_blocks=1024, + dtype=dtype, + ) + + +if __name__ == '__main__': + test_attention() diff --git a/tests/kernels.py b/tests/kernels/cache.py similarity index 79% rename from tests/kernels.py rename to tests/kernels/cache.py index 8d2eb22e8a02e..9eebe437448f8 100644 --- a/tests/kernels.py +++ b/tests/kernels/cache.py @@ -2,7 +2,7 @@ import torch -from cacheflow.ops import reshape_and_cache +from cacheflow import cache_ops def test_reshape_and_cache( @@ -26,30 +26,30 @@ def test_reshape_and_cache( key_cache = torch.randn(size=key_cache_shape, dtype=dtype, device='cuda') cloned_key_cache = key_cache.clone() - value_cache_shape = (num_blocks, num_heads, block_size, head_size) + value_cache_shape = (num_blocks, num_heads, head_size, block_size) value_cache = torch.randn( size=value_cache_shape, dtype=dtype, device='cuda') cloned_value_cache = value_cache.clone() - reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) + cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping) for i in range(num_tokens): reshaped_key = key.reshape(num_tokens, num_heads, head_size // x, x) block_idx = slot_mapping[i] // block_size block_offset = slot_mapping[i] % block_size cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i] - cloned_value_cache[block_idx, :, block_offset, :] = value[i] + cloned_value_cache[block_idx, :, :, block_offset] = value[i] assert torch.allclose(key_cache, cloned_key_cache) assert torch.allclose(value_cache, cloned_value_cache) -@torch.no_grad() -def test_kernels(): +@torch.inference_mode() +def test_cache() -> None: test_reshape_and_cache( - num_tokens=3, num_heads=2, head_size=16, block_size=2, num_blocks=2, + num_tokens=3, num_heads=2, head_size=16, block_size=8, num_blocks=2, dtype=torch.half) if __name__ == '__main__': - test_kernels() + test_cache()