Pass a dynamic token count to the cascade kernels #635

Merged · 1 commit · Nov 25, 2024
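Summary (reconstructed from the diff below): the batch-prefill dispatchers used to pass the cascade merge kernels a host-side scalar total_num_rows, a value that CUDA graph capture freezes at capture time. This PR splits it into max_total_num_rows, a static upper bound used to size launches and buffers, plus an optional total_num_rows device pointer that the kernels dereference at run time, so a single captured graph can be replayed on batches with different token counts. A minimal CUDA sketch of the mechanism (kernel name, signature, and body are illustrative assumptions, not FlashInfer's actual merge kernels):

    #include <cstdint>

    // Sketch: the grid is sized for the worst case (max_total_num_rows) so the launch
    // configuration can be baked into a CUDA graph; each block then consults a
    // device-resident count, and blocks past the live count become no-ops on replay.
    __global__ void MergeRowsSketch(const float* __restrict__ tmp_v, float* __restrict__ o,
                                    uint32_t head_dim, uint32_t max_total_num_rows,
                                    const uint32_t* total_num_rows /* nullptr without CUDA graphs */) {
      const uint32_t num_rows = total_num_rows != nullptr ? *total_num_rows : max_total_num_rows;
      const uint32_t row = blockIdx.x;
      if (row >= num_rows) return;  // row beyond the live token count: exit early
      for (uint32_t d = threadIdx.x; d < head_dim; d += blockDim.x) {
        o[row * head_dim + d] = tmp_v[row * head_dim + d];  // stand-in for the real merge
      }
    }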
include/flashinfer/attention/prefill.cuh (12 additions, 14 deletions)
@@ -2121,7 +2121,6 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
   const uint32_t num_qo_heads = params.num_qo_heads;
   const uint32_t num_kv_heads = params.num_kv_heads;
   const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
-  const uint32_t total_num_rows = params.total_num_rows;
   constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2198,13 +2197,13 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P
       FLASHINFER_CUDA_CALL(
           cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       if constexpr (AttentionVariant::use_softmax) {
-        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                       total_num_rows, nullptr, num_qo_heads,
-                                                       HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+            tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
+            params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       } else {
-        FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
-                                                        total_num_rows, nullptr, num_qo_heads,
-                                                        HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(
+            VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
+                                       params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       }
     }
   }
@@ -2223,7 +2222,6 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
   const uint32_t num_qo_heads = params.num_qo_heads;
   const uint32_t num_kv_heads = params.paged_kv.num_heads;
   const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads);
-  const uint32_t total_num_rows = params.total_num_rows;
   constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
   constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
@@ -2300,13 +2298,13 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa
       FLASHINFER_CUDA_CALL(
           cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
       if constexpr (AttentionVariant::use_softmax) {
-        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse,
-                                                       total_num_rows, nullptr, num_qo_heads,
-                                                       HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
+            tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows,
+            params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       } else {
-        FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o,
-                                                        total_num_rows, nullptr, num_qo_heads,
-                                                        HEAD_DIM, stream));
+        FLASHINFER_CUDA_CALL(
+            VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, params.max_total_num_rows,
+                                       params.total_num_rows, num_qo_heads, HEAD_DIM, stream));
       }
     }
   }
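The call sites above now pass the pair (params.max_total_num_rows, params.total_num_rows) where a scalar and a nullptr used to go. A hedged reading of the updated launcher signature implied by that arity (the actual declaration lives in FlashInfer's cascade headers and may differ in detail):

    // Assumed shape, inferred from the call sites: max_seq_len is the host-side bound
    // used for launch sizing; seq_len is an optional device pointer with the live count.
    template <typename DTypeIn, typename DTypeO, typename IdType>
    cudaError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* merge_indptr,
                                          DTypeO* v_merged, float* s_merged, uint32_t max_seq_len,
                                          uint32_t* seq_len, uint32_t num_heads, uint32_t head_dim,
                                          cudaStream_t stream = nullptr);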
include/flashinfer/attention/prefill_params.cuh (8 additions, 4 deletions)
@@ -136,7 +136,8 @@ struct BatchPrefillRaggedParams {
   IdType* o_indptr;
   IdType* kv_chunk_size_ptr;
   bool* block_valid_mask;
-  uint32_t total_num_rows;
+  uint32_t max_total_num_rows;
+  uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
 
@@ -178,7 +179,8 @@ struct BatchPrefillRaggedParams {
         o_indptr(nullptr),
         kv_chunk_size_ptr(nullptr),
         block_valid_mask(nullptr),
-        total_num_rows(0),
+        max_total_num_rows(0),
+        total_num_rows(nullptr),
         padded_batch_size(0),
         partition_kv(false) {}
 
@@ -227,7 +229,8 @@ struct BatchPrefillPagedParams {
   IdType* o_indptr;
   bool* block_valid_mask;
   IdType* kv_chunk_size_ptr;
-  uint32_t total_num_rows;
+  uint32_t max_total_num_rows;
+  uint32_t* total_num_rows;
   uint32_t padded_batch_size;
   bool partition_kv;
 
@@ -261,7 +264,8 @@ struct BatchPrefillPagedParams {
         o_indptr(nullptr),
         block_valid_mask(nullptr),
         kv_chunk_size_ptr(nullptr),
-        total_num_rows(0),
+        max_total_num_rows(0),
+        total_num_rows(nullptr),
         padded_batch_size(0),
         partition_kv(false) {}
 
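The same field split in both params structs, annotated (comments are editorial; surrounding members elided):

    uint32_t max_total_num_rows;  // static upper bound on query/output rows, fixed when
                                  // the workspace is planned; sizes grids and buffers
    uint32_t* total_num_rows;     // live row count in workspace memory, read by kernels
                                  // at run time; stays nullptr when CUDA graphs are off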
include/flashinfer/attention/scheduler.cuh (25 additions, 13 deletions)
@@ -518,6 +518,7 @@ inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, uin
 struct PrefillPlanInfo {
   int64_t padded_batch_size;
   int64_t total_num_rows;
+  int64_t total_num_rows_offset;
   int64_t cta_tile_q;
   int64_t request_indices_offset;
   int64_t qo_tile_indices_offset;
@@ -534,6 +535,7 @@ struct PrefillPlanInfo {
   PrefillPlanInfo()
       : padded_batch_size(0),
         total_num_rows(0),
+        total_num_rows_offset(0),
         cta_tile_q(0),
         request_indices_offset(0),
         qo_tile_indices_offset(0),
@@ -551,6 +553,7 @@ struct PrefillPlanInfo {
   std::vector<int64_t> ToVector() const {
     return {padded_batch_size,
             total_num_rows,
+            total_num_rows_offset,
             cta_tile_q,
             request_indices_offset,
             qo_tile_indices_offset,
@@ -567,25 +570,26 @@ struct PrefillPlanInfo {
 
   // From std::vector<int64_t> to PrefillPlanInfo
   void FromVector(const std::vector<int64_t>& vec) {
-    if (vec.size() != 14) {
+    if (vec.size() != 15) {
       std::ostringstream err_msg;
       err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 14, but got " << vec.size();
       FLASHINFER_ERROR(err_msg.str());
     }
     padded_batch_size = vec[0];
     total_num_rows = vec[1];
-    cta_tile_q = vec[2];
-    request_indices_offset = vec[3];
-    qo_tile_indices_offset = vec[4];
-    kv_tile_indices_offset = vec[5];
-    merge_indptr_offset = vec[6];
-    o_indptr_offset = vec[7];
-    kv_chunk_size_ptr_offset = vec[8];
-    v_offset = vec[9];
-    s_offset = vec[10];
-    block_valid_mask_offset = vec[11];
-    enable_cuda_graph = vec[12];
-    split_kv = vec[13];
+    total_num_rows_offset = vec[2];
+    cta_tile_q = vec[3];
+    request_indices_offset = vec[4];
+    qo_tile_indices_offset = vec[5];
+    kv_tile_indices_offset = vec[6];
+    merge_indptr_offset = vec[7];
+    o_indptr_offset = vec[8];
+    kv_chunk_size_ptr_offset = vec[9];
+    v_offset = vec[10];
+    s_offset = vec[11];
+    block_valid_mask_offset = vec[12];
+    enable_cuda_graph = vec[13];
+    split_kv = vec[14];
   }
 };
 
@@ -640,6 +644,14 @@ inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_i
   plan_info.kv_chunk_size_ptr_offset =
       int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
 
+  if (plan_info.enable_cuda_graph) {
+    plan_info.total_num_rows_offset =
+        int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows");
+    uint32_t* total_num_rows_h =
+        GetPtrFromBaseOffset<uint32_t>(page_locked_int_buffer, plan_info.total_num_rows_offset);
+    *total_num_rows_h = qo_indptr_h[batch_size];
+  }
+
   IdType* request_indices_h =
       GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
   IdType* qo_tile_indices_h =
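The count is allocated in the page-locked integer workspace (16-byte aligned) and seeded with qo_indptr_h[batch_size], the actual number of query rows. Putting it in the workspace is what makes CUDA graph replay work; a hedged host-side sketch of the intended flow (function and buffer names are illustrative, not FlashInfer API):

    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstring>

    // Sketch: a captured graph's kernel arguments are frozen, so the only way a replay
    // can observe a new token count is through memory that the kernels read.
    void ReplayWithNewTokenCount(char* pinned_int_buf, char* device_int_buf, size_t int_buf_bytes,
                                 size_t total_num_rows_offset, uint32_t new_total_num_rows,
                                 cudaGraphExec_t graph_exec, cudaStream_t stream) {
      // Host side: the planner rewrites the pinned copy of the count in place.
      std::memcpy(pinned_int_buf + total_num_rows_offset, &new_total_num_rows, sizeof(uint32_t));
      // Push the pinned workspace to the device buffer the captured kernels read from.
      cudaMemcpyAsync(device_int_buf, pinned_int_buf, int_buf_bytes, cudaMemcpyHostToDevice, stream);
      // Replay without re-capture; kernels dereference the updated count.
      cudaGraphLaunch(graph_exec, stream);
    }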
python/csrc/batch_prefill.cu (10 additions, 2 deletions)
@@ -160,8 +160,12 @@ void BatchPrefillWithRaggedKVCacheRun(
           GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
     }
   }
-  params.total_num_rows = plan_info.total_num_rows;
   params.padded_batch_size = plan_info.padded_batch_size;
+  params.max_total_num_rows = plan_info.total_num_rows;
+  if (plan_info.enable_cuda_graph) {
+    params.total_num_rows =
+        GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
+  }
 
   cudaError_t status = cudaSuccess;
 
@@ -290,8 +294,12 @@ void BatchPrefillWithPagedKVCacheRun(
           GetPtrFromBaseOffset<bool>(int_buffer_ptr, plan_info.block_valid_mask_offset);
     }
   }
-  params.total_num_rows = plan_info.total_num_rows;
   params.padded_batch_size = plan_info.padded_batch_size;
+  params.max_total_num_rows = plan_info.total_num_rows;
+  if (plan_info.enable_cuda_graph) {
+    params.total_num_rows =
+        GetPtrFromBaseOffset<uint32_t>(int_buffer_ptr, plan_info.total_num_rows_offset);
+  }
 
   cudaError_t status = cudaSuccess;
 
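One wiring detail worth spelling out: total_num_rows is populated only when enable_cuda_graph is set, so on the eager path it keeps the nullptr installed by the params constructors and the kernels must fall back to the static bound. The guard this implies on the device side, as a sketch (an illustrative helper, not a FlashInfer function):

    // Prefer the live device-side count when a CUDA graph supplies one; otherwise
    // fall back to the capture-time bound.
    template <typename Params>
    __device__ __forceinline__ uint32_t GetNumRows(const Params& params) {
      return params.total_num_rows != nullptr ? *params.total_num_rows : params.max_total_num_rows;
    }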