
Commit 4b55b26

cyx-6 and yzh119 authored
use ffi::TensorView instead of ffi::Tensor (#1844)
## 📌 Description

This PR migrates kernel argument types from `ffi::Tensor` to `ffi::TensorView`.

- `ffi::Tensor`: an owning tensor, friendly to `torch.Tensor`.
- `ffi::TensorView`: a non-owning tensor, equivalent to `DLTensor*`, which is commonly used in JAX/XLA.

Migrating to the non-owning type therefore broadens framework support beyond PyTorch.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used my preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

---------

Co-authored-by: Zihao Ye <expye@outlook.com>
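To make the distinction concrete, here is a minimal sketch in plain DLPack terms. The struct names (`OwningTensor`, `TensorViewLike`) are illustrative stand-ins, not the real `tvm::ffi` classes, which layer reference counting, conversion helpers, and error checking on top of this idea:

```cpp
#include <dlpack/dlpack.h>

#include <cstdint>
#include <memory>

// Owning handle: keeps the producer's DLManagedTensor alive for as long as the
// handle exists (the role ffi::Tensor plays; this is what makes it friendly to
// torch.Tensor, whose DLPack export is reference counted). A real owner would
// also invoke managed->deleter when the last reference drops.
struct OwningTensor {
  std::shared_ptr<DLManagedTensor> managed;
  const DLTensor* operator->() const { return &managed->dl_tensor; }
};

// Non-owning view: just a DLTensor*. The caller (e.g. an XLA custom call that
// hands out raw device buffers) guarantees the buffer outlives the call, so no
// ownership bookkeeping is needed. That property is why ffi::TensorView suits
// JAX/XLA and is the target type of this migration.
struct TensorViewLike {
  const DLTensor* ptr;
  const DLTensor* operator->() const { return ptr; }
};

// Call sites are unaffected either way: shape/device access through operator->
// is identical, e.g. the workspace-size computation seen in the diffs below.
int64_t workspace_bytes(TensorViewLike buf, int64_t element_size) {
  return buf->shape[0] * element_size;
}
```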
1 parent 8276d03 · commit 4b55b26

File tree: 89 files changed, +925 −833 lines


csrc/batch_attention.cu

Lines changed: 12 additions & 10 deletions
```diff
@@ -35,11 +35,12 @@ cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params pa
 
 using namespace flashinfer;
 
-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
@@ -63,11 +64,12 @@ Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int
   return Array(plan_info.ToVector());
 }
 
-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size,
                             double v_scale,  // must use double due to pytorch binding
                             double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS) {
```

csrc/batch_attention_jit_binding.cu

Lines changed: 12 additions & 11 deletions
```diff
@@ -19,18 +19,19 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchPagedAttentionPlan(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                       Tensor kv_indptr, Tensor kv_len, int64_t batch_size,
-                                       int64_t num_qo_heads, int64_t num_kv_heads,
-                                       int64_t head_dim_o, bool causal);
+Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
+                                       TensorView int_workspace_buffer,
+                                       TensorView page_locked_int_workspace_buffer,
+                                       TensorView qo_indptr, TensorView kv_indptr,
+                                       TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
+                                       int64_t num_kv_heads, int64_t head_dim_o, bool causal);
 
-void BatchPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                            Array<int64_t> plan_info_vec, Tensor q, Tensor k_cache, Tensor v_cache,
-                            Tensor kv_indices, Tensor o, Optional<Tensor> maybe_lse,
-                            int64_t mask_mode_code, int64_t layout_code, int64_t num_qo_heads,
-                            int64_t num_kv_heads, int64_t page_size, double v_scale,
-                            double sm_scale,
+void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                            Array<int64_t> plan_info_vec, TensorView q, TensorView k_cache,
+                            TensorView v_cache, TensorView kv_indices, TensorView o,
+                            Optional<TensorView> maybe_lse, int64_t mask_mode_code,
+                            int64_t layout_code, int64_t num_qo_heads, int64_t num_kv_heads,
+                            int64_t page_size, double v_scale, double sm_scale,
                             double logits_soft_cap ADDITIONAL_FUNC_PARAMS PROFILER_FUNC_PARAMS);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, &BatchPagedAttentionPlan);
```

csrc/batch_decode.cu

Lines changed: 11 additions & 9 deletions
```diff
@@ -37,11 +37,11 @@ using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
 Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
     int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
     int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data) {
+    TensorView empty_q_data, TensorView empty_kv_data) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
@@ -78,12 +78,14 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
   return Array(plan_info.ToVector());
 }
 
-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
   QKVLayout kv_layout = static_cast<QKVLayout>(kv_layout_code);
```
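A pattern worth noting in the plan/run pair above: the planner flattens all of its scheduling decisions into a plain `Array<int64_t>` (`plan_info.ToVector()`), and the runner rebuilds them with `FromVector`, so only integers ever cross the FFI boundary. A minimal sketch of that round trip follows; the fields are made up for illustration, and FlashInfer's `DecodePlanInfo` carries its own set:

```cpp
#include <cstdint>
#include <vector>

// Stand-in for a plan-info struct: every scheduling decision is representable
// as an int64_t, so the whole plan serializes to a flat vector.
struct PlanInfoSketch {
  int64_t num_ctas = 0;  // hypothetical field
  int64_t split_kv = 0;  // hypothetical field
  std::vector<int64_t> ToVector() const { return {num_ctas, split_kv}; }
  void FromVector(const std::vector<int64_t>& v) {
    num_ctas = v[0];
    split_kv = v[1];
  }
};

// Plan side: decide once, serialize to plain integers (mirrors
// `return Array(plan_info.ToVector());` in the diff above).
std::vector<int64_t> Plan() {
  PlanInfoSketch info;
  info.num_ctas = 132;
  info.split_kv = 1;
  return info.ToVector();
}

// Run side: mirrors `plan_info.FromVector(std::vector<int64_t>(
// plan_info_vec.begin(), plan_info_vec.end()))` in the diff above, then
// launches kernels using the recovered parameters.
void Run(const std::vector<int64_t>& plan_info_vec) {
  PlanInfoSketch info;
  info.FromVector(plan_info_vec);
  // ... kernel launches parameterized by info.num_ctas and info.split_kv ...
}
```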

csrc/batch_decode_jit_binding.cu

Lines changed: 11 additions & 9 deletions
```diff
@@ -21,18 +21,20 @@ using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
 Array<int64_t> BatchDecodeWithPagedKVCachePlan(
-    Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-    Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
     int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
     int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
-    Tensor empty_q_data, Tensor empty_kv_data);
+    TensorView empty_q_data, TensorView empty_kv_data);
 
-void BatchDecodeWithPagedKVCacheRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                    Array<int64_t> plan_info_vec, Tensor q, Tensor paged_k_cache,
-                                    Tensor paged_v_cache, Tensor paged_kv_indptr,
-                                    Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
-                                    Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
-                                    int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS);
+void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer,
+                                    TensorView int_workspace_buffer, Array<int64_t> plan_info_vec,
+                                    TensorView q, TensorView paged_k_cache,
+                                    TensorView paged_v_cache, TensorView paged_kv_indptr,
+                                    TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
+                                    TensorView o, Optional<TensorView> maybe_lse,
+                                    int64_t kv_layout_code, int64_t window_left,
+                                    bool enable_pdl ADDITIONAL_FUNC_PARAMS);
 
 // Batched decode with paged KV-Cache plan
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlan);
```

csrc/batch_decode_mla_binding.cu

Lines changed: 11 additions & 12 deletions
```diff
@@ -5,21 +5,20 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph);
 
-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl);
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchDecodeWithPagedKVCachePlanMLA);
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, BatchDecodeWithPagedKVCacheRunMLA);
```

csrc/batch_decode_mla_cute_sm80.cu

Lines changed: 11 additions & 9 deletions
```diff
@@ -11,10 +11,10 @@ using namespace flashinfer;
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_buffer,
-                                                  ffi::Tensor int_workspace_buffer,
-                                                  ffi::Tensor page_locked_int_workspace_buffer,
-                                                  ffi::Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspace_buffer,
+                                                  ffi::TensorView int_workspace_buffer,
+                                                  ffi::TensorView page_locked_int_workspace_buffer,
+                                                  ffi::TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph) {
   size_t float_workspace_size_in_bytes =
@@ -43,11 +43,13 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(ffi::Tensor float_workspace_bu
 }
 
 void BatchDecodeWithPagedKVCacheRunMLA(
-    ffi::Tensor float_workspace_buffer, ffi::Tensor int_workspace_buffer,
-    Array<int64_t> plan_info_vec, ffi::Tensor q_nope, ffi::Tensor q_pe, ffi::Tensor paged_ckv_cache,
-    ffi::Tensor paged_kpe_cache, ffi::Tensor paged_kv_indptr, ffi::Tensor paged_kv_indices,
-    ffi::Tensor paged_kv_last_page_len, ffi::Tensor o, double sm_scale, int64_t window_left,
-    double logits_soft_cap, double rope_scale, double rope_theta, Optional<ffi::Tensor> maybe_lse,
+    ffi::TensorView float_workspace_buffer, ffi::TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, ffi::TensorView q_nope, ffi::TensorView q_pe,
+    ffi::TensorView paged_ckv_cache, ffi::TensorView paged_kpe_cache,
+    ffi::TensorView paged_kv_indptr, ffi::TensorView paged_kv_indices,
+    ffi::TensorView paged_kv_last_page_len, ffi::TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta,
+    Optional<ffi::TensorView> maybe_lse,
     bool enable_pdl  // fake placeholder, sm80 does not support pdl
 ) {
   DecodePlanInfo plan_info;
```

csrc/batch_decode_mla_plan.cu

Lines changed: 4 additions & 4 deletions
```diff
@@ -9,10 +9,10 @@ using namespace flashinfer;
 
 using tvm::ffi::Array;
 
-Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(Tensor float_workspace_buffer,
-                                                  Tensor int_workspace_buffer,
-                                                  Tensor page_locked_int_workspace_buffer,
-                                                  Tensor indptr, int64_t batch_size,
+Array<int64_t> BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buffer,
+                                                  TensorView int_workspace_buffer,
+                                                  TensorView page_locked_int_workspace_buffer,
+                                                  TensorView indptr, int64_t batch_size,
                                                   int64_t num_qo_heads, int64_t page_size,
                                                   bool enable_cuda_graph) {
   cudaSetDevice(float_workspace_buffer->device.device_id);
```

csrc/batch_decode_mla_run.cu

Lines changed: 7 additions & 8 deletions
```diff
@@ -10,14 +10,13 @@ using namespace flashinfer;
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-void BatchDecodeWithPagedKVCacheRunMLA(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                                       Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                                       Tensor paged_ckv_cache, Tensor paged_kpe_cache,
-                                       Tensor paged_kv_indptr, Tensor paged_kv_indices,
-                                       Tensor paged_kv_last_page_len, Tensor o, double sm_scale,
-                                       int64_t window_left, double logits_soft_cap,
-                                       double rope_scale, double rope_theta,
-                                       Optional<Tensor> maybe_lse, bool enable_pdl) {
+void BatchDecodeWithPagedKVCacheRunMLA(
+    TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+    Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe, TensorView paged_ckv_cache,
+    TensorView paged_kpe_cache, TensorView paged_kv_indptr, TensorView paged_kv_indices,
+    TensorView paged_kv_last_page_len, TensorView o, double sm_scale, int64_t window_left,
+    double logits_soft_cap, double rope_scale, double rope_theta, Optional<TensorView> maybe_lse,
+    bool enable_pdl) {
   DecodePlanInfo plan_info;
   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));
 
```
csrc/batch_mla_binding.cu

Lines changed: 10 additions & 9 deletions
```diff
@@ -20,16 +20,17 @@
 using tvm::ffi::Array;
 using tvm::ffi::Optional;
 
-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal);
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal);
 
-void BatchMLAPagedAttentionRun(Tensor float_workspace_buffer, Tensor int_workspace_buffer,
-                               Array<int64_t> plan_info_vec, Tensor q_nope, Tensor q_pe,
-                               Tensor ckv_cache, Tensor kpe_cache, Tensor kv_indices, Tensor o,
-                               Optional<Tensor> maybe_lse, int64_t mask_mode_code,
+void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_workspace_buffer,
+                               Array<int64_t> plan_info_vec, TensorView q_nope, TensorView q_pe,
+                               TensorView ckv_cache, TensorView kpe_cache, TensorView kv_indices,
+                               TensorView o, Optional<TensorView> maybe_lse, int64_t mask_mode_code,
                                int64_t num_heads, int64_t page_size, double sm_scale);
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(plan, BatchMLAPagedAttentionPlan);
```

csrc/batch_mla_plan.cu

Lines changed: 6 additions & 5 deletions
```diff
@@ -23,11 +23,12 @@ using namespace flashinfer;
 
 using tvm::ffi::Array;
 
-Array<int64_t> BatchMLAPagedAttentionPlan(Tensor float_workspace_buffer,
-                                          Tensor int_workspace_buffer,
-                                          Tensor page_locked_int_workspace_buffer, Tensor qo_indptr,
-                                          Tensor kv_indptr, Tensor kv_len, int64_t num_heads,
-                                          int64_t head_dim_o, bool causal) {
+Array<int64_t> BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer,
+                                          TensorView int_workspace_buffer,
+                                          TensorView page_locked_int_workspace_buffer,
+                                          TensorView qo_indptr, TensorView kv_indptr,
+                                          TensorView kv_len, int64_t num_heads, int64_t head_dim_o,
+                                          bool causal) {
   size_t float_workspace_size_in_bytes =
       float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
```
