-
Couldn't load subscription status.
- Fork 100
Support CP with query length larger than 1 #93
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
b8792e9 to
54be252
Compare
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
hopper/flash_api.cpp
Outdated
| bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 && | ||
| params.is_local && | ||
| bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 && | ||
| params.is_local && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you mind removing the unrelated formatting changes? trying to stay as close to upstream as possible when possible
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed.
| : std::max(n_block_min, | ||
| cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN)); | ||
| cute::ceil_div(m_idx_max + | ||
| params.cp_world_size * seqlen_k - |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use cp_tot_seqlen_k to skip the mul here? should branch in the non-cp case to save the mul?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could make cp_tot_seqlen_k == seqlen_k in the params.cp_world_size == 1 case
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
hopper/seqlen.h
Outdated
| , cp_world_size(cp_world_size) | ||
| , cp_tot_seqlen_k(cp_tot_seqused_k == nullptr | ||
| ? 0 | ||
| : cp_tot_seqused_k[bidb]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome work! left a few comments but its looking really good!
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
Signed-off-by: Ming Yang <minos.future@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This PR implements the causal mask for interleave context parallelism to allow query length > 1.
The solution follows the discussion between @LucasWilkinson , @youkaichao , and @youzhedian on slack.
key illustration made by @LucasWilkinson :
In the DCP case, the k/v tokens are distributed in an interleaved fashion, see vllm-project/vllm#23734.

Therefore we have 0,2,4 kv on rank0 and 1,3,5 kv on rank1 in the example above. The mask shape is no longer a bottom right triangle.
This requires FA to be aware of cp world size and cp rank, in order to determine the causal mask.
The block tiling implementation also needs to be updated. As illustrated below, we now needs to process block tile (0,1) in CP case, while it can be skipped previously in normal case.
Tests
Added and passed unit tests for CP.