feat: sliding window attention (#406)
As requested in #390, this PR implements sliding window attention.

This PR also results in a slight performance degradation because we don't yet
specialize the kernels for the cases that do and do not use a sliding window.
I believe we can address this by landing the JIT compilation feature. I'll merge
this feature first and improve performance in later PRs.
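For reference, the masking rule the decode kernels below implement: with window_left >= 0, a decoding query at position seq_len - 1 only attends to the last window_left + 1 KV entries, i.e. positions at or after left_close_bound = seq_len - (window_left + 1); window_left < 0 disables the window. The following is a minimal host-side sketch of that rule for illustration only; the helper mirrors the device-side sub_if_greater_or_zero, and the -5e4 sentinel matches what the kernels write for masked logits.

#include <cstdint>
#include <vector>

// Saturating subtraction, mirroring the kernels' sub_if_greater_or_zero helper.
inline uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { return x > y ? x - y : 0; }

// Illustrative host-side sliding-window mask for a single decoding query at
// position seq_len - 1: logits for KV positions before left_close_bound are
// pushed to -5e4 so they vanish after the softmax.
void apply_sliding_window(std::vector<float>& logits, uint32_t seq_len, int32_t window_left) {
  const uint32_t left_close_bound =
      (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0;
  for (uint32_t pos = 0; pos < seq_len; ++pos) {
    if (pos < left_close_bound) logits[pos] = -5e4f;
  }
}

For example, with seq_len = 10 and window_left = 3, left_close_bound = 10 - 4 = 6, so only KV positions 6 through 9 contribute to the output.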
yzh119 authored Jul 29, 2024
1 parent 74ffba1 commit 28cffd3
Showing 20 changed files with 586 additions and 123 deletions.
45 changes: 28 additions & 17 deletions include/flashinfer/attention/decode.cuh
@@ -74,9 +74,10 @@ template <LogitsPostHook logits_post_hook, PosEncodingMode pos_encoding_mode, ui
 __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage_idx,
     const vec_t<float, vec_size>& q_vec,
     const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
-    uint32_t iter_base, uint32_t iter_bound,
-    const int32_t q_offset, float alibi_slope, float* s,
-    state_t<vec_size>& st, const float logits_soft_cap) {
+    uint32_t iter_base, uint32_t left_close_bound,
+    uint32_t iter_bound, const int32_t q_offset,
+    float alibi_slope, float* s, state_t<vec_size>& st,
+    const float logits_soft_cap) {
   uint32_t tx = threadIdx.x, tz = threadIdx.z;
   float m_prev = st.m;
 #pragma unroll
@@ -100,9 +101,10 @@ __device__ __forceinline__ void compute_qk(const T* smem, uint32_t compute_stage
       s[j] += math::shfl_xor_sync(s[j], offset);
     }
     s[j] = apply_logits_post_hook<logits_post_hook>(s[j], logits_soft_cap);
-    s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -5e4;
+    const uint32_t pos = kv_idx_base + tz * tile_size + j;
+    s[j] = (iter_base + tz * tile_size + j < iter_bound && pos >= left_close_bound) ? s[j] : -5e4;
     if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) {
-      s[j] += alibi_slope * float(int(kv_idx_base + tz * tile_size + j) - q_offset);
+      s[j] += alibi_slope * float(int(pos) - q_offset);
     }
     st.m = max(st.m, s[j]);
   }
@@ -212,9 +214,9 @@ template <LogitsPostHook logits_post_hook, bool partition_kv, PosEncodingMode po
 __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* __restrict__ k,
     DTypeKV* __restrict__ v, DTypeOut* __restrict__ o,
     float* __restrict__ lse, tensor_info_t info,
-    float logits_soft_cap, float sm_scale,
-    float rope_rcp_scale, float rope_rcp_theta,
-    uint32_t kv_chunk_size) {
+    int32_t window_left, float logits_soft_cap,
+    float sm_scale, float rope_rcp_scale,
+    float rope_rcp_theta, uint32_t kv_chunk_size) {
   auto block = cg::this_thread_block();
   auto grid = cg::this_grid();
   sm_scale *=
@@ -227,6 +229,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
   uint32_t num_qo_heads = info.num_qo_heads;
   const float alibi_slope = get_alibi_slope(qo_head_idx, num_qo_heads) * math::log2e;
   uint32_t seq_len = info.kv_len;
+  uint32_t left_close_bound =
+      (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0;

   extern __shared__ uint8_t smem[];
   DTypeKV* k_smem = (DTypeKV*)smem;
@@ -303,8 +307,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeQ* __restrict__ q, DTypeKV* _
     block.sync();
     compute_qk<logits_post_hook, pos_encoding_mode, vec_size, bdx, bdy * tile_size_per_bdx>(
         k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, stage_idx, q_vec,
-        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size,
-        seq_len - 1, alibi_slope, s, st_local, logits_soft_cap);
+        freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, left_close_bound,
+        kv_chunk_size, seq_len - 1, alibi_slope, s, st_local, logits_soft_cap);
     block.sync();
     // load k
     for (uint32_t j = 0; j < tile_size_per_bdx; ++j) {
@@ -389,8 +393,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
     paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
-    float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
-    float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
+    float* __restrict__ lse, bool* __restrict__ block_valid_mask, int32_t window_left,
+    float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta) {
   auto block = cg::this_thread_block();
   sm_scale *=
       (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap));
@@ -415,6 +419,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
           : 0;
   const uint32_t seq_len =
       partition_kv ? kv_partition_info.seq_lens_before_partition[batch_idx] : kv_chunk_len;
+  const uint32_t left_close_bound =
+      (window_left >= 0) ? sub_if_greater_or_zero(seq_len, window_left + 1) : 0;
   const uint32_t mapped_batch_idx =
       partition_kv ? kv_partition_info.batch_idx_map[batch_idx] : batch_idx;

@@ -521,8 +527,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
         freq,
         (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[mapped_batch_idx]) +
             cur_chunk_start + iter * tile_size_per_bdx * bdy * bdz,
-        iter * tile_size_per_bdx * bdy * bdz, kv_chunk_len, q_offset_val, alibi_slope, s, st,
-        logits_soft_cap);
+        iter * tile_size_per_bdx * bdy * bdz, left_close_bound, kv_chunk_len, q_offset_val,
+        alibi_slope, s, st, logits_soft_cap);
     block.sync();

 #pragma unroll
@@ -627,8 +633,9 @@ template <uint32_t HEAD_DIM, LogitsPostHook LOGITS_POST_HOOK, PosEncodingMode PO
 cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeOut* o,
     DTypeOut* tmp, uint32_t num_qo_heads,
     uint32_t num_kv_heads, uint32_t seq_len,
-    QKVLayout kv_layout, float logits_soft_cap,
-    float sm_scale, float rope_scale, float rope_theta,
+    QKVLayout kv_layout, int32_t window_left,
+    float logits_soft_cap, float sm_scale,
+    float rope_scale, float rope_theta,
     cudaStream_t stream) {
   const float rope_rcp_scale = 1.f / rope_scale;
   const float rope_rcp_theta = 1.f / rope_theta;
@@ -664,6 +671,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
         (void*)&o,
         (void*)&lse,
         (void*)&info,
+        (void*)&window_left,
         (void*)&logits_soft_cap,
         (void*)&sm_scale,
         (void*)&rope_rcp_scale,
@@ -704,6 +712,7 @@ cudaError_t SingleDecodeWithKVCacheDispatched(DTypeQ* q, DTypeKV* k, DTypeKV* v,
         (void*)&tmp,
         (void*)&tmp_lse,
         (void*)&info,
+        (void*)&window_left,
         (void*)&logits_soft_cap,
         (void*)&sm_scale,
         (void*)&rope_rcp_scale,
@@ -724,7 +733,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
     DTypeQ* q, IdType* q_offset, paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp_v, float* tmp_s,
     float* lse, bool* block_valid_mask, uint32_t padded_batch_size, uint32_t num_qo_heads,
-    float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
+    int32_t window_left, float logits_soft_cap, float sm_scale, float rope_scale, float rope_theta,
     cudaStream_t stream) {
   const float rope_rcp_scale = 1.f / rope_scale;
   const float rope_rcp_theta = 1.f / rope_theta;
@@ -761,6 +770,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
         (void*)&o,
         (void*)&lse,
         (void*)&block_valid_mask,
+        (void*)&window_left,
         (void*)&logits_soft_cap,
         (void*)&sm_scale,
         (void*)&rope_rcp_scale,
@@ -782,6 +792,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
         (void*)&tmp_v,
         (void*)&tmp_s,
         (void*)&block_valid_mask,
+        (void*)&window_left,
         (void*)&logits_soft_cap,
         (void*)&sm_scale,
         (void*)&rope_rcp_scale,
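Both dispatchers above launch their kernels through a pointer-per-parameter argument array, which is why the new window_left kernel parameter also shows up as an extra (void*)&window_left entry at the matching position in each args list. The snippet below is a minimal sketch of that launch pattern with a toy kernel; the kernel, buffer, and launcher names here are illustrative, not FlashInfer's.

#include <cuda_runtime.h>

// Toy kernel standing in for the decode kernels above; the only point is the
// parameter list, which must match the args array below entry for entry.
__global__ void ToyDecodeKernel(float* out, int32_t window_left, float sm_scale) {
  out[threadIdx.x] = static_cast<float>(window_left) * sm_scale;
}

cudaError_t LaunchToyDecode(float* out, int32_t window_left, float sm_scale,
                            cudaStream_t stream) {
  // One pointer per kernel parameter, in declaration order: adding window_left to
  // the kernel signature therefore also requires adding (void*)&window_left at the
  // matching position here, as the dispatchers above do.
  void* args[] = {(void*)&out, (void*)&window_left, (void*)&sm_scale};
  return cudaLaunchKernel((void*)ToyDecodeKernel, dim3(1), dim3(32), args, 0, stream);
}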
4 changes: 2 additions & 2 deletions include/flashinfer/attention/handler.cuh
@@ -42,8 +42,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     DTypeQ* __restrict__ q, IdType* __restrict__ q_offset,
     paged_kv_t<page_storage, DTypeKV, IdType> paged_kv,
     kv_partition_info_t<IdType> kv_partition_info, DTypeOut* __restrict__ o,
-    float* __restrict__ lse, bool* __restrict__ block_valid_mask, float logits_soft_cap,
-    float sm_scale, float rope_rcp_scale, float rope_rcp_theta);
+    float* __restrict__ lse, bool* __restrict__ block_valid_mask, int maybe_window_left,
+    float logits_soft_cap, float sm_scale, float rope_rcp_scale, float rope_rcp_theta);

 /*!
  * \brief Compute the maximum number of pages per batch and the new batch size
