diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 821be5cd..67a00bed 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -184,7 +184,8 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \param q [num_qo_heads, head_dim] The query matrix * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache @@ -203,9 +204,9 @@ __device__ __forceinline__ void sync_state(state_t& st, float* smem, f */ template -__global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, - DTypeIn* __restrict__ v, DTypeOut* __restrict__ o, + uint32_t bdy, uint32_t bdz, typename DTypeQ, typename DTypeKV, typename DTypeOut> +__global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, + DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp, tensor_info_t info, float sm_scale, float rope_rcp_scale, @@ -224,11 +225,11 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * - sizeof(DTypeIn)); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * + sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * - sizeof(DTypeIn)); + sizeof(DTypeKV)); uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -260,7 +261,7 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* // preload k tiles and v tiles uint32_t producer_kv_idx_base = chunk_start; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { @@ -356,10 +357,10 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn* } template + uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename DTypeQ, + typename DTypeKV, typename DTypeOut> __global__ void BatchDecodeWithPaddedKVCacheKernel( - DTypeIn* __restrict__ q, DTypeIn* __restrict__ k, DTypeIn* __restrict__ v, + DTypeQ* __restrict__ q, DTypeKV* __restrict__ k, DTypeKV* __restrict__ v, DTypeOut* __restrict__ o, float* __restrict__ lse, tensor_info_t info, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) { @@ -376,9 +377,9 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( uint32_t seq_len = info.kv_len; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn)); - float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeIn)); + DTypeKV* 
k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * bdz * head_dim * sizeof(DTypeKV)); uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -407,7 +408,7 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( // preload k tiles and v tiles uint32_t producer_kv_idx_base = 0; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { cp_async::pred_load( @@ -495,7 +496,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( * \tparam bdy A template integer indicates the block size in y dimension * \tparam bdz A template integer indicates the block size in z dimension * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param q [batch_size, num_qo_heads, head_dim] The query matrix @@ -512,11 +514,11 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel( */ template + PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, + DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale, @@ -548,13 +550,13 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( partition_kv ? 
kv_partition_info.batch_idx_map[batch_idx] : batch_idx; extern __shared__ uint8_t smem[]; - DTypeIn* k_smem = (DTypeIn*)smem; - DTypeIn* v_smem = (DTypeIn*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * - sizeof(DTypeIn)); - DTypeIn** k_ptrs_smem = (DTypeIn**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeIn)); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * + sizeof(DTypeKV)); + DTypeKV** k_ptrs_smem = (DTypeKV**)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * + head_dim * sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * - sizeof(DTypeIn)); + sizeof(DTypeKV)); const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; @@ -582,7 +584,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // preload k/v tiles uint32_t stage_idx = 0; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; // NOTE(Zihao): when CUDAGraph is disabled, gridDim.x = batch_size, otherwise, // we guarantee that indptr array length is greater than or equal to batch_size + 1, // so we can safely access paged_kv.indptr[batch_idx + 1] @@ -597,7 +599,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( } block.sync(); - DTypeIn* k_ptrs[tile_size_per_bdx]; + DTypeKV* k_ptrs[tile_size_per_bdx]; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { #pragma unroll @@ -615,7 +617,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( cp_async::commit_group(); #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeIn* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); + DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, @@ -684,7 +686,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel( // load v tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - DTypeIn* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); + DTypeKV* v_ptr = k_ptrs[j] + paged_kv.kv_offset_delta(); cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, @@ -737,7 +739,8 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo /*! 
* \brief FlashAttention decoding with kv-cache for a single request - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \param q The query matrix, shape: [num_qo_heads, head_dim] * \param k The key matrix in kv-cache, shape: [seq_len, num_kv_heads, head_dim] @@ -758,33 +761,34 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeo * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, + typename DTypeOut> +cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32U); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = - std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeIn)), bdx * bdy); + std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); tensor_info_t info(1, seq_len, num_kv_heads); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 8U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; const uint32_t smem_size = - 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeIn) + + 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + 2U * bdy * bdz * sizeof(float); if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel auto kernel = SingleDecodeWithKVCacheKernel; + DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -805,7 +809,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v // use partition-kv kernel auto kernel = SingleDecodeWithKVCacheKernel; + bdy, bdz, DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -845,9 +849,9 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { @@ -856,17 +860,17 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t num_qo_heads = num_kv_heads * GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); if (tmp_v == nullptr) { // do not use partition-kv kernel @@ -875,7 +879,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( auto kernel = BatchDecodeWithPagedKVCacheKernel; + bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -896,7 +900,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel; + kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(cudaFuncSetAttribute( partition_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); void* args[] = {(void*)&q, @@ -926,7 +930,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( /*! 
* \brief FlashAttention decoding cuda kernel with paged kv-cache for batched requests * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeQ A template type indicates the query data type + * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type used in paged kv-cache * \param q [batch_size, num_qo_heads, head_dim] The query matrix @@ -942,8 +947,8 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( * \return status Indicates whether CUDA calls are successful */ template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, @@ -952,7 +957,7 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); @@ -961,12 +966,12 @@ cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DType constexpr uint32_t bdz = num_threads / (bdx * bdy); const uint32_t smem_size = - 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + 2 * bdy * bdz * sizeof(float); + 2 * num_stages_smem * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + 2 * bdy * bdz * sizeof(float); dim3 nblks(batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); auto kernel = BatchDecodeWithPaddedKVCacheKernel; + vec_size, bdx, bdy, bdz, DTypeQ, DTypeKV, DTypeOut>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); tensor_info_t info(1, padded_kv_len, num_kv_heads); diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index 41724e1d..6d8bf3e0 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -31,11 +31,11 @@ namespace flashinfer { template + PageStorage page_storage, QKVLayout kv_layout, typename DTypeQ, typename DTypeKV, + typename DTypeOut, typename IdType> __global__ void BatchDecodeWithPagedKVCacheKernel( - DTypeIn* __restrict__ q, IdType* __restrict__ q_offset, - paged_kv_t paged_kv, + DTypeQ* __restrict__ q, IdType* __restrict__ q_offset, + paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* __restrict__ o, DTypeOut* __restrict__ tmp_v, float* __restrict__ tmp_s, float* __restrict__ lse, bool* __restrict__ block_valid_mask, float sm_scale, float rope_rcp_scale, @@ -86,7 +86,7 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \brief Estimate the temporary buffer size and the maximum grid size for the * partition-kv BatchDecodeWithPagedKVCache kernel * \tparam page_storage Whether to store indices or pointers of each active page - * \tparam DTypeIn A template type indicates the input data type + * \tparam DTypeKV A template 
type indicates the key-value data type * \tparam DTypeOut A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param tmp_size The estimated temporary buffer size, return 0 if not use partition-kv kernel @@ -100,27 +100,29 @@ std::pair PartitionPagedKVCacheBinarySearchMinNumPagePerBatc * \return status Indicates whether CUDA calls are successful */ template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( uint32_t& tmp_size, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, uint32_t& new_batch_size, uint32_t batch_size, IdType* kv_indptr, const uint32_t num_qo_heads, const uint32_t page_size, bool enable_cuda_graph, cudaStream_t stream) { - constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeIn), HEAD_DIM / 32UL); + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t num_stages_smem = 2U; constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeIn) == 1 ? 2U : 4U) : 1U; + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; const uint32_t smem_size = - 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeIn) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeIn*), 2 * bdy * bdz * sizeof(float)); + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + // Note that the dtype of Q should not impact the cudaOccupancyMaxActiveBlocksPerMultiprocessor + // return, which is why we just use DTypeKV as it simplifies the API. 
auto partition_kv_kernel = BatchDecodeWithPagedKVCacheKernel< /*partition_kv=*/true, POS_ENCODING_MODE, num_stages_smem, tile_size_per_bdx, vec_size, bdx, - bdy, bdz, page_storage, kv_layout, DTypeIn, DTypeOut, IdType>; + bdy, bdz, page_storage, kv_layout, DTypeQ, DTypeKV, DTypeOut, IdType>; int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; @@ -294,7 +296,7 @@ class BatchDecodeHandler { bool* GetBlockValidMask() const { return block_valid_mask_; } template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BeginForwardDispatched(void* buffer, size_t workspace_size_in_bytes, IdType* indptr, IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads, uint32_t page_size) { @@ -303,8 +305,8 @@ class BatchDecodeHandler { uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size; auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + kv_layout, POS_ENCODING_MODE, + DTypeQ, DTypeKV, DTypeOut, IdType>; FLASHINFER_CUDA_CALL(work_estimation_func(tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr, num_qo_heads, page_size, diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index b28aefa2..40998418 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -29,35 +29,35 @@ namespace flashinfer { template -cudaError_t SingleDecodeWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode pos_encoding_mode, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheDispatched( - DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv, + DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, uint32_t padded_batch_size, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template -cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut> +cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template + PosEncodingMode POS_ENCODING_MODE, typename DTypeQ, typename DTypeKV, typename DTypeOut, typename IdType> cudaError_t BatchDecodeWithPagedKVCacheWrapperDispatched( - BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset, - paged_kv_t paged_kv, DTypeOut* o, float* lse, + BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset, + paged_kv_t paged_kv, DTypeOut* o, float* lse, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { - paged_kv_t new_paged_kv = paged_kv; + paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; DTypeOut* tmp_v = handler->GetTempV(); float* tmp_s = handler->GetTempS(); @@ -82,7 +82,7 @@ cudaError_t 
BatchDecodeWithPagedKVCacheWrapperDispatched( } return BatchDecodeWithPagedKVCacheDispatched( + POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>( q, q_offset, new_paged_kv, kv_partition_info, o, tmp_v, tmp_s, lse, handler->GetBlockValidMask(), handler->GetPaddedBatchSize(), sm_scale, rope_scale, rope_theta, stream); diff --git a/python/csrc/batch_decode.cu b/python/csrc/batch_decode.cu index b58b5efb..1af97f01 100644 --- a/python/csrc/batch_decode.cu +++ b/python/csrc/batch_decode.cu @@ -33,8 +33,7 @@ std::vector batch_decode_with_padded_kv_cache( CHECK_SHAPE(k_padded, v_padded); CHECK_EQ(q.size(0), k_padded.size(0)); CHECK_EQ(q.size(2), k_padded.size(3)); - CHECK_EQ(q.scalar_type(), k_padded.scalar_type()); - CHECK_EQ(q.scalar_type(), v_padded.scalar_type()); + CHECK_EQ(v_padded.scalar_type(), k_padded.scalar_type()); unsigned int batch_size = q.size(0); unsigned int num_qo_heads = q.size(1); unsigned int head_dim = q.size(2); @@ -58,53 +57,57 @@ std::vector batch_decode_with_padded_kv_cache( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - nv_half* tmp = nullptr; - cudaError_t status = - BatchDecodeWithPaddedKVCacheDispatched( - static_cast(q.data_ptr()), - static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), - static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + nv_half* tmp = nullptr; + cudaError_t status = + BatchDecodeWithPaddedKVCacheDispatched( + static_cast(q.data_ptr()), + static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), + static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); - }); + }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - c_type* tmp = nullptr; - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, c_type, c_type>( - static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), - static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), - /*tmp=*/tmp, - /*lse=*/return_lse ? static_cast(lse.data_ptr()) : nullptr, - batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, - torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPaddedKVCache failed with error code ", status); - return true; + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k_padded.scalar_type(), kv_type, [&] { + q_type* tmp = nullptr; + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + cudaError_t status = BatchDecodeWithPaddedKVCacheDispatched< + GROUP_SIZE, HEAD_DIM, KV_LAYOUT, POS_ENCODING_MODE, q_type, kv_type, q_type>( + static_cast(q.data_ptr()), static_cast(k_padded.data_ptr()), + static_cast(v_padded.data_ptr()), static_cast(o.data_ptr()), + /*tmp=*/tmp, + /*lse=*/return_lse ? 
static_cast(lse.data_ptr()) : nullptr, + batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale, rope_theta, + torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPaddedKVCache failed with error code ", status); + return true; + }); }); - }); + }); }); }); }); @@ -121,7 +124,7 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, unsigned int pos_encoding_mode, - torch::Tensor empty_data) { + torch::Tensor empty_q_data, torch::Tensor empty_kv_data) { // NOTE(zihao): not necessary to be CUDA tensor CHECK_CONTIGUOUS(indptr); CHECK_CONTIGUOUS(last_page_len); @@ -136,50 +139,54 @@ void BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward( cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(); handler_->SetCUDAStream(torch_current_stream); - if (is_float8_tensor(empty_data)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_data.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + if (is_float8_tensor(empty_q_data)) { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(empty_q_data.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_data.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - handler_->BeginForwardDispatched( - static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, - static_cast(indptr.data_ptr()), - static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, - page_size); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(empty_q_data.scalar_type(), q_type, [&] { + return 
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(empty_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + handler_->BeginForwardDispatched( + static_cast(workspace_buffer.data_ptr()), workspace_size_in_bytes, + static_cast(indptr.data_ptr()), + static_cast(last_page_len.data_ptr()), batch_size, num_qo_heads, + page_size); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); @@ -208,7 +215,6 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( CHECK_DIM(1, paged_kv_last_page_len); // (B,) CHECK_DIM(1, paged_kv_indptr); // (B+1,) CHECK_DIM(1, paged_kv_indices); // (nnz,) - CHECK_EQ(q.scalar_type(), paged_kv_data.scalar_type()); // (num_max_pages, 2, H_kv, page_size, head_dim) for HND // (num_max_pages, 2, page_size, H_kv, head_dim) for NHD CHECK_DIM(5, paged_kv_data); @@ -242,61 +248,65 @@ std::vector BatchDecodeWithPagedKVCachePyTorchWrapper::Forward( } if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - c_type, nv_half, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, + q_type, kv_type, nv_half, int32_t>( + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, - static_cast(paged_kv_data.data_ptr()), - static_cast(paged_kv_indices.data_ptr()), - static_cast(paged_kv_indptr.data_ptr()), - static_cast(paged_kv_last_page_len.data_ptr())); - cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< - PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, - c_type, c_type, int32_t>( - handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, - paged_kv, static_cast(o.data_ptr()), - /*lse=*/(return_lse ? static_cast(lse.data_ptr()) : nullptr), - sm_scale, rope_scale, rope_theta, - /*stream=*/torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "BatchDecodeWithPagedKVCache failed with error ", - cudaGetErrorString(status)); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(paged_kv_data.scalar_type(), kv_type, [&] { + return DISPATCH_kv_layout(kv_layout_, KV_LAYOUT, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + paged_kv_t paged_kv( + num_kv_heads, page_size, head_dim, batch_size, + static_cast(paged_kv_data.data_ptr()), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + cudaError_t status = BatchDecodeWithPagedKVCacheWrapperDispatched< + PageStorage::kIndices, KV_LAYOUT, GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, + q_type, kv_type, q_type, int32_t>( + handler_.get(), static_cast(q.data_ptr()), /*q_offset=*/nullptr, + paged_kv, static_cast(o.data_ptr()), + /*lse=*/(return_lse ? 
static_cast(lse.data_ptr()) : nullptr), + sm_scale, rope_scale, rope_theta, + /*stream=*/torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "BatchDecodeWithPagedKVCache failed with error ", + cudaGetErrorString(status)); + return true; + }); + }); }); }); }); diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 06ff2dc2..e7ab8a07 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -77,7 +77,7 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { void BeginForward(torch::Tensor workspace_buffer, torch::Tensor indptr, torch::Tensor last_page_len, unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads, unsigned int head_dim, unsigned int page_size, - unsigned int pos_encoding_mode, torch::Tensor empty_data); + unsigned int pos_encoding_mode, torch::Tensor empty_q_data, torch::Tensor empty_kv_data); void EndForward(); void UpdatePageLockedBufferSize(uint32_t max_workspace_size_in_bytes); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } diff --git a/python/csrc/pytorch_extension_utils.h b/python/csrc/pytorch_extension_utils.h index 8d5a6952..f67fa369 100644 --- a/python/csrc/pytorch_extension_utils.h +++ b/python/csrc/pytorch_extension_utils.h @@ -31,6 +31,7 @@ using namespace flashinfer; + #ifdef FLASHINFER_ENABLE_BF16 #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ [&]() -> bool { \ @@ -96,6 +97,88 @@ using namespace flashinfer; }() #endif +#if defined (FLASHINFER_ENABLE_BF16) && defined (FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#elif defined (FLASHINFER_ENABLE_BF16) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#elif defined (FLASHINFER_ENABLE_FP8) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#else +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() +#endif + #define _DISPATCH_SWITCH(var_name, cond, ...) \ [&]() -> bool { \ switch (cond) { \ diff --git a/python/csrc/single_decode.cu b/python/csrc/single_decode.cu index 9172fd8b..6f591bad 100644 --- a/python/csrc/single_decode.cu +++ b/python/csrc/single_decode.cu @@ -32,8 +32,7 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc CHECK_DIM(3, v); CHECK_SHAPE(k, v); CHECK_EQ(q.size(1), k.size(2)); - CHECK_EQ(q.scalar_type(), k.scalar_type()); - CHECK_EQ(q.scalar_type(), v.scalar_type()); + CHECK_EQ(v.scalar_type(), k.scalar_type()); unsigned int num_qo_heads = q.size(0); unsigned int head_dim = q.size(1); unsigned int kv_len, num_kv_heads; @@ -51,47 +50,51 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc q, q.options().dtype(is_float8_tensor(q) ? torch::kFloat16 : q.scalar_type())); if (is_float8_tensor(q)) { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); }); } else { - DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), c_type, [&] { - return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { - return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { - return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { - return DISPATCH_pos_encoding_mode( - PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { - cudaError_t status = - SingleDecodeWithKVCacheDispatched( - static_cast(q.data_ptr()), static_cast(k.data_ptr()), - static_cast(v.data_ptr()), static_cast(o.data_ptr()), - static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, - rope_scale, rope_theta, 
torch_current_stream); - TORCH_CHECK(status == cudaSuccess, - "SingleDecodeWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - return true; - }); + DISPATCH_PYTORCH_DTYPE_TO_CTYPE(q.scalar_type(), q_type, [&] { + return DISPATCH_PYTORCH_DTYPE_TO_CTYPE_COMBINED_FP8(k.scalar_type(), kv_type, [&] { + return DISPATCH_group_size(num_qo_heads / num_kv_heads, GROUP_SIZE, [&] { + return DISPATCH_head_dim(head_dim, HEAD_DIM, [&] { + return DISPATCH_kv_layout(kv_layout, KV_LAYOUT, [&] { + return DISPATCH_pos_encoding_mode( + PosEncodingMode(pos_encoding_mode), POS_ENCODING_MODE, [&] { + cudaError_t status = + SingleDecodeWithKVCacheDispatched( + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(v.data_ptr()), static_cast(o.data_ptr()), + static_cast(tmp.data_ptr()), num_kv_heads, kv_len, sm_scale, + rope_scale, rope_theta, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "SingleDecodeWithKVCache kernel launch failed, error: " + + std::string(cudaGetErrorString(status))); + return true; + }); + }); }); }); }); diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 20b53c0d..3e291335 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -540,6 +540,7 @@ def begin_forward( page_size: int, pos_encoding_mode: str = "NONE", data_type: Union[str, torch.dtype] = "float16", + q_data_type: Optional[Union[str, torch.dtype]] = None, ): r"""Create auxiliary data structures for batch decode for multiple forward calls within the same decode step. @@ -566,6 +567,9 @@ def begin_forward( ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. data_type : Union[str, torch.dtype] The data type of the paged kv cache + q_data_type : Optional[Union[str, torch.dtype]] + The data type of the query tensor. If None, will be set to + ``data_type``. 
Note ---- @@ -599,8 +603,16 @@ def begin_forward( self._paged_kv_indices_buf = indices self._paged_kv_last_page_len_buf = last_page_len - # NOTE(Zihao): the following tensor acts as placeholder to pass dtype info - empty_data = torch.empty( + # NOTE(Zihao): the following tensors acts as placeholder to pass dtype info + if not q_data_type: + q_data_type = data_type + empty_q_data = torch.empty( + 0, + dtype=( + getattr(torch, q_data_type) if isinstance(q_data_type, str) else q_data_type + ), + ) + empty_kv_data = torch.empty( 0, dtype=( getattr(torch, data_type) if isinstance(data_type, str) else data_type @@ -616,7 +628,8 @@ def begin_forward( head_dim, page_size, PosEncodingMode[pos_encoding_mode].value, - empty_data, + empty_q_data, + empty_kv_data, ) def end_forward(self): diff --git a/python/generate_batch_padded_decode_inst.py b/python/generate_batch_padded_decode_inst.py index 1ef596d4..fa5fb973 100644 --- a/python/generate_batch_padded_decode_inst.py +++ b/python/generate_batch_padded_decode_inst.py @@ -29,15 +29,16 @@ def get_cu_file_str( head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ): content = """#include namespace flashinfer {{ -template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, +template cudaError_t BatchDecodeWithPaddedKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( + {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, float* lse, uint32_t batch_size, uint32_t padded_kv_len, uint32_t num_qo_heads, float sm_scale, float rope_scale, @@ -49,7 +50,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], ) return content @@ -58,7 +60,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_padded_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_batch_paged_decode_inst.py b/python/generate_batch_paged_decode_inst.py index 1e72a9a8..bd7d2652 100644 --- a/python/generate_batch_paged_decode_inst.py +++ b/python/generate_batch_paged_decode_inst.py @@ -26,7 +26,7 @@ def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_in, dtype_out, idtype + group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out, idtype ): content = """#include @@ -34,9 +34,9 @@ def get_cu_file_str( constexpr PageStorage page_storage = PageStorage::kIndices; -template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}, {idtype}>( - {dtype_in}* q, {idtype}* q_offset, - paged_kv_t paged_kv, +template cudaError_t BatchDecodeWithPagedKVCacheDispatched<{group_size}, {head_dim}, page_storage, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}, {idtype}>( + {dtype_q}* q, {idtype}* q_offset, + paged_kv_t paged_kv, kv_partition_info_t<{idtype}> kv_partition_info, {dtype_out}* o, {dtype_out}* tmp_v, float* tmp_s, float* lse, bool* block_valid_mask, 
uint32_t padded_batch_size, @@ -49,7 +49,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], idtype=idtype_literal[idtype], ) @@ -59,7 +60,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"batch_paged_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)_idtype_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/generate_single_decode_inst.py b/python/generate_single_decode_inst.py index 67d417c0..8fc36218 100644 --- a/python/generate_single_decode_inst.py +++ b/python/generate_single_decode_inst.py @@ -21,14 +21,14 @@ def get_cu_file_str( - group_size, head_dim, kv_layout, pos_encoding_mode, dtype_in, dtype_out + group_size, head_dim, kv_layout, pos_encoding_mode, dtype_q, dtype_kv, dtype_out ): content = """#include namespace flashinfer {{ -template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_in}, {dtype_out}>( - {dtype_in}* q, {dtype_in}* k, {dtype_in}* v, {dtype_out}* o, +template cudaError_t SingleDecodeWithKVCacheDispatched<{group_size}, {head_dim}, {kv_layout}, {pos_encoding_mode}, {dtype_q}, {dtype_kv}, {dtype_out}>( + {dtype_q}* q, {dtype_kv}* k, {dtype_kv}* v, {dtype_out}* o, {dtype_out}* tmp, uint32_t num_kv_heads, uint32_t seq_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); @@ -39,7 +39,8 @@ def get_cu_file_str( group_size=group_size, head_dim=head_dim, pos_encoding_mode=pos_encoding_mode_literal[int(pos_encoding_mode)], - dtype_in=dtype_literal[dtype_in], + dtype_q=dtype_literal[dtype_q], + dtype_kv=dtype_literal[dtype_kv], dtype_out=dtype_literal[dtype_out], ) return content @@ -48,7 +49,7 @@ def get_cu_file_str( if __name__ == "__main__": pattern = ( r"single_decode_group_([0-9]+)_head_([0-9]+)_layout_([0-9]+)_posenc_([0-9]+)_" - r"dtypein_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" + r"dtypeq_([a-z0-9]+)_dtypekv_([a-z0-9]+)_dtypeout_([a-z0-9]+)\.cu" ) compiled_pattern = re.compile(pattern) diff --git a/python/setup.py b/python/setup.py index fff016b7..c758f394 100644 --- a/python/setup.py +++ b/python/setup.py @@ -94,12 +94,12 @@ def get_instantiation_cu() -> List[str]: idtypes = ["i32"] prefill_dtypes = ["f16"] decode_dtypes = ["f16"] + fp8_dtypes = ["e4m3", "e5m2"] if enable_bf16: prefill_dtypes.append("bf16") decode_dtypes.append("bf16") - fp8_dtypes = [] if enable_fp8: - fp8_dtypes = ["e4m3", "e5m2"] + decode_dtypes.extend(fp8_dtypes) files = [] # single decode files @@ -114,29 +114,17 @@ def get_instantiation_cu() -> List[str]: kv_layouts, pos_encoding_modes, ): - for dtype in decode_dtypes: - fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_single_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, 
pos_encoding_mode, - dtype, - dtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"single_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}.cu" - files.append(prefix + "/" + fname) - content = generate_single_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ) write_if_different(root / prefix / fname, content) @@ -154,58 +142,33 @@ def get_instantiation_cu() -> List[str]: pos_encoding_modes, ): for idtype in idtypes: - for dtype in decode_dtypes: - fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}_idtype_{idtype}.cu" - files.append(prefix + "/" + fname) - content = generate_batch_paged_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype, - dtype, - idtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}_idtype_{idtype}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"batch_paged_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}_idtype_{idtype}.cu" files.append(prefix + "/" + fname) content = generate_batch_paged_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, idtype, ) write_if_different(root / prefix / fname, content) - for dtype in decode_dtypes: - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype}_dtypeout_{dtype}.cu" - files.append(prefix + "/" + fname) - content = generate_batch_padded_decode_inst.get_cu_file_str( - group_size, - head_dim, - kv_layout, - pos_encoding_mode, - dtype, - dtype, - ) - write_if_different(root / prefix / fname, content) - - for dtype_in in fp8_dtypes: - dtype_out = "f16" - fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypein_{dtype_in}_dtypeout_{dtype_out}.cu" + for dtype_q, dtype_kv in itertools.product(decode_dtypes, decode_dtypes): + dtype_out = dtype_q if dtype_q not in fp8_dtypes else "f16" + fname = f"batch_padded_decode_group_{group_size}_head_{head_dim}_layout_{kv_layout}_posenc_{pos_encoding_mode}_dtypeq_{dtype_q}_dtypekv_{dtype_kv}_dtypeout_{dtype_out}.cu" files.append(prefix + "/" + fname) content = generate_batch_padded_decode_inst.get_cu_file_str( group_size, head_dim, kv_layout, pos_encoding_mode, - dtype_in, + dtype_q, + dtype_kv, dtype_out, ) write_if_different(root / prefix / fname, content) diff --git a/python/tests/test_batch_decode_kernels.py b/python/tests/test_batch_decode_kernels.py index dbdee27a..d7dc92a0 100644 --- a/python/tests/test_batch_decode_kernels.py +++ b/python/tests/test_batch_decode_kernels.py @@ -30,7 +30,10 @@ @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) @pytest.mark.parametrize( - "dtype", [torch.float16, torch.float8_e4m3fn, 
torch.float8_e5m2] + "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +@pytest.mark.parametrize( + "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) def test_batch_decode_with_paged_kv_cache( batch_size, @@ -41,9 +44,10 @@ def test_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, - dtype, + q_dtype, + kv_dtype, ): - q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( @@ -68,9 +72,10 @@ def test_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, + q_dtype, ) - o = wrapper.forward(q, kv_data.to(dtype), pos_encoding_mode=pos_encoding_mode) + o = wrapper.forward(q, kv_data.to(kv_dtype), pos_encoding_mode=pos_encoding_mode) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] @@ -90,7 +95,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) vi = torch.cat( [ kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] @@ -105,7 +110,7 @@ def test_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) o_ref_i = flashinfer.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) @@ -123,7 +128,10 @@ def test_batch_decode_with_paged_kv_cache( @pytest.mark.parametrize("kv_layout", ["HND", "NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"]) @pytest.mark.parametrize( - "dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] + "q_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] +) +@pytest.mark.parametrize( + "kv_dtype", [torch.float16, torch.float8_e4m3fn, torch.float8_e5m2] ) def test_cuda_graph_batch_decode_with_paged_kv_cache( batch_size, @@ -134,9 +142,10 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, kv_layout, pos_encoding_mode, - dtype, + q_dtype, + kv_dtype, ): - q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(dtype) + q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).to(q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size kv_data = ( @@ -144,7 +153,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( if kv_layout == "HND" else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim).to(0) ) - kv_data_dtype = kv_data.to(dtype) + kv_data_dtype = kv_data.to(kv_dtype) kv_indptr_host_warmup = torch.arange(0, batch_size + 1).int() kv_indices_host_warmup = torch.arange(0, batch_size).int() kv_last_page_len_host_warmup = torch.full( @@ -173,7 +182,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, + q_dtype, ) # warmup s = torch.cuda.Stream() @@ -204,7 +214,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, + q_dtype, ) g.replay() @@ -224,7 +235,8 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( head_dim, page_size, "NONE", - dtype, + kv_dtype, + q_dtype, ) g.replay() @@ -250,7 +262,7 @@ def test_cuda_graph_batch_decode_with_paged_kv_cache( .reshape(-1, num_kv_heads, head_dim), ], dim=0, - ).to(dtype) + ).to(kv_dtype) vi = torch.cat( [ kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] @@ -265,7 +277,7 @@ def 
            .reshape(-1, num_kv_heads, head_dim),
         ],
         dim=0,
