|
| 1 | +/* |
| 2 | + * Copyright (c) 2025 by FlashInfer team. |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#include <flashinfer/attention/mask.cuh> |
| 17 | +#include <flashinfer/attention/scheduler.cuh> |
| 18 | +#include <flashinfer/layout.cuh> |
| 19 | +#include <flashinfer/pos_enc.cuh> |
| 20 | +#include <optional> |
| 21 | + |
| 22 | +#include "batch_attention_config.inc" |
| 23 | +#include "pytorch_conversion_utils.h" |
| 24 | +#include "pytorch_extension_utils.h" |
| 25 | + |
| 26 | +namespace flashinfer { |
| 27 | + |
| 28 | +template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, |
| 29 | + MaskMode MASK_MODE, typename AttentionVariant, typename Params> |
| 30 | +cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2, |
| 31 | + const uint32_t num_blks_x, const uint32_t num_blks_y, |
| 32 | + const cudaStream_t stream); |
| 33 | +} // namespace flashinfer |
| 34 | + |
| 35 | +using namespace flashinfer; |
| 36 | + |
| 37 | +at::Tensor BatchPagedAttentionPlan(at::Tensor float_workspace_buffer, |
| 38 | + at::Tensor int_workspace_buffer, |
| 39 | + at::Tensor page_locked_int_workspace_buffer, |
| 40 | + at::Tensor qo_indptr, at::Tensor kv_indptr, at::Tensor kv_len, |
| 41 | + int64_t batch_size, int64_t num_qo_heads, int64_t num_kv_heads, |
| 42 | + int64_t head_dim_o, bool causal) { |
| 43 | + size_t float_workspace_size_in_bytes = |
| 44 | + float_workspace_buffer.size(0) * float_workspace_buffer.element_size(); |
| 45 | + size_t int_workspace_size_in_bytes = |
| 46 | + int_workspace_buffer.size(0) * int_workspace_buffer.element_size(); |
| 47 | + |
| 48 | + HolisticPlanInfo<2> plan_info; |
| 49 | + |
| 50 | + const c10::cuda::OptionalCUDAGuard device_guard(float_workspace_buffer.device()); |
| 51 | + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); |
| 52 | + |
| 53 | + cudaError_t status = TwoStageHolisticPlan<IdType>( |
| 54 | + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, |
| 55 | + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), |
| 56 | + int_workspace_size_in_bytes, plan_info, qo_indptr.data_ptr<IdType>(), |
| 57 | + kv_indptr.data_ptr<IdType>(), kv_len.data_ptr<IdType>(), batch_size, num_qo_heads, |
| 58 | + num_kv_heads, head_dim_o, causal, stream); |
| 59 | + |
| 60 | + TORCH_CHECK(status == cudaSuccess, |
| 61 | + "Failed to plan persistent paged attention, error: ", cudaGetErrorString(status)); |
| 62 | + |
| 63 | + return vec_to_tensor(plan_info.ToVector()); |
| 64 | +} |
| 65 | + |
| 66 | +void BatchPagedAttentionRun(at::Tensor float_workspace_buffer, at::Tensor int_workspace_buffer, |
| 67 | + at::Tensor plan_info_vec, at::Tensor q, at::Tensor k_cache, |
| 68 | + at::Tensor v_cache, at::Tensor kv_indices, at::Tensor o, |
| 69 | + std::optional<at::Tensor> maybe_lse, int64_t mask_mode_code, |
| 70 | + int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads, |
| 71 | + int64_t page_size, double sm_scale ADDITIONAL_FUNC_PARAMS) { |
| 72 | + HolisticPlanInfo<2> plan_info; |
| 73 | + plan_info.FromVector(tensor_to_vec(plan_info_vec)); |
| 74 | + |
| 75 | + auto device = q.device(); |
| 76 | + |
| 77 | + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); |
| 78 | + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); |
| 79 | + |
| 80 | + const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code); |
| 81 | + |
| 82 | + auto q_scalar_type = q.scalar_type(); |
| 83 | + auto kv_scalar_type = k_cache.scalar_type(); |
| 84 | + |
| 85 | + // NOTE (Yilong): assume both q and o are NHD |
| 86 | + unsigned int q_stride_n = q.stride(0); |
| 87 | + unsigned int q_stride_h = q.stride(1); |
| 88 | + |
| 89 | + // layout only constraint paged KV |
| 90 | + const QKVLayout kv_layout = static_cast<QKVLayout>(layout_code); |
| 91 | + unsigned int k_stride_page = k_cache.stride(0); |
| 92 | + unsigned int v_stride_page = v_cache.stride(0); |
| 93 | + unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h; |
| 94 | + if (kv_layout == QKVLayout::kNHD) { |
| 95 | + k_stride_h = k_cache.stride(2); |
| 96 | + k_stride_n = k_cache.stride(1); |
| 97 | + v_stride_h = v_cache.stride(2); |
| 98 | + v_stride_n = v_cache.stride(1); |
| 99 | + } else { |
| 100 | + k_stride_h = k_cache.stride(1); |
| 101 | + k_stride_n = k_cache.stride(2); |
| 102 | + v_stride_h = v_cache.stride(1); |
| 103 | + v_stride_n = v_cache.stride(2); |
| 104 | + } |
| 105 | + |
| 106 | + const c10::cuda::OptionalCUDAGuard device_guard(device); |
| 107 | + const cudaStream_t stream = c10::cuda::getCurrentCUDAStream(); |
| 108 | + |
| 109 | + DISPATCH_context( |
| 110 | + DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, |
| 111 | + AttentionVariant, PersistentParams, [&] { |
| 112 | + PersistentParams params[2]; |
| 113 | + |
| 114 | + for (int i = 0; i < 2; i++) { |
| 115 | + params[i].q = static_cast<DTypeQ*>(q.data_ptr()); |
| 116 | + params[i].k = static_cast<DTypeKV*>(k_cache.data_ptr()); |
| 117 | + params[i].v = static_cast<DTypeKV*>(v_cache.data_ptr()); |
| 118 | + |
| 119 | + params[i].q_indptr = |
| 120 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset); |
| 121 | + params[i].kv_indptr = |
| 122 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset); |
| 123 | + params[i].partial_indptr = GetPtrFromBaseOffset<IdType>( |
| 124 | + int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset); |
| 125 | + params[i].kv_indices = static_cast<int*>(kv_indices.data_ptr()); |
| 126 | + params[i].q_len = |
| 127 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_len_offset); |
| 128 | + params[i].kv_len = |
| 129 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_len_offset); |
| 130 | + params[i].q_start = |
| 131 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].q_start_offset); |
| 132 | + params[i].kv_start = |
| 133 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_start_offset); |
| 134 | + params[i].kv_end = |
| 135 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_end_offset); |
| 136 | + params[i].kv_head_idx_arr = |
| 137 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].kv_head_idx_offset); |
| 138 | + params[i].work_indptr = |
| 139 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset); |
| 140 | + params[i].len_kv_chunk = |
| 141 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.tasks[i].len_kv_chunk_offset); |
| 142 | + |
| 143 | + params[i].final_o = static_cast<DTypeO*>(o.data_ptr()); |
| 144 | + params[i].final_lse = |
| 145 | + maybe_lse.has_value() ? static_cast<float*>(maybe_lse->data_ptr()) : nullptr; |
| 146 | + params[i].partial_o = |
| 147 | + GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset); |
| 148 | + params[i].partial_lse = |
| 149 | + GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset); |
| 150 | + |
| 151 | + // for state reduction |
| 152 | + params[i].merge_indptr = |
| 153 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset); |
| 154 | + params[i].merge_o_indices = |
| 155 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_o_indices_offset); |
| 156 | + params[i].num_packed_qo_len = |
| 157 | + GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.num_qo_len_offset); |
| 158 | + |
| 159 | + params[i].num_kv_heads = num_kv_heads; |
| 160 | + params[i].gqa_group_size = uint_fastdiv(num_qo_heads / num_kv_heads); |
| 161 | + params[i].page_size = uint_fastdiv(page_size); |
| 162 | + |
| 163 | + params[i].q_stride_n = q_stride_n; |
| 164 | + params[i].q_stride_h = q_stride_h; |
| 165 | + params[i].k_stride_page = k_stride_page; |
| 166 | + params[i].k_stride_h = k_stride_h; |
| 167 | + params[i].k_stride_n = k_stride_n; |
| 168 | + params[i].v_stride_page = v_stride_page; |
| 169 | + params[i].v_stride_h = v_stride_h; |
| 170 | + params[i].v_stride_n = v_stride_n; |
| 171 | + |
| 172 | + params[i].sm_scale = sm_scale; |
| 173 | + |
| 174 | + ADDITIONAL_PARAMS_SETTER |
| 175 | + } |
| 176 | + |
| 177 | + cudaError_t status = BatchPagedAttentionPersistent<128, 16, HEAD_DIM_QK, HEAD_DIM_VO, |
| 178 | + MASK_MODE, AttentionVariant>( |
| 179 | + params[0], params[1], plan_info.num_blks_x, plan_info.num_blks_y, stream); |
| 180 | + TORCH_CHECK(status == cudaSuccess, "Failed to run persistent paged attention, error: ", |
| 181 | + cudaGetErrorString(status)); |
| 182 | + return true; |
| 183 | + }); |
| 184 | +} |
0 commit comments