diff --git a/include/flashinfer/decode.cuh b/include/flashinfer/decode.cuh
index 2179ec9972..fa233783e0 100644
--- a/include/flashinfer/decode.cuh
+++ b/include/flashinfer/decode.cuh
@@ -137,32 +137,49 @@ __device__ __forceinline__ void update_local_state(const T* smem, const float* s
   }
 }
 
+enum class SyncRange {
+  // synchronize the state along the blockDim.z dimension
+  kSyncBdz = 0U,
+  // synchronize the state along the blockDim.y * blockDim.z dimensions
+  kSyncBdyBdz = 1U,
+};
+
 /*!
  * \brief Synchronize the state of all warps inside a threadblock.
  * \tparam vec_size A template integer indicates the vector size
  * \tparam bdx A template integer indicates the block size in x dimension
  * \tparam bdy A template integer indicates the block size in y dimension
+ * \tparam bdz A template integer indicates the block size in z dimension
+ * \tparam sync_range A template enum indicates the range of synchronization
  * \param st The warp local state
  * \param smem The pointer to shared memory buffer for o
  * \param smem_md The pointer to shared memory buffer for m/d
  */
-template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz>
+template <uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, SyncRange sync_range>
 __device__ __forceinline__ void sync_state(state_t<vec_size>& st, float* smem, float* smem_md) {
-  if constexpr (bdz > 1) {
-    constexpr uint32_t head_dim = bdx * vec_size;
-    auto block = cg::this_thread_block();
-    uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
-    st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
-    smem_md[(tz * bdy + ty) * 2] = st.m;
-    smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
-    block.sync();
-    st.init();
+  constexpr uint32_t head_dim = bdx * vec_size;
+  auto block = cg::this_thread_block();
+  uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z;
+  st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size);
+  smem_md[(tz * bdy + ty) * 2] = st.m;
+  smem_md[(tz * bdy + ty) * 2 + 1] = st.d;
+  block.sync();
+  st.init();
+  if constexpr (sync_range == SyncRange::kSyncBdz) {
 #pragma unroll
     for (uint32_t j = 0; j < bdz; ++j) {
-      float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1];
-      vec_t<float, vec_size> oz;
-      oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
-      st.merge(oz, mz, dz);
+      float m = smem_md[(j * bdy + ty) * 2], d = smem_md[(j * bdy + ty) * 2 + 1];
+      vec_t<float, vec_size> o;
+      o.load(smem + (j * bdy + ty) * head_dim + tx * vec_size);
+      st.merge(o, m, d);
+    }
+  } else if constexpr (sync_range == SyncRange::kSyncBdyBdz) {
+#pragma unroll
+    for (uint32_t j = 0; j < bdy * bdz; ++j) {
+      float m = smem_md[j * 2], d = smem_md[j * 2 + 1];
+      vec_t<float, vec_size> o;
+      o.load(smem + j * head_dim + tx * vec_size);
+      st.merge(o, m, d);
     }
   }
 }
@@ -328,7 +345,8 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
+  sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st_local, reinterpret_cast<float*>(smem),
+                                                           smem_md);
 
   if constexpr (cooperative) {
     // update tmp buffer
@@ -339,25 +357,30 @@ __global__ void SingleDecodeWithKVCacheKernel(DTypeIn* __restrict__ q, DTypeIn*
     grid.sync();
 
     // sync global states
-    if (kv_chunk_idx == 0) {
-      state_t<vec_size> st_global;
-#pragma unroll 2
-      for (uint32_t iter = 0; iter < ceil_div(num_kv_chunks, bdz); ++iter) {
-        uint32_t kv_chunk_idx = iter * bdz + tz;
-        if (kv_chunk_idx < num_kv_chunks) {
-          float2 md = *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2];
-          st_local.m = md.x;
-          st_local.d = md.y;
-          st_local.o.load(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim +
-                          tx * vec_size);
-          st_global.merge(st_local);
+#pragma unroll 1
+    for (uint32_t i = 0; i < ceil_div(num_qo_heads, gridDim.x * gridDim.y); ++i) {
+      qo_head_idx = (i * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x;
+      if (qo_head_idx < num_qo_heads) {
+        state_t<vec_size> st_global;
+#pragma unroll 1
+        for (uint32_t iter = 0; iter < ceil_div(num_kv_chunks, bdy * bdz); ++iter) {
+          uint32_t kv_chunk_idx = (iter * bdz + tz) * bdy + ty;
+          if (kv_chunk_idx < num_kv_chunks) {
+            float2 md = *(float2*)&tmp_md[(qo_head_idx * num_kv_chunks + kv_chunk_idx) * 2];
+            st_local.m = md.x;
+            st_local.d = md.y;
+            st_local.o.load(tmp + (qo_head_idx * num_kv_chunks + kv_chunk_idx) * head_dim +
+                            tx * vec_size);
+            st_global.merge(st_local);
+          }
         }
+        block.sync();
+        // sync local state of all warps inside a threadblock
+        sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdyBdz>(
+            st_global, reinterpret_cast<float*>(smem), smem_md);
+        st_global.normalize();
+        st_global.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
       }
-      block.sync();
-      // sync local state of all warps inside a threadblock
-      sync_state<vec_size, bdx, bdy, bdz>(st_global, reinterpret_cast<float*>(smem), smem_md);
-      st_global.normalize();
-      st_global.o.cast_store(o + info.get_qo_elem_offset(0, qo_head_idx, tx * vec_size));
     }
   } else {
     st_local.normalize();
@@ -480,7 +503,8 @@ __global__ void BatchDecodeWithPaddedKVCacheKernel(DTypeIn* __restrict__ q, DTyp
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st_local, reinterpret_cast<float*>(smem), smem_md);
+  sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st_local, reinterpret_cast<float*>(smem),
+                                                           smem_md);
 
   st_local.normalize();
   st_local.o.cast_store(o + batch_idx * num_qo_heads * head_dim +
@@ -696,7 +720,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
   block.sync();
 
   // sync local state of all warps inside a threadblock
-  sync_state<vec_size, bdx, bdy, bdz>(st, reinterpret_cast<float*>(smem), smem_md);
+  sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(st, reinterpret_cast<float*>(smem),
+                                                           smem_md);
 
   if constexpr (cooperative) {
     auto grid = cg::this_grid();
@@ -727,7 +752,8 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
     }
     block.sync();
     // sync local state of all warps inside a threadblock
-    sync_state<vec_size, bdx, bdy, bdz>(st_global, reinterpret_cast<float*>(smem), smem_md);
+    sync_state<vec_size, bdx, bdy, bdz, SyncRange::kSyncBdz>(
+        st_global, reinterpret_cast<float*>(smem), smem_md);
     st_global.normalize();
     st_global.o.cast_store(
         o + (paged_kv.batch_idx_map()[batch_idx] * num_qo_heads + qo_head_idx) * head_dim +
@@ -753,7 +779,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(
  * \param sizeof_dtype The size (in terms of bytes) of the input data type
  */
 constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) {
-  if (group_size == 8U) {
+  if (group_size > 1U) {
     if (sizeof_dtype == 1U) {
       return 256U;  // not enough registers for 512 threads
     } else {
@@ -788,7 +814,7 @@ cudaError_t SingleDecodeWithKVCacheWorkEstimation(uint32_t& tmp_size, uint32_t&
                                                   RotaryMode rotary_mode = RotaryMode::kNone,
                                                   cudaStream_t stream = nullptr) {
   const uint32_t GROUP_SIZE = num_qo_heads / num_kv_heads;
-  if (seq_len <= 128U / uint32_t(std::sqrt(GROUP_SIZE))) {
+  if (seq_len < 128U) {
     tmp_size = 0;
   } else {
     SWITCH_GQA_GROUP_SIZE(
@@ -889,7 +915,7 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
         const uint32_t smem_size =
             2U * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * sizeof(DTypeIn) +
            2U * bdy * bdz * sizeof(float);
-        if (seq_len <= 256 || tmp == nullptr) {
+        if (seq_len < 128 || tmp == nullptr) {
          // no need to use cooperative kernel
          auto kernel = SingleDecodeWithKVCacheKernel
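For reference, the quantity that sync_state repeatedly folds together, per head and per vector lane, is an online-softmax partial state (o, m, d); the new SyncRange parameter only decides whether those states are gathered across blockDim.z or across blockDim.y * blockDim.z. Below is a minimal host-side sketch of that merge rule, illustrative only and not part of the patch: the struct and function names are invented here, and the assumption that o stores the un-normalized weighted value sum may differ from the exact storage convention of flashinfer's state_t.

// Illustrative sketch only (not part of the patch). A partial attention state for one
// head: m is the running max of scores, d = sum(exp(score - m)), and o is assumed to
// hold the un-normalized weighted value sum, i.e. sum(exp(score - m) * v).
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct PartialState {
  std::vector<float> o;  // un-normalized weighted value accumulator (length head_dim)
  float m;               // running maximum of attention scores
  float d;               // softmax denominator accumulated relative to m
};

// Fold `other` into `self`. Both sides are rescaled to the shared new maximum before
// being summed, so the result is independent of merge order up to rounding.
void merge(PartialState& self, const PartialState& other) {
  float m_new = std::max(self.m, other.m);
  float scale_self = std::exp(self.m - m_new);
  float scale_other = std::exp(other.m - m_new);
  self.d = self.d * scale_self + other.d * scale_other;
  for (std::size_t i = 0; i < self.o.size(); ++i) {
    self.o[i] = self.o[i] * scale_self + other.o[i] * scale_other;
  }
  self.m = m_new;
}

// The final normalized output is o[i] / d, corresponding to st.normalize() followed by
// st.o.cast_store(...) in the kernels above.

Because both operands are rescaled to the shared maximum before being added, merging is order-insensitive, which is what lets the kSyncBdz and kSyncBdyBdz paths walk the shared-memory slots in different orders and still produce the same reduced state.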