
[KVCACHE] Improved schedule for prefill attention #17482

Merged
1 commit merged on Oct 28, 2024

Conversation

@krishnaraj36 (Contributor) commented on Oct 22, 2024

Improvements:
- Added transpose to K for better vectorization during matmul.
- Improved the load schedule.
- Improved performance by a bit more than 2x in most cases.

Llama-2 7B observation:

| kernel                  | baseline | optimized |
|-------------------------|----------|-----------|
| batch_prefill_ragged_kv | 15 ms    | 7.1 ms    |
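Transposing K changes the memory access pattern of the score matmul. The following NumPy sketch (illustrative shapes and names, not the PR's TIR schedule) shows the equivalence the kernel relies on: scores computed against a materialized K^T match scores computed against K directly, while K^T laid out as [head_dim, seq_len] lets one vector load fetch the same head-dim element of adjacent keys.

```python
import numpy as np

# Hedged sketch: why storing K transposed can help vectorized matmul.
# With K as [seq_len, head_dim], a kernel keeping several keys in flight
# gathers strided elements; with K^T as [head_dim, seq_len], the same
# elements of adjacent keys are contiguous and load as one vector.
rng = np.random.default_rng(0)
seq_len, head_dim = 16, 128
Q = rng.standard_normal((seq_len, head_dim), dtype=np.float32)
K = rng.standard_normal((seq_len, head_dim), dtype=np.float32)

# Baseline layout: scores via Q @ K^T, transpose taken on the fly.
S_ref = Q @ K.T

# Transposed layout: K^T materialized once as a contiguous buffer.
K_T = np.ascontiguousarray(K.T)
S_opt = Q @ K_T

assert np.allclose(S_ref, S_opt)
```

The numerical result is identical either way; only the layout seen by the inner loop changes.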

This PR fixes the correctness issue reported against PR #17446. The issue is caused by incorrect code generation during the unroll phase, so we removed the explicit unroll and observed little to no performance degradation.
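To make the workaround concrete, here is a schematic pure-Python illustration (hypothetical arrays and names, not the PR's TIR) of what removing the explicit unroll means: unrolling copies the loop body with the induction variable substituted, and each substituted index expression is then simplified independently, which is the stage where the miscompiled offsets appeared; keeping the rolled loop leaves a single index expression evaluated at run time.

```python
# Hedged sketch of loop unrolling as a source-level transformation.
def store_rolled(out, base, n):
    # One index expression, evaluated n times at run time.
    for i in range(n):
        out[(base + i) // 7] += 1

def store_unrolled(out, base):
    # n = 4 copies of the body with i substituted; each constant-folded
    # index expression is simplified independently, which is where the
    # incorrect offsets were produced in the generated OpenCL.
    out[(base + 0) // 7] += 1
    out[(base + 1) // 7] += 1
    out[(base + 2) // 7] += 1
    out[(base + 3) // 7] += 1

a = [0] * 8
b = [0] * 8
store_rolled(a, 5, 4)
store_unrolled(b, 5)
assert a == b  # semantically identical when the simplifier is correct
```

When the per-copy simplification is buggy, only the unrolled variant miscompiles, which is why dropping the unroll sidesteps the issue.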

We generated the OpenCL kernels by extracting the generated modules after setting num_qo_heads=28 in
https://github.qualcomm.com/gpgpu/apache-tvm/blob/85e15d494d5a42360859941cbc972c4f175c3b94/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py#L36
Original PR Codegen

int cur_L_3 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) / 7) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31)) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]);
if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 3584) + ((convert_int(get_group_id(1))) * 896)) + ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) + (7 & (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) + 1) % 7) >> 31))) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
}
int cur_L_4 = ((((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) / 7) - -306783377) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]);
if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 3584) + ((convert_int(get_group_id(1))) * 896)) + (((((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + LH_start) - 2147483637) % 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)));
}

In the O_store block we notice that large, incorrect pointer offsets were generated during subsequent stages of unrolling. This can be observed indirectly as zero elements in the output and compute instability.

Fusing the unroll loops so that they are unrolled together does not resolve this.

Oddly enough, the initial test case does not trigger the issue and works as intended:

int cur_L_3 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 1) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1))));
if (cur_L_3 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[3] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 1)]))))), 0, output + (((((cur_L_3 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 1) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)) + 4));
}
int cur_L_4 = ((((((convert_int(get_local_id(0))) >> 4) + ((LH_start + 2) >> 2)) >> 1) + q_indptr[(b_idx_1 + q_indptr_elem_offset)]) + (convert_int(get_local_id(1))));
if (cur_L_4 < q_indptr[((b_idx_1 + q_indptr_elem_offset) + 1)]) {
    vstore4((convert_half4((O_local[4] / ((float4)(d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)], d_smem[((((convert_int(get_local_id(1))) * 8) + (((convert_int(get_local_id(0))) >> 4) * 4)) + 2)]))))), 0, output + ((((cur_L_4 * 4096) + ((convert_int(get_group_id(1))) * 1024)) + (((((((convert_int(get_local_id(0))) >> 4) * 4) + (LH_start & 7)) + 2) & 7) * 128)) + (((convert_int(get_local_id(0))) & 15) * 8)));
}

@krishnaraj36 (Contributor Author) commented:

@MasterJH5574 @tqchen
We have fixed the issue raised in PR #17466.
Could you please take a look at this PR?

@MasterJH5574 (Contributor) left a comment:

Thank you @krishnaraj36 so much for the fix!

@MasterJH5574 (Contributor) commented:

I have also observed the "large and incorrect" pointer offsets before, but I didn't get time to nail down the issue. Roughly, I remember they are generated by some floordiv simplification in src/tir/transforms/lower_intrin.cc.
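For context on why such a simplification pass exists at all: TIR's floordiv rounds toward negative infinity, while integer `/` in C and OpenCL truncates toward zero, so lower_intrin must rewrite floordiv into truncating-division expressions, and the large magic constants seen in the bad codegen come from that rewrite. A small, hedged Python illustration of the semantic gap (not the TVM lowering itself):

```python
# C/OpenCL-style truncating division vs. floor division.
def trunc_div(a, b):
    # Round toward zero, as `/` does for ints in C and OpenCL.
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b >= 0) else -q

def floor_div(a, b):
    # Python's // already rounds toward negative infinity, like TIR floordiv.
    return a // b

# The two agree when the result is non-negative...
assert trunc_div(7, 2) == floor_div(7, 2) == 3
# ...but disagree whenever the operands' signs differ.
assert trunc_div(-7, 2) == -3  # truncates toward zero
assert floor_div(-7, 2) == -4  # floors toward -inf
```

Any offset-shifting trick the lowering uses to bridge this gap must stay within int32 range, which is presumably where the overflowed constants crept in.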

@krishnaraj36 (Contributor Author) commented:

> Thank you @krishnaraj36 so much for the fix!

@MasterJH5574
There is only one change compared to the previously reverted commit: removing sch.unroll(xi).

@srkreddy1238 srkreddy1238 merged commit e3e27f5 into apache:main Oct 28, 2024
20 checks passed