-
Notifications
You must be signed in to change notification settings - Fork 137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CK_TILE] Add PagedAttention kernels #1387
Conversation
const auto k_dram = [&]() { | ||
if constexpr(kIsPagedKV) | ||
{ | ||
return make_k_dram(nullptr, kargs.page_block_size); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, Does this mean create a global memory tensor/tile on pointer with shape (page_block_size
, hdim_q
), stride (stride_k
, 1
)? And what doese the will update this pointer if using paged-kvcache
mean? Does it mean this memory tensor will move to next k/v cache page?
Thanks for replying!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we use tensor view here to describe the global memory (shape & stride). When moving to next k/v cache page, we will have to update underlying data pointer to point to the correct page. This allows us to reuse same tile window object.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for you kind explanation. I am working on Attention on AMD gpus and occasionally find this PR, trying to learn the ck_tile.
So from my point of view, creating k_ptrs
and v_ptrs
could be unnecessary when the paged kv is enabled, beacause the navigator will update the pointer itself for a paged window?
const KDataType* k_ptr =
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const VDataType* v_ptr =
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, both of the k_ptr
& v_tpr
are only for non-paged implementations. Offset computations for paged k/v cache is in the navigator and the navigator creating lambdas.
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for | ||
// load | ||
|
||
auto k_block_tile = load_tile(k_dram_window); | ||
{ | ||
// moving k_dram_window is an in-page-block operation, so there is | ||
// no need to invoke k_page_block_navigator.move_tile_window() here. | ||
move_tile_window(k_dram_window, {0, kK0}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this k_dram_window
only move within the current page block?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, k_dram_window
only move within 1 page block, The first value of step {0, kK0}
is zero.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
28f3d73
to
9623c0c
Compare
|
||
auto v_dram_window = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From this previous commit, I understand v_dram_window
's origin is {i_n1, 0}
, I was wonderig why didn't this be the same as k_dram_widow
above which has origin {0,0}
and why it starts at {i_n1,0}
?
And why does this was deleted?
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I deleted them because we will create new tile windows for k & v in pipeline using new origins: {adjusted_seqlen_k_start, 0}
& {0, adjusted_seqlen_k_start}
respectively. So the origins written here are confusing. Btw, origin for v should be {i_n1, adjusted_seqlen_k_start}
actually. In fmha_fwd_splitkv.py, kN1
is always same as hdim_v
, hence i_n1
is always 0, so we hardcode it for now.
From output tensor point of view, we divide (seqlen_q, hdim_v) memory into small tiles (kM0, kN1). Each workgroup/tile use (i_m0
, i_n1
) to locate their position (see fmha_fwd_splitkv_tile_partitioner.hpp. So logically the first value of v origin should be i_n1
.
All the test has pass (MI200 + MI300 @ ROCm6.1) in flash attention |
@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel | |||
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); | |||
} | |||
}(); | |||
const auto k_dram = [&]() { | |||
|
|||
const auto make_k_dram = [&](const KDataType* data, index_t height) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems this pr doesn't support group mode. So if i want to customize this for group mode, we could add a Q dram navigator to implement it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the only difference between batch & group mode is final kargs.seqlen_q
value & batch_offset_q
computation logics. In order to support group mode, you will have to setup kargs.seqstart_q_ptr
array and check its values in *_kernel.hpp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, thanks. I will have a try.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -574,6 +637,9 @@ def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]: | |||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't': | |||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not | |||
continue | |||
if pipeline.F_pagedkv == 't': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, Why should we disbale the code generation for grouped paged kv if group and batch mode's only difference is offset of q
?
* Use dictionary to config all the functions * Add init codegen logic for fmha fwd appendkv * Call HIP_CHECK_ERROR() macro to get real source info * Setup meaningfull arguments * Sync kernel name with the codegen * Add knew/vnew tensors to the kernel argument * Fix wrong K values after appending * Fix vnew append errro * Extract common logics * Fix Vnew tile dstr for row major case * Conditionally add fwd_splitkv API in fmha_fwd example * Conditionally add call to fmha_fwd_splitkv() * Remove "EXAMPLE_" prefix of cmake variables * Regsiter API handlers automatically * Early return if 0 < s_k_new is not supported * Show message if we are ignoring option * Unify CMakeLists.txt coding style * Set num_splits=1 if split-kv is not supported * Add length/stride getters for HostTensor * Add RoPE example utilities * Add reference_rotary_position_embedding() (not implemented) * Finish reference_rotary_position_embedding() impl * Fix typo of HostTensor<>::get_length() * Fix compilation errors * Fix wrong answer when interleaved=false * Fix wrong answer when interleaved=true * Append K/V in the host verification code * Simplify K appending logics * Simplify v_host_ref definition * Reduce input/output dimensions * Rename function: add "batched" prefix * Apply RoPE on host side * Rename RoPE utility function * Fix wrong tensor size * Avoid invoking deprecated method 'find_module' * Pass RoPE kernel args * Create Rotary Cos/Sin tile windows in kernel * Add compute data type alias for RoPE * Randomly generate seqlen_knew if needed * Fix seqlen_knew enabling check logic * Add minimum seqlen_k to generate compliance kvcache * Fix compilation error in debug mode * Fix wrong boundaries * Fix wrong seqlen_k for kvcache * Rename variables used in distributio encoding * Fix rotary cos/sin tensor/tile size * Add constraint to the rotary_dim option * Remove unused inner namespace * Add dram distribution for rotary_cos/rotary_sin (interleaved) * Only apply interleaved RoPE on Knew for now * Fix wrong thread starting offset * Instantiate multiple kernels for RoPE approaches * Clean-up pipeline * Fix error in RoPE host reference * Handle RoPE half-rotated logics * Support 8x rotary_dim under half-rotated RoPE * Add comment * Apply elementwise function to the loaded tiles * Unify parameter/variable naming style * Remove constness from q_ptr * Add code blocks for q_tile * Apply RoPE to q_tile * Remove debug print code in kernel * Fix wrong knew/vnew appending positions * Use better naming for tile indices * Add make_tile_window() for adding distribution only * Skip code if # of block is more than needed * Move thread locating logics into policy * Remove always true static_assert() * Rename header * Rename RotaryEmbeddingEnum * Extract rotary embedding logic out * Re-order parameters * Align naming of some tile size constants * Rename more tile size constants * Fix wrong grid size * Fix wrong shape of knew_host/vnew_host * Fix wrong index into knew_host/vnew_host * Fix wrong rotary_cos/rotary_sin memory size for Q * Extract Q/Knew vector size to helper methods * Use different rotary_cos/rotary_sin distr for Q/Knew * Update host/device specifiers * Fix wrong data type for Q rotary_cos/rotary_sin * Remove RoPEComputeDataType type alias * Shift rotary_cos/rotary_sin by cache_seqlen_k * Add comment for why I just 't' for all padding flags * Align commit message to the real comment * Fix wrong pipeline * Rename utility function * Disable host verification if API not exist * Fix wrong rope key for fp8 pipeline * Allow only apply RoPE on Q (without append KV) * Add append-kv smoke tests * Remove debug statements * Remove more debug statements * Re-arrange the 'set +x' command * Remove no-longer used method in pipeline * Add missing init code * Refine pipeline padding settings * Enlarge rotary_dim limit (8 -> 16) * Enlarge KPerThread for rotary_interleaved=false * Update rotary_dim range in smoke_test_fwd.sh * Add template argument 'kIsPagedKV' for splitkv kernels * Launch splitkv kernel if given page_block_size * Fix wrong kernel name * Fix seqlen_k_min for pre-fill case (1 -> 0) * Add copy_const<> type trait * Add another make_tile_window() * Introduce 'TileWindowNavigator' types * Simplify TileWindowNavigator interfaces * Fix tile window navigation bugs * Disable calling fmha_fwd() * Remove ununnecessary data members * Simplify more make_tile_window() overloads * Move V tile through TileWindowNavigator * Fix uneven split checking logic * Move code after decide seqlen_q/seqlen_k * Make sure we always start reading complete tile * Use 128 as minimus page_block_size * Fix wrong origin for bias * Add batch_stride_k/batch_stride_v in group mode * Unify origin * Add missing kernel arguments for group mode * Add paged-kv codegen logic for appendkv kernels * Add block_table kernel args for appendkv kernel * Add tile navigators to the appendkv kernel * Fix wrong tensor descriptor lengths * Pass re-created tile window to pipeline * Fix wrong strides for appendkv kernel * Allow transit tile_window to another page-block * Handle cross-page-block write * Donot perform write again if already in last page-block * Always add fmha_fwd() api * Add missing group mode argument * Remove debug macro usages * Rename option s_k_new to s_knew * Separate splitkv/non-splitkv args/traits * Remove fmha_fwd_dispatch() * Fix compilation errors * Remove dropout code in splitkv kernel * Allow problem types without define kHasDropout attr * Use generic lambda to init traits objects * Separate more non-splitkv & splitkv traits/args * Display more info for specific kernels * Show more detailed warning message * Rename 'max_num_blocks' to 'max_num_page_blocks' * Remove no-longer used pipeline files * Wrap code by #if directives * Move functors to the begining of validation code * Use generic lambda to init all the api traits/args * Fix wrong seqlen for kvcache * Add missing comment * Rename TileWindowNavigator to PageBlockNavigator * Only expose necessary methods (not attributes) * Re-order pipeline paremeters * Refine smoke_test_fwd.sh * Fix wrong arugment count * Make tile window directly via PageBlockNavigator * Remove unused template paremeter * Remove group mode from appendkv kernel * Fix skcheck logic * Fix wrong syntax in skcheck expr * Use meaningful options in smoke test * Remove options * Fix formatting * Fix more format * Re-organize bash functions * Pass cache_batch_idx to kernels * Support cache_batch_idx in example * Fix compilation error * Add more appendkv test * Add more case for appendkv * Fix unexisted attribute * Remove 0 < seqlen_knew constraint * Clarify the case in warning message * Remove macro checking * Force batch mode when invoking appendkv & splitkv apis * Fix mode overriding logics * Fix wrong parameter name * Randomize seqlen_k if use kvcache * Use randomized seqlen_k for kvcache * Avoid using too small rotary_cos & rotary_sin * Rename parameter * Add seqlen_q & seqlen_k rules * Add comment * Add more comments * Fix compilation errors * Fix typo in comment * Remove type argument * Avoid seqlen_k=0 for kvcache * Revert "Avoid seqlen_k=0 for kvcache" This reverts commit 21c4df8. * Fix wrong uneven split checking logics * Only randomize kvcache seqlen_k if 1 < batch * Return earlier if split is empty * Revert "Only randomize kvcache seqlen_k if 1 < batch" This reverts commit b9a4ab0. * Re-order seqlen_k_start adjustment logics * Fix compilation errors * Re-format script * Find executable from folder automatically * Fix kvcache seqlen_k generating logic * Make comment more clear * Fix wrong knew/vew appending logic on host * Add s_barrier to sync threads * Revert "Add s_barrier to sync threads" This reverts commit d3f550f. * Support only using 1 row of rotary_cos/rotary_sin * Rotate Q in different way * Unify tensor view creation logics * Fix wrong argument * Add mask to switch how we use the rotary_cos/sin * Move attr from traits to problem * Move has_mask to fmha_fwd_appendkv_args * Support use uint32_t as SAD operand in Alibi<> * Use sad_u32() in splitkv kernels * Store tensor views in PageBlockNavigator * Use stored tensor view to update tile windows * Enlarge tensor view size * Remove debug code * Fix wrong tensor view size * Wrap tensor view into PageBlockNavigator * Add DataType member to PageBlockNavigator * Remove unnecessary member functions * Refind macro use * Fix typo * Add blank line between directives and actual code * Re-format files * Remove type in comment --------- Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com>
fmha_fwd_appendkv()
API which runs ahead thefmha_fwd()
/fmha_fwd_splitkv()
API.fmha_fwd_appendkv()
andfmha_fwd_splitkv()
APIsThe
fmha_fwd_appendkv()
+fmha_fwd_splitkv()
combination implement the functionality ofmha_fwd_kvcache()
in FA 2.5