22 changes: 12 additions & 10 deletions csrc/batch_attention.cu
@@ -35,11 +35,12 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa

using namespace flashinfer;

-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
  size_t int_workspace_size_in_bytes =
@@ -63,11 +64,12 @@ Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int
  return Array(plan_info.ToVector());
}

-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size,
                             double v_scale,  // must use double due to pytorch binding
                             double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
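Every hunk in this diff is the same mechanical migration: the FFI entry points now take tvm::ffi::TensorView, a non-owning view of the caller's buffer, in place of the owning tvm::ffi::Tensor, while the function bodies stay untouched (TensorView exposes the same operator-> access to shape, dtype, and device). Below is a minimal sketch of the post-migration binding pattern; ExamplePlan, its export name, and the assumption that "tvm_ffi_utils.h" supplies the ffi types plus get_element_size are illustrative, not flashinfer code.

#include <vector>

#include "tvm_ffi_utils.h"  // assumed repo helper: tvm::ffi types + get_element_size

using tvm::ffi::Array;
using tvm::ffi::TensorView;

// The parameter is borrowed, not retained: a TensorView is only valid for the
// duration of the call, so it must not be stored past the function's return.
Array<int64_t> ExamplePlan(TensorView workspace) {
  size_t workspace_size_in_bytes =
      workspace->shape[0] * get_element_size(workspace);
  std::vector<int64_t> plan_info{static_cast<int64_t>(workspace_size_in_bytes)};
  return Array(plan_info);  // same construction pattern the Plan functions above use
}

TVM_FFI_DLL_EXPORT_TYPED_FUNC(example_plan, ExamplePlan);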
23 changes: 12 additions & 11 deletions csrc/batch_attention_jit_binding.cu
@@ -19,18 +19,19 @@
using tvm::ffi::Array;
using tvm::ffi::Optional;

-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal);
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal);

-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size, double v_scale,
-                            double sm_scale,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, &BatchPagedAttentionPlan);
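The *_jit_binding.cu files define nothing; they restate the signatures and hand them to TVM_FFI_DLL_EXPORT_TYPED_FUNC, which is why every signature in this diff changes in two places. Each declaration must match the definition in the sibling .cu exactly, or the export refers to a symbol that was never compiled and the JIT module fails to link. A reduced sketch of that two-file layout (file and function names here are illustrative, not repo code):

// example_impl.cu -- the definition lives with the kernel code.
#include "tvm_ffi_utils.h"  // assumed helper header, as above

int64_t ExampleRun(tvm::ffi::TensorView q) { return q->shape[0]; }

// example_jit_binding.cu -- re-declares the exact signature, then binds it.
int64_t ExampleRun(tvm::ffi::TensorView q);  // must stay in sync with the definition

TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, ExampleRun);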
20 changes: 11 additions & 9 deletions csrc/batch_decode.cu
@@ -37,11 +37,11 @@ using tvm::ffi::Array;
using tvm::ffi::Optional;

Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
    int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data) {
+    TensorView empty_q_data, TensorView empty_kv_data) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
  size_t int_workspace_size_in_bytes =
@@ -78,12 +78,14 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
  return Array(plan_info.ToVector());
}

-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
  QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
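The plan/run pairs share one handshake, visible in this file: Plan serializes its scheduling decisions with ToVector() into the Array<int64_t> it returns, and Run rehydrates them with FromVector. A minimal sketch of that round trip, with PlanInfo standing in for flashinfer's DecodePlanInfo (the fields are illustrative only; the real struct packs many more):

#include <cstdint>
#include <vector>

// Stand-in for DecodePlanInfo: the contract is a fixed-order flattening of
// all scheduling state into int64_t values.
struct PlanInfo {
  int64_t num_chunks = 0;
  int64_t padded_batch_size = 0;

  std::vector<int64_t> ToVector() const { return {num_chunks, padded_batch_size}; }

  void FromVector(const std::vector<int64_t>& v) {
    num_chunks = v.at(0);
    padded_batch_size = v.at(1);
  }
};

int main() {
  PlanInfo planned;
  planned.num_chunks = 4;
  planned.padded_batch_size = 128;

  // Plan side: flatten into the vector handed across the FFI boundary.
  std::vector<int64_t> plan_info_vec = planned.ToVector();

  // Run side: rebuild the struct, mirroring
  //   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
  PlanInfo restored;
  restored.FromVector(plan_info_vec);
  return restored.num_chunks == 4 ? 0 : 1;
}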
20 changes: 11 additions & 9 deletions csrc/batch_decode_jit_binding.cu
@@ -21,18 +21,20 @@ using tvm::ffi::Array;
using tvm::ffi::Optional;

Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
    int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
    int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data);
+    TensorView empty_q_data, TensorView empty_kv_data);

-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS);

// Batched decode with paged KV-Cache plan
TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlan);
23 changes: 11 additions & 12 deletions csrc/batch_decode_mla_binding.cu
@@ -5,21 +5,20 @@
using tvm::ffi::Array;
using tvm::ffi::Optional;

-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                  int64_t num_qo_heads, int64_t page_size,
                                                  bool enable_cuda_graph);

-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl);
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlanMLA);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchDecodeWithPagedKVCacheRunMLA);
20 changes: 11 additions & 9 deletions csrc/batch_decode_mla_cute_sm80.cu
@@ -11,10 +11,10 @@ using namespace flashinfer;
using tvm::ffi::Array;
using tvm::ffi::Optional;

-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_buffer,
-                                                  ffi::Tensor int_workspace_buffer,
-                                                  ffi::Tensor page_locked_int_workspace_buffer,
-                                                  ffi::Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspace_buffer,
+                                                  ffi::TensorView int_workspace_buffer,
+                                                  ffi::TensorView page_locked_int_workspace_buffer,
+                                                  ffi::TensorView indptr, int64_t batch_size,
                                                  int64_t num_qo_heads, int64_t page_size,
                                                  bool enable_cuda_graph) {
  size_t float_workspace_size_in_bytes =
@@ -43,11 +43,13 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_bu
}

void BatchDecodeWithPagedKVCacheRunMLA(
-    ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer,
-    Array<int64_t> plan_info_vec, ffi::Tensor q_nope, ffi::Tensor q_pe, ffi::Tensor paged_ckv_cache,
-    ffi::Tensor paged_kpe_cache, ffi::Tensor paged_kv_indptr, ffi::Tensor paged_kv_indices,
-    ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, double sm_scale, int64_t window_left,
-    double logits_soft_cap, double rope_scale, double rope_theta, Optional<ffi::Tensor> maybe_lse,
+    ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, ffi::TensorView q_nope, ffi::TensorView q_pe,
+    ffi::TensorView paged_ckv_cache, ffi::TensorView paged_kpe_cache,
+    ffi::TensorView paged_kv_indptr, ffi::TensorView paged_kv_indices,
+    ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta,
+    Optional<ffi::TensorView> maybe_lse,
    bool enable_pdl  // fake placeholder, sm80 does not support pdl
) {
  DecodePlanInfo plan_info;
8 changes: 4 additions & 4 deletions csrc/batch_decode_mla_plan.cu
@@ -9,10 +9,10 @@ using namespace flashinfer;

using tvm::ffi::Array;

-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                  int64_t num_qo_heads, int64_t page_size,
                                                  bool enable_cuda_graph) {
  cudaSetDevice(float_workspace_buffer->device.device_id);
15 changes: 7 additions & 8 deletions csrc/batch_decode_mla_run.cu
@@ -10,14 +10,13 @@ using namespace flashinfer;
using tvm::ffi::Array;
using tvm::ffi::Optional;

-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl) {
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl) {
  DecodePlanInfo plan_info;
  plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

19 changes: 10 additions & 9 deletions csrc/batch_mla_binding.cu
@@ -20,16 +20,17 @@
using tvm::ffi::Array;
using tvm::ffi::Optional;

-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal);
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal);

-void BatchMLAPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                               Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                               Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o,
-                               Optional<Tensor> maybe_lse, int64_t mask_mode_code,
+void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                               Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe,
+                               TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices,
+                               TensorView o, Optional<TensorView> maybe_lse, int64_t mask_mode_code,
                               int64_t num_heads, int64_t page_size, double sm_scale);

TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan);
11 changes: 6 additions & 5 deletions csrc/batch_mla_plan.cu
@@ -23,11 +23,12 @@ using namespace flashinfer;

using tvm::ffi::Array;

-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal) {
  size_t float_workspace_size_in_bytes =
      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
  size_t int_workspace_size_in_bytes =