-    ).to(dtype)
+    ).to(kv_dtype)
     o_ref_i = flashinfer.single_decode_with_kv_cache(
         qi, ki, vi, pos_encoding_mode=pos_encoding_mode
     )
@@ -276,17 +288,23 @@
 if __name__ == "__main__":
     test_batch_decode_with_paged_kv_cache(
-        256, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16
+        256, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
     )
     test_batch_decode_with_paged_kv_cache(
-        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16
+        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
     )
     test_batch_decode_with_paged_kv_cache(
-        12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2
+        12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float8_e5m2
     )
     test_cuda_graph_batch_decode_with_paged_kv_cache(
-        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16
+        12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
     )
     test_cuda_graph_batch_decode_with_paged_kv_cache(
-        128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16
+        128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16
     )
+    test_batch_decode_with_paged_kv_cache(
+        12, 54, 1, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16
+    )
+    test_cuda_graph_batch_decode_with_paged_kv_cache(
+        12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float8_e5m2, torch.float16
+    )
\ No newline at end of file
diff --git a/src/bench_batch_decode.cu b/src/bench_batch_decode.cu
index e0b793ff..9aa1b919 100644
--- a/src/bench_batch_decode.cu
+++ b/src/bench_batch_decode.cu
@@ -73,7 +73,7 @@ void bench_flashinfer_batch_decode(nvbench::state& state) {
   size_t workspace_size_in_bytes = 32 * 1024 * 1024;
   thrust::device_vector buffer(workspace_size_in_bytes);
   // begin forward
-  BatchDecodeHandlerBeginForward(
+  BatchDecodeHandlerBeginForward(
       &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
       kv_indptr_host.data(), kv_last_page_len_host.data(), batch_size, num_qo_heads, num_kv_heads,
       head_dim, page_size, pos_encoding_mode);
diff --git a/src/bench_cascade.cu b/src/bench_cascade.cu
index 34011d3e..ec09cdb5 100644
--- a/src/bench_cascade.cu
+++ b/src/bench_cascade.cu
@@ -109,7 +109,7 @@ void bench_two_level_single_prefix_cascade_decode(nvbench::state& state) {
   BatchDecodeHandler cascade_handler;
   size_t workspace_size_in_bytes = 32 * 1024 * 1024;
   thrust::device_vector buffer(workspace_size_in_bytes);
-  BatchDecodeHandlerBeginForward(
+  BatchDecodeHandlerBeginForward(
       &cascade_handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
       kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(), batch_size, num_qo_heads,
       num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
diff --git a/src/flashinfer_ops.cuh b/src/flashinfer_ops.cuh
index cb3e1f15..84294e00 100644
--- a/src/flashinfer_ops.cuh
+++ b/src/flashinfer_ops.cuh
@@ -145,8 +145,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
   return cudaSuccess;
 }
 
-template
-cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, DTypeOut* tmp,
+template
+cudaError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o, DTypeOut* tmp,
                                     uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len,
                                     uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD,
                                     PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
@@ -175,8 +175,8 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
   return cudaSuccess;
 }
 
-template
-cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
+template
+cudaError_t BatchDecodeWithPaddedKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
                                          DTypeOut* tmp, float* lse, uint32_t batch_size,
                                          uint32_t padded_kv_len, uint32_t num_qo_heads,
                                          uint32_t num_kv_heads, uint32_t head_dim,
@@ -200,17 +200,17 @@ cudaError_t BatchDecodeWithPaddedKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTy
   {DISPATCH_pos_encoding_mode(
       pos_encoding_mode, POS_ENCODING_MODE, {DISPATCH_kv_layout(kv_layout, KV_LAYOUT, {
         return BatchDecodeWithPaddedKVCacheDispatched(
+            POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut>(
             q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale, rope_scale,
             rope_theta, stream);
       })})})});
   return cudaSuccess;
 }
 
-template
+template
 cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
-    DTypeIn* q, IdType* q_offset, paged_kv_t paged_kv,
+    DTypeQ* q, IdType* q_offset, paged_kv_t paged_kv,
     kv_partition_info_t kv_partition_info, DTypeOut* o, float* lse, uint32_t num_qo_heads,
     PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
@@ -231,8 +231,8 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
       {DISPATCH_head_dim(
           head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
             return BatchDecodeWithPagedKVCacheDispatched(
+                kv_layout, POS_ENCODING_MODE, DTypeQ,
+                DTypeKV, DTypeOut, IdType>(
                 q, q_offset, paged_kv, kv_partition_info, o, /*tmp_v=*/nullptr, /*tmp_s=*/nullptr,
                 lse, /*block_valid_mask=*/nullptr, /*padded_batch_size=*/paged_kv.batch_size, sm_scale,
@@ -247,7 +247,8 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
  * for cooperative kernels.
  * \tparam page_storage Whether to store indices or pointers of each active page
  * \tparam kv_layout The layout of last 3 dimensions in KV-Cache
- * \tparam DTypeIn The data type of input tensor.
+ * \tparam DTypeQ The data type of query tensor.
+ * \tparam DTypeKV The data type of key-value tensor.
  * \tparam DTypeOut The data type of output tensor.
  * \tparam IdType The data type of index tensor.
  * \param handler The handler for the batch decode forward request.
@@ -263,11 +264,11 @@ cudaError_t BatchDecodeWithPagedKVCacheNoSplitKV(
  * \note This wrapper function should be only called after we call BeginForward function in the
  * BatchDecodeHandler.
  */
-template
+template
 cudaError_t BatchDecodeWithPagedKVCacheWrapper(
-    BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
-    paged_kv_t paged_kv, DTypeOut* o, float* lse,
+    BatchDecodeHandler* handler, DTypeQ* q, IdType* q_offset,
+    paged_kv_t paged_kv, DTypeOut* o, float* lse,
     uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
     std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
     float rope_theta = 1e4, cudaStream_t stream = nullptr) {
@@ -287,14 +288,14 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
     {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
       return BatchDecodeWithPagedKVCacheWrapperDispatched(
+          DTypeQ, DTypeKV, DTypeOut, IdType>(
           handler, q, q_offset, paged_kv, o, lse, sm_scale, rope_scale, rope_theta, stream);
     })})});
   return cudaSuccess;
 }
 
-template
+template
 cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* buffer,
                                            size_t workspace_size_in_bytes, IdType* indptr,
                                            IdType* last_page_len, uint32_t batch_size,
@@ -311,7 +312,7 @@ cudaError_t BatchDecodeHandlerBeginForward(BatchDecodeHandler* handler, void* bu
   DISPATCH_head_dim(head_dim, HEAD_DIM, {
     DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, {
       return handler->BeginForwardDispatched(
+          POS_ENCODING_MODE, DTypeQ, DTypeKV, DTypeOut, IdType>(
           buffer, workspace_size_in_bytes, indptr, last_page_len, batch_size, num_qo_heads,
           page_size);
     });
diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu
index 6b14372d..b8f77a97 100644
--- a/src/test_batch_decode.cu
+++ b/src/test_batch_decode.cu
@@ -100,7 +100,7 @@ void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, si
   flashinfer::BatchDecodeHandler handler;
   size_t workspace_size_in_bytes = 32 * 1024 * 1024;
   thrust::device_vector buffer(workspace_size_in_bytes);
-  BatchDecodeHandlerBeginForward(
+  BatchDecodeHandlerBeginForward(
      &handler, (void*)thrust::raw_pointer_cast(buffer.data()), workspace_size_in_bytes,
      kv_indptr.data(), kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads,
      head_dim, page_size, pos_encoding_mode);
diff --git a/src/test_cascade.cu b/src/test_cascade.cu
index e760804a..0b1e6d18 100644
--- a/src/test_cascade.cu
+++ b/src/test_cascade.cu
@@ -283,12 +283,12 @@ void _TestTwoLevelSinglePrefixCascadeDecodeCorrectness(size_t batch_size,
   thrust::device_vector buffer_baseline(workspace_size_in_bytes),
       buffer_cascade(workspace_size_in_bytes);
-  BatchDecodeHandlerBeginForward(
+  BatchDecodeHandlerBeginForward(
       &baseline_handler, (void*)thrust::raw_pointer_cast(buffer_baseline.data()),
       workspace_size_in_bytes, kv_indptr_combined_h.data(), kv_last_page_len_combined_h.data(),
       batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
-  BatchDecodeHandlerBeginForward(
+  BatchDecodeHandlerBeginForward(
      &cascade_handler, (void*)thrust::raw_pointer_cast(buffer_cascade.data()),
      workspace_size_in_bytes, kv_indptr_unique_h.data(), kv_last_page_len_unique_h.data(),
      batch_size, num_qo_heads, num_kv_heads, head_dim, page_size, PosEncodingMode::kNone);
diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu
index f3e6c6ec..b7682972 100644
--- a/src/tvm_wrapper.cu
+++ b/src/tvm_wrapper.cu
@@ -424,7 +424,7 @@ void _FlashInferAttentionDecodeWithPagedKVCacheBeginForward(
   batch_decode_handlers[handler_idx].SetCUDAStream(static_cast(copy_stream));
   DISPATCH_TVM_CUDA_IDTYPE(page_table_indptr->dtype, dtype_idx, {
     cudaError_t status =
-        BatchDecodeHandlerBeginForward(
+        BatchDecodeHandlerBeginForward(
             batch_decode_handlers + handler_idx, static_cast(workspace_buffer->data),
             workspace_size_in_bytes, static_cast(page_table_indptr->data) +