Update seqlen of kvcache before splitkv
rocking5566 committed Aug 21, 2024
1 parent d6aac9e commit ae24800
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -146,7 +146,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
c10::optional<const at::Tensor> &seqlens_k_,
const at::Tensor seqlens_k,
c10::optional<const at::Tensor> &cache_batch_idx_,
c10::optional<at::Tensor> &block_table_,
c10::optional<at::Tensor> &alibi_slopes_,
@@ -193,8 +193,7 @@ fmha_fwd_splitkv_args get_ck_fmha_fwd_splitkv_args(bool has_lse,

args.seqstart_q_ptr = nullptr;
args.seqstart_k_ptr = nullptr;
args.seqlen_k_ptr = seqlens_k_.has_value() ?
reinterpret_cast<int *>(seqlens_k_.value().data_ptr()) : nullptr;
args.seqlen_k_ptr = seqlens_k.data_ptr();

args.seqlen_q = seqlen_q;
args.seqlen_k = seqlen_k;
@@ -506,7 +505,14 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
}

// we use splitkv even num_splits == 1
// seqlens_k_ is the current seqlen of the kv-cache. We need to add seqlen_knew to it before attention
auto append_seqlens_k = torch::empty({batch_size}, opts.dtype(torch::kInt32));
if (seqlens_k_.has_value())
append_seqlens_k = seqlens_k_.value() + seqlen_knew;
else
append_seqlens_k.fill_(seqlen_knew);

// we use splitkv even when num_splits == 1, because fmha_fwd() does not support seqlen_k_ in batch mode
auto splitkv_traits =
get_ck_fmha_fwd_splitkv_traits(mask, q_dtype_str, head_size_8x, has_lse, alibi_slopes_.has_value());

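For reference, a minimal standalone sketch of the seqlen update this commit introduces (assuming libtorch; the sample lengths are made up): the per-batch kv-cache lengths are advanced by seqlen_knew before the splitkv path runs, and when no cache lengths are supplied every entry is simply seqlen_knew.

#include <torch/torch.h>
#include <iostream>

int main() {
    const int64_t batch_size = 2;
    const int64_t seqlen_knew = 2;  // illustrative: number of new tokens appended to the kv-cache
    auto opts = torch::TensorOptions().dtype(torch::kInt32);

    // Caller supplied per-batch cache lengths, e.g. [3, 5] -> updated to [5, 7]
    auto seqlens_k = torch::tensor({3, 5}, opts);
    auto append_seqlens_k = seqlens_k + seqlen_knew;
    std::cout << append_seqlens_k << std::endl;

    // No cache lengths supplied: every batch entry is just seqlen_knew
    auto fallback = torch::empty({batch_size}, opts);
    fallback.fill_(seqlen_knew);
    std::cout << fallback << std::endl;
    return 0;
}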
