@@ -37,11 +37,11 @@ using tvm::ffi::Array;
3737using tvm::ffi::Optional;
3838
3939Array<int64_t > BatchDecodeWithPagedKVCachePlan (
40- Tensor float_workspace_buffer, Tensor int_workspace_buffer,
41- Tensor page_locked_int_workspace_buffer, Tensor indptr, int64_t batch_size,
40+ TensorView float_workspace_buffer, TensorView int_workspace_buffer,
41+ TensorView page_locked_int_workspace_buffer, TensorView indptr, int64_t batch_size,
4242 int64_t num_qo_heads, int64_t num_kv_heads, int64_t page_size, bool enable_cuda_graph,
4343 int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo,
44- Tensor empty_q_data, Tensor empty_kv_data) {
44+ TensorView empty_q_data, TensorView empty_kv_data) {
4545 size_t float_workspace_size_in_bytes =
4646 float_workspace_buffer->shape [0 ] * get_element_size (float_workspace_buffer);
4747 size_t int_workspace_size_in_bytes =
@@ -78,12 +78,14 @@ Array<int64_t> BatchDecodeWithPagedKVCachePlan(
7878 return Array (plan_info.ToVector ());
7979}
8080
81- void BatchDecodeWithPagedKVCacheRun (Tensor float_workspace_buffer, Tensor int_workspace_buffer,
82- Array<int64_t > plan_info_vec, Tensor q, Tensor paged_k_cache,
83- Tensor paged_v_cache, Tensor paged_kv_indptr,
84- Tensor paged_kv_indices, Tensor paged_kv_last_page_len,
85- Tensor o, Optional<Tensor> maybe_lse, int64_t kv_layout_code,
86- int64_t window_left, bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
81+ void BatchDecodeWithPagedKVCacheRun (TensorView float_workspace_buffer,
82+ TensorView int_workspace_buffer, Array<int64_t > plan_info_vec,
83+ TensorView q, TensorView paged_k_cache,
84+ TensorView paged_v_cache, TensorView paged_kv_indptr,
85+ TensorView paged_kv_indices, TensorView paged_kv_last_page_len,
86+ TensorView o, Optional<TensorView> maybe_lse,
87+ int64_t kv_layout_code, int64_t window_left,
88+ bool enable_pdl ADDITIONAL_FUNC_PARAMS) {
8789 DecodePlanInfo plan_info;
8890 plan_info.FromVector (std::vector<int64_t >(plan_info_vec.begin (), plan_info_vec.end ()));
8991 QKVLayout kv_layout = static_cast <QKVLayout>(kv_layout_code);
0 commit comments