Skip to content

Commit 534c3c9

Browse files
authored
fix: Fix MLA TVM binding for the latest changes (#940)
This PR applies changes in #898 and #900 to the MLA TVM binding.
1 parent 3b07839 commit 534c3c9

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tvm_binding/batch_mla_run.cu

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,16 @@ void BatchMLAPagedAttentionRun(DLTensor* float_workspace_buffer, DLTensor* int_w
9797
int_buffer_ptr, plan_info.merge_packed_offset_start_offset);
9898
params.merge_packed_offset_end =
9999
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_packed_offset_end_offset);
100-
params.merge_indptr =
101-
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_indptr_offset);
100+
params.merge_partial_packed_offset_start = GetPtrFromBaseOffset<IdType>(
101+
int_buffer_ptr, plan_info.merge_partial_packed_offset_start_offset);
102+
params.merge_partial_packed_offset_end = GetPtrFromBaseOffset<IdType>(
103+
int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset);
104+
params.merge_partial_stride =
105+
GetPtrFromBaseOffset<IdType>(int_buffer_ptr, plan_info.merge_partial_stride_offset);
102106
params.final_o = static_cast<DTypeO*>(o->data) + o->byte_offset / sizeof(DTypeO);
103107
params.final_lse = static_cast<float*>(lse->data) + lse->byte_offset / sizeof(float);
104108
params.partial_o =
105-
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_o_offset);
109+
GetPtrFromBaseOffset<DTypeO>(float_buffer_ptr, plan_info.partial_o_offset);
106110
params.partial_lse =
107111
GetPtrFromBaseOffset<float>(float_buffer_ptr, plan_info.partial_lse_offset);
108112

0 commit comments

Comments
 (0)