diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu
index bf8dc5bde9..4d17cdd215 100644
--- a/csrc/batch_prefill.cu
+++ b/csrc/batch_prefill.cu
@@ -257,6 +257,7 @@ void BatchPrefillWithPagedKVCacheRun(
         static_cast(paged_kv_last_page_len.data_ptr()));
     params.paged_kv = paged_kv;
     params.q_indptr = static_cast(qo_indptr.data_ptr());
+    params.q_lenptr = nullptr;  // disable non-contiguous qo
     params.o = static_cast(o.data_ptr());
     params.lse = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr;
diff --git a/csrc/batch_prefill_customize_config.jinja b/csrc/batch_prefill_customize_config.jinja
index 77490d71b2..8a04d5dfcb 100644
--- a/csrc/batch_prefill_customize_config.jinja
+++ b/csrc/batch_prefill_customize_config.jinja
@@ -86,7 +86,10 @@ struct PagedParams {
   DTypeQ* q;
   paged_kv_t paged_kv;
+
   IdType* q_indptr;
+  uint32_t* q_lenptr;
+
   DTypeO* o;
   float* lse;
   uint_fastdiv group_size;
@@ -110,6 +113,9 @@ struct PagedParams {
   bool partition_kv;

   __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
+    if(q_lenptr){
+      return q_lenptr[batch_idx];
+    }
     return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
   }
diff --git a/csrc/pod.cu b/csrc/pod.cu
index 2b9a8d6f6c..2da7610d1a 100644
--- a/csrc/pod.cu
+++ b/csrc/pod.cu
@@ -13,194 +13,124 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include
+#include
+#include
 #include
 #include
-#include "aot_extension_utils.h"
 #include "pod_config.inc"
 #include "pytorch_conversion_utils.h"
 #include "pytorch_extension_utils.h"

 namespace flashinfer {

-template
-cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
-                                           typename PrefillParams::DTypeO* tmp,
-                                           DecodeParams decode_params,
-                                           typename DecodeParams::DTypeO* tmp_v, float* tmp_s,
-                                           cudaStream_t stream);
+template
+cudaError_t PODWithPagedKVCacheDispatched(PrefillParams prefill_params, DecodeParams decode_params,
+                                          typename DecodeParams::DTypeO* tmp_v, float* tmp_s,
+                                          cudaStream_t stream);

 } // namespace flashinfer

 using namespace flashinfer;

-void pod_with_kv_cache_tensor(
-    // Prefill params
-    at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p,
-    std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p,
-    int64_t window_left_p, std::optional maybe_custom_mask_p,
-    std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p,
-    double rope_rcp_scale_p, double rope_rcp_theta_p,
-    // Decode params
-    at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d,
-    at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d,
-    at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d,
-    at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d,
-    std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d,
-    int64_t window_left_d, std::optional maybe_custom_mask_d,
-    std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d,
-    double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d,
-    // Shared params
-    int64_t cuda_stream) {
-  // Prefill setup
-  unsigned int head_dim_qk = q_p.size(2);
-  unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads;
-  QKVLayout kv_layout_p = static_cast(layout_p);
-  qo_len_p = q_p.size(0);
-  num_qo_heads = q_p.size(1);
-  uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p,
-      v_stride_n_p, v_stride_h_p;
-  if (kv_layout_p == QKVLayout::kNHD) {
kv_len_p = k_p.size(0); - num_kv_heads = k_p.size(1); - k_stride_n_p = k_p.stride(0); - k_stride_h_p = k_p.stride(1); - v_stride_n_p = v_p.stride(0); - v_stride_h_p = v_p.stride(1); - } else { - kv_len_p = k_p.size(1); - num_kv_heads = k_p.size(0); - k_stride_h_p = k_p.stride(0); - k_stride_n_p = k_p.stride(1); - v_stride_h_p = v_p.stride(0); - v_stride_n_p = v_p.stride(1); - } - if (maybe_lse_p) { - const auto& lse = *maybe_lse_p; - TORCH_CHECK(lse.size(0) == qo_len_p, lse.size(0), q_p.size(0)); - TORCH_CHECK(lse.size(1) == num_qo_heads, lse.size(1), q_p.size(1)); - } - - const MaskMode mask_mode_p = static_cast(mask_mode_code_p); +at::Tensor PODWithPagedKVCachePlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_last_page_len, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, int64_t cuda_stream) { + size_t float_workspace_size_in_bytes = + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); + size_t int_workspace_size_in_bytes = + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); + + PoDPlanInfo plan_info; - auto q_scalar_type = q_p.scalar_type(); - auto kv_scalar_type = k_p.scalar_type(); - - // Decode setup (Tensor decode = batched prefill) - PrefillPlanInfo plan_info; - plan_info.FromVector(tensor_to_vec(plan_info_vec)); - QKVLayout kv_layout_d = static_cast(layout_d); - auto device = q_d.device(); - int64_t batch_size = paged_kv_indptr_d.size(0) - 1; - int64_t num_qo_heads_d = q_d.size(1); - - TORCH_CHECK(num_qo_heads == num_qo_heads_d, - "POD currently requires same # Query heads for prefill and decode"); + cudaStream_t stream = reinterpret_cast(cuda_stream); + cudaError_t status = + PoDPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr(), + kv_indptr.data_ptr(), kv_last_page_len.data_ptr(), + total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, + head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); + + TORCH_CHECK(status == cudaSuccess, + "Failed to plan PoD Attention with error: ", cudaGetErrorString(status)); + + return vec_to_tensor(plan_info.ToVector()); +} - int64_t num_kv_heads_d, page_size_d; - uint32_t head_dim_qk_d = q_d.size(2); - if (kv_layout_d == QKVLayout::kHND) { - num_kv_heads_d = paged_k_cache_d.size(1); - page_size_d = paged_k_cache_d.size(2); +void PODWithPagedKVCacheRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, at::Tensor paged_kv_indices, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t cuda_stream) { + PoDPlanInfo pod_plan_info; + pod_plan_info.FromVector(tensor_to_vec(plan_info_vec)); + + QKVLayout kv_layout = static_cast(layout); + const MaskMode mask_mode = static_cast(mask_mode_code); + auto device = q.device(); + + int64_t num_qo_heads = q.size(1); + int64_t num_kv_heads, page_size; + uint32_t head_dim_qk = q.size(2); + if (kv_layout == QKVLayout::kHND) { + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - 
page_size_d = paged_k_cache_d.size(1); - num_kv_heads_d = paged_k_cache_d.size(2); + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } - TORCH_CHECK(num_kv_heads == num_kv_heads_d, - "POD currently requires same # KV heads for prefill and decode; Prefill: ", - num_kv_heads, ", Decode: ", num_kv_heads_d); - if (maybe_lse_d) { - const auto& lse = *maybe_lse_d; - TORCH_CHECK(lse.size(0) == q_d.size(0), lse.size(0), q_d.size(0)); - TORCH_CHECK(lse.size(1) == q_d.size(1), lse.size(1), q_d.size(1)); + if (maybe_lse) { + const auto& lse = *maybe_lse; + TORCH_CHECK(lse.size(0) == q.size(0), lse.size(0), q.size(0)); + TORCH_CHECK(lse.size(1) == q.size(1), lse.size(1), q.size(1)); } - void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); - void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); - - const MaskMode mask_mode_d = static_cast(mask_mode_code_d); - auto q_scalar_type_d = q_d.scalar_type(); - auto kv_scalar_type_d = paged_k_cache_d.scalar_type(); + void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer.data_ptr()); + auto q_scalar_type = q.scalar_type(); + auto kv_scalar_type = paged_k_cache.scalar_type(); // get q_stride_n and q_stride_h - const auto q_stride_n_d = q_d.stride(0); - const auto q_stride_h_d = q_d.stride(1); + const auto q_stride_n = q.stride(0); + const auto q_stride_h = q.stride(1); // get kv_cache_strides - const int64_t* kv_cache_strides_d = nullptr; - auto k_strides_d = paged_k_cache_d.strides(); - auto v_strides_d = paged_v_cache_d.strides(); - TORCH_CHECK(k_strides_d == v_strides_d, "k/v strides must be identical"); - kv_cache_strides_d = k_strides_d.data(); + const int64_t* kv_cache_strides = nullptr; + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + TORCH_CHECK(k_strides == v_strides, "k/v strides must be identical"); + kv_cache_strides = k_strides.data(); cudaStream_t stream = reinterpret_cast(cuda_stream); - DISPATCH_context( - MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, - USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, [&] { - PrefillParams prefill_params; - { - // Make params a reference to prefill_params to set values - PrefillParams& params = prefill_params; - params.q = static_cast(q_p.data_ptr()); - params.k = static_cast(k_p.data_ptr()); - params.v = static_cast(v_p.data_ptr()); - params.o = static_cast(o_p.data_ptr()); - params.lse = maybe_lse_p ? static_cast(maybe_lse_p->data_ptr()) : nullptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len_p; - params.kv_len = kv_len_p; - params.q_stride_n = q_stride_n_p; - params.q_stride_h = q_stride_h_p; - params.k_stride_n = k_stride_n_p; - params.k_stride_h = k_stride_h_p; - params.v_stride_n = v_stride_n_p; - params.v_stride_h = v_stride_h_p; - - params.window_left = window_left_p; - params.partition_kv = false; - - params.maybe_custom_mask = maybe_custom_mask_p - ? static_cast(maybe_custom_mask_p->data_ptr()) - : nullptr; - params.maybe_alibi_slopes = maybe_alibi_slopes_p - ? 
static_cast(maybe_alibi_slopes_p->data_ptr()) - : nullptr; - params.logits_soft_cap = logits_soft_cap_p; - params.sm_scale = sm_scale_p; - params.rope_rcp_scale = rope_rcp_scale_p; - params.rope_rcp_theta = rope_rcp_theta_p; - } - - DecodeParams decode_params; + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, + RaggedParams, PagedParams, [&] { + PagedParams params_p, params_d; DTypeO* tmp_v = nullptr; float* tmp_s = nullptr; - { - DecodeParams& params = decode_params; - params.q = static_cast(q_d.data_ptr()); - paged_kv_t paged_kv( - num_kv_heads, page_size_d, HEAD_DIM_VO, batch_size, kv_layout_d, - static_cast(paged_k_cache_d.data_ptr()), - static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, - static_cast(paged_kv_indices_d.data_ptr()), - static_cast(paged_kv_indptr_d.data_ptr()), - static_cast(paged_kv_last_page_len_d.data_ptr())); - params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_d.data_ptr()); - params.o = static_cast(o_d.data_ptr()); - params.lse = maybe_lse_d ? static_cast(maybe_lse_d->data_ptr()) : nullptr; + auto _configureParams = [&](PagedParams& params, const PrefillPlanInfo& plan_info, + int64_t batch_size) { + params.q = static_cast(q.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.q_indptr = nullptr; + params.q_lenptr = nullptr; + params.lse = maybe_lse ? static_cast(maybe_lse->data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads; - params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); - params.q_stride_n = q_stride_n_d; - params.q_stride_h = q_stride_h_d; - params.window_left = window_left_d; + params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.window_left = window_left; params.request_indices = nullptr; params.qo_tile_indices = nullptr; @@ -214,16 +144,13 @@ void pod_with_kv_cache_tensor( params.padded_batch_size = 0; params.partition_kv = false; - params.maybe_mask_indptr = maybe_mask_indptr_d - ? static_cast(maybe_mask_indptr_d->data_ptr()) - : nullptr; - params.maybe_alibi_slopes = maybe_alibi_slopes_d - ? 
static_cast(maybe_alibi_slopes_d->data_ptr()) - : nullptr; - params.logits_soft_cap = logits_soft_cap_d; - params.sm_scale = sm_scale_d; - params.rope_rcp_scale = rope_rcp_scale_d; - params.rope_rcp_theta = rope_rcp_theta_d; + ADDITIONAL_PARAMS_SETTER + + paged_kv_t paged_kv( + num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices.data_ptr()), nullptr, nullptr); params.request_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.request_indices_offset); @@ -234,9 +161,20 @@ void pod_with_kv_cache_tensor( params.o_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.o_indptr_offset); params.kv_chunk_size_ptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_chunk_size_ptr_offset); + params.q_indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_ptr_offset); + params.q_lenptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_ptr_offset); + paged_kv.indptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_start_ptr_offset); + paged_kv.len_ptr = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_ptr_offset); + paged_kv.last_page_len = + GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_last_page_offset); + params.paged_kv = paged_kv; + if (plan_info.split_kv) { - params.merge_indptr = - GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_indptr_offset); + params.partition_kv = true; // used in prefill kernel tmp_v = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.v_offset); tmp_s = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.s_offset); if (plan_info.enable_cuda_graph) { @@ -250,25 +188,31 @@ void pod_with_kv_cache_tensor( params.total_num_rows = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.total_num_rows_offset); } + }; + + _configureParams(params_p, pod_plan_info.plan_info_p, pod_plan_info.batch_size_vec_p); + _configureParams(params_d, pod_plan_info.plan_info_d, pod_plan_info.batch_size_vec_d); + + if (pod_plan_info.plan_info_p.split_kv || pod_plan_info.plan_info_d.split_kv) { + params_p.merge_indptr = GetPtrFromBaseOffset( + int_buffer_ptr, pod_plan_info.plan_info_p.merge_indptr_offset); + params_d.merge_indptr = GetPtrFromBaseOffset( + int_buffer_ptr, pod_plan_info.plan_info_d.merge_indptr_offset); } - constexpr bool use_custom_mask_p = MASK_MODE_P == MaskMode::kCustom; - using PrefillAttentionVariant = - DefaultAttention; - constexpr bool use_custom_mask_d = MASK_MODE_D == MaskMode::kCustom; - using DecodeAttentionVariant = - DefaultAttention; - // DISPATCH_CTA_TILE_Q(plan_info.cta_tile_q, CTA_TILE_Q, { - constexpr size_t CTA_TILE_Q = 16; - cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< - HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, - CTA_TILE_Q, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>( - prefill_params, static_cast(tmp_p.data_ptr()), decode_params, tmp_v, tmp_s, - stream); - TORCH_CHECK(status == cudaSuccess, "PODWithKVCache kernel launch failed, error: " + - std::string(cudaGetErrorString(status))); - //}); + cudaError_t status = cudaSuccess; + + DISPATCH_CTA_TILE_Q(pod_plan_info.plan_info_p.cta_tile_q, CTA_TILE_Q_P, { + DISPATCH_CTA_TILE_Q(pod_plan_info.plan_info_d.cta_tile_q, CTA_TILE_Q_D, { + status = flashinfer::PODWithPagedKVCacheDispatched< + CTA_TILE_Q_P, CTA_TILE_Q_D, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, + USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, PagedParams, PagedParams>( + params_p, params_d, tmp_v, 
tmp_s, stream); + }); + }); + + TORCH_CHECK(status == cudaSuccess, "PODWithPagedKVCacheDispatched failed with error ", + cudaGetErrorString(status)); + return true; }); } diff --git a/csrc/pod_config.inc b/csrc/pod_config.inc index 16306651a0..9bc74391af 100644 --- a/csrc/pod_config.inc +++ b/csrc/pod_config.inc @@ -1,8 +1,6 @@ #pragma once #include -#include #include -#include #include #include #include @@ -12,34 +10,36 @@ #include "aot_default_additional_params.h" #include "aot_extension_utils.h" -using namespace flashinfer; +using IdType = int32_t; -#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ - USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ -{ \ - DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \ - return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \ - return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ - q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ - using DTypeO = DTypeQ; \ - constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ - constexpr bool USE_FP16_QK_REDUCTION = false; \ - return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ - [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ - return DISPATCH_BOOL(window_left_p > -1, USE_SLIDING_WINDOW_P, [&] { \ - return DISPATCH_BOOL(window_left_d > -1, USE_SLIDING_WINDOW_D, [&] { \ - return DISPATCH_BOOL(false, USE_LOGITS_SOFT_CAP, [&] { \ - using IdType = int32_t; \ - using PrefillParams = SinglePrefillParams;\ - using DecodeParams = BatchPrefillPagedParams; \ - __VA_ARGS__(); \ - return true; \ - }); \ - }); \ - }); \ - }); \ - }); \ - }); \ - }); \ -} +#define ADDITIONAL_FUNC_PARAMS BATCH_PREFILL_ADDITIONAL_FUNC_PARAMS +#define ADDITIONAL_PARAMS_SETTER BATCH_PREFILL_ADDITIONAL_PARAMS_SETTER + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, \ + POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, \ + USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) \ + { \ + DISPATCH_mask_mode(mask_mode, MASK_MODE, [&] { \ + return DISPATCH_PYTORCH_QKV_DTYPE_TO_CTYPE( \ + q_scalar_type, kv_scalar_type, DTypeQ, DTypeKV, [&] { \ + using DTypeO = DTypeQ; \ + using RaggedParams = BatchPrefillRaggedParams; \ + using PagedParams = BatchPrefillPagedParams; \ + constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + constexpr bool use_custom_mask = MASK_MODE == MaskMode::kCustom; \ + return DISPATCH_head_dim(head_dim_qk, HEAD_DIM_QK, [&] { \ + [[maybe_unused]] constexpr int HEAD_DIM_VO = HEAD_DIM_QK; \ + return DISPATCH_BOOL(window_left > -1, USE_SLIDING_WINDOW, [&] { \ + return DISPATCH_BOOL(logits_soft_cap > 0.f, USE_LOGITS_SOFT_CAP, [&] { \ + using AttentionVariant = \ + DefaultAttention; \ + __VA_ARGS__(); \ + return true; \ + }); \ + }); \ + }); \ + }); \ + }); \ + } diff --git a/csrc/pod_customize_config.jinja b/csrc/pod_customize_config.jinja index 0e3c3f51dd..8a04d5dfcb 100644 --- a/csrc/pod_customize_config.jinja +++ b/csrc/pod_customize_config.jinja @@ -5,10 +5,17 @@ #include #include #include -#include -#include #include -#include + +#define ADDITIONAL_FUNC_PARAMS {{ additional_func_params }} +#define ADDITIONAL_PARAMS_SETTER {{ additional_params_setter }} + +#define DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, USE_FP16_QK_REDUCTION, AttentionVariant, RaggedParams, PagedParams, ...) 
\ + DISPATCH_MASK_MODE(mask_mode, MASK_MODE, { \ + constexpr auto use_custom_mask = MASK_MODE == MaskMode::kCustom; \ + using AttentionVariant = {{ variant_name }}; \ + __VA_ARGS__(); \ + }) using namespace flashinfer; @@ -19,25 +26,102 @@ using IdType = {{ idtype }}; constexpr int HEAD_DIM_QK = {{ head_dim_qk }}; constexpr int HEAD_DIM_VO = {{ head_dim_vo }}; constexpr bool USE_FP16_QK_REDUCTION = {{ use_fp16_qk_reduction }}; -constexpr auto USE_LOGITS_SOFT_CAP_P = {{ use_logits_soft_cap_p }}; -constexpr auto POS_ENCODING_MODE_P = {{ pos_encoding_mode_p }}; -constexpr auto USE_SLIDING_WINDOW_P = {{ use_sliding_window_p }}; - -constexpr auto USE_LOGITS_SOFT_CAP_D = {{ use_logits_soft_cap_d }}; -constexpr auto POS_ENCODING_MODE_D = {{ pos_encoding_mode_d }}; -constexpr auto USE_SLIDING_WINDOW_D = {{ use_sliding_window_d }}; - -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; -constexpr bool USE_LOGITS_SOFT_CAP = false; - -using PrefillParams = SinglePrefillParams; -using DecodeParams = BatchPrefillPagedParams; - -#define DISPATCH_context(MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, \ - USE_SLIDING_WINDOW_P, USE_SLIDING_WINDOW_D, USE_LOGITS_SOFT_CAP, ...) \ - DISPATCH_mask_mode(mask_mode_p, MASK_MODE_P, [&] { \ - return DISPATCH_mask_mode(mask_mode_d, MASK_MODE_D, [&] { \ - __VA_ARGS__(); \ - return true; \ - }); \ -}); +constexpr auto USE_LOGITS_SOFT_CAP = {{ use_logits_soft_cap }}; +constexpr auto POS_ENCODING_MODE = {{ pos_encoding_mode }}; +constexpr auto USE_SLIDING_WINDOW = {{ use_sliding_window }}; + + +struct RaggedParams { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + IdType* q_indptr; + IdType* kv_indptr; + DTypeO* o; + float* lse; + uint_fastdiv group_size; + + {{ additional_params_decl }} + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + int32_t window_left; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } +}; + +struct PagedParams { + using DTypeQ = DTypeQ; + using DTypeKV = DTypeKV; + using DTypeO = DTypeO; + using IdType = IdType; + + DTypeQ* q; + paged_kv_t paged_kv; + + IdType* q_indptr; + uint32_t* q_lenptr; + + DTypeO* o; + float* lse; + uint_fastdiv group_size; + + {{ additional_params_decl }} + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + bool* block_valid_mask; + IdType* kv_chunk_size_ptr; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + if(q_lenptr){ + return q_lenptr[batch_idx]; + } + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ 
__device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } +}; + +{{ variant_decl }} diff --git a/csrc/pod_jit_pybind.cu b/csrc/pod_jit_pybind.cu index f67f50708d..fb0be84a40 100644 --- a/csrc/pod_jit_pybind.cu +++ b/csrc/pod_jit_pybind.cu @@ -16,26 +16,23 @@ #include "pod_config.inc" #include "pytorch_extension_utils.h" -void pod_with_kv_cache_tensor( - // Prefill params - at::Tensor q_p, at::Tensor k_p, at::Tensor v_p, at::Tensor tmp_p, at::Tensor o_p, - std::optional maybe_lse_p, int64_t mask_mode_code_p, int64_t layout_p, - int64_t window_left_p, std::optional maybe_custom_mask_p, - std::optional maybe_alibi_slopes_p, double logits_soft_cap_p, double sm_scale_p, - double rope_rcp_scale_p, double rope_rcp_theta_p, - // Decode params - at::Tensor float_workspace_buffer_d, at::Tensor int_workspace_buffer_d, - at::Tensor plan_info_vec, at::Tensor q_d, at::Tensor paged_k_cache_d, - at::Tensor paged_v_cache_d, at::Tensor qo_indptr_d, at::Tensor paged_kv_indptr_d, - at::Tensor paged_kv_indices_d, at::Tensor paged_kv_last_page_len_d, at::Tensor o_d, - std::optional maybe_lse_d, int64_t mask_mode_code_d, int64_t layout_d, - int64_t window_left_d, std::optional maybe_custom_mask_d, - std::optional maybe_mask_indptr_d, std::optional maybe_alibi_slopes_d, - double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, - // Shared params - int64_t cuda_stream); +at::Tensor PODWithPagedKVCachePlan(at::Tensor float_workspace_buffer, + at::Tensor int_workspace_buffer, + at::Tensor page_locked_int_workspace_buffer, + at::Tensor qo_indptr, at::Tensor kv_indptr, + at::Tensor kv_last_page_len, int64_t total_num_rows, + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, + int64_t page_size, bool enable_cuda_graph, int64_t head_dim_qk, + int64_t head_dim_vo, bool causal, int64_t cuda_stream); + +void PODWithPagedKVCacheRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, + at::Tensor plan_info_vec, at::Tensor q, at::Tensor paged_k_cache, + at::Tensor paged_v_cache, at::Tensor paged_kv_indices, at::Tensor o, + std::optional maybe_lse, int64_t mask_mode_code, + int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS, + int64_t cuda_stream); TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { - // Batch-request prefill attention with KV-Cache operator - m.def("pod_with_kv_cache_tensor", pod_with_kv_cache_tensor); + m.def("plan", PODWithPagedKVCachePlan); + m.def("paged_run", PODWithPagedKVCacheRun); } diff --git a/csrc/pod_kernel_inst.jinja b/csrc/pod_kernel_inst.jinja index 1584437763..cc11bbffd8 100644 --- a/csrc/pod_kernel_inst.jinja +++ b/csrc/pod_kernel_inst.jinja @@ -1,34 +1,16 @@ -#include -#include -#include -#include -#include #include -#include -#include -#include - -#include "pytorch_conversion_utils.h" -#include "pytorch_extension_utils.h" -#include "aot_default_additional_params.h" -#include "aot_extension_utils.h" - #include "pod_config.inc" -using namespace flashinfer; - namespace flashinfer { -constexpr auto use_custom_mask_p = {{ mask_mode_p }} == MaskMode::kCustom; -constexpr auto use_custom_mask_d = {{ mask_mode_d }} == MaskMode::kCustom; -// Not sure about the below declaration -constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone; -template cudaError_t PODWithKVCacheTensorDispatched< - {{ head_dim_qk }}, {{ head_dim_vo }}, POS_ENCODING_MODE, - {{ use_fp16_qk_reduction }}, {{ mask_mode_p }}, 16, - {{ mask_mode_d }}, {{ variant_name_p }}, - {{ variant_name_d }}, 
PrefillParams, DecodeParams>( - PrefillParams prefill_params, {{ dtype_o }}* tmp, - DecodeParams decode_params, {{ dtype_o }}* tmp_v, - float *tmp_s, cudaStream_t stream); -}; +constexpr auto use_custom_mask = {{ mask_mode }} == MaskMode::kCustom; + +{% for cta_tile_q_p in [16, 64, 128] %} +{% for cta_tile_q_d in [16, 64, 128] %} +template cudaError_t PODWithPagedKVCacheDispatched< + /*CTA_TILE_Q_P=*/{{cta_tile_q_p}}, /*CTA_TILE_Q_D=*/{{cta_tile_q_d}}, {{head_dim_qk}}, {{head_dim_vo}}, {{pos_encoding_mode}}, {{use_fp16_qk_reduction}}, {{mask_mode}}, + {{ variant_name }}, PagedParams, PagedParams>(PagedParams prefill_params, PagedParams decode_params, {{ dtype_o }}* tmp_v, float* tmp_s, cudaStream_t stream); +{% endfor %} +{% endfor %} + +}; // namespace flashinfer diff --git a/flashinfer/jit/attention/pytorch.py b/flashinfer/jit/attention/pytorch.py index 660f346493..079d1477c8 100644 --- a/flashinfer/jit/attention/pytorch.py +++ b/flashinfer/jit/attention/pytorch.py @@ -328,32 +328,29 @@ def get_single_prefill_uri( def get_pod_uri( + backend: str, dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, - pos_encoding_mode_p: int, - use_sliding_window_p: bool, - use_logits_soft_cap_p: bool, - use_fp16_qk_reduction: bool, dtype_idx: torch.dtype, - pos_encoding_mode_d: int, - use_sliding_window_d: bool, - use_logits_soft_cap_d: bool, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, ) -> str: return ( f"pod_with_kv_cache_dtype_q_{filename_safe_dtype_map[dtype_q]}_" f"dtype_kv_{filename_safe_dtype_map[dtype_kv]}_" f"dtype_o_{filename_safe_dtype_map[dtype_o]}_" - f"head_dim_{head_dim}_" - f"posenc_p_{pos_encoding_mode_p}_" - f"use_swa_p_{use_sliding_window_p}_" - f"use_logits_cap_p_{use_logits_soft_cap_p}_" - f"posenc_d_{pos_encoding_mode_d}_" - f"use_swa_d_{use_sliding_window_d}_" - f"use_logits_cap_d_{use_logits_soft_cap_d}_" f"dtype_idx_{filename_safe_dtype_map[dtype_idx]}_" - f"f16qk_{use_fp16_qk_reduction}" + f"head_dim_qk_{head_dim_qk}_" + f"head_dim_vo_{head_dim_vo}_" + f"posenc_{pos_encoding_mode}_" + f"use_swa_{use_sliding_window}_" + f"use_logits_cap_{use_logits_soft_cap}_" + f"f16qk_{use_fp16_qk_reduction}" + ("_sm90" if backend == "fa3" else "") ) @@ -494,168 +491,167 @@ def gen_single_prefill_module( def gen_pod_module( + backend: str, dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - head_dim: int, - pos_encoding_mode_p: int, - use_sliding_window_p: bool, - use_logits_soft_cap_p: bool, - use_fp16_qk_reduction: bool, dtype_idx: torch.dtype, - pos_encoding_mode_d: int, - use_sliding_window_d: bool, - use_logits_soft_cap_d: bool, + head_dim_qk: int, + head_dim_vo: int, + pos_encoding_mode: int, + use_sliding_window: bool, + use_logits_soft_cap: bool, + use_fp16_qk_reduction: bool, ): + assert backend in ["fa2"], "Only fa2 backend is supported for pod module" uri = get_pod_uri( + backend, dtype_q, dtype_kv, dtype_o, - head_dim, - pos_encoding_mode_p, - use_sliding_window_p, - use_logits_soft_cap_p, - use_fp16_qk_reduction, dtype_idx, - pos_encoding_mode_d, - use_sliding_window_d, - use_logits_soft_cap_d, + head_dim_qk, + head_dim_vo, + pos_encoding_mode, + use_sliding_window, + use_logits_soft_cap, + use_fp16_qk_reduction, ) - additional_tensor_names = ["maybe_custom_mask", "maybe_alibi_slopes"] - additional_tensor_dtypes = ["uint8_t", "float"] - additional_scalar_names = [ - "logits_soft_cap", - "sm_scale", - "rope_rcp_scale", - 
"rope_rcp_theta", - ] - additional_scalar_dtypes = ["float", "float", "float", "float"] - variant_name_p = f"DefaultAttention" - variant_name_d = f"DefaultAttention" - variant_decl = f"#include" + if backend == "fa2": + additional_tensor_names = [ + "maybe_custom_mask", + "maybe_mask_indptr", + "maybe_alibi_slopes", + ] + additional_tensor_dtypes = [ + "uint8_t", + "int32_t", + "float", + ] # NOTE(Zihao): int32_t should follow dtype_idx + additional_scalar_names = [ + "logits_soft_cap", + "sm_scale", + "rope_rcp_scale", + "rope_rcp_theta", + ] + additional_scalar_dtypes = ["double", "double", "double", "double"] + variant_name = f"DefaultAttention" + variant_decl = f"#include" + else: + assert False, f"Unsupported backend: {backend}" return gen_customize_pod_module( + backend, uri, dtype_q, dtype_kv, dtype_o, dtype_idx, - head_dim, + head_dim_qk, + head_dim_vo, additional_tensor_names, additional_tensor_dtypes, additional_scalar_names, additional_scalar_dtypes, - variant_name_p, - variant_name_d, + variant_name, variant_decl, - pos_encoding_mode_p=pos_encoding_mode_p, - use_sliding_window_p=use_sliding_window_p, - use_logits_soft_cap_p=use_logits_soft_cap_p, - pos_encoding_mode_d=pos_encoding_mode_d, - use_sliding_window_d=use_sliding_window_d, - use_logits_soft_cap_d=use_logits_soft_cap_d, + pos_encoding_mode=pos_encoding_mode, + use_sliding_window=use_sliding_window, + use_logits_soft_cap=use_logits_soft_cap, use_fp16_qk_reduction=use_fp16_qk_reduction, ) def gen_customize_pod_module( + backend: str, uri: str, dtype_q: torch.dtype, dtype_kv: torch.dtype, dtype_o: torch.dtype, - dtype_idx: torch.dtype, - head_dim: int, + idtype: torch.dtype, + head_dim_qk: int, + head_dim_vo: int, additional_tensor_names: List[str], additional_tensor_dtypes: List[str], additional_scalar_names: List[str], additional_scalar_dtypes: List[str], - variant_name_p: str, - variant_name_d: str, + variant_name: str, variant_decl: str, - pos_encoding_mode_p: int = 0, - use_sliding_window_p: bool = False, - use_logits_soft_cap_p: bool = False, - pos_encoding_mode_d: int = 0, - use_sliding_window_d: bool = False, - use_logits_soft_cap_d: bool = False, + pos_encoding_mode: int = 0, + use_sliding_window: bool = False, + use_logits_soft_cap: bool = False, use_fp16_qk_reduction: bool = False, ): - gen_directory = FLASHINFER_GEN_SRC_DIR / uri - - ( - additional_params_decl, - additional_func_params, - additional_params_setter, - ) = generate_additional_params( - additional_tensor_names, - additional_tensor_dtypes, - additional_scalar_names, - additional_scalar_dtypes, - ) - - with open(FLASHINFER_CSRC_DIR / "pod_customize_config.jinja") as f: - config_templ = jinja2.Template(f.read()) - - with open(FLASHINFER_CSRC_DIR / "pod_kernel_inst.jinja") as f: - kernel_inst_templ = jinja2.Template(f.read()) - kwargs = { - "additional_func_params": additional_func_params, - "additional_params_decl": additional_params_decl, - "additional_params_setter": additional_params_setter, "variant_decl": variant_decl, - "variant_name_p": variant_name_p, - "variant_name_d": variant_name_d, + "variant_name": variant_name, "dtype_q": dtype_map[dtype_q], "dtype_kv": dtype_map[dtype_kv], "dtype_o": dtype_map[dtype_o], - "idtype": dtype_map[dtype_idx], - "head_dim_qk": head_dim, - "head_dim_vo": head_dim, - "pos_encoding_mode_p": pos_encoding_mode_literal[pos_encoding_mode_p], - "pos_encoding_mode_d": pos_encoding_mode_literal[pos_encoding_mode_d], - "use_sliding_window_p": str(use_sliding_window_p).lower(), - "use_logits_soft_cap_p": 
str(use_logits_soft_cap_p).lower(), - "use_sliding_window_d": str(use_sliding_window_d).lower(), - "use_logits_soft_cap_d": str(use_logits_soft_cap_d).lower(), + "idtype": dtype_map[idtype], + "head_dim_qk": head_dim_qk, + "head_dim_vo": head_dim_vo, + "pos_encoding_mode": pos_encoding_mode_literal[pos_encoding_mode], + "use_sliding_window": str(use_sliding_window).lower(), + "use_logits_soft_cap": str(use_logits_soft_cap).lower(), "use_fp16_qk_reduction": str(use_fp16_qk_reduction).lower(), } + if backend == "fa2": + gen_directory = FLASHINFER_GEN_SRC_DIR / uri - generated_inc_str = config_templ.render( - **kwargs, - ) + (additional_params_decl, additional_func_params, additional_params_setter) = ( + generate_additional_params( + additional_tensor_names, + additional_tensor_dtypes, + additional_scalar_names, + additional_scalar_dtypes, + ) + ) + kwargs |= { + "additional_params_decl": additional_params_decl, + "additional_func_params": additional_func_params, + "additional_params_setter": additional_params_setter, + } - os.makedirs(gen_directory, exist_ok=True) + with open(FLASHINFER_CSRC_DIR / "pod_customize_config.jinja") as f: + config_templ = jinja2.Template(f.read()) - source_paths = [] + with open(FLASHINFER_CSRC_DIR / "pod_kernel_inst.jinja") as f: + kernel_inst_templ = jinja2.Template(f.read()) - for mask_mode_p in [0, 1, 2]: - for mask_mode_d in [0, 1, 2]: - kwargs["mask_mode_p"] = mask_mode_literal[mask_mode_p] - kwargs["mask_mode_d"] = mask_mode_literal[mask_mode_d] + generated_inc_str = config_templ.render( + **kwargs, + ) - filename = f"pod_kernel_mask_{mask_mode_p}p_{mask_mode_d}d.cu" - dest_path = gen_directory / filename + os.makedirs(gen_directory, exist_ok=True) + + source_paths = [] + + for mask_mode in [0, 1, 2]: + dest_path = gen_directory / f"pod_paged_kernel_mask_{mask_mode}.cu" source_paths.append(dest_path) source = kernel_inst_templ.render( + mask_mode=mask_mode_literal[mask_mode], **kwargs, ) write_if_different(dest_path, source) - for filename in [ - "pod.cu", - "pod_jit_pybind.cu", - ]: - src_path = FLASHINFER_CSRC_DIR / filename - dest_path = gen_directory / filename - source_paths.append(dest_path) - with open(src_path, "r") as f: - source = f.read() - write_if_different(dest_path, source) + for filename in [ + "pod.cu", + "pod_jit_pybind.cu", + ]: + src_path = FLASHINFER_CSRC_DIR / filename + dest_path = gen_directory / filename + source_paths.append(dest_path) + with open(src_path, "r") as f: + source = f.read() + write_if_different(dest_path, source) - generated_config_path = gen_directory / "pod_config.inc" - write_if_different(generated_config_path, generated_inc_str) - return load_cuda_ops(uri, source_paths) + generated_config_path = gen_directory / "pod_config.inc" + write_if_different(generated_config_path, generated_inc_str) + return load_cuda_ops(uri, source_paths) + else: + assert False, f"Unsupported backend: {backend}" def gen_batch_decode_module( diff --git a/flashinfer/pod.py b/flashinfer/pod.py index 34fdcf52fc..4c47b652b9 100644 --- a/flashinfer/pod.py +++ b/flashinfer/pod.py @@ -22,19 +22,8 @@ import torch -from .decode import get_batch_decode_module -from .jit import ( - gen_batch_decode_module, - gen_batch_prefill_module, - gen_customize_batch_prefill_module, - gen_pod_module, - gen_single_prefill_module, - get_pod_uri, - has_prebuilt_ops, - prebuilt_ops_uri, -) +from .jit import gen_pod_module, get_pod_uri, has_prebuilt_ops, prebuilt_ops_uri from .page import block_sparse_indices_to_vector_sparse_offsets, get_seq_lens -from .prefill 
import get_batch_prefill_module from .quantization import packbits, segment_packbits from .utils import ( MaskMode, @@ -42,13 +31,10 @@ TensorLayout, _check_cached_qkv_data_type, _check_kv_layout, - _check_pos_encoding_mode, + _check_shape_dtype_device, _get_cache_alibi_slopes_buf, - _get_cache_buf, - _get_range_buf, _unpack_paged_kv_cache, canonicalize_torch_dtype, - determine_attention_backend, get_cuda_stream, is_float8, register_custom_op, @@ -58,21 +44,26 @@ _pod_modules = {} -def get_pod_module(*args): - global _pod_modules - if args not in _pod_modules: - uri = get_pod_uri(*args) +def get_pod_module(backend): + def backend_module(*args): + global _pod_modules + if args not in _pod_modules: + uri = get_pod_uri(backend, *args) - if has_prebuilt_ops and uri in prebuilt_ops_uri: - _kernels = torch.ops.flashinfer_kernels - # torch library for pod_with_kv_cache - # No tensor deprecated due to poor performance. Just use tensor cores for both. - run_tensor = _kernels.pod_with_kv_cache_tensor - else: - run_tensor = gen_pod_module(*args).pod_with_kv_cache_tensor - # Register the module - _pod_modules[args] = SimpleNamespace(run_tensor=run_tensor) - return _pod_modules[args] + if has_prebuilt_ops and uri in prebuilt_ops_uri: + assert False, "Prebuilt ops are not supported for POD module" + else: + module = gen_pod_module(backend, *args) + plan_func = module.plan + paged_run_func = module.paged_run + + _pod_modules[args] = SimpleNamespace( + plan=plan_func, + paged_run=paged_run_func, + ) + return _pod_modules[args] + + return backend_module class PODWithPagedKVCacheWrapper: @@ -146,9 +137,13 @@ def __init__( float_workspace_buffer: torch.Tensor, kv_layout: str = "NHD", use_cuda_graph: bool = False, - paged_kv_indptr_buffer: Optional[torch.Tensor] = None, - paged_kv_indices_buffer: Optional[torch.Tensor] = None, - paged_kv_last_page_len_buffer: Optional[torch.Tensor] = None, + qo_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indptr_buf: Optional[torch.Tensor] = None, + paged_kv_indices_buf: Optional[torch.Tensor] = None, + paged_kv_last_page_len_buf: Optional[torch.Tensor] = None, + custom_mask_buf: Optional[torch.Tensor] = None, + mask_indptr_buf: Optional[torch.Tensor] = None, + backend: str = "auto", jit_args: Optional[List[Any]] = None, ) -> None: r"""Constructor of :class:`PODWithPagedKVCacheWrapper`. @@ -189,69 +184,36 @@ def __init__( otherwise, the wrapper will use default attention implementation. """ _check_kv_layout(kv_layout) - """ - if jit_args is not None: - if use_tensor_cores: - self._jit_module = get_batch_prefill_jit_module( - jit_args[0], gen_customize_batch_prefill_module("fa2", *jit_args) - ) - else: - self._jit_module = get_batch_decode_jit_module( - jit_args[0], gen_customize_batch_decode_module(*jit_args) - ) - else: - """ - # Override options. Only tensor core version is performant. 
- use_tensor_cores = True self._jit_module = None self._kv_layout = kv_layout self._float_workspace_buffer = float_workspace_buffer self.device = float_workspace_buffer.device + self._int_workspace_buffer = torch.empty( (8 * 1024 * 1024,), dtype=torch.uint8, device=self.device ) self._pin_memory_int_workspace_buffer = torch.empty( - (8 * 1024 * 1024,), - dtype=torch.uint8, - pin_memory=True, + self._int_workspace_buffer.shape, + dtype=self._int_workspace_buffer.dtype, device="cpu", + pin_memory=True, ) + self._use_cuda_graph = use_cuda_graph if use_cuda_graph: - if not torch.is_tensor(paged_kv_indptr_buffer): - raise ValueError( - "paged_kv_indptr_buffer should be a torch.Tensor in cudagraph mode" - ) - if not torch.is_tensor(paged_kv_indices_buffer): - raise ValueError( - "paged_kv_indices_buffer should be a torch.Tensor in cudagraph mode" - ) - if not torch.is_tensor(paged_kv_last_page_len_buffer): - raise ValueError( - "paged_kv_last_page_len_buffer should be a torch.Tensor in cudagraph mode" - ) - self._fixed_batch_size = len(paged_kv_last_page_len_buffer) - if len(paged_kv_indptr_buffer) != self._fixed_batch_size + 1: - raise ValueError( - "The size of paged_kv_indptr_buffer should be batch_size + 1" - ) + assert False, "CudaGraph is not supported for PODWithPagedKVCacheWrapper" else: self._fixed_batch_size = 0 - self._paged_kv_indptr_buf = paged_kv_indptr_buffer - self._paged_kv_indices_buf = paged_kv_indices_buffer - self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buffer - self._use_tensor_cores = use_tensor_cores - self._use_cuda_graph = use_cuda_graph - - if use_cuda_graph: - # NOTE(Zihao): if once created, no need to update it in plan/run - self._qo_indptr_buf = torch.arange( - self._fixed_batch_size + 1, - dtype=torch.int32, - device=float_workspace_buffer.device, - ) + self._qo_indptr_buf = qo_indptr_buf + self._paged_kv_indptr_buf = paged_kv_indptr_buf + self._paged_kv_indices_buf = paged_kv_indices_buf + self._paged_kv_last_page_len_buf = paged_kv_last_page_len_buf + self._custom_mask_buf = custom_mask_buf + self._mask_indptr_buf = mask_indptr_buf + self._max_total_num_rows = None + self._backend = backend @property def is_cuda_graph_enabled(self) -> bool: @@ -283,67 +245,109 @@ def reset_workspace_buffer( def plan( self, - indptr: torch.Tensor, - indices: torch.Tensor, - last_page_len: torch.Tensor, + qo_indptr: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, num_qo_heads: int, num_kv_heads: int, - head_dim: int, + head_dim_qk: int, page_size: int, + head_dim_vo: Optional[int] = None, + custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, + causal: bool = False, pos_encoding_mode: str = "NONE", - window_left: int = -1, - q_data_type: Optional[Union[str, torch.dtype]] = "float16", - kv_data_type: Optional[Union[str, torch.dtype]] = None, - data_type: Optional[Union[str, torch.dtype]] = None, + use_fp16_qk_reduction: bool = False, sm_scale: Optional[float] = None, + window_left: int = -1, + logits_soft_cap: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, + q_data_type: Union[str, torch.dtype] = "float16", + kv_data_type: Optional[Union[str, torch.dtype]] = None, non_blocking: bool = False, ) -> None: - r"""Plan POD's batch decode for given problem specification. + r"""Plan batch prefill/append attention on Paged KV-Cache for given problem specification. 
Parameters ---------- - indptr : torch.Tensor - The indptr of the paged kv cache, shape: ``[batch_size + 1]`` - indices : torch.Tensor - The page indices of the paged kv cache, shape: ``[qo_indptr[-1]]`` - last_page_len : torch.Tensor - The number of entries in the last page of each request in the paged kv - cache, shape: ``[batch_size]`` + qo_indptr : torch.Tensor + The indptr of the query/output tensor, shape: ``[batch_size + 1]``. + paged_kv_indptr : torch.Tensor + The indptr of the paged kv-cache, shape: ``[batch_size + 1]``. + paged_kv_indices : torch.Tensor + The page indices of the paged kv-cache, shape: ``[qo_indptr[-1]]``. + paged_kv_last_page_len : torch.Tensor + The number of entries in the last page of each request in the paged + kv-cache, shape: ``[batch_size]``. num_qo_heads : int - The number of query/output heads + The number of query/output heads. num_kv_heads : int - The number of key/value heads - head_dim : int - The dimension of the heads + The number of key/value heads. + head_dim_qk : int + The dimension of the query/key heads. page_size : int - The page size of the paged kv cache + The size of each page in the paged kv-cache. + head_dim_vo : Optional[int] + The dimension of the value/output heads, if not provided, will be set to + ``head_dim_qk``. + custom_mask : Optional[torch.Tensor] + The flattened boolean mask tensor, shape: ``(sum(q_len[i] * k_len[i] for i in range(batch_size))``. + The elements in the mask tensor should be either ``True`` or ``False``, + where ``False`` means the corresponding element in the attention matrix will be + masked out. + + Please refer to the :ref:`mask layout ` for more details about flattened + layout of mask tensor. + + When :attr:`custom_mask` is provided, and :attr:`packed_custom_mask` is not, the + function will pack the custom mask tensor into a 1D packed mask tensor, which introduces + additional overhead. + packed_custom_mask : Optional[torch.Tensor] + The 1D packed uint8 mask tensor, if provided, the :attr:`custom_mask` will be ignored. + The packed mask tensor is generated by :func:`flashinfer.quantization.packbits`. + causal : bool + Whether to apply causal mask to the attention matrix. + This is only effective when :attr:`custom_mask` is not provided in + :meth:`plan`. pos_encoding_mode : str The position encoding applied inside attention kernels, could be ``NONE``/``ROPE_LLAMA`` (LLAMA style rotary embedding) /``ALIBI``. - Defaults to ``NONE``. + Default is ``NONE``. + use_fp16_qk_reduction : bool + Whether to use f16 for qk reduction (faster at the cost of slight precision + loss). window_left : int The left (inclusive) window size for the attention window, when set to ``-1``, the window size will be set to the full length of the sequence. Defaults to ``-1``. - q_data_type : Optional[Union[str, torch.dtype]] + logits_soft_cap : Optional[float] + The attention logits soft capping value (used in Gemini, Grok and Gemma-2, etc.), if not + provided, will be set to ``0``. If greater than 0, the logits will be capped according to + formula: + :math:`\texttt{logits_soft_cap} \times \mathrm{tanh}(x / \texttt{logits_soft_cap})`, + where :math:`x` is the input logits. + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim)``. + rope_scale : Optional[float] + The scale used in RoPE interpolation, if not provided, will be set to + ``1.0``. + rope_theta : Optional[float] + The theta used in RoPE, if not provided, will be set to ``1e4``. 
+ q_data_type : Union[str, torch.dtype] The data type of the query tensor, defaults torch.float16. kv_data_type : Optional[Union[str, torch.dtype]] - The data type of the key/value tensor. If None, will be set to - ``q_data_type``. Defaults to ``None``. - data_type: Optional[Union[str, torch.dtype]] - The data type of both the query and key/value tensors. Defaults to torch.float16. - data_type is deprecated, please use q_data_type and kv_data_type instead. + The data type of the key/value tensor. If None, will be set to :attr:`q_data_type`. non_blocking : bool Whether to copy the input tensors to the device asynchronously, defaults to ``False``. If ``True``, user should synchronize before calling :meth:`run` or cuda graph replay. - Note ---- The :meth:`plan` method should be called before any :meth:`run` or :meth:`run_return_lse` calls, auxiliary data structures will be created - during this call and cached for multiple run calls. + during this call and cached for multiple kernel runs. The ``num_qo_heads`` must be a multiple of ``num_kv_heads``. If ``num_qo_heads`` is not equal to ``num_kv_heads``, the function will use @@ -351,100 +355,91 @@ def plan( The :meth:`plan` method cannot be used in Cuda Graph or in ``torch.compile``. """ - # Logits soft cap is not supported currently - logits_soft_cap = False - batch_size = len(last_page_len) + q_data_type = canonicalize_torch_dtype(q_data_type) + if kv_data_type is None: + kv_data_type = q_data_type + kv_data_type = canonicalize_torch_dtype(kv_data_type) + if logits_soft_cap is None: logits_soft_cap = 0.0 + if head_dim_vo is None: + head_dim_vo = head_dim_qk + + batch_size = len(qo_indptr) - 1 + if custom_mask is not None or packed_custom_mask is not None: + assert False, "Not supported" + if packed_custom_mask is None and custom_mask is not None: + assert ( + False + ), "custom_mask is not supported, please use packed_custom_mask instead" + + qo_indptr_host = qo_indptr.to("cpu") + paged_kv_indptr_host = paged_kv_indptr.to("cpu") + paged_kv_last_page_len_host = paged_kv_last_page_len.to("cpu") + total_num_rows = qo_indptr_host[-1] - qo_indptr_host = _get_range_buf(batch_size + 1, "cpu") if self.is_cuda_graph_enabled: - if batch_size != self._fixed_batch_size: - raise ValueError( - "The batch size should be fixed in cudagraph mode, the runtime batch size {} " - " mismatches the batch size set during initialization {}".format( - batch_size, self._fixed_batch_size - ) - ) - if len(indices) > len(self._paged_kv_indices_buf): - raise ValueError( - "The size of indices should be less than or equal to the allocated buffer" - ) - self._paged_kv_indptr_buf.copy_(indptr, non_blocking=non_blocking) - self._paged_kv_indices_buf[: len(indices)].copy_( - indices, non_blocking=non_blocking - ) - self._paged_kv_last_page_len_buf.copy_( - last_page_len, non_blocking=non_blocking - ) + assert False, "CudaGraph is not supported for PODWithPagedKVCacheWrapper" else: - self._paged_kv_indptr_buf = indptr.to( + self._qo_indptr_buf = qo_indptr.to(self.device, non_blocking=non_blocking) + self._paged_kv_indptr_buf = paged_kv_indptr.to( self.device, non_blocking=non_blocking ) - self._paged_kv_indices_buf = indices.to( + self._paged_kv_indices_buf = paged_kv_indices.to( self.device, non_blocking=non_blocking ) - self._paged_kv_last_page_len_buf = last_page_len.to( + self._paged_kv_last_page_len_buf = paged_kv_last_page_len.to( self.device, non_blocking=non_blocking ) - self._qo_indptr_buf = qo_indptr_host.to( - self.device, non_blocking=non_blocking - ) - - indptr_host 
= indptr.to("cpu") - last_page_len_host = last_page_len.to("cpu") - - if data_type is not None: - if q_data_type is None: - q_data_type = data_type - if kv_data_type is None: - kv_data_type = data_type - - q_data_type = canonicalize_torch_dtype(q_data_type) - if kv_data_type is None: - kv_data_type = q_data_type - kv_data_type = canonicalize_torch_dtype(kv_data_type) self._cached_q_data_type = q_data_type self._cached_kv_data_type = kv_data_type - kv_lens_arr_host = get_seq_lens(indptr_host, last_page_len_host, page_size) + if self._jit_module is not None: self._cached_module = self._jit_module else: - self._cached_module = get_batch_prefill_module("fa2")( + if self._backend == "auto": + # only support fa2 + self._backend = "fa2" + + get_module_args = ( q_data_type, kv_data_type, q_data_type, - indptr.dtype, - head_dim, # head_dim_qk - head_dim, # head_dim_vo + paged_kv_indptr.dtype, + head_dim_qk, + head_dim_vo, PosEncodingMode[pos_encoding_mode].value, - window_left != -1, # use_sliding_window + window_left >= 0, # use_sliding_window logits_soft_cap > 0, # use_logits_soft_cap - False, # use_fp16_qk_reduction + use_fp16_qk_reduction, ) + + self._cached_module = get_pod_module(self._backend)(*get_module_args) + with self.device as device: self._plan_info = self._cached_module.plan( self._float_workspace_buffer, self._int_workspace_buffer, self._pin_memory_int_workspace_buffer, qo_indptr_host, - indptr_host, - kv_lens_arr_host, - batch_size, # total_num_rows + paged_kv_indptr_host, + paged_kv_last_page_len_host, + self._max_total_num_rows or total_num_rows, batch_size, num_qo_heads, num_kv_heads, page_size, self.is_cuda_graph_enabled, - head_dim, - head_dim, - False, # causal + head_dim_qk, + head_dim_vo, + causal, get_cuda_stream(device), ) - self._indptr_type = indptr.dtype + self._causal = causal self._pos_encoding_mode = pos_encoding_mode + self._use_fp16_qk_reduction = use_fp16_qk_reduction self._window_left = window_left self._logits_soft_cap = logits_soft_cap self._sm_scale = sm_scale @@ -453,184 +448,244 @@ def plan( begin_forward = plan + def forward( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + causal: bool = False, + pos_encoding_mode: str = "NONE", + use_fp16_qk_reduction: bool = False, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + ) -> torch.Tensor: + r"""Warning: This function is deprecated, please use :meth:`run` instead.""" + self._causal = causal + self._pos_encoding_mode = pos_encoding_mode + self._use_fp16_qk_reduction = use_fp16_qk_reduction + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta + return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale) + + @overload def run( self, - # Main params (prefill and decode) - q_p: torch.Tensor, - k_p: torch.Tensor, - v_p: torch.Tensor, - q_d: torch.Tensor, - paged_kv_cache_d: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], - # Prefill options - custom_mask_p: Optional[torch.Tensor] = None, - packed_custom_mask_p: Optional[torch.Tensor] = None, - causal_p: bool = False, - kv_layout_p: str = "NHD", - pos_encoding_mode_p: str = "NONE", - sm_scale_p: Optional[float] = None, - window_left_p: int = -1, - rope_scale_p: Optional[float] = None, - 
rope_theta_p: Optional[float] = None, - return_lse_p: bool = False, - # Decode options - custom_mask_d: Optional[torch.Tensor] = None, - packed_custom_mask_d: Optional[torch.Tensor] = None, - causal_d: bool = False, - kv_layout_d: str = "NHD", - pos_encoding_mode_d: str = "NONE", - sm_scale_d: Optional[float] = None, - window_left_d: int = -1, - rope_scale_d: Optional[float] = None, - rope_theta_d: Optional[float] = None, - q_scale: Optional[float] = None, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + *args, k_scale: Optional[float] = None, v_scale: Optional[float] = None, - return_lse_d: bool = False, - use_fp16_qk_reduction: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + return_lse: Literal[False] = False, + ) -> torch.Tensor: ... + + @overload + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + *args, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + return_lse: Literal[True] = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: ... + + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - r"""Compute POD-attention for a batch of requests.""" - # Currently unsupported - logits_soft_cap_p = None - logits_soft_cap_d = None - # Prefill setup - _check_pos_encoding_mode(pos_encoding_mode_p) - _check_kv_layout(kv_layout_p) - tmp_p = _get_cache_buf("pod_with_kv_cache_tmp", 32 * 1024 * 1024, q_p.device) - if logits_soft_cap_p is None: - logits_soft_cap_p = 0.0 - if sm_scale_p is None: - sm_scale_p = 1.0 / math.sqrt(q_p.size(-1)) - if rope_scale_p is None: - rope_scale_p = 1.0 - if rope_theta_p is None: - rope_theta_p = 1e4 - if custom_mask_p is not None and packed_custom_mask_p is None: - # create packed custom mask from custom mask - packed_custom_mask_p = packbits( - custom_mask_p.contiguous().view(-1), bitorder="little" - ) + r"""Compute batch prefill/append attention between query and paged kv-cache. - if packed_custom_mask_p is not None: - mask_mode_p = MaskMode.CUSTOM.value + Parameters + ---------- + q : torch.Tensor + The query tensor, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]`` + paged_kv_cache : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + The paged KV-Cache stored as a tuple of tensors or a single tensor: + + * a tuple ``(k_cache, v_cache)`` of 4-D tensors, each with shape: + ``[max_num_pages, page_size, num_kv_heads, head_dim]`` if :attr:`kv_layout` is ``NHD``, + and ``[max_num_pages, num_kv_heads, page_size, head_dim]`` if :attr:`kv_layout` is ``HND``. + + * a single 5-D tensor with shape: + ``[max_num_pages, 2, page_size, num_kv_heads, head_dim]`` if + :attr:`kv_layout` is ``NHD``, and + ``[max_num_pages, 2, num_kv_heads, page_size, head_dim]`` if + :attr:`kv_layout` is ``HND``. Where ``paged_kv_cache[:, 0]`` is the key-cache and + ``paged_kv_cache[:, 1]`` is the value-cache. + + *args + Additional arguments for custom kernels. + k_scale : Optional[float] + The calibration scale of key for fp8 input, if not provided, will be set to ``1.0``. 
+ v_scale : Optional[float] + The calibration scale of value for fp8 input, if not provided, will be set to ``1.0``. + out : Optional[torch.Tensor] + The output tensor, if not provided, will be allocated internally. + lse : Optional[torch.Tensor] + The log-sum-exp of attention logits, if not provided, will be allocated internally. + return_lse : bool + Whether to return the logsumexp of attention output + + Returns + ------- + Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + If :attr:`return_lse` is ``False``, the attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + If :attr:`return_lse` is ``True``, a tuple of two tensors: + + * The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. + * The logsumexp of attention output, shape: ``[qo_indptr[-1], num_qo_heads]``. + """ + k_cache, v_cache = _unpack_paged_kv_cache(paged_kv_cache, self._kv_layout) + _check_cached_qkv_data_type( + q, k_cache, self._cached_q_data_type, self._cached_kv_data_type + ) + stride_block = k_cache.stride(0) + if self._kv_layout == "NHD": + page_size = k_cache.shape[1] + stride_n = k_cache.stride(1) else: - if causal_p: - mask_mode_p = MaskMode.CAUSAL.value + page_size = k_cache.shape[2] + stride_n = k_cache.stride(2) + window_left = self._window_left + logits_soft_cap = self._logits_soft_cap + sm_scale = self._sm_scale + rope_scale = self._rope_scale + rope_theta = self._rope_theta + if logits_soft_cap is None: + logits_soft_cap = 0.0 + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) + if k_scale is not None: + sm_scale *= k_scale + if rope_scale is None: + rope_scale = 1.0 + if rope_theta is None: + rope_theta = 1e4 + if return_lse: + if lse is None: + lse = torch.empty( + (q.size(0), q.size(1)), dtype=torch.float32, device=q.device + ) else: - mask_mode_p = MaskMode.NON_CAUSAL.value + _check_shape_dtype_device( + lse, (q.size(0), q.size(1)), torch.float32, q.device, "lse" + ) - lse_p = None - if return_lse_p: - lse_p = torch.empty( - (q_p.size(0), q_p.size(1)), dtype=torch.float32, device=q_p.device + if out is None: + out = torch.empty( + q.shape[:-1] + v_cache.shape[-1:], dtype=q.dtype, device=q.device + ) + else: + _check_shape_dtype_device( + out, q.shape[:-1] + v_cache.shape[-1:], q.dtype, q.device, "out" ) - out_p = torch.empty_like(q_p) + if self._custom_mask_buf is not None: + mask_mode = MaskMode.CUSTOM.value + else: + if self._causal: + mask_mode = MaskMode.CAUSAL.value + else: + mask_mode = MaskMode.NON_CAUSAL.value - # Decode setup - k_cache_d, v_cache_d = _unpack_paged_kv_cache(paged_kv_cache_d, self._kv_layout) - _check_cached_qkv_data_type( - q_d, k_cache_d, self._cached_q_data_type, self._cached_kv_data_type - ) - # TODO_AK: Where are these coming from? - pos_encoding_mode_d = self._pos_encoding_mode - window_left_d = self._window_left - logits_soft_cap_d = self._logits_soft_cap - sm_scale_d = self._sm_scale - rope_scale_d = self._rope_scale - rope_theta_d = self._rope_theta - _check_pos_encoding_mode(pos_encoding_mode_d) - # What are the above for and what are the below? 
- if logits_soft_cap_d is None: - logits_soft_cap_d = 0.0 - if sm_scale_d is None: - head_dim = q_d.shape[-1] - sm_scale_d = 1.0 / math.sqrt(head_dim) - if q_scale is not None: - sm_scale_d *= q_scale - if k_scale is not None: - sm_scale_d *= k_scale - if rope_scale_d is None: - rope_scale_d = 1.0 - if rope_theta_d is None: - rope_theta_d = 1e4 - - lse_d = None - if return_lse_d: - lse_d = torch.empty( - (q_d.size(0), q_d.size(1)), dtype=torch.float32, device=q_d.device - ) - out_d = torch.empty_like(q_d) - - with q_p.device as device: # device guard - module_getter = get_pod_module( - # Prefill params - q_p.dtype, - k_p.dtype, - q_p.dtype, - q_p.shape[-1], - PosEncodingMode[pos_encoding_mode_p].value, - window_left_p >= 0, # use_sliding_window - logits_soft_cap_p > 0, # use_logits_soft_cap - use_fp16_qk_reduction, - # Decode params - # q_d.dtype, - # self._cached_kv_data_type, - # self._cached_q_data_type, - self._indptr_type, - # head_dim, # head_dim_qk - # head_dim, # head_dim_vo - PosEncodingMode[pos_encoding_mode_d].value, - window_left_d != -1, # use_sliding_window - logits_soft_cap_d > 0, # use_logits_soft_cap - ) - module_getter.run_tensor( - # Prefill params - q_p, - k_p, - v_p, - tmp_p, - out_p, - lse_p, - mask_mode_p, - TensorLayout[kv_layout_p].value, - window_left_p, - packed_custom_mask_p, - _get_cache_alibi_slopes_buf(q_p.shape[1], q_p.device), - logits_soft_cap_p, - sm_scale_p, - 1.0 / rope_scale_p, - 1.0 / rope_theta_p, - # Decode params - self._float_workspace_buffer, - self._int_workspace_buffer, - self._plan_info, - q_d, - k_cache_d, - v_cache_d, - self._qo_indptr_buf, - self._paged_kv_indptr_buf, + if self._backend == "fa3": + # NOTE(Zihao): we divide both stride_block and stride_n by stride_n + # because we will multiply stride_n back in the kernel + sparse_indices = block_sparse_indices_to_vector_sparse_offsets( self._paged_kv_indices_buf, - self._paged_kv_last_page_len_buf, - out_d, - lse_d, - MaskMode.NON_CAUSAL.value, - TensorLayout[self._kv_layout].value, - window_left_d, - None, # packed_custom_mask - None, # mask_indptr_buf - _get_cache_alibi_slopes_buf(q_d.shape[1], q_d.device), - logits_soft_cap_d, - sm_scale_d, - 1.0 / rope_scale_d, - 1.0 / rope_theta_d, - get_cuda_stream(device), + self._paged_kv_indptr_buf, + self._vector_sparse_indices_buffer, # output + self._vector_sparse_indptr_buffer, + self._kv_lens_buffer, + stride_block // stride_n, + 1, # stride_n // stride_n + page_size, ) + sparse_indptr = self._vector_sparse_indptr_buffer + else: + sparse_indices = self._paged_kv_indices_buf + sparse_indptr = self._paged_kv_indptr_buf + + run_args = [ + self._float_workspace_buffer, + self._int_workspace_buffer, + self._plan_info, + q, + k_cache, + v_cache, + sparse_indices, + out, + lse, + mask_mode, + TensorLayout[self._kv_layout].value, + window_left, + ] + if self._jit_module is not None: + run_args.extend(list(args)) + else: + run_args += [ + self._custom_mask_buf, + self._mask_indptr_buf, + _get_cache_alibi_slopes_buf(q.shape[1], q.device), + logits_soft_cap, + sm_scale, + rope_scale, + rope_theta, + get_cuda_stream(q.device), + ] + + self._cached_module.paged_run(*run_args) if v_scale is not None: - out_d *= v_scale + out *= v_scale + + return (out, lse) if return_lse else out - return (out_p, out_d) + run_return_lse = functools.partialmethod(run, return_lse=True) + + def forward_return_lse( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + causal: bool = False, + pos_encoding_mode: str = "NONE", + 
use_fp16_qk_reduction: bool = False, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + sm_scale: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Warning: This function is deprecated, please use :meth:`run_return_lse` instead.""" + self._causal = causal + self._pos_encoding_mode = pos_encoding_mode + self._use_fp16_qk_reduction = use_fp16_qk_reduction + self._window_left = window_left + self._logits_soft_cap = logits_soft_cap + self._sm_scale = sm_scale + self._rope_scale = rope_scale + self._rope_theta = rope_theta + return self.run_return_lse(q, paged_kv_cache, k_scale, v_scale) def end_forward(self) -> None: r"""Warning: this function is deprecated and has no effect.""" diff --git a/include/flashinfer/attention/default_prefill_params.cuh b/include/flashinfer/attention/default_prefill_params.cuh index b10ebec038..e836ab53dd 100644 --- a/include/flashinfer/attention/default_prefill_params.cuh +++ b/include/flashinfer/attention/default_prefill_params.cuh @@ -278,7 +278,12 @@ struct BatchPrefillPagedParams { DTypeQ* q; paged_kv_t paged_kv; uint8_t* maybe_custom_mask; + + // Add q_lenptr for non-continous qo layout + // which is used in PoD Attention IdType* q_indptr; + uint32_t* q_lenptr; + IdType* maybe_mask_indptr; IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for fused-rope attention DTypeO* o; @@ -311,6 +316,7 @@ struct BatchPrefillPagedParams { paged_kv(), maybe_custom_mask(nullptr), q_indptr(nullptr), + q_lenptr(nullptr), maybe_mask_indptr(nullptr), maybe_q_rope_offset(nullptr), o(nullptr), @@ -348,6 +354,7 @@ struct BatchPrefillPagedParams { paged_kv(paged_kv), maybe_custom_mask(maybe_custom_mask), q_indptr(q_indptr), + q_lenptr(nullptr), maybe_mask_indptr(maybe_mask_indptr), maybe_q_rope_offset(maybe_q_rope_offset), o(o), @@ -375,6 +382,9 @@ struct BatchPrefillPagedParams { partition_kv(false) {} __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + if (q_lenptr) { + return q_lenptr[batch_idx]; + } return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; } diff --git a/include/flashinfer/attention/pod.cuh b/include/flashinfer/attention/pod.cuh index d84870b7f5..d9af88ddec 100644 --- a/include/flashinfer/attention/pod.cuh +++ b/include/flashinfer/attention/pod.cuh @@ -25,426 +25,227 @@ namespace flashinfer { -namespace cg = cooperative_groups; -using cp_async::SharedMemFillMode; -using mma::MMAMode; +namespace PoDOp { +constexpr int PREFILL = 0; +constexpr int DECODE = 1; +constexpr int NUM_OPS = 2; +} // namespace PoDOp -enum Operation { - PREFILL = 0, - DECODE = 1, -}; - -template +template __global__ __launch_bounds__(std::max( KTraits_P::NUM_THREADS, - KTraits_D::NUM_THREADS)) void PODWithKVCacheTensorKernel(const uint32_t xsize, - const __grid_constant__ PrefillParams - prefill_params, - const __grid_constant__ DecodeParams - decode_params, - int* tbAssign) { + KTraits_D::NUM_THREADS)) void PODWithPagedKVCacheKernel(const __grid_constant__ PrefillParams + prefill_params, + const __grid_constant__ DecodeParams + decode_params) { extern __shared__ uint8_t smem[]; - // PREFILL VARS - const uint32_t num_kv_heads_p = prefill_params.num_kv_heads; - const uint32_t num_chunks = prefill_params.partition_kv; - const uint32_t qo_len = prefill_params.qo_len; - // DECODE VARS - const uint32_t padded_bsize = decode_params.padded_batch_size; + const uint32_t 
padded_bsz_p = prefill_params.padded_batch_size; + const uint32_t num_kv_heads_p = prefill_params.paged_kv.num_heads; + const uint32_t padded_bsz_d = decode_params.padded_batch_size; const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads; - // THREADBLOCKS - const uint32_t prefill_blocks = num_kv_heads_p * xsize * (PartitionKV_P ? num_chunks : 1); - const uint32_t decode_blocks = padded_bsize * num_kv_heads_d; - - int op; - int linear_bid; - // SM-aware CTA scheduler - if (threadIdx.x == 0) { - // TODO_AK: If num_threads dont match, use virtual sub-CTAs. - // Requires changing block-level sync in main prefill/decode kernels. - constexpr int blk_factor_p = 1; - constexpr int blk_factor_d = 1; - - // SM-aware threadblock scheduler code - // Find out which SM this threadblock is scheduled on - int num_SMs; - // WARNING: nsmid has only been tested on A100/H100, and matches SM count - // No guarantee this will work on other GPUs - asm volatile("mov.u32 %0, %nsmid;" : "=r"(num_SMs)); - asm volatile("mov.u32 %0, %smid;" : "=r"(linear_bid)); - const int prefill_slots = (prefill_blocks + blk_factor_p - 1) / blk_factor_p; - const int decode_slots = (decode_blocks + blk_factor_d - 1) / blk_factor_d; - - if (prefill_slots <= decode_slots) { - // Total tags = (decode + prefill) / min(decode, prefill) - // = 1 + decode / prefill; when prefill < decode - const int total_tags = decode_slots / prefill_slots + 1; - // For this SM, what's the next operation we want to run? - op = (atomicAdd(&tbAssign[linear_bid], 1) % total_tags); - if (op > 0) { - op = 1; - } - } else { - // Total tags = (decode + prefill) / min(decode, prefill) - // = 1 + prefill / decode; when decode < prefill - const int pref_tags = prefill_slots / decode_slots; - - // For this SM, what's the next operation we want to run? - op = (atomicAdd(&tbAssign[linear_bid], 1) % (pref_tags + 1)); - if (op < pref_tags) { - op = 0; - } else { - op = 1; - } - } - - // Get the next blockId for that operation - linear_bid = atomicAdd(&tbAssign[num_SMs + op], 1); - // If the blockId obtained exceeds the max blockIds for that op, switch to the other op - if (op == 0 && linear_bid >= prefill_slots) { - linear_bid = atomicAdd(&tbAssign[num_SMs + 1], 1); - op = !op; - } else if (op == 1 && linear_bid >= decode_slots) { - op = !op; - linear_bid = atomicAdd(&tbAssign[num_SMs + 0], 1); - } - // Write the blockId and operation to shared memory - ((int*)smem)[0] = linear_bid; - ((int*)smem)[1] = op; - } - // Sync to wait for dynamic scheduler to finish - __syncthreads(); - // Fetch from shared memory the assigned blockId and operation. 
- linear_bid = ((int*)smem)[0]; - op = ((int*)smem)[1]; - // Sync to force all threads to wait - __syncthreads(); - - if (op == PREFILL) { - const uint32_t linear_tid = threadIdx.x; + const uint32_t num_blk_p = padded_bsz_p * num_kv_heads_p; + const uint32_t num_blk_d = padded_bsz_d * num_kv_heads_d; + const uint32_t physical_bid = blockIdx.x; + + if (physical_bid < num_blk_p) { + /* OP == PREFILL */ + auto& smem_storage = reinterpret_cast(smem); + + const uint32_t logical_bid = physical_bid % padded_bsz_p; + const uint32_t kv_head_idx = physical_bid / padded_bsz_p; + const uint32_t physical_tid = threadIdx.x; + // Return if threadId exceeds number of threads for this op - if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; + if (physical_tid >= WARP_SIZE * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return; - const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q, - (linear_tid / 32) / KTraits_P::NUM_WARPS_Q); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - // BlockID exceeds limit - if (linear_bid >= prefill_blocks) return; + const dim3 tid = + dim3(physical_tid % WARP_SIZE, (physical_tid / WARP_SIZE) % KTraits_P::NUM_WARPS_Q, + (physical_tid / WARP_SIZE) / KTraits_P::NUM_WARPS_Q); - const uint32_t bx = linear_bid % xsize; - auto& smem_storage = reinterpret_cast(smem); - // Not partition_kv - if constexpr (!PartitionKV_P) { - const uint32_t chunk_idx = 0; - const uint32_t kv_head_idx = linear_bid / xsize; - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, 1, num_kv_heads_p); - } else { - const uint32_t chunk_idx = (linear_bid / xsize) % num_chunks; - const uint32_t kv_head_idx = linear_bid / (xsize * num_chunks); - SinglePrefillWithKVCacheDevice(prefill_params, smem_storage, tid, bx, chunk_idx, - kv_head_idx, num_chunks, num_kv_heads_p); - } + BatchPrefillWithPagedKVCacheDevice(prefill_params, smem_storage, tid, logical_bid, + kv_head_idx, num_kv_heads_p); } else /* OP == DECODE */ { auto& smem_storage = reinterpret_cast(smem); - // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads); - if (linear_bid >= decode_blocks) return; - const uint32_t bx = linear_bid % padded_bsize; - const uint32_t kv_head_idx = linear_bid / padded_bsize; + const uint32_t logical_bid = (physical_bid - num_blk_p) % padded_bsz_d; + const uint32_t kv_head_idx = (physical_bid - num_blk_p) / padded_bsz_d; + const uint32_t physical_tid = threadIdx.x; - // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D); - const uint32_t linear_tid = threadIdx.x; // Return if threadId exceeds number of threads for this op - if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; + if (physical_tid >= WARP_SIZE * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return; - const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q, - (linear_tid / 32) / KTraits_D::NUM_WARPS_Q); + const dim3 tid = + dim3(physical_tid % WARP_SIZE, (physical_tid / WARP_SIZE) % KTraits_D::NUM_WARPS_Q, + (physical_tid / WARP_SIZE) / KTraits_D::NUM_WARPS_Q); - BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, bx, kv_head_idx, - num_kv_heads_d); + BatchPrefillWithPagedKVCacheDevice(decode_params, smem_storage, tid, logical_bid, + kv_head_idx, num_kv_heads_d); } } -template -cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params, - typename 
PrefillParams::DTypeO* tmp_p, - DecodeParams decode_params, - typename DecodeParams::DTypeO* tmp_v, float* tmp_s, - cudaStream_t stream) { - static_assert(std::is_same::value); - static_assert( - std::is_same::value); - static_assert(std::is_same::value); - // Ensure heads match - assert(prefill_params.num_kv_heads == decode_params.paged_kv.num_heads); - assert(prefill_params.num_qo_heads == decode_params.num_qo_heads); - // Prefill variable setup - using DTypeQ_P = typename PrefillParams::DTypeQ; - using DTypeKV_P = typename PrefillParams::DTypeKV; - using DTypeO_P = typename PrefillParams::DTypeO; +template +cudaError_t PODWithPagedKVCacheDispatched(PrefillParams prefill_params, DecodeParams decode_params, + typename DecodeParams::DTypeO* tmp_v, float* tmp_s, + cudaStream_t stream) { + using DTypeQ = typename PrefillParams::DTypeQ; + using DTypeKV = typename PrefillParams::DTypeKV; + using DTypeO = typename PrefillParams::DTypeO; const uint32_t num_qo_heads = prefill_params.num_qo_heads; - const uint32_t num_kv_heads = prefill_params.num_kv_heads; - const uint32_t qo_len = prefill_params.qo_len; - const uint32_t kv_len = prefill_params.kv_len; - if (kv_len < qo_len && MASK_MODE_P == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " - "to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - FLASHINFER_ERROR(err_msg.str()); + const uint32_t num_kv_heads = prefill_params.paged_kv.num_heads; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + assert(num_qo_heads == decode_params.num_qo_heads); + assert(num_kv_heads == decode_params.paged_kv.num_heads); + + const uint32_t padded_bsz_p = prefill_params.padded_batch_size; + const uint32_t padded_bsz_d = decode_params.padded_batch_size; + + if (padded_bsz_p == 0 && padded_bsz_d == 0) { + // No request, skip + return cudaSuccess; } - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - uint32_t cta_tile_q_p = 0; - int64_t unpacked_qo_len = qo_len * group_size; - if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) { - cta_tile_q_p = 128; - } else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (unpacked_qo_len > 16) { - // avg_packed_qo_len <= 64 - cta_tile_q_p = 64; - } else { - // avg_packed_qo_len <= 16 - cta_tile_q_p = 16; - } - } else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout - cta_tile_q_p = 64; - } - } + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + // Prefill metadata setups + constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); + constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); - // Decode vars setup - using DTypeQ_D = typename DecodeParams::DTypeQ; - using DTypeKV_D = typename DecodeParams::DTypeKV; - using DTypeO_D = typename DecodeParams::DTypeO; - const uint32_t padded_batch_size_d = decode_params.padded_batch_size; + // Decode metadata setups constexpr uint32_t NUM_MMA_Q_D = get_num_mma_q(CTA_TILE_Q_D); constexpr uint32_t NUM_WARPS_Q_D = get_num_warps_q(CTA_TILE_Q_D); constexpr uint32_t NUM_WARPS_KV_D = get_num_warps_kv(CTA_TILE_Q_D); - if (padded_batch_size_d == 0) { - // No 
request, skip - // this won't happen in CUDAGraph mode because we fixed the padded_batch_size - return cudaSuccess; - } + int nblks_p(padded_bsz_p * 1 * num_kv_heads); + int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); + int nblks_d(padded_bsz_d * 1 * num_kv_heads); + int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); - // constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; - // constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - using DTypeQKAccum_D = - typename std::conditional, half, - float>::type; + int nblks = nblks_p + nblks_d; + int nthrs = max(nthrs_p, nthrs_d); + // Calculate occupancy + // we expect each sm execute two threadblocks int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); int max_smem_per_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation - const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_D) * 16) ? 2 : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + int num_sm = 0; + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - constexpr uint32_t max_num_mma_kv_reg_d = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q_D == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - !USE_FP16_QK_REDUCTION) - ? 2 - : (8 / NUM_MMA_Q_D); - // TODO(Zihao): fix the following computation - const uint32_t max_num_mma_kv_smem_d = - (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ_D)) - - NUM_MMA_Q_D * NUM_WARPS_Q_D) / - (2 * NUM_WARPS_KV_D); - - // DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, { - constexpr size_t CTA_TILE_Q_P = 128; - constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P); - constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P); - constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P); - - using DTypeQKAccum_P = - typename std::conditional, half, - float>::type; - - // we expect each sm execute two threadblocks - // TODO(Zihao): fix the following computation + // Prefill occupancy const int num_ctas_per_sm_p = - max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1; + max_smem_per_sm >= 2 * (CTA_TILE_Q_P * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV_P * sizeof(DTypeKV)) + ? 2 + : 1; const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p; - - constexpr uint32_t max_num_mma_kv_reg_p = + const uint32_t max_num_mma_kv_smem_p = + (max_smem_per_threadblock_p - CTA_TILE_Q_P * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV_P * sizeof(DTypeKV)); + const uint32_t max_num_mma_kv_reg_p = (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 
2 : (8 / NUM_MMA_Q_P); - // TODO(Zihao): fix the following computation - const uint32_t max_num_mma_kv_smem_p = - (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) - - NUM_MMA_Q_P * NUM_WARPS_Q_P) / - (2 * NUM_WARPS_KV_P); - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { - using KTraits_P = KernelTraits; - - if constexpr (KTraits_P::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P - << " NUM_WARPS_KV=" << NUM_WARPS_KV_P - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // Decode stuff - // TODO: Is there a way to avoid this nested dispatch? - DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { - using KTraits_D = - KernelTraits; - if constexpr (KTraits_D::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D - << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D - << " NUM_WARPS_KV=" << NUM_WARPS_KV_D - << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } else { - // End decode stuff - constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE; - size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); - size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); - - auto kernel = - PODWithKVCacheTensorKernel; - // Prefill: decide num_splits for split-kv - int num_blocks_per_sm = 0; - int num_sm = 0; - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); - FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads_p, smem_size_p)); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } else { - num_chunks = 0; - } + // Decode occupancy + const int num_ctas_per_sm_d = + max_smem_per_sm >= 2 * (CTA_TILE_Q_D * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV_D * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock_d = max_smem_per_sm / num_ctas_per_sm_d; + const uint32_t max_num_mma_kv_smem_d = + (max_smem_per_threadblock_d - CTA_TILE_Q_D * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV_D * sizeof(DTypeKV)); + const uint32_t max_num_mma_kv_reg_d = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q_D == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q_D); - // Setup new prefill params if (not) split - auto o_p = prefill_params.o; - auto lse_p = prefill_params.lse; - float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); - if (num_chunks <= 1 || tmp_p == nullptr) { - // Enough parallelism, do not split-kv - prefill_params.partition_kv = 0; - kernel = PODWithKVCacheTensorKernel; - } else { - // Use cooperative groups to increase occupancy - prefill_params.partition_kv = num_chunks; - prefill_params.o = tmp_p; - prefill_params.lse = tmp_lse; - kernel = - PODWithKVCacheTensorKernel; - } + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, { + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, { + using KTraits_P = KernelTraits; + using KTraits_D = KernelTraits; + if constexpr (KTraits_D::IsInvalid() || KTraits_P::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q_P=" << NUM_MMA_Q_P + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV_P=" << NUM_MMA_KV_P << " NUM_WARPS_Q_P=" << NUM_WARPS_Q_P + << " NUM_WARPS_KV_P=" << NUM_WARPS_KV_P << std::endl; + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q_D=" << NUM_MMA_Q_D + << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV_D=" << NUM_MMA_KV_D << " NUM_WARPS_Q_D=" << NUM_WARPS_Q_D + << " NUM_WARPS_KV_D=" << NUM_WARPS_KV_D + << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage); + size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage); + size_t smem_size = max(smem_size_p, smem_size_d); + auto kernel = PODWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + // setup partition_kv metadata + auto o = prefill_params.o; + auto lse = prefill_params.lse; + assert(o == decode_params.o); + assert(lse == decode_params.lse); + if (prefill_params.partition_kv || decode_params.partition_kv) { + assert(tmp_v != nullptr && tmp_s != nullptr); + // either is partitioned will lead to additional merge kernel + prefill_params.o = tmp_v; + prefill_params.lse = tmp_s; + decode_params.o = tmp_v; + decode_params.lse = tmp_s; + } - // Setup new decode params if (not) split - auto o_d = decode_params.o; - auto lse_d = decode_params.lse; - if (tmp_v == nullptr) { - // do not partition kv - decode_params.partition_kv = false; + // Launch kernel + void* args[] = {(void*)&prefill_params, (void*)&decode_params}; + FLASHINFER_CUDA_CALL( + cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + + // Post-kernel stuff for split-kv + if (prefill_params.partition_kv || decode_params.partition_kv) { + assert(prefill_params.merge_indptr == decode_params.merge_indptr); + if constexpr (AttentionVariant::use_softmax) { + FLASHINFER_CUDA_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, prefill_params.merge_indptr, o, lse, + (prefill_params.max_total_num_rows + decode_params.max_total_num_rows), nullptr, + num_qo_heads, HEAD_DIM_VO, stream)); } else { - decode_params.partition_kv = true; - decode_params.o = tmp_v; - decode_params.lse = tmp_s; - } - uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q); - int 
nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) * - num_kv_heads); - int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P); - - int nblks_d(padded_batch_size_d * 1 * num_kv_heads); - int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D); - - // ******* Select final combined sizes here ******* / - size_t smem_size = max(smem_size_p, smem_size_d); - int nblks = nblks_p + nblks_d; - int nthrs = max(nthrs_p, nthrs_d); - - // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, - // smem_size); printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, - // nblks); printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, nthrs); - // ************************************************ / - - static int* tbAssign = nullptr; - if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2)); - cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2)); - - // Setup kernel arguments - void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params, - (void*)&tbAssign}; - FLASHINFER_CUDA_CALL( - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - // Launch kernel - FLASHINFER_CUDA_CALL( - cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - - // Post-kernel stuff for split-kv prefill - if (!(num_chunks <= 1 || tmp_p == nullptr)) { - if constexpr (PrefillAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL( - AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); - } - } - // Post-kernel stuff for split-kv decode - if (tmp_v != nullptr) { - if constexpr (DecodeAttentionVariant::use_softmax) { - FLASHINFER_CUDA_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d, - decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads, - HEAD_DIM_VO, stream)); - } else { - FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( - tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows, - decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, stream)); - } + FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( + tmp_v, prefill_params.merge_indptr, o, + (prefill_params.max_total_num_rows + decode_params.max_total_num_rows), nullptr, + num_qo_heads, HEAD_DIM_VO, stream)); } } - }); - } + } + }); }); - //}); return cudaSuccess; } diff --git a/include/flashinfer/attention/scheduler.cuh b/include/flashinfer/attention/scheduler.cuh index ae943a7198..5da62d2f07 100644 --- a/include/flashinfer/attention/scheduler.cuh +++ b/include/flashinfer/attention/scheduler.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -486,7 +487,8 @@ inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in } template -inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, +inline auto PrefillSplitQOKVIndptr(IdType* qo_start_ptr_h, uint32_t* qo_len_ptr_h, + IdType* kv_indptr_h, uint32_t* kv_len_ptr_h, uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size, uint32_t max_batch_size_if_split, @@ -500,18 +502,16 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
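  // Illustration (annotation, hedged): per-request lengths now come from explicit
  // arrays rather than adjacent indptr differences, e.g. qo_len_ptr_h = {5, 1, 9}
  // together with kv_len_ptr_h = {12, 3, 20} is valid even when the corresponding
  // qo/kv regions are not contiguous in the caller's indptr; this is the layout the
  // PoD planner produces after re-grouping requests into prefill and decode partitions.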
std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); for (uint32_t i = 0; i < batch_size; ++i) { - packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + packed_qo_len_arr[i] = int64_t(qo_len_ptr_h[i]) * int64_t(gqa_group_size); if (packed_qo_len_arr[i] < 0) { std::ostringstream err_msg; - err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" - << qo_indptr_h[i] << " should be non-negative"; + err_msg << "qo_len_ptr_h[" << i << "]: " << qo_len_ptr_h[i] << " should be non-negative"; FLASHINFER_ERROR(err_msg.str()); } - kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + kv_len_arr[i] = int64_t(kv_len_ptr_h[i]); if (kv_len_arr[i] < 0) { std::ostringstream err_msg; - err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]" - << kv_indptr_h[i] << " should be non-negative"; + err_msg << "kv_len_ptr_h[" << i << "]: " << kv_len_ptr_h[i] << " should be positive"; FLASHINFER_ERROR(err_msg.str()); } } @@ -605,6 +605,13 @@ struct PrefillPlanInfo { bool enable_cuda_graph; bool split_kv; + // used in PoD attention + int64_t kv_start_ptr_offset; + int64_t kv_len_ptr_offset; + int64_t q_start_ptr_offset; + int64_t q_len_ptr_offset; + int64_t kv_last_page_offset; + PrefillPlanInfo() : padded_batch_size(0), total_num_rows(0), @@ -620,7 +627,12 @@ struct PrefillPlanInfo { s_offset(0), block_valid_mask_offset(0), enable_cuda_graph(false), - split_kv(false) {} + split_kv(false), + kv_start_ptr_offset(0), + kv_len_ptr_offset(0), + q_start_ptr_offset(0), + q_len_ptr_offset(0), + kv_last_page_offset(0) {} // convert PrefillPlanInfo to std::vector std::vector ToVector() const { @@ -638,12 +650,17 @@ struct PrefillPlanInfo { s_offset, block_valid_mask_offset, enable_cuda_graph, - split_kv}; + split_kv, + kv_start_ptr_offset, + kv_len_ptr_offset, + q_start_ptr_offset, + q_len_ptr_offset, + kv_last_page_offset}; } // From std::vector to PrefillPlanInfo void FromVector(const std::vector& vec) { - if (vec.size() != 15) { + if (vec.size() != 20) { std::ostringstream err_msg; err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size(); FLASHINFER_ERROR(err_msg.str()); @@ -663,6 +680,11 @@ struct PrefillPlanInfo { block_valid_mask_offset = vec[12]; enable_cuda_graph = vec[13]; split_kv = vec[14]; + kv_start_ptr_offset = vec[15]; + kv_len_ptr_offset = vec[16]; + q_start_ptr_offset = vec[17]; + q_len_ptr_offset = vec[18]; + kv_last_page_offset = vec[19]; } }; @@ -692,11 +714,17 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; // step 2: determine kv_chunk_size + // get qo_len_ptr, which is torch.diff(qo_indptr_h) in batch prefill + std::vector qo_len_ptr, kv_len_ptr; + for (uint32_t i = 0; i < batch_size; ++i) { + qo_len_ptr.push_back(qo_indptr_h[i + 1] - qo_indptr_h[i]); + kv_len_ptr.push_back(kv_indptr_h[i + 1] - kv_indptr_h[i]); + } auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph); + PrefillSplitQOKVIndptr(qo_indptr_h, qo_len_ptr.data(), kv_indptr_h, kv_len_ptr.data(), + total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_vo, + page_size, 
max_batch_size_if_split, enable_cuda_graph); plan_info.cta_tile_q = cta_tile_q; plan_info.total_num_rows = total_num_rows; @@ -1319,5 +1347,401 @@ inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_by return cudaSuccess; } +struct PoDPlanInfo { + PrefillPlanInfo plan_info_p; + PrefillPlanInfo plan_info_d; + + // Need to keep the partition info + int64_t batch_size_vec_p; + int64_t batch_size_vec_d; + + PoDPlanInfo() : plan_info_p(), plan_info_d() {} + + // convert PoDPlanInfo to std::vector + std::vector ToVector() const { + std::vector vec_concat = plan_info_p.ToVector(); + std::vector vec_plan_info_d = plan_info_d.ToVector(); + + vec_concat.insert(vec_concat.end(), vec_plan_info_d.begin(), vec_plan_info_d.end()); + vec_concat.push_back(batch_size_vec_p); + vec_concat.push_back(batch_size_vec_d); + return vec_concat; + } + + // From std::vector to PoDPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 42) { + std::ostringstream err_msg; + err_msg << "PoDPlanInfo::FromVector: vec.size() should be 30, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); + } + std::vector vec_p(vec.begin(), vec.begin() + 20); + std::vector vec_d(vec.begin() + 20, vec.end() - 2); + plan_info_p.FromVector(vec_p); + plan_info_d.FromVector(vec_d); + + batch_size_vec_p = vec[40]; + batch_size_vec_d = vec[41]; + } + + /* + * Partition the entire working set into two distinct group. + * PoDAttention: Prefill / Decode group + * @param qo_indptr_h: akin to standard flashinfer API. [bsz+1,] + * @param kv_indptr_h: akin to standard flashinfer API. [bsz+1,] + * @param kv_last_page_len_h: akin to standard flashinfer API. [bsz,] + * @param batch_size: batch size of on-the-fly requests + * @return: partition all the metadata into two groups + */ + template + static inline auto PartitionWorkloads(const IdType* qo_indptr_h, const IdType* kv_indptr_h, + const IdType* kv_last_page_len_h, uint32_t batch_size) { + // classify each requests + // Modify this to change PoD workload partition + std::vector partition_bitmask; + std::vector qo_start_ptr_h_p, qo_start_ptr_h_d, kv_start_ptr_h_p, kv_start_ptr_h_d, + kv_last_page_len_h_p, kv_last_page_len_h_d; + std::vector qo_len_ptr_h_p, qo_len_ptr_h_d, kv_len_ptr_h_p, kv_len_ptr_h_d; + for (uint32_t i = 0; i < batch_size; ++i) { + uint32_t qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + uint32_t kv_page_len = kv_indptr_h[i + 1] - kv_indptr_h[i]; + if (qo_len > 8) { + // Prefill / Verify phase -> PoDOp: PREFILL=0 + partition_bitmask.push_back(false); + qo_start_ptr_h_p.push_back(qo_indptr_h[i]); + qo_len_ptr_h_p.push_back(qo_len); + kv_start_ptr_h_p.push_back(kv_indptr_h[i]); + kv_len_ptr_h_p.push_back(kv_page_len); + kv_last_page_len_h_p.push_back(kv_last_page_len_h[i]); + } else { + // Decode phase -> PoDOp: DECODE=1 + partition_bitmask.push_back(true); + qo_start_ptr_h_d.push_back(qo_indptr_h[i]); + qo_len_ptr_h_d.push_back(qo_len); + kv_start_ptr_h_d.push_back(kv_indptr_h[i]); + kv_len_ptr_h_d.push_back(kv_page_len); + kv_last_page_len_h_d.push_back(kv_last_page_len_h[i]); + } + } + // Append boundary element + // used in produce_kv_page + if (!kv_start_ptr_h_p.empty()) { + kv_start_ptr_h_p.push_back(kv_start_ptr_h_p.back() + kv_len_ptr_h_p.back()); + } + if (!kv_start_ptr_h_d.empty()) { + kv_start_ptr_h_d.push_back(kv_start_ptr_h_d.back() + kv_len_ptr_h_d.back()); + } + + return std::make_tuple(partition_bitmask, qo_start_ptr_h_p, qo_len_ptr_h_p, kv_start_ptr_h_p, + kv_len_ptr_h_p, kv_last_page_len_h_p, qo_start_ptr_h_d, 
qo_len_ptr_h_d, + kv_start_ptr_h_d, kv_len_ptr_h_d, kv_last_page_len_h_d); + } + + /* + * Partition the SMs into two distinct group. + * Assign SMs according to the memory_loading size + * @return: [num_sm_prefill, num_sm_decode] + */ + static inline auto PartitionSMs(const std::vector& qo_len_ptr_h_p, + const std::vector& kv_len_ptr_h_p, + const std::vector& qo_len_ptr_h_d, + const std::vector& kv_len_ptr_h_d, uint32_t page_size, + int num_sm) { + uint32_t total_len_p = + std::accumulate(qo_len_ptr_h_p.begin(), qo_len_ptr_h_p.end(), 0) + + 2 * page_size * std::accumulate(kv_len_ptr_h_p.begin(), kv_len_ptr_h_p.end(), 0); + uint32_t total_len_d = + std::accumulate(qo_len_ptr_h_d.begin(), qo_len_ptr_h_d.end(), 0) + + 2 * page_size * std::accumulate(kv_len_ptr_h_d.begin(), kv_len_ptr_h_d.end(), 0); + int num_sm_prefill = num_sm * total_len_p / (total_len_p + total_len_d); + num_sm_prefill = std::min(std::max(num_sm_prefill, 1), num_sm - 1); + int num_sm_decode = num_sm - num_sm_prefill; + return std::make_tuple(num_sm_prefill, num_sm_decode); + } +}; + +template +inline cudaError_t PoDPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PoDPlanInfo& plan_info, + IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_last_page_len_h, + uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t head_dim_qk, uint32_t head_dim_vo, + uint32_t page_size, bool enable_cuda_graph, uint32_t sizeof_dtype_o, + cudaStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + // Do not support CUDA graph for now + assert(enable_cuda_graph == false); + + // step 0: partition the workloads + auto [partition_bitmask, qo_start_ptr_vec_p, qo_len_ptr_vec_p, kv_start_ptr_vec_p, + kv_len_ptr_vec_p, kv_last_page_len_vec_p, qo_start_ptr_vec_d, qo_len_ptr_vec_d, + kv_start_ptr_vec_d, kv_len_ptr_vec_d, kv_last_page_len_vec_d] = + PoDPlanInfo::PartitionWorkloads(qo_indptr_h, kv_indptr_h, kv_last_page_len_h, batch_size); + + uint32_t batch_size_p = qo_start_ptr_vec_p.size(); + uint32_t batch_size_d = qo_start_ptr_vec_d.size(); + assert(batch_size_p + batch_size_d == batch_size); + assert(partition_bitmask.size() == batch_size); + plan_info.batch_size_vec_p = batch_size_p; + plan_info.batch_size_vec_d = batch_size_d; + + uint32_t total_num_rows_p = std::accumulate(qo_len_ptr_vec_p.begin(), qo_len_ptr_vec_p.end(), 0); + uint32_t total_num_rows_d = std::accumulate(qo_len_ptr_vec_d.begin(), qo_len_ptr_vec_d.end(), 0); + assert(total_num_rows_p + total_num_rows_d == total_num_rows); + + // step 1: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 2; + auto [num_sm_p, num_sm_d] = PoDPlanInfo::PartitionSMs( + qo_len_ptr_vec_p, kv_len_ptr_vec_p, qo_len_ptr_vec_d, kv_len_ptr_vec_d, page_size, num_sm); + uint32_t max_batch_size_if_split_p = num_blocks_per_sm * num_sm_p / num_kv_heads; + uint32_t max_batch_size_if_split_d = num_blocks_per_sm * num_sm_d / num_kv_heads; + + // step 2: determine kv_chunk_size + // get qo_len_ptr, which is torch.diff(qo_indptr_h) in batch prefill + auto [split_kv_p, new_batch_size_p, padded_batch_size_p, 
cta_tile_q_p, kv_chunk_size_p, + request_indices_vec_p, qo_tile_indices_vec_p, kv_tile_indices_vec_p, merge_indptr_vec_p, + o_indptr_vec_p] = + PrefillSplitQOKVIndptr(qo_start_ptr_vec_p.data(), qo_len_ptr_vec_p.data(), + kv_start_ptr_vec_p.data(), kv_len_ptr_vec_p.data(), total_num_rows_p, + batch_size_p, num_qo_heads, num_kv_heads, head_dim_vo, page_size, + max_batch_size_if_split_p, enable_cuda_graph); + auto [split_kv_d, new_batch_size_d, padded_batch_size_d, cta_tile_q_d, kv_chunk_size_d, + request_indices_vec_d, qo_tile_indices_vec_d, kv_tile_indices_vec_d, merge_indptr_vec_d, + o_indptr_vec_d] = + PrefillSplitQOKVIndptr(qo_start_ptr_vec_d.data(), qo_len_ptr_vec_d.data(), + kv_start_ptr_vec_d.data(), kv_len_ptr_vec_d.data(), total_num_rows_d, + batch_size_d, num_qo_heads, num_kv_heads, head_dim_vo, page_size, + max_batch_size_if_split_d, enable_cuda_graph); + + // step 3: update o_indptr and merge_indptr + // NOTE(Yilong): only call the merge kernel once. merge_indtpr is shared + std::vector o_indptr_vec_p_tmp, o_indptr_vec_d_tmp, merge_indptr_vec; + IdType cur_o_ptr = 0; + merge_indptr_vec.push_back(cur_o_ptr); + uint32_t idx_p = 0, idx_d = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + uint32_t qo_len = 0; + uint32_t num_tiles_kv = 0; + if (partition_bitmask[i]) { + // Decode + o_indptr_vec_d_tmp.push_back(cur_o_ptr); + qo_len = qo_len_ptr_vec_d[idx_d]; + num_tiles_kv = ceil_div(o_indptr_vec_d[idx_d + 1] - o_indptr_vec_d[idx_d], qo_len); + idx_d++; + } else { + // Prefill + o_indptr_vec_p_tmp.push_back(cur_o_ptr); + qo_len = qo_len_ptr_vec_p[idx_p]; + num_tiles_kv = ceil_div(o_indptr_vec_p[idx_p + 1] - o_indptr_vec_p[idx_p], qo_len); + idx_p++; + } + cur_o_ptr += num_tiles_kv * qo_len; + for (uint32_t row = 0; row < qo_len; ++row) { + merge_indptr_vec.push_back(merge_indptr_vec.back() + num_tiles_kv); + } + } + o_indptr_vec_p = std::move(o_indptr_vec_p_tmp); + o_indptr_vec_d = std::move(o_indptr_vec_d_tmp); + + // step 4: instantiate the plan_info + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + + // Prefill Plan Info + plan_info.plan_info_p.cta_tile_q = cta_tile_q_p; + plan_info.plan_info_p.total_num_rows = total_num_rows_p; + plan_info.plan_info_p.enable_cuda_graph = enable_cuda_graph; + plan_info.plan_info_p.padded_batch_size = padded_batch_size_p; + plan_info.plan_info_p.split_kv = split_kv_p; + + plan_info.plan_info_p.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size_p, 16, "batch_prefill_request_indices_p"); + plan_info.plan_info_p.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size_p, 16, "batch_prefill_qo_tile_indices_p"); + plan_info.plan_info_p.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size_p, 16, "batch_prefill_kv_tile_indices_p"); + plan_info.plan_info_p.o_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * batch_size_p, 16, "batch_prefill_o_indptr_p"); + plan_info.plan_info_p.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr_p"); + plan_info.plan_info_p.kv_start_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (batch_size_p + 1), 16, "batch_prefill_kv_start_ptr_p"); + plan_info.plan_info_p.kv_len_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(uint32_t) * batch_size_p, 16, "batch_prefill_kv_len_ptr_p"); + plan_info.plan_info_p.q_start_ptr_offset = int_allocator.aligned_alloc_offset( + 
sizeof(IdType) * batch_size_p, 16, "batch_prefill_q_start_ptr_p"); + plan_info.plan_info_p.q_len_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(uint32_t) * batch_size_p, 16, "batch_prefill_q_len_ptr_p"); + plan_info.plan_info_p.kv_last_page_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * batch_size_p, 16, "batch_prefill_kv_last_page_len_ptr_p"); + + if (plan_info.plan_info_p.enable_cuda_graph) { + plan_info.plan_info_p.total_num_rows_offset = + int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows_p"); + uint32_t* total_num_rows_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.total_num_rows_offset); + *total_num_rows_h = total_num_rows_p; + } + + IdType* request_indices_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.request_indices_offset); + IdType* qo_tile_indices_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.qo_tile_indices_offset); + IdType* kv_tile_indices_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.kv_tile_indices_offset); + IdType* o_indptr_h_p = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.plan_info_p.o_indptr_offset); + IdType* kv_chunk_size_ptr_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.kv_chunk_size_ptr_offset); + IdType* kv_start_ptr_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.kv_start_ptr_offset); + uint32_t* kv_len_ptr_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.kv_len_ptr_offset); + IdType* q_start_ptr_h_p = GetPtrFromBaseOffset(page_locked_int_buffer, + plan_info.plan_info_p.q_start_ptr_offset); + uint32_t* q_len_ptr_h_p = GetPtrFromBaseOffset(page_locked_int_buffer, + plan_info.plan_info_p.q_len_ptr_offset); + IdType* kv_last_page_len_ptr_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.kv_last_page_offset); + + std::copy(request_indices_vec_p.begin(), request_indices_vec_p.end(), request_indices_h_p); + std::copy(qo_tile_indices_vec_p.begin(), qo_tile_indices_vec_p.end(), qo_tile_indices_h_p); + std::copy(kv_tile_indices_vec_p.begin(), kv_tile_indices_vec_p.end(), kv_tile_indices_h_p); + std::copy(o_indptr_vec_p.begin(), o_indptr_vec_p.end(), o_indptr_h_p); + kv_chunk_size_ptr_h_p[0] = kv_chunk_size_p; + std::copy(kv_start_ptr_vec_p.begin(), kv_start_ptr_vec_p.end(), kv_start_ptr_h_p); + std::copy(kv_len_ptr_vec_p.begin(), kv_len_ptr_vec_p.end(), kv_len_ptr_h_p); + std::copy(qo_start_ptr_vec_p.begin(), qo_start_ptr_vec_p.end(), q_start_ptr_h_p); + std::copy(qo_len_ptr_vec_p.begin(), qo_len_ptr_vec_p.end(), q_len_ptr_h_p); + std::copy(kv_last_page_len_vec_p.begin(), kv_last_page_len_vec_p.end(), kv_last_page_len_ptr_h_p); + + // Decode Plan Info + plan_info.plan_info_d.cta_tile_q = cta_tile_q_d; + plan_info.plan_info_d.total_num_rows = total_num_rows_d; + plan_info.plan_info_d.enable_cuda_graph = enable_cuda_graph; + plan_info.plan_info_d.padded_batch_size = padded_batch_size_d; + plan_info.plan_info_d.split_kv = split_kv_d; + + plan_info.plan_info_d.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size_d, 16, "batch_prefill_request_indices_d"); + plan_info.plan_info_d.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size_d, 16, "batch_prefill_qo_tile_indices_d"); + plan_info.plan_info_d.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * 
padded_batch_size_d, 16, "batch_prefill_kv_tile_indices_d"); + plan_info.plan_info_d.o_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * batch_size_d, 16, "batch_prefill_o_indptr_d"); + plan_info.plan_info_d.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr_d"); + plan_info.plan_info_d.kv_start_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (batch_size_d + 1), 16, "batch_prefill_kv_start_ptr_d"); + plan_info.plan_info_d.kv_len_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(uint32_t) * batch_size_d, 16, "batch_prefill_kv_len_ptr_d"); + plan_info.plan_info_d.q_start_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * batch_size_d, 16, "batch_prefill_q_start_ptr_d"); + plan_info.plan_info_d.q_len_ptr_offset = int_allocator.aligned_alloc_offset( + sizeof(uint32_t) * batch_size_d, 16, "batch_prefill_q_len_ptr_d"); + plan_info.plan_info_d.kv_last_page_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * batch_size_d, 16, "batch_prefill_kv_last_page_len_ptr_d"); + + if (plan_info.plan_info_d.enable_cuda_graph) { + plan_info.plan_info_d.total_num_rows_offset = + int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows_d"); + uint32_t* total_num_rows_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.total_num_rows_offset); + *total_num_rows_h = total_num_rows_d; + } + + IdType* request_indices_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.request_indices_offset); + IdType* qo_tile_indices_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.qo_tile_indices_offset); + IdType* kv_tile_indices_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.kv_tile_indices_offset); + IdType* o_indptr_h_d = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.plan_info_d.o_indptr_offset); + IdType* kv_chunk_size_ptr_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.kv_chunk_size_ptr_offset); + IdType* kv_start_ptr_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.kv_start_ptr_offset); + uint32_t* kv_len_ptr_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.kv_len_ptr_offset); + IdType* q_start_ptr_h_d = GetPtrFromBaseOffset(page_locked_int_buffer, + plan_info.plan_info_d.q_start_ptr_offset); + uint32_t* q_len_ptr_h_d = GetPtrFromBaseOffset(page_locked_int_buffer, + plan_info.plan_info_d.q_len_ptr_offset); + IdType* kv_last_page_len_ptr_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.kv_last_page_offset); + + std::copy(request_indices_vec_d.begin(), request_indices_vec_d.end(), request_indices_h_d); + std::copy(qo_tile_indices_vec_d.begin(), qo_tile_indices_vec_d.end(), qo_tile_indices_h_d); + std::copy(kv_tile_indices_vec_d.begin(), kv_tile_indices_vec_d.end(), kv_tile_indices_h_d); + std::copy(o_indptr_vec_d.begin(), o_indptr_vec_d.end(), o_indptr_h_d); + kv_chunk_size_ptr_h_d[0] = kv_chunk_size_d; + std::copy(kv_start_ptr_vec_d.begin(), kv_start_ptr_vec_d.end(), kv_start_ptr_h_d); + std::copy(kv_len_ptr_vec_d.begin(), kv_len_ptr_vec_d.end(), kv_len_ptr_h_d); + std::copy(qo_start_ptr_vec_d.begin(), qo_start_ptr_vec_d.end(), q_start_ptr_h_d); + std::copy(qo_len_ptr_vec_d.begin(), qo_len_ptr_vec_d.end(), q_len_ptr_h_d); + std::copy(kv_last_page_len_vec_d.begin(), kv_last_page_len_vec_d.end(), kv_last_page_len_ptr_h_d); + + if 
(split_kv_p || split_kv_d) { + // One of the partition has split kv will incur merge kernel + // only one metata are used for both partition + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + + // shared + plan_info.plan_info_p.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * (padded_batch_size_p * cta_tile_q_p + padded_batch_size_d * cta_tile_q_d) * + head_dim_vo * sizeof(float), + 16, "batch_prefill_tmp_v"); + plan_info.plan_info_d.v_offset = plan_info.plan_info_p.v_offset; + plan_info.plan_info_p.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * (padded_batch_size_p * cta_tile_q_p + padded_batch_size_d * cta_tile_q_d) * + sizeof(float), + 16, "batch_prefill_tmp_s"); + plan_info.plan_info_d.s_offset = plan_info.plan_info_p.s_offset; + plan_info.plan_info_p.merge_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (total_num_rows_p + total_num_rows_d + 1), 16, + "batch_prefill_merge_indptr"); + plan_info.plan_info_d.merge_indptr_offset = plan_info.plan_info_p.merge_indptr_offset; + + IdType* merge_indptr_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.merge_indptr_offset); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); + + // Exclusive + plan_info.plan_info_p.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size_p, 16, "batch_prefill_block_valid_mask_p"); + plan_info.plan_info_d.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size_d, 16, "batch_prefill_block_valid_mask_d"); + + bool* block_valid_mask_h_p = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_p.block_valid_mask_offset); + bool* block_valid_mask_h_d = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.plan_info_d.block_valid_mask_offset); + for (uint32_t i = 0; i < padded_batch_size_p; ++i) { + block_valid_mask_h_p[i] = i < new_batch_size_p; + } + for (uint32_t i = 0; i < padded_batch_size_d; ++i) { + block_valid_mask_h_d[i] = i < new_batch_size_d; + } + } + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + cudaMemcpyHostToDevice, stream)); + + return cudaSuccess; +} + } // namespace flashinfer #endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index 1f5d328da8..30648b3269 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -58,6 +58,9 @@ struct paged_kv_t { // [batch_size] The start position of each request in the batch. IdType* rope_pos_offset; + // [batch_size, ] The page lenptr array, used when indptr is not contiguous (PoD implementation) + uint32_t* len_ptr; + /*! * \brief Construct an empty paged key-value cache */ @@ -74,7 +77,8 @@ struct paged_kv_t { indices(nullptr), indptr(nullptr), last_page_len(nullptr), - rope_pos_offset(nullptr) {} + rope_pos_offset(nullptr), + len_ptr(nullptr) {} /*! 
* \brief Construct a paged key-value cache @@ -101,7 +105,8 @@ struct paged_kv_t { indices(indices), indptr(indptr), last_page_len(last_page_len), - rope_pos_offset(rope_pos_offset) { + rope_pos_offset(rope_pos_offset), + len_ptr(nullptr) { stride_page = num_heads * page_size * head_dim; this->k_data = k_data; this->v_data = v_data; @@ -136,7 +141,8 @@ struct paged_kv_t { indices(indices), indptr(indptr), last_page_len(last_page_len), - rope_pos_offset(rope_pos_offset) { + rope_pos_offset(rope_pos_offset), + len_ptr(nullptr) { stride_page = kv_strides[0]; this->k_data = k_data; this->v_data = v_data; @@ -145,6 +151,9 @@ struct paged_kv_t { } __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const { + if (len_ptr) { + return (len_ptr[batch_idx] - 1) * page_size + last_page_len[batch_idx]; + } if (indptr[batch_idx + 1] == indptr[batch_idx]) { return 0; } diff --git a/tests/jit_utils.py b/tests/jit_utils.py index 05357f0536..a9ee62b9d0 100644 --- a/tests/jit_utils.py +++ b/tests/jit_utils.py @@ -188,6 +188,27 @@ def jit_prefill_attention_func_args( ), ) ) + if q_dtype == torch.float16 and kv_dtype == torch.float16: + # not tested on 8bit + # load potential useful PoD modules + load_module_func_args.append( + ( + flashinfer.pod.gen_pod_module, + ( + "fa2", + q_dtype, + kv_dtype, + q_dtype, + torch.int32, + head_dim, # head_dim_qk + head_dim, # head_dim_vo + pos_encoding_mode, + use_sliding_window, + use_logits_soft_cap, + use_fp16_qk_reduction, + ), + ) + ) load_module_func_args.append( ( diff --git a/tests/test_pod_kernels.py b/tests/test_pod_kernels.py index 5ebfcad4cf..8c4ceece9f 100644 --- a/tests/test_pod_kernels.py +++ b/tests/test_pod_kernels.py @@ -14,12 +14,15 @@ limitations under the License. """ +import math +import random +from typing import Tuple + import pytest import torch -from jit_utils import jit_decode_attention_func_args, jit_prefill_attention_func_args +from jit_utils import jit_prefill_attention_func_args import flashinfer -from flashinfer.jit.attention.pytorch import gen_pod_module @pytest.fixture(autouse=True, scope="module") @@ -29,44 +32,15 @@ def warmup_jit(): else: try: flashinfer.jit.parallel_load_modules( - jit_decode_attention_func_args( + jit_prefill_attention_func_args( [torch.float16], # q_dtypes [torch.float16], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes + [128, 256], # head_dims + [0, 1], # pos_encoding_modes [False], # use_sliding_windows + [False], # use_logits_soft_caps [False], # use_fp16_qk_reductions ) - + jit_prefill_attention_func_args( - [torch.float16], # q_dtypes - [ - torch.float16, - ], # kv_dtypes - [128], # head_dims - [0], # pos_encoding_modes - [False], # use_sliding_windows - [False], # use_logits_soft_cap - [False], # use_fp16_qk_reductions - ) - + [ - ( - gen_pod_module, - [ - torch.float16, # dtype_q - torch.float16, # dtype_kv - torch.float16, # dtype_o - 128, # head_dim - 0, # pos_encoding_mode_p - False, # use_sliding_window_p - False, # use_logits_soft_cap_p - False, # use_fp16_qk_reduction - torch.int32, # dtype_idx - 0, # pos_encoding_mode_d - False, # use_sliding_window_d - False, # use_logits_soft_cap_d - ], - ) - ] ) except Exception as e: # abort the test session if warmup fails @@ -75,215 +49,170 @@ def warmup_jit(): yield -@pytest.mark.parametrize("kv_len_p", [127, 12288]) -@pytest.mark.parametrize("qo_len_p", [127, 12288]) -@pytest.mark.parametrize("causal", [False, True]) -@pytest.mark.parametrize("batch_size_d", [1, 17, 127]) -@pytest.mark.parametrize("kv_len_d", [127, 12288]) 
-@pytest.mark.parametrize("page_size_d", [1, 16]) -@pytest.mark.parametrize("kv_layout_d", ["NHD"]) -@pytest.mark.parametrize("num_kv_heads", [8]) -@pytest.mark.parametrize("num_qo_heads", [8, 32]) -@pytest.mark.parametrize("head_dim", [128]) +def _gen_reqs(bsz: int, qo_len: int, seq_len: Tuple[int, int], stride: int): + random.seed(0) # for reproducibility + reqs = [] + for i in range(bsz): + if (i + 1) % stride == 0: + len_q = qo_len + else: + len_q = 1 + + len_kv = seq_len[0] + reqs.append((len_q, len_kv)) + return reqs + + +def _gen_metadata( + reqs, page_size, kv_layout, num_qo_heads, num_kv_heads, head_dim, device +): + torch.manual_seed(0) # for reproducibility + + total_qo_len = sum([r[0] for r in reqs]) + total_kv_len = sum([r[1] for r in reqs]) + + q = torch.randn( + total_qo_len, + num_qo_heads, + head_dim, + device=device, + dtype=torch.half, + ) + + kv_indptr_cpu = [0] + qo_indptr_cpu = [0] + kv_last_page_cpu = [] + for req in reqs: + kv_indptr_cpu.append(kv_indptr_cpu[-1] + math.ceil(req[1] / page_size)) + kv_last_page_cpu.append((req[1] - 1) % page_size + 1) + qo_indptr_cpu.append(qo_indptr_cpu[-1] + req[0]) + + kv_indices_cpu = list(range(kv_indptr_cpu[-1])) + kv_indices_cpu.extend([0] * 256) + + kv_indptr_cpu = torch.tensor(kv_indptr_cpu, dtype=torch.int32, device="cpu") + kv_indices_cpu = torch.tensor(kv_indices_cpu, dtype=torch.int32, device="cpu") + kv_last_page_cpu = torch.tensor(kv_last_page_cpu, dtype=torch.int32, device="cpu") + qo_indptr_cpu = torch.tensor(qo_indptr_cpu, dtype=torch.int32, device="cpu") + + if kv_layout == "HND": + kv_data = torch.randn( + len(kv_indices_cpu), + 2, + num_kv_heads, + page_size, + head_dim, + device=device, + dtype=torch.float32, + ) + else: + kv_data = torch.randn( + len(kv_indices_cpu), + 2, + page_size, + num_kv_heads, + head_dim, + device=device, + dtype=torch.float32, + ) + + return q, kv_data, kv_indptr_cpu, kv_indices_cpu, kv_last_page_cpu, qo_indptr_cpu + + +@pytest.mark.parametrize("batch_size", [12, 17, 64]) +@pytest.mark.parametrize("kv_len", [54, 511, 2042, 8911]) +@pytest.mark.parametrize("qo_len", [17, 47, 127, 577]) +@pytest.mark.parametrize("stride", [1, 2, 5, 1024]) +@pytest.mark.parametrize("page_size", [1, 16]) +@pytest.mark.parametrize("num_kv_heads", [4]) +@pytest.mark.parametrize("num_qo_heads", [4, 28]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("causal", [True]) +@pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE"]) -@pytest.mark.parametrize("q_dtype", [torch.float16]) -@pytest.mark.parametrize("kv_dtype", [torch.float16]) +@pytest.mark.parametrize("use_cuda_graph", [False]) +@pytest.mark.parametrize("logits_soft_cap", [0.0]) +@pytest.mark.parametrize("return_lse", [False]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_pod_with_paged_kv_cache( - # Prefill params - kv_len_p, - qo_len_p, - causal, - # Decode params - batch_size_d, - kv_len_d, - page_size_d, - kv_layout_d, - # Shared params + batch_size, + kv_len, + qo_len, + stride, + page_size, num_kv_heads, num_qo_heads, head_dim, + causal, + kv_layout, pos_encoding_mode, - q_dtype, - kv_dtype, + use_cuda_graph, + logits_soft_cap, + return_lse, contiguous_kv, ): - if causal and qo_len_p > kv_len_p: - pytest.skip("Causal prefill with qo_len_p > kv_len_p is not supported") - return_lse = False - # Prefill inputs - kv_layout_p = "NHD" - q_p = torch.randn( - qo_len_p, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - k_p = torch.randn( - kv_len_p, 
num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - v_p = torch.randn( - kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - # Generate prefill reference output - o_ref_p = flashinfer.prefill.single_prefill_with_kv_cache( - q_p, - k_p, - v_p, - causal=causal, - pos_encoding_mode=pos_encoding_mode, - ) - # Decode inputs - q_d = torch.randn( - batch_size_d, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16 - ) - num_pages_per_seq = (kv_len_d + page_size_d - 1) // page_size_d - total_num_pages = num_pages_per_seq * batch_size_d - if kv_layout_d == "HND": - kv_shape = [total_num_pages, 2, num_kv_heads, page_size_d, head_dim] - else: - kv_shape = [total_num_pages, 2, page_size_d, num_kv_heads, head_dim] - if not contiguous_kv: - tmp = [kv_shape[0]] - for v_d in kv_shape[1:]: - tmp.append(2) - tmp.append(v_d) - kv_shape = tmp - kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32) - kv_data = kv_data_fp32.to(kv_dtype) - kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] - kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] - # actual data is stored in non-contiguous memory - assert ( - kv_data.stride(-4) - != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] + if qo_len > kv_len and causal: + pytest.skip("qo_len > kv_len and causal is not supported") + + reqs = _gen_reqs(batch_size, qo_len, (kv_len, kv_len + 128), stride) + ( + q, + kv_data_fp32, + kv_indptr_cpu, + kv_indices_cpu, + kv_last_page_len_cpu, + q_indptr_cpu, + ) = _gen_metadata( + reqs, page_size, kv_layout, num_qo_heads, num_kv_heads, head_dim, "cuda:0" + ) + kv_data = kv_data_fp32.half() + + workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device="cuda:0") + if not use_cuda_graph: + q_indptr_gpu = q_indptr_cpu.to(0) + kv_indptr_gpu = kv_indptr_cpu.to(0) + kv_indices_gpu = kv_indices_cpu.to(0) + kv_last_page_len_gpu = kv_last_page_len_cpu.to(0) + wrapper = flashinfer.PODWithPagedKVCacheWrapper(workspace_buffer, kv_layout) + wrapper.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, + num_qo_heads, + num_kv_heads, + head_dim, + page_size, + causal=causal, + pos_encoding_mode=pos_encoding_mode, + logits_soft_cap=logits_soft_cap, ) - else: - kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32) - kv_data = kv_data_fp32.to(kv_dtype) - kv_indptr_d = ( - torch.arange(0, batch_size_d + 1, device="cuda:0", dtype=torch.int32) - * num_pages_per_seq - ) - kv_indices_d = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) - kv_last_page_len = torch.full( - (batch_size_d,), - (kv_len_d - 1) % page_size_d + 1, - device="cuda:0", - dtype=torch.int32, - ) + if return_lse: + o, _ = wrapper.run(q, kv_data, return_lse=True) + else: + o = wrapper.run(q, kv_data) - # Generate decode reference output - decode_workspace_buffer = torch.empty( - 32 * 1024 * 1024, device="cuda:0", dtype=torch.int8 - ) - decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( - decode_workspace_buffer, kv_layout_d - ) - decode_wrapper.plan( - kv_indptr_d, - kv_indices_d, - kv_last_page_len, - num_qo_heads, - num_kv_heads, - head_dim, - page_size_d, - pos_encoding_mode=pos_encoding_mode, - data_type=kv_dtype, - q_data_type=q_dtype, - ) - o_ref_d = decode_wrapper.run(q_d, kv_data) + # test with pre-allocated output + o_buffer = torch.empty_like(o) + wrapper.run(q, kv_data, out=o_buffer) + torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) - workspace_buffer = torch.empty(32 * 1024 * 1024, 
device="cuda:0", dtype=torch.int8) - pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper( - workspace_buffer, - kv_layout_d, - ) - pod_wrapper.plan( - kv_indptr_d, - kv_indices_d, - kv_last_page_len, + wrapper_ref = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, kv_layout, backend="fa2" + ) + wrapper_ref.plan( + q_indptr_gpu, + kv_indptr_gpu, + kv_indices_gpu, + kv_last_page_len_gpu, num_qo_heads, num_kv_heads, head_dim, - page_size_d, + page_size, + causal=causal, pos_encoding_mode=pos_encoding_mode, - data_type=kv_dtype, - q_data_type=q_dtype, - ) - - o_p, o_d = pod_wrapper.run( - q_p, - k_p, - v_p, - q_d, - kv_data, - pos_encoding_mode_p=pos_encoding_mode, - causal_p=causal, - ) - # Prefill is run with batch size 1 - torch.testing.assert_close( - o_p, o_ref_p, rtol=1e-3, atol=1e-3, msg="Prefill mismatch" + logits_soft_cap=logits_soft_cap, ) - # Decode uses all batches at once. - torch.testing.assert_close( - o_d, o_ref_d, rtol=1e-3, atol=1e-3, msg="Decode mismatch" - ) - + o_ref = wrapper_ref.run(q, kv_data) -if __name__ == "__main__": - test_pod_with_paged_kv_cache( - # Prefill params - 128, - 128, - True, - # Decode params - 80, - 12288, - 16, - "NHD", - # Other shared params - 8, - 8, - 128, - "NONE", - torch.float16, - torch.float16, - True, - ) - test_pod_with_paged_kv_cache( - # Prefill params - 12288, - 12288, - True, - # Decode params - 220, - 12288, - 16, - "NHD", - # Other shared params - 4, - 16, - 128, - "NONE", - torch.float16, - torch.float16, - True, - ) - test_pod_with_paged_kv_cache( - # Prefill params - 16384, - 16384, - True, - # Decode params - 250, - 12288, - 16, - "NHD", - # Other shared params - 4, - 16, - 128, - "NONE", - torch.float16, - torch.float16, - True, - ) - print("POD test(s) passed!") + torch.testing.assert_close(o_ref, o, rtol=1e-3, atol=1e-3)
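
For reference, a minimal end-to-end usage sketch of the PODWithPagedKVCacheWrapper API exercised by the new test above. The wrapper name, the plan()/run() arguments, and the NHD paged-KV layout come from this patch; the concrete batch composition, lengths, shapes, and dtypes below are illustrative assumptions only.

import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim, page_size = 4, 4, 128, 16

# Two requests served by one PoD wrapper: a prefill-like request with 17 query
# tokens and a decode-like request with a single query token, each attending
# to 64 KV tokens (4 pages of size 16).
q_indptr = torch.tensor([0, 17, 18], dtype=torch.int32, device="cuda:0")
kv_indptr = torch.tensor([0, 4, 8], dtype=torch.int32, device="cuda:0")
kv_indices = torch.arange(8, dtype=torch.int32, device="cuda:0")
kv_last_page_len = torch.tensor([16, 16], dtype=torch.int32, device="cuda:0")

# Ragged queries packed along the first dimension: [total_qo_len, num_qo_heads, head_dim].
q = torch.randn(18, num_qo_heads, head_dim, dtype=torch.half, device="cuda:0")
# NHD paged KV cache: [num_pages, 2, page_size, num_kv_heads, head_dim].
kv_data = torch.randn(
    8, 2, page_size, num_kv_heads, head_dim, dtype=torch.half, device="cuda:0"
)

workspace = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
wrapper = flashinfer.PODWithPagedKVCacheWrapper(workspace, "NHD")
wrapper.plan(
    q_indptr,
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_dim,
    page_size,
    causal=True,
    pos_encoding_mode="NONE",
    logits_soft_cap=0.0,
)
o = wrapper.run(q, kv_data)  # [18, num_qo_heads, head_dim]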