Support RoPE position info in batch prefill/decode kernels #69
This PR adds q/k position information to the batch prefill/decode kernels. More specifically, the kernels now accept two additional arrays:

- `q_rope_position` with shape `(total_q_len,)`, giving the in-sequence position of each entry in the input q.
- `k_rope_pos_offset` with shape `(num_sequence,)`, giving the start position of each sequence in k.

These two arrays help calculate RoPE on the fly in multi-level cases.
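To make the role of the two arrays concrete, here is a minimal pure-Python sketch of how such position arrays could be consumed. It is not the kernel code from this PR: the helper names (`rope_rotate`, `key_positions`, `query_positions_for_append`), the `kv_lens` and `q_lens` inputs, the base `theta=10000.0`, and the non-interleaved pairing of dimensions are all assumptions made for illustration.

```python
import math

def rope_rotate(vec, pos, theta=10000.0):
    """Rotate an even-length vector by RoPE at integer position `pos`.

    Assumption: dimension i is paired with dimension i + d/2
    (non-interleaved layout); the PR does not state which layout it uses.
    """
    d = len(vec)
    half = d // 2
    out = [0.0] * d
    for i in range(half):
        angle = pos * theta ** (-2.0 * i / d)
        c, s = math.cos(angle), math.sin(angle)
        x, y = vec[i], vec[i + half]
        out[i] = x * c - y * s
        out[i + half] = x * s + y * c
    return out

def key_positions(k_rope_pos_offset, kv_lens):
    """Position of the j-th key of sequence i is k_rope_pos_offset[i] + j."""
    return [[off + j for j in range(n)]
            for off, n in zip(k_rope_pos_offset, kv_lens)]

def query_positions_for_append(k_rope_pos_offset, kv_lens, q_lens):
    """One hypothetical way to fill q_rope_position: the queries of each
    sequence are taken to be its last q_len tokens. Returns a flat list of
    length total_q_len."""
    positions = []
    for off, kv_len, q_len in zip(k_rope_pos_offset, kv_lens, q_lens):
        start = off + kv_len - q_len
        positions.extend(range(start, start + q_len))
    return positions

# Two sequences: the first starts at position 0, the second at position 5.
k_rope_pos_offset = [0, 5]
kv_lens = [3, 2]
q_lens = [2, 1]
print(key_positions(k_rope_pos_offset, kv_lens))   # [[0, 1, 2], [5, 6]]
print(query_positions_for_append(k_rope_pos_offset, kv_lens, q_lens))  # [1, 2, 6]
```

Because each pair of dimensions is rotated by an angle proportional to the position, the dot product of a rotated query and a rotated key depends only on the difference of their positions. That is why passing explicit positions for q and a per-sequence offset for k is enough to apply RoPE inside the attention kernel.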
Tests

`test_batch_prefill` and `test_batch_decode` pass. Performance has not been validated yet. Per discussion with Zihao, this change is unlikely to cause a significant performance regression.