Skip to content

Commit

Permalink
refactor: remove page_size from template parameters for prefill ker…
Browse files Browse the repository at this point in the history
…nels (#306)

Similar to #301, in this PR we remove `page_size` from the template
parameters so that we can support any `page_size` for prefill kernels
(previously we only supported a fixed set such as 1, 4, 8, 16), as well as
reduce the binary size and speed up compilation.
  • Loading branch information
yzh119 authored Jun 15, 2024
1 parent 955dfc5 commit 82fd8c7
Show file tree
Hide file tree
Showing 19 changed files with 146 additions and 241 deletions.
57 changes: 26 additions & 31 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ flashinfer_option(FLASHINFER_TVM_SOURCE_DIR "The path to tvm for building tvm bi

# The following configurations can impact the binary
# size of the generated library
flashinfer_option(FLASHINFER_GEN_PAGE_SIZES "Prefill page sizes to enable" 1 16 32)
flashinfer_option(FLASHINFER_GEN_HEAD_DIMS "Head dims to enable" 64 128 256)
flashinfer_option(FLASHINFER_GEN_KV_LAYOUTS "KV layouts to enable" 0 1)
flashinfer_option(FLASHINFER_GEN_LOGITS_POST_HOOKS "Logits post hooks" 0 1)
Expand Down Expand Up @@ -80,7 +79,6 @@ if(FLASHINFER_ENABLE_BF16)
endif(FLASHINFER_ENABLE_BF16)

# generate kernel inst
set (PAGE_SIZES ${FLASHINFER_GEN_PAGE_SIZES})
set (HEAD_DIMS ${FLASHINFER_GEN_HEAD_DIMS})
set (LOGITS_POST_HOOKS ${FLASHINFER_GEN_LOGITS_POST_HOOKS})
set (KV_LAYOUTS ${FLASHINFER_GEN_KV_LAYOUTS})
Expand All @@ -103,7 +101,6 @@ if(FLASHINFER_ENABLE_BF16)
endif(FLASHINFER_ENABLE_BF16)

# log options
message(STATUS "FLASHINFER_PAGE_SIZES=${PAGE_SIZES}")
message(STATUS "FLASHINFER_HEAD_DIMS=${HEAD_DIMS}")
message(STATUS "FLASHINFER_KV_LAYOUTS=${KV_LAYOUTS}")
message(STATUS "FLASHINFER_POS_ENCODING_MODES=${POS_ENCODING_MODES}")
Expand All @@ -115,7 +112,7 @@ file(MAKE_DIRECTORY ${PROJECT_SOURCE_DIR}/src/generated)
set(dispatch_inc_file ${PROJECT_SOURCE_DIR}/src/dispatch.inc)
add_custom_command(
OUTPUT ${dispatch_inc_file}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --page_sizes ${FLASHINFER_GEN_PAGE_SIZES} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py --path ${PROJECT_SOURCE_DIR}/src/dispatch.inc --head_dims ${HEAD_DIMS} --logits_post_hooks ${LOGITS_POST_HOOKS} --kv_layouts ${KV_LAYOUTS} --pos_encoding_modes ${POS_ENCODING_MODES} --allow_fp16_qk_reductions ${ALLOW_FP16_QK_REDUCTIONS} --mask_modes ${MASK_MODES}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_dispatch_inc.py
COMMENT "Generating additional source file ${generated_dispatch_inc}"
VERBATIM
Expand Down Expand Up @@ -249,33 +246,31 @@ foreach(head_dim IN LISTS HEAD_DIMS)
endforeach(head_dim)

# batch paged prefill kernel inst generation
foreach(page_size IN LISTS PAGE_SIZES)
foreach(head_dim IN LISTS HEAD_DIMS)
foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_page_${page_size}_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
endforeach(head_dim)
endforeach(page_size)
foreach(head_dim IN LISTS HEAD_DIMS)
foreach(logits_post_hook IN LISTS LOGITS_POST_HOOKS)
foreach(kv_layout IN LISTS KV_LAYOUTS)
foreach(pos_encoding_mode IN LISTS POS_ENCODING_MODES)
foreach(allow_fp16_qk_reduction IN LISTS ALLOW_FP16_QK_REDUCTIONS)
foreach(mask_mode IN LISTS MASK_MODES)
foreach(dtype IN LISTS PREFILL_DTYPES)
foreach(idtype IN LISTS IDTYPES)
set(generated_kernel_src ${PROJECT_SOURCE_DIR}/src/generated/batch_paged_prefill_head_${head_dim}_logitshook_${logits_post_hook}_layout_${kv_layout}_posenc_${pos_encoding_mode}_fp16qkred_${allow_fp16_qk_reduction}_mask_${mask_mode}_dtypein_${dtype}_dtypeout_${dtype}_idtype_${idtype}.cu)
add_custom_command(
OUTPUT ${generated_kernel_src}
COMMAND ${Python3_EXECUTABLE} ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py ${generated_kernel_src}
DEPENDS ${PROJECT_SOURCE_DIR}/python/generate_batch_paged_prefill_inst.py
COMMENT "Generating additional source file ${generated_kernel_src}"
VERBATIM
)
list(APPEND batch_paged_prefill_kernels_src ${generated_kernel_src})
endforeach(idtype)
endforeach(dtype)
endforeach(mask_mode)
endforeach(allow_fp16_qk_reduction)
endforeach(pos_encoding_mode)
endforeach(kv_layout)
endforeach(logits_post_hook)
endforeach(head_dim)

# batch ragged prefill kernel inst generation
foreach(head_dim IN LISTS HEAD_DIMS)
Expand Down
1 change: 0 additions & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ set(FLASHINFER_DISTRIBUTED ON)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_PAGE_SIZES 1 16 32)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
Expand Down
20 changes: 9 additions & 11 deletions include/flashinfer/attention/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,10 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
static_assert(num_stages_smem <= bdx);
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin + (((j * bdz + tz) * bdy + ty) * bdx + tx) / paged_kv.page_size,
kv_head_idx, (((j * bdz + tz) * bdy + ty) * bdx + tx) % paged_kv.page_size, 0, last_indptr);
uint32_t q, r;
paged_kv.page_size.divmod(((j * bdz + tz) * bdy + ty) * bdx + tx, q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] =
paged_kv.protective_get_k_ptr(cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
}
block.sync();

Expand Down Expand Up @@ -643,15 +644,12 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
if ((iter + num_stages_smem) % bdx == 0) {
#pragma unroll
for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
uint32_t q, r;
paged_kv.page_size.divmod(((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx),
q, r);
k_ptrs_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_k_ptr(
cur_page_indptr_begin + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx) /
paged_kv.page_size,
kv_head_idx,
((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz +
((j * bdz + tz) * bdy + ty) * bdx + tx) %
paged_kv.page_size,
0, last_indptr);
cur_page_indptr_begin + q, kv_head_idx, r, 0, last_indptr);
}
}
// compute qk
Expand Down
2 changes: 1 addition & 1 deletion include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&num_blocks_per_sm, partition_kv_kernel, num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * num_kv_heads >= num_sm) {
if (batch_size * num_kv_heads >= max_grid_size) {
tmp_size = 0;
new_batch_size = batch_size;
} else {
Expand Down
109 changes: 41 additions & 68 deletions include/flashinfer/attention/prefill.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
#endif
#include <cuda_runtime.h>

#include <optional>
#include <tuple>

#include "../cp_async.cuh"
#include "../fastdiv.cuh"
#include "../layout.cuh"
Expand Down Expand Up @@ -175,65 +172,41 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

template <bool produce_v, uint32_t page_size, uint32_t num_warps, uint32_t num_frags_y,
uint32_t num_frags_z, PageStorage page_storage, QKVLayout kv_layout, typename DType,
typename IdType>
template <bool produce_v, uint32_t num_warps, uint32_t num_frags_y, uint32_t num_frags_z,
PageStorage page_storage, QKVLayout kv_layout, typename DType, typename IdType>
__device__ __forceinline__ void page_produce_kv(
smem_t smem, uint32_t* smem_offset,
paged_kv_t<page_storage, kv_layout, DType, IdType>& paged_kv, const uint32_t kv_idx_base,
const uint32_t page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
const uint32_t packed_page_iter_base, const uint32_t kv_len, const IdType last_indptr) {
constexpr SharedMemFillMode fill_mode =
produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t channel_size_128b_in = head_dim / num_elems_per_128b<DType>();
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t kv_head_idx = blockIdx.z;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8;
if constexpr (page_size % 4 == 0) {
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4) % page_size + tx / 8;
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
} else {
#pragma unroll
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
const uint32_t page_iter = page_iter_base + (4 * num_warps * i + ty * 4 + tx / 8) / page_size;
const uint32_t entry_idx = (4 * num_warps * i + ty * 4 + tx / 8) % page_size;
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
for (uint32_t i = 0; i < num_frags_z * 4 / num_warps; ++i) {
uint32_t page_iter, entry_idx;
paged_kv.page_size.divmod(packed_page_iter_base + ty * 4 + tx / 8 + 4 * num_warps * i,
page_iter, entry_idx);
DType* gptr =
produce_v
? paged_kv.protective_get_v_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr)
: paged_kv.protective_get_k_ptr(page_iter, kv_head_idx, entry_idx,
(tx % 8) * num_elems_per_128b<DType>(), last_indptr);
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
gptr += 8 * num_elems_per_128b<DType>();
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
kv_idx += num_warps * 4;
*smem_offset = smem.advance_offset_by_row<num_warps * 4, channel_size_128b_in>(*smem_offset) -
2 * num_frags_y;
}
*smem_offset -= num_frags_z * 16 * channel_size_128b_in;
}

template <uint32_t num_frags_y>
Expand Down Expand Up @@ -1342,10 +1315,10 @@ __global__ void BatchPrefillWithRaggedKVCacheKernel(
}
}

template <LogitsPostHook logits_post_hook, uint32_t page_size, MaskMode mask_mode,
PosEncodingMode pos_encoding_mode, uint32_t num_frags_x, uint32_t num_frags_y,
uint32_t num_frags_z, uint32_t num_warps, PageStorage page_storage, QKVLayout kv_layout,
typename DTypeIn, typename DTypeQKAccum, typename DTypeOut, typename IdType>
template <LogitsPostHook logits_post_hook, MaskMode mask_mode, PosEncodingMode pos_encoding_mode,
uint32_t num_frags_x, uint32_t num_frags_y, uint32_t num_frags_z, uint32_t num_warps,
PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeQKAccum,
typename DTypeOut, typename IdType>
__global__ void BatchPrefillWithPagedKVCacheKernel(
IdType* __restrict__ request_indices, IdType* __restrict__ tile_indices,
DTypeIn* __restrict__ q, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
Expand Down Expand Up @@ -1448,12 +1421,12 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
smem_t::get_permuted_offset<channel_size_128b_in>(ty * 4 + tx / 8, tx % 8);
const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size];

uint32_t page_iter_base = paged_kv.indptr[request_idx];
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size;
page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
cp_async::commit_group();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, 0, page_iter_base, kv_len, last_indptr);
page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, 0, packed_page_iter_base, kv_len, last_indptr);
cp_async::commit_group();

const uint32_t num_iterations = ceil_div(
Expand Down Expand Up @@ -1508,10 +1481,10 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(s_frag, o_frag, m, d);

block.sync();
page_iter_base += 16 * num_frags_z / page_size;
page_produce_kv<false, page_size, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
last_indptr);
packed_page_iter_base += 16 * num_frags_z;
page_produce_kv<false, num_warps, num_frags_y, num_frags_z>(
k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
kv_len, last_indptr);
cp_async::commit_group();
cp_async::wait_group<1>();
block.sync();
Expand All @@ -1521,9 +1494,9 @@ __global__ void BatchPrefillWithPagedKVCacheKernel(
o_frag, d);

block.sync();
page_produce_kv<true, page_size, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, page_iter_base, kv_len,
last_indptr);
page_produce_kv<true, num_warps, num_frags_y, num_frags_z>(
v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * num_frags_z, packed_page_iter_base,
kv_len, last_indptr);
cp_async::commit_group();
}
cp_async::wait_group<0>();
Expand Down Expand Up @@ -1776,7 +1749,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
return cudaSuccess;
}

template <PageStorage page_storage, uint32_t num_frags_x, uint32_t PAGE_SIZE, uint32_t HEAD_DIM,
template <PageStorage page_storage, uint32_t num_frags_x, uint32_t HEAD_DIM,
LogitsPostHook LOGITS_POST_HOOK, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode,
bool ALLOW_FP16_QK_REDUCTION, MaskMode MASK_MODE, typename DTypeIn, typename DTypeOut,
typename IdType>
Expand Down Expand Up @@ -1831,8 +1804,8 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(
throw std::invalid_argument(err_msg.str());
} else {
auto kernel = BatchPrefillWithPagedKVCacheKernel<
LOGITS_POST_HOOK, PAGE_SIZE, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y,
num_frags_z, num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
LOGITS_POST_HOOK, MASK_MODE, pos_encoding_mode, num_frags_x, num_frags_y, num_frags_z,
num_warps, page_storage, kv_layout, DTypeIn, DTypeQKAccum, DTypeOut, IdType>;
uint32_t smem_size =
(num_frags_x * num_warps + num_frags_z * 2) * 16 * HEAD_DIM * sizeof(DTypeIn);
FLASHINFER_CUDA_CALL(
Expand Down
1 change: 1 addition & 0 deletions include/flashinfer/fastdiv.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
#ifndef FLASHINFER_FASTDIV_CUH_
#define FLASHINFER_FASTDIV_CUH_
#include <cstdint>

namespace flashinfer {

Expand Down
Loading

0 comments on commit 82fd8c7

Please sign in to comment.