diff --git a/csrc/attention/attention_dtypes.cuh b/csrc/attention/attention_dtypes.cuh new file mode 100644 index 000000000000..1d586ddf7522 --- /dev/null +++ b/csrc/attention/attention_dtypes.cuh @@ -0,0 +1,5 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float16.cuh" +#include "dtype_float32.cuh" diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh new file mode 100644 index 000000000000..799f873f462a --- /dev/null +++ b/csrc/attention/attention_generic.cuh @@ -0,0 +1,47 @@ +#pragma once + +#include + +namespace cacheflow { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace cacheflow diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu new file mode 100644 index 000000000000..a4bd6aeb6867 --- /dev/null +++ b/csrc/attention/attention_kernels.cu @@ -0,0 +1,451 @@ +#include +#include + +#include "attention_dtypes.cuh" +#include "attention_utils.cuh" + +#include + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace cacheflow { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const int q_stride) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int seq_idx = blockIdx.y; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + Q_vec q_vecs[NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); + const bool mask = token_idx >= context_len; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE + + head_idx * HEAD_SIZE * BLOCK_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +} // namespace cacheflow + +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + cacheflow::single_query_cached_kv_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + query_stride); + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int query_stride = query.stride(0); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(T); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + case 32: + LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + break; + case 64: + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + case 160: + LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + break; + case 192: + LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + break; + case 256: + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len); + +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 1: \ + CALL_KERNEL_LAUNCHER(T, 1); \ + break; \ + case 2: \ + CALL_KERNEL_LAUNCHER(T, 2); \ + break; \ + case 4: \ + CALL_KERNEL_LAUNCHER(T, 4); \ + break; \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + case 64: \ + CALL_KERNEL_LAUNCHER(T, 64); \ + break; \ + case 128: \ + CALL_KERNEL_LAUNCHER(T, 128); \ + break; \ + case 256: \ + CALL_KERNEL_LAUNCHER(T, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void single_query_cached_kv_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len) { + // TODO(woosuk): Support FP32 and BF16. + if (query.dtype() == at::ScalarType::Half) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh new file mode 100644 index 000000000000..df529095d9c2 --- /dev/null +++ b/csrc/attention/attention_utils.cuh @@ -0,0 +1,38 @@ +#pragma once + +#include "attention_dtypes.cuh" + +#include +#include + +namespace cacheflow { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace cacheflow diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh new file mode 100644 index 000000000000..d2a60353e116 --- /dev/null +++ b/csrc/attention/dtype_float16.cuh @@ -0,0 +1,426 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#include + +namespace cacheflow { + +// FP16 vector types for Q, K, V. +template<> +struct Vec { + using Type = uint16_t; +}; +template<> +struct Vec { + using Type = uint32_t; +}; +template<> +struct Vec { + using Type = uint2; +}; +template<> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template<> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template<> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template<> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template<> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// Zero-out a vector. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. +inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +} // namespace cacheflow diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh new file mode 100644 index 000000000000..517da64b3609 --- /dev/null +++ b/csrc/attention/dtype_float32.cuh @@ -0,0 +1,250 @@ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace cacheflow { + +// Define FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template<> +inline __device__ float sum(float v) { + return v; +} + +template<> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template<> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template<> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template<> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. +inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +} // namespace cacheflow diff --git a/csrc/attention_kernels.cu b/csrc/attention_kernels.cu deleted file mode 100644 index c25acbb8be6f..000000000000 --- a/csrc/attention_kernels.cu +++ /dev/null @@ -1,896 +0,0 @@ -#include -#include - -#include "attention_utils.h" -#include "cuda_primitives.h" -#include "reduction_utils.h" - -#include - -#define WARP_SIZE 32 -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#define MIN(a, b) ((a) < (b) ? (a) : (b)) - -namespace cacheflow { - -// Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void single_query_cached_kv_attention_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const int q_stride) { - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int seq_idx = blockIdx.y; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread group - // fetch or compute 16 bytes at a time. - // For example, if the size of a thread group is 4 and the data type is half, - // then the vector size is 16 / (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... - // th vectors of the query, and so on. - // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - Q_vec q_vecs[NUM_VECS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 logits and accumulation. - float *logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(scalar_t); - float qk_max = -FLT_MAX; - - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - const int physical_block_number = block_table[block_idx]; - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in the group - // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th - // vectors of the key, and so on. - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - K_vec k_vecs[NUM_VECS_PER_THREAD]; - -#pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE - + physical_block_offset * x; - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); - const bool mask = token_idx >= context_len; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - logits[token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - - float accs[NUM_ROWS_PER_THREAD]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { - const int physical_block_number = block_table[block_idx]; - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec = *reinterpret_cast(logits + token_idx); - - const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE - + head_idx * HEAD_SIZE * BLOCK_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec = *reinterpret_cast(v_ptr + offset); - accs[i] += dot(logits_vec, cast_to_float(v_vec)); - } - } - } - - // Perform reduction within each warp. -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; -#pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for logits - // is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); -#pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - convert_from_float(*(out_ptr + row_idx), accs[i]); - } - } - } -} - -} // namespace cacheflow - -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ - cacheflow::single_query_cached_kv_attention_kernel \ - <<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - query_stride); - -// TODO(woosuk): Tune NUM_THREADS. -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128> -void single_query_cached_kv_attention_launcher( - torch::Tensor& out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int query_stride = query.stride(0); - - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - T* out_ptr = reinterpret_cast(out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_context_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - int shared_mem_size = std::max(logits_size, outputs_size); - - dim3 grid(num_heads, num_seqs); - dim3 block(NUM_THREADS); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - case 32: - LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - break; - case 64: - LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); - break; - case 80: - LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); - break; - case 96: - LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); - break; - case 128: - LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); - break; - case 160: - LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - break; - case 192: - LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - break; - case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); - break; - default: - assert(false); - break; - } -} - -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len); - -void single_query_cached_kv_attention( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len) { - // TODO(woosuk): Support BF16. - if (query.element_size() == 2) { - // Half. - if (block_size == 1) { - CALL_KERNEL_LAUNCHER(uint16_t, 1); - } else if (block_size == 2) { - CALL_KERNEL_LAUNCHER(uint16_t, 2); - } else if (block_size == 4) { - CALL_KERNEL_LAUNCHER(uint16_t, 4); - } else if (block_size == 8) { - CALL_KERNEL_LAUNCHER(uint16_t, 8); - } else if (block_size == 16) { - CALL_KERNEL_LAUNCHER(uint16_t, 16); - } else if (block_size == 32) { - CALL_KERNEL_LAUNCHER(uint16_t, 32); - } else if (block_size == 64) { - CALL_KERNEL_LAUNCHER(uint16_t, 64); - } else if (block_size == 128) { - CALL_KERNEL_LAUNCHER(uint16_t, 128); - } else if (block_size == 256) { - CALL_KERNEL_LAUNCHER(uint16_t, 256); - } else { - assert(false); - } - } else { - // Float. - assert(false); - } -} - -// namespace cacheflow { - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __device__ void multi_query_cached_kv_attention_kernel_unoptimized_( -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const int seq_start_idx, -// const int seq_len, -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_table, // [num_seqs, max_num_blocks_per_seq] -// const int context_len, -// const int max_num_blocks_per_seq, -// const int q_stride) { -// constexpr int THREAD_GROUP_SIZE = WARP_SIZE / BLOCK_SIZE; -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// const int thread_idx = threadIdx.x; -// const int warp_idx = thread_idx / WARP_SIZE; -// const int lane = thread_idx % WARP_SIZE; - -// const int head_idx = blockIdx.x; -// const int num_heads = gridDim.x; -// const int seq_idx = blockIdx.y; - -// // A vector type to store a part of a key or a query. -// // The vector size is configured in such a way that the threads in a thread group -// // fetch or comput 16 bytes at a time. -// // For example, if the size of a thread group is 4 and the data type is half, -// // then the vector size is 16 / (4 * sizeof(half)) == 2. -// constexpr int VEC_SIZE = 16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)); -// using K_vec = typename Vec::Type; -// using Q_vec = typename Vec::Type; - -// constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; -// constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - -// const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; -// const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - -// // Load the query to registers. -// // Each thread in a thread group has a different part of the query. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... -// // th vectors of the query, and so on. -// // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. -// const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; -// Q_vec q_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// q_vecs[i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); -// } - -// // Memory planning. -// extern __shared__ char shared_mem[]; -// // NOTE(woosuk): We use FP32 logits and accumulation. -// float *logits = reinterpret_cast(shared_mem); -// // Workspace for reduction. -// __shared__ float red_smem[2 * NUM_WARPS]; - -// // x == THREAD_GROUP_SIZE * VEC_SIZE -// // Each thread group fetches x elements from the key at a time. -// constexpr int x = 16 / sizeof(scalar_t); -// float qk_max = -FLT_MAX; - -// const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; -// const int mask_boundary = context_len - seq_len + 1 + (seq_idx - seq_start_idx); - -// // Iterate over the key blocks. -// // Each warp fetches a block of keys for each iteration. -// // Each thread group in a warp fetches a key from the block, and computes -// // dot product with the query. -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = thread_group_idx % BLOCK_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - -// // Load a key to registers. -// // Each thread in a thread group has a different part of the key. -// // For example, if the the thread group size is 4, then the first thread in the group -// // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th -// // vectors of the key, and so on. -// K_vec k_vecs[NUM_VECS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_VECS_PER_THREAD; i++) { -// const scalar_t* k_ptr = k_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE -// + physical_block_offset * x; -// const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; -// const int offset1 = (vec_idx * VEC_SIZE) / x; -// const int offset2 = (vec_idx * VEC_SIZE) % x; -// k_vecs[i] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); -// } - -// // Compute dot product. -// // This includes a reduction across the threads in the same thread group. -// const float qk = scale * Qk_dot::dot(q_vecs, k_vecs); -// const bool mask = token_idx >= mask_boundary; - -// if (thread_group_offset == 0) { -// // Store the partial reductions to shared memory. -// // NOTE(woosuk): It is required to zero out the masked logits. -// logits[token_idx] = mask ? 0.f : qk; -// // Update the max value. -// qk_max = mask ? qk_max : fmaxf(qk_max, qk); -// } -// } - -// // Perform reduction across the threads in the same warp to get the -// // max qk value for each "warp" (not across the thread block yet). -// // The 0-th thread of each thread group already has its max qk value. -// #pragma unroll -// for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// if (lane == 0) { -// red_smem[warp_idx] = qk_max; -// } -// __syncthreads(); - -// // TODO(woosuk): Refactor this part. -// // Get the max qk value for the sequence. -// qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -// #pragma unroll -// for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -// qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -// } -// // Broadcast the max qk value to all threads. -// qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - -// // Get the sum of the exp values. -// float exp_sum = 0.f; -// for (int i = thread_idx; i < mask_boundary; i += NUM_THREADS) { -// float val = __expf(logits[i] - qk_max); -// logits[i] = val; -// exp_sum += val; -// } -// exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - -// // Compute softmax. -// const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); -// for (int i = thread_idx; i < context_len; i += NUM_THREADS) { -// logits[i] *= inv_sum; -// } -// __syncthreads(); - -// // Each thread will fetch 16 bytes from the value cache at a time. -// constexpr int V_VEC_SIZE = 16 / sizeof(scalar_t); -// using V_vec = typename Vec::Type; -// using L_vec = typename FloatVec::Type; - -// constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; -// constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; -// constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; - -// float accs[NUM_ROWS_PER_THREAD]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// accs[i] = 0.f; -// } - -// for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { -// const int physical_block_number = block_table[block_idx]; -// const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; -// const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; -// L_vec logits_vec = *reinterpret_cast(logits + token_idx); - -// const scalar_t* v_ptr = v_cache + physical_block_number * num_heads * HEAD_SIZE * BLOCK_SIZE -// + head_idx * HEAD_SIZE * BLOCK_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE) { -// const int offset = row_idx * BLOCK_SIZE + physical_block_offset; -// V_vec v_vec = *reinterpret_cast(v_ptr + offset); -// accs[i] += dot(logits_vec, cast_to_float(v_vec)); -// } -// } -// } - -// // Perform reduction within each warp. -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// float acc = accs[i]; -// #pragma unroll -// for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -// acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -// } -// accs[i] = acc; -// } - -// // NOTE(woosuk): A barrier is required because the shared memory space for logits -// // is reused for the output. -// __syncthreads(); - -// // Perform reduction across warps. -// float* out_smem = reinterpret_cast(shared_mem); -// #pragma unroll -// for (int i = NUM_WARPS; i > 1; i /= 2) { -// int mid = i / 2; -// // Upper warps write to shared memory. -// if (warp_idx >= mid && warp_idx < i) { -// float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// dst[row_idx] = accs[i]; -// } -// } -// } -// __syncthreads(); - -// // Lower warps update the output. -// if (warp_idx < mid) { -// const float* src = &out_smem[warp_idx * HEAD_SIZE]; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// accs[i] += src[row_idx]; -// } -// } -// } -// __syncthreads(); -// } - -// // Write the final output. -// if (warp_idx == 0) { -// scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -// #pragma unroll -// for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { -// const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; -// if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { -// convert_from_float(*(out_ptr + row_idx), accs[i]); -// } -// } -// } -// } - - -// // Grid: (num_heads, num_query_tokens). -// template< -// typename scalar_t, -// int HEAD_SIZE, -// int BLOCK_SIZE, -// int NUM_THREADS> -// __global__ void multi_query_cached_kv_attention_kernel( -// const int* cu_query_lens, // [num_prompts+1] -// const int* seq_prompt_mapping, // [num_seqs] mapping from seq_idx to prompt_idx -// scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] -// const scalar_t* __restrict__ k_cache, // [num_blocks, num_heads, head_size/x, block_size, x] -// const scalar_t* __restrict__ v_cache, // [num_blocks, num_heads, head_size, block_size] -// const float scale, -// const int* __restrict__ block_tables, // [num_prompts, max_num_blocks_per_seq] -// const int* __restrict__ context_lens, // [num_prompts] -// const int max_num_blocks_per_seq, -// const int q_stride) { -// const int seq_idx = blockIdx.y; -// const int prompt_idx = seq_prompt_mapping[seq_idx]; -// const int seq_start_idx = cu_query_lens[prompt_idx]; -// const int seq_len = cu_query_lens[prompt_idx + 1] - seq_start_idx; -// const int* block_table = block_tables + prompt_idx * max_num_blocks_per_seq; -// const int context_len = context_lens[prompt_idx]; -// multi_query_cached_kv_attention_kernel_unoptimized_< -// scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( -// out, -// q, -// seq_start_idx, -// seq_len, -// k_cache, -// v_cache, -// scale, -// block_table, -// context_len, -// max_num_blocks_per_seq, -// q_stride); -// } - -// } // namespace cacheflow - -// #define LAUNCH_MULTI_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ -// cacheflow::multi_query_cached_kv_attention_kernel \ -// <<>>( \ -// cu_query_lens_ptr, \ -// seq_prompt_mapping_ptr, \ -// out_ptr, \ -// query_ptr, \ -// key_cache_ptr, \ -// value_cache_ptr, \ -// scale, \ -// block_tables_ptr, \ -// context_lens_ptr, \ -// max_num_blocks_per_seq, \ -// query_stride); - - -// // TODO(woosuk): Tune NUM_THREADS. -// template< -// typename T, -// int BLOCK_SIZE, -// int NUM_THREADS = 128> -// void multi_query_cached_kv_attention_launcher( -// torch::Tensor& cu_query_lens, -// torch::Tensor& seq_prompt_mapping, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int max_context_len) { -// int num_seqs = query.size(0); -// int num_heads = query.size(1); -// int head_size = query.size(2); -// int max_num_blocks_per_seq = block_tables.size(1); -// int query_stride = query.stride(0); - -// int* cu_query_lens_ptr = cu_query_lens.data_ptr(); -// int* seq_prompt_mapping_ptr = seq_prompt_mapping.data_ptr(); -// T* out_ptr = reinterpret_cast(out.data_ptr()); -// T* query_ptr = reinterpret_cast(query.data_ptr()); -// T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); -// T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); -// int* block_tables_ptr = block_tables.data_ptr(); -// int* context_lens_ptr = context_lens.data_ptr(); - -// constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; -// int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; -// int logits_size = padded_max_context_len * sizeof(float); -// int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); -// int shared_mem_size = std::max(logits_size, outputs_size); - -// dim3 grid(num_heads, num_seqs); -// dim3 block(NUM_THREADS); -// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); -// switch (head_size) { -// case 32: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); -// break; -// case 64: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); -// break; -// case 80: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); -// break; -// case 96: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); -// break; -// case 128: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); -// break; -// case 160: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); -// break; -// case 192: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); -// break; -// case 256: -// LAUNCH_MULTI_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); -// break; -// default: -// assert(false); -// break; -// } -// } - -// void multi_query_cached_kv_attention( -// torch::Tensor& cu_query_lens, -// torch::Tensor& out, -// torch::Tensor& query, -// torch::Tensor& key_cache, -// torch::Tensor& value_cache, -// float scale, -// torch::Tensor& block_tables, -// torch::Tensor& context_lens, -// int block_size, -// int max_context_len) { - -// torch::Tensor query_lens = cu_query_lens.to(torch::kCPU); - -// int num_queries = query_lens.size(0) - 1; -// const int* query_lens_ptr = query_lens.data_ptr(); -// int num_seqs = query.size(0); - -// torch::Tensor cpu_tensor = torch::empty({num_seqs}, torch::dtype(torch::kInt32)); -// auto accessor = cpu_tensor.accessor(); -// for (int i = 0, query_cursor = 0; i < num_seqs; ++i) { -// if (i >= query_lens_ptr[query_cursor + 1]) { -// ++query_cursor; -// } -// accessor[i] = query_cursor; -// } - -// // TODO(suquark): This can be slow, as it to(torch::kCPU) and to(torch::kCUDA) -// // implicitly synchronizes the CPU and GPU. And we can avoid this issue by giving -// // the mapping as an input parameter. Let's do this optimization in a later PR. -// torch::Tensor seq_prompt_mapping = cpu_tensor.to(torch::kCUDA); - -// // TODO(woosuk): Support BF16. -// if (query.element_size() == 2) { -// // Half. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else if (query.element_size() == 4) { -// // Float. -// if (block_size == 8) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 16) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else if (block_size == 32) { -// multi_query_cached_kv_attention_launcher( -// cu_query_lens, -// seq_prompt_mapping, -// out, -// query, -// key_cache, -// value_cache, -// scale, -// block_tables, -// context_lens, -// max_context_len); -// } else { -// assert(false); -// } -// } else { -// assert(false); -// } -// } - -#undef WARP_SIZE -#undef MAX -#undef MIN diff --git a/csrc/attention_utils.h b/csrc/attention_utils.h deleted file mode 100644 index 049555390715..000000000000 --- a/csrc/attention_utils.h +++ /dev/null @@ -1,165 +0,0 @@ -#pragma once - -#include "cuda_primitives.h" - -#include -#include - -#define MMHA_USE_FP32_ACUM_FOR_FMA -#define MMHA_USE_FP32_ACUM_FOR_OUT - -namespace cacheflow { - -// A vector type to store Q, K, V elements. -template -struct Vec {}; -template<> -struct Vec { - using Type = float; -}; -template<> -struct Vec { - using Type = float2; -}; -template<> -struct Vec { - using Type = float4; -}; -template<> -struct Vec { - using Type = uint16_t; -}; -template<> -struct Vec { - using Type = uint32_t; -}; -template<> -struct Vec { - using Type = uint2; -}; -template<> -struct Vec { - using Type = uint4; -}; - -template -struct FloatVec {}; -template<> -struct FloatVec { - using Type = float; -}; -template<> -struct FloatVec { - using Type = float2; -}; -template<> -struct FloatVec { - using Type = float4; -}; -template<> -struct FloatVec { - using Type = float; -}; -template<> -struct FloatVec { - using Type = float2; -}; -template<> -struct FloatVec { - using Type = Float4_; -}; -template<> -struct FloatVec { - using Type = Float8_; -}; - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ - using K_vec_acum = typename FloatVec::Type; - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. - float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 - using K_vec_acum = typename FloatVec::Type; - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -} // namespace cacheflow - -#undef MMHA_USE_FP32_ACUM_FOR_FMA -#undef MMHA_USE_FP32_ACUM_FOR_OUT diff --git a/csrc/cuda_primitives.h b/csrc/cuda_primitives.h deleted file mode 100644 index 10e730fd7bda..000000000000 --- a/csrc/cuda_primitives.h +++ /dev/null @@ -1,1340 +0,0 @@ -#pragma once - -#include - -namespace cacheflow { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float8_ { - float2 x; - float2 y; - float2 z; - float2 w; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct Float4_ { - float2 x; - float2 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -struct bf16_4_t { - __nv_bfloat162 x; - __nv_bfloat162 y; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct bf16_8_t { - __nv_bfloat162 x; - __nv_bfloat162 y; - __nv_bfloat162 z; - __nv_bfloat162 w; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float add(float a, float b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(float2 a, float2 b) -{ - float2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 add(float4 a, float4 b) -{ - float4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return a + b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t add(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t add(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 add(uint2 a, uint2 b) -{ - uint2 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 add(uint4 a, uint4 b) -{ - uint4 c; - c.x = add(a.x, b.x); - c.y = add(a.y, b.y); - c.z = add(a.z, b.z); - c.w = add(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint16_t float_to_half(float f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? - float zero = 0.f; - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); -#endif - return tmp.u16[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t float2_to_half2(float2 f) -{ - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); -#else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif - return tmp.u32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float half_to_float(uint16_t h) -{ - float f; - asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 half2_to_float2(uint32_t v) -{ - uint16_t lo, hi; - asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); - return make_float2(half_to_float(lo), half_to_float(hi)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 add(uint32_t a, float2 fb) -{ - float2 fa = half2_to_float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(uint2 a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(uint4 a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t h0_h0(uint16_t a) -{ - uint32_t b; - asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); - return b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(float a, float b, float c) -{ - return a * b + c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float2 a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(float a, float2 b, float2 c) -{ - float2 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float4 a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 fma(float a, float4 b, float4 c) -{ - float4 d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) -{ - Float4_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) -{ - Float8_ d; - d.x = fma(a, b.x, c.x); - d.y = fma(a, b.y, c.y); - d.z = fma(a, b.z, c.z); - d.w = fma(a, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float2 add(__nv_bfloat162 a, float2 fb) -{ - float2 fa = bf1622float2(a); - return add(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) -{ - Float4_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) -{ - Float8_ fc; - fc.x = add(a.x, fb.x); - fc.y = add(a.y, fb.y); - fc.z = add(a.z, fb.z); - fc.w = add(a.w, fb.w); - return fc; -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) -{ - uint32_t d; - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) -{ - return fma(h0_h0(a), b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) -{ - uint2 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) -{ - uint32_t s = h0_h0(a); - uint2 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) -{ - uint4 d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) -{ - uint32_t s = h0_h0(a); - uint4 d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(uint16_t a, uint16_t b, float fc) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) -{ - return fma(h0_h0(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) -{ - uint32_t s = h0_h0(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) -{ - uint32_t s = h0_h0(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(bf162bf162(a), b, c); -} -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) -{ - bf16_4_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) -{ - bf16_8_t d; - d.x = fma(a.x, b.x, c.x); - d.y = fma(a.y, b.y, c.y); - d.z = fma(a.z, b.z, c.z); - d.w = fma(a.w, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t d; - d.x = fma(s, b.x, c.x); - d.y = fma(s, b.y, c.y); - d.z = fma(s, b.z, c.z); - d.w = fma(s, b.w, c.w); - return d; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) -{ - return __bfloat162float(a) * __bfloat162float(b) + fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return fma(fa, fb, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) -{ - return fma(bf162bf162(a), b, fc); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) -{ - Float4_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) -{ - Float8_ fd; - fd.x = fma(a.x, b.x, fc.x); - fd.y = fma(a.y, b.y, fc.y); - fd.z = fma(a.z, b.z, fc.z); - fd.w = fma(a.w, b.w, fc.w); - return fd; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fd; - fd.x = fma(s, b.x, fc.x); - fd.y = fma(s, b.y, fc.y); - fd.z = fma(s, b.z, fc.z); - fd.w = fma(s, b.w, fc.w); - return fd; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ Acc mul(A a, B b); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float2 a, float2 b) -{ - float2 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(float a, float2 b) -{ - float2 c; - c.x = a * b.x; - c.y = a * b.y; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float4 a, float4 b) -{ - float4 c; - c.x = a.x * b.x; - c.y = a.y * b.y; - c.z = a.z * b.z; - c.w = a.w * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float4 mul(float a, float4 b) -{ - float4 c; - c.x = a * b.x; - c.y = a * b.y; - c.z = a * b.z; - c.w = a * b.w; - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint16_t mul(uint16_t a, uint16_t b) -{ - uint16_t c; - asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint32_t a, uint32_t b) -{ - uint32_t c; - asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint32_t mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint2 a, uint2 b) -{ - uint2 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint2 mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - uint2 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint4 a, uint4 b) -{ - uint4 c; - c.x = mul(a.x, b.x); - c.y = mul(a.y, b.y); - c.z = mul(a.z, b.z); - c.w = mul(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ uint4 mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - uint4 c; - c.x = mul(s, b.x); - c.y = mul(s, b.y); - c.z = mul(s, b.z); - c.w = mul(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(uint16_t a, uint16_t b) -{ - float fa = half_to_float(a); - float fb = half_to_float(b); - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint32_t a, uint32_t b) -{ - float2 fa = half2_to_float2(a); - float2 fb = half2_to_float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(uint16_t a, uint32_t b) -{ - return mul(h0_h0(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint2 a, uint2 b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(uint16_t a, uint2 b) -{ - uint32_t s = h0_h0(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint4 a, uint4 b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(uint16_t a, uint4 b) -{ - uint32_t s = h0_h0(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -template<> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - return __hmul(a, b); -#else - return bf16hmul(a, b); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) -{ - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_4_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) -{ - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - bf16_8_t c; - c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); - c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); - c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); - c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) -{ - float fa = (float)a; - float fb = (float)b; - return fa * fb; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) -{ - float2 fa = bf1622float2(a); - float2 fb = bf1622float2(b); - return mul(fa, fb); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) -{ - return mul(bf162bf162(a), b); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) -{ - Float4_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float4_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) -{ - Float8_ fc; - fc.x = mul(a.x, b.x); - fc.y = mul(a.y, b.y); - fc.z = mul(a.z, b.z); - fc.w = mul(a.w, b.w); - return fc; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) -{ - __nv_bfloat162 s = bf162bf162(a); - Float8_ fc; - fc.x = mul(s, b.x); - fc.y = mul(s, b.y); - fc.z = mul(s, b.z); - fc.w = mul(s, b.w); - return fc; -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float v) -{ - return v; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float2 v) -{ - return v.x + v.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(float4 v) -{ - return v.x + v.y + v.z + v.w; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ float sum(__nv_bfloat162 v) -{ - float2 vf = bf1622float2(v); - return vf.x + vf.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_4_t v) -{ - return sum(v.x) + sum(v.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(bf16_8_t v) -{ - return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint16_t v) -{ - return half_to_float(v); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint32_t v) -{ - float2 tmp = half2_to_float2(v); - return tmp.x + tmp.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint2 v) -{ - uint32_t c = add(v.x, v.y); - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(uint4 v) -{ -#if 1 - uint32_t c = add(v.x, v.y); - c = add(c, v.z); - c = add(c, v.w); -#else - uint32_t c = add(v.x, v.y); - uint32_t d = add(v.z, v.w); - c = add(c, d); -#endif - return sum(c); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float4_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float sum(Float8_ v) -{ - return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(float a, float b) -{ - return a * b; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(float2 a, float2 b) -{ - float2 c = mul(a, b); - return c.x + c.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(Float4_ a, Float4_ b) -{ - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - return acc.x + acc.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float dot(Float8_ a, Float8_ b) -{ - float2 acc = mul(a.x, b.x); - acc = fma(a.y, b.y, acc); - acc = fma(a.z, b.z, acc); - acc = fma(a.w, b.w, acc); - return acc.x + acc.y; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float dot(T a, T b) -{ - return sum(mul(a, b)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void zero(uint16_t& dst) -{ - dst = uint16_t(0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void zero(T& dst) -{ - constexpr int WORDS = sizeof(T) / 4; - union { - T raw; - uint32_t words[WORDS]; - } tmp; -#pragma unroll - for (int ii = 0; ii < WORDS; ++ii) { - tmp.words[ii] = 0u; - } - dst = tmp.raw; -} - - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float cast_to_float(float u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float2 cast_to_float(float2 u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ float4 cast_to_float(float4 u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ Float4_ cast_to_float(Float4_ u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// inline __device__ Float8_ cast_to_float(Float8_ u) -// { -// return u; -// } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(uint16_t u) -{ - return half_to_float(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -} diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 84372ed2dd60..a9606b106721 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -1,7 +1,7 @@ #include #include -#include "reduction_utils.h" +#include "reduction_utils.cuh" namespace cacheflow { diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh new file mode 100644 index 000000000000..7f904c08698a --- /dev/null +++ b/csrc/reduction_utils.cuh @@ -0,0 +1,34 @@ +#pragma once + +namespace cacheflow { + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val += __shfl_xor_sync(0xffffffff, val, mask, 32); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +} // namespace cacheflow diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.h deleted file mode 100644 index f977ab70f1fe..000000000000 --- a/csrc/reduction_utils.h +++ /dev/null @@ -1,76 +0,0 @@ -#pragma once - -namespace cacheflow { - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); -} - -#define FINAL_MASK 0xffffffff - -template -__inline__ __device__ T warpReduceSum(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); - val = warpReduceSum(val); - - return val; -} - -} // namespace cacheflow diff --git a/setup.py b/setup.py index e96c73033379..bac0b0f18c74 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ # Attention kernels. attention_extension = cpp_extension.CUDAExtension( name='cacheflow.attention_ops', - sources=['csrc/attention.cpp', 'csrc/attention_kernels.cu'], + sources=['csrc/attention.cpp', 'csrc/attention/attention_kernels.cu'], extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS}, ) ext_modules.append(attention_extension) diff --git a/tests/kernels/attention.py b/tests/kernels/attention.py index 7c2f350f1140..4567315d2e7a 100644 --- a/tests/kernels/attention.py +++ b/tests/kernels/attention.py @@ -271,78 +271,6 @@ def test_multi_query_kv_attention( assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) -def test_multi_query_cached_kv_attention( - num_queries: int, - num_heads: int, - head_size: int, - block_size: int, - num_blocks: int, - dtype: torch.dtype, -) -> None: - query_lens = random.sample(range(1, MAX_SEQ_LEN), num_queries) - cu_query_lens = [0] - for query_len in query_lens: - cu_query_lens.append(cu_query_lens[-1] + query_len) - num_total_tokens = cu_query_lens[-1] - - qkv = torch.randn( - num_total_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda') - query, _, _ = qkv.unbind(dim=1) - x = 16 // torch.tensor([], dtype=dtype).element_size() - key_block_shape = (num_heads, head_size // x, block_size, x) - key_cache = torch.randn( - size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda') - value_block_shape = (num_heads, head_size, block_size) - value_cache = torch.randn( - size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda') - - cu_query_lens = torch.tensor(cu_query_lens, dtype=torch.int, device='cuda') - context_lens = [ - query_len + random.randint(0, MAX_SEQ_LEN - query_len) - for query_len in query_lens - ] - max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda') - - max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size - block_tables = [] - for _ in range(num_queries): - block_table = [ - random.randint(0, num_blocks - 1) - for _ in range(max_num_blocks_per_seq) - ] - block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda') - - scale = float(1.0 / (head_size ** 0.5)) - output = torch.empty( - num_total_tokens, num_heads, head_size, dtype=dtype, device='cuda') - - attention_ops.multi_query_cached_kv_attention( - cu_query_lens, - output, - query, - key_cache, - value_cache, - scale, - block_tables, - context_lens, - block_size, - max_context_len, - ) - - ref_output = ref_multi_query_cached_kv_attention( - cu_query_lens, - query, - key_cache, - value_cache, - block_tables, - context_lens, - dtype, - ) - assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) - - @torch.inference_mode() def test_attention(seed: int) -> None: # NOTE(woosuk): Even when the seed is fixed, there is a chance that @@ -364,24 +292,6 @@ def test_attention(seed: int) -> None: dtype=dtype, ) - # NOTE(siyuan): Same as above. Re-run the test if it fails. Also - # note that the test is also more likely to fail due to the much - # larger amount of tokens in the input may increase the variance. - for dtype in [torch.half, torch.float]: - for block_size in [8, 16, 32]: - for head_size in [32, 64, 80, 96, 128, 160, 192, 256]: - print(f'Testing multi_query_cached_kv_attention with ' - f'dtype={dtype}, block_size={block_size}, ' - f'head_size={head_size}') - test_multi_query_cached_kv_attention( - num_queries=11, - num_heads=3, - head_size=head_size, - block_size=block_size, - num_blocks=1024, - dtype=dtype, - ) - # NOTE(woosuk): FlashAttention does not support FP32. for dtype in [torch.half]: # NOTE(woosuk): FlashAttention does not support head_size > 128.