104 changes: 65 additions & 39 deletions include/flashinfer/decode.cuh
@@ -137,32 +137,49 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s
}
}

enum class SyncRange {
// synchronize the state across the blockDim.z dimension
kSyncBdz = 0U,
// synchronize the state across the blockDim.y * blockDim.z dimensions
kSyncBdyBdz = 1U,
};

/*!
* \brief Synchronize the state of all warps inside a threadblock.
* \tparam vec_size A template integer indicating the vector size
* \tparam bdx A template integer indicating the block size in the x dimension
* \tparam bdy A template integer indicating the block size in the y dimension
* \tparam bdz A template integer indicating the block size in the z dimension
* \tparam sync_range A template enum indicating the range of synchronization
* \param st The warp-local state
* \param smem The pointer to the shared memory buffer for o
* \param smem_md The pointer to the shared memory buffer for m/d
*/
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, SyncRange sync_range>
__device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, float* smem_md) {
if constexpr (bdz > 1) {
constexpr uint32_t head_dim = bdx * vec_size;
auto block = cg::this_thread_block();
uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
smem_md[(tz * bdy + ty) * 2] = st.m;
smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
block.sync();
st.init();
constexpr uint32_t head_dim = bdx * vec_size;
auto block = cg::this_thread_block();
uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
smem_md[(tz * bdy + ty) * 2] = st.m;
smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
block.sync();
st.init();
if constexpr (sync_range == SyncRange::kSyncBdz) {
#pragma unroll
for (uint32_t j = 0; j < bdz; ++j) {
float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
vec_t<float, vec_size> oz;
oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
st.merge(oz, mz, dz);
float m = smem_md[(j * bdy + ty) * 2], d = smem_md[(j * bdy + ty) * 2 + 1];
vec_t<float, vec_size> o;
o.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
st.merge(o, m, d);
}
} else if constexpr (sync_range == SyncRange::kSyncBdyBdz) {
#pragma unroll
for (uint32_t j = 0; j < bdy * bdz; ++j) {
float m = smem_md[j * 2], d = smem_md[j * 2 + 1];
vec_t<float, vec_size> o;
o.load(smem + j * head_dim + tx * vec_size);
st.merge(o, m, d);
}
}
}
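
A minimal call-site sketch of the two modes (the block shape below is an assumed example and the shared-memory carving is simplified; the real kernels slice these buffers out of their own dynamic shared memory): kSyncBdz folds the bdz partial states that share the same ty lane, so each query head keeps its own merged state, while kSyncBdyBdz folds all bdy * bdz partial states in the block into a single state, which is what the cooperative epilogue below needs.

// Hypothetical kernel fragment for illustration only; not part of decode.cuh.
template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
__global__ void ExampleReduceKernel(/* ... */) {
  __shared__ float smem[bdy * bdz * bdx * vec_size];  // one partial o vector per (ty, tz)
  __shared__ float smem_md[bdy * bdz * 2];            // one (m, d) pair per (ty, tz)
  state_t<vec_size> st;
  // ... each thread accumulates its partial softmax state over its KV slice ...

  // Per-head reduction: merge across blockDim.z only.
  sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st, smem, smem_md);
  // Alternatively, merge everything across blockDim.y * blockDim.z:
  // sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdyBdz>(st, smem, smem_md);
}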
@@ -328,7 +345,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
block.sync();

// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st_local, reinterpret_cast<float*>(smem),
smem_md);

if constexpr (cooperative) {
// update tmp buffer
@@ -339,25 +357,30 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
grid.sync();

// sync global states
if (kv_chunk_idx == 0) {
state_t<vec_size> st_global;
#pragma unroll 2
for (uint32_t iter = 0; iter < ceil_div(num_kv_chunks, bdz); ++iter) {
uint32_t kv_chunk_idx = iter * bdz + tz;
if (kv_chunk_idx < num_kv_chunks) {
float2 md = *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2];
st_local.m = md.x;
st_local.d = md.y;
st_local.o.load(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim +
tx * vec_size);
st_global.merge(st_local);
#pragma unroll 1
for (uint32_t i = 0; i < ceil_div(num_qo_heads, gridDim.x * gridDim.y); ++i) {
qo_head_idx = (i * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x;
if (qo_head_idx < num_qo_heads) {
state_t<vec_size> st_global;
#pragma unroll 1
for (uint32_t iter = 0; iter < ceil_div(num_kv_chunks, bdy * bdz); ++iter) {
uint32_t kv_chunk_idx = (iter * bdz + tz) * bdy + ty;
if (kv_chunk_idx < num_kv_chunks) {
float2 md = *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2];
st_local.m = md.x;
st_local.d = md.y;
st_local.o.load(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim +
tx * vec_size);
st_global.merge(st_local);
}
}
block.sync();
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdyBdz>(
st_global, reinterpret_cast<float*>(smem), smem_md);
st_global.normalize();
st_global.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
}
block.sync();
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_global, reinterpret_cast<float*>(smem), smem_md);
st_global.normalize();
st_global.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
}
} else {
st_local.normalize();
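
For intuition on the new chunk-to-thread mapping above, here is a small host-side sketch (the block shape is an assumed example, not a value taken from the kernel) that prints which kv chunk each (iter, tz, ty) combination handles; every chunk index in [0, num_iters * bdy * bdz) is visited exactly once, so adding the ty dimension lets a block drain bdy times more chunks per iteration than the old iter * bdz + tz scheme.

// Standalone C++ sketch; bdy and bdz are illustrative values.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr uint32_t bdy = 4, bdz = 2;
  for (uint32_t iter = 0; iter < 2; ++iter) {
    for (uint32_t tz = 0; tz < bdz; ++tz) {
      for (uint32_t ty = 0; ty < bdy; ++ty) {
        // Same flattening as the kernel: (iter * bdz + tz) * bdy + ty.
        uint32_t kv_chunk_idx = (iter * bdz + tz) * bdy + ty;
        printf("iter=%u tz=%u ty=%u -> kv_chunk_idx=%u\n", iter, tz, ty, kv_chunk_idx);
      }
    }
  }
  return 0;
}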
@@ -480,7 +503,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeIn* __restrict__ q, DTyp
block.sync();

// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st_local, reinterpret_cast<float*>(smem),
smem_md);

st_local.normalize();
st_local.o.cast_store(o + batch_idx * num_qo_heads * head_dim +
@@ -696,7 +720,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
block.sync();

// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st, reinterpret_cast<float*>(smem),
smem_md);

if constexpr (cooperative) {
auto grid = cg::this_grid();
@@ -727,7 +752,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
}
block.sync();
// sync local state of all warps inside a threadblock
sync_state<vec_size, bdx, bdy, bdz>(st_global, reinterpret_cast<float*>(smem), smem_md);
sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(
st_global, reinterpret_cast<float*>(smem), smem_md);
st_global.normalize();
st_global.o.cast_store(
o + (paged_kv.batch_idx_map()[batch_idx] * num_qo_heads + qo_head_idx) * head_dim +
@@ -753,7 +779,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
* \param sizeof_dtype The size (in terms of bytes) of the input data type
*/
constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) {
if (group_size == 8U) {
if (group_size > 1U) {
if (sizeof_dtype == 1U) {
return 256U; // not enough registers for 512 threads
} else {
@@ -788,7 +814,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
RotaryMode rotary_mode = RotaryMode::kNone,
cudaStream_t stream = nullptr) {
const uint32_t GROUP_SIZE = num_qo_heads / num_kv_heads;
if (seq_len <= 128U / uint32_t(std::sqrt(GROUP_SIZE))) {
if (seq_len < 128U) {
tmp_size = 0;
} else {
SWITCH_GQA_GROUP_SIZE(
@@ -889,7 +915,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
const uint32_t smem_size = 2U * num_stages_smem * bdy * tile_size_per_bdx * bdz *
head_dim * sizeof(DTypeIn) +
2U * bdy * bdz * sizeof(float);
if (seq_len <= 256 || tmp == nullptr) {
if (seq_len < 128 || tmp == nullptr) {
// no need to use cooperative kernel
auto kernel =
SingleDecodeWithKVCacheKernel<QKV_LAYOUT, /*cooperative=*/false, ROTARY_MODE,
@@ -931,7 +957,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
&num_blocks_per_sm, kernel, num_threads, smem_size));
uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm);
uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256);
uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 16);
dim3 nblks = dim3(ceil_div(seq_len, kv_chunk_size), num_kv_heads);
if (nblks.x == 0 || nblks.y == 0) {
std::cerr << "Invalid kernel configuration: nblks=(" << nblks.x << ","
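
As a back-of-envelope sketch of the lowered kv_chunk_size floor above (16 instead of 256), with all numbers assumed for illustration rather than measured: on a device with 32 SMs, 2 resident blocks per SM, and 8 KV heads, max_num_kv_chunks = 64 / 8 = 8, so a 4096-token sequence still gets kv_chunk_size = 512, while a 256-token sequence now splits into 8 chunks of 32 instead of a single chunk of 256, letting far more blocks participate in the cooperative kernel.

// Standalone C++ sketch of the chunk-size arithmetic; values are assumptions.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
  const uint32_t num_sm = 32, blocks_per_sm = 2, num_kv_heads = 8;
  const uint32_t max_grid_size = num_sm * blocks_per_sm;            // 64
  const uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads;  // 8
  const uint32_t seq_lens[] = {256, 1024, 4096};
  for (uint32_t seq_len : seq_lens) {
    // New floor of 16 (previously 256).
    uint32_t kv_chunk_size = std::max(ceil_div(seq_len, max_num_kv_chunks), 16u);
    printf("seq_len=%u -> kv_chunk_size=%u, num_kv_chunks=%u\n", seq_len, kv_chunk_size,
           ceil_div(seq_len, kv_chunk_size));
  }
  return 0;
}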