From f236f70a7beead58276aab120b7f67aca54ae2b6 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Wed, 20 Nov 2024 03:09:27 -0800 Subject: [PATCH] refactor: rename num_frags to num_mma (#621) fragment should refer to the elements held by each thread in a tensor core layout. mma is a more accurate term for an mma block. --- include/flashinfer/attention/prefill.cuh | 1024 +++++++++++----------- include/flashinfer/utils.cuh | 56 +- 2 files changed, 542 insertions(+), 538 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 4d8c62ad..8464f053 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -55,7 +55,7 @@ constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) { return 4 / get_num_warps_q(cta_tile_kv); } -constexpr uint32_t get_num_frags_q(const uint32_t cta_tile_q) { +constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) { if (cta_tile_q > 64) { return 2; } else { @@ -66,13 +66,12 @@ namespace { template -constexpr bool is_invalid_configuration(uint32_t NUM_FRAGS_Q, uint32_t NUM_FRAGS_D, - uint32_t NUM_FRAGS_KV, uint32_t NUM_WARPS_Q, - uint32_t NUM_WARPS_KV) { - return ((NUM_FRAGS_D < 4) || (NUM_FRAGS_D == 4 && NUM_FRAGS_KV % 2 == 1) || - (NUM_FRAGS_D > 4 && NUM_FRAGS_D % (2 * NUM_WARPS_Q) != 0) || - (NUM_FRAGS_Q * (8 * NUM_FRAGS_D + 2 * sizeof(DTypeQKAccum) * NUM_FRAGS_KV) >= 256) || - (sizeof(DTypeKV) == 1 && NUM_FRAGS_KV * 2 % NUM_WARPS_Q != 0) || +constexpr bool is_invalid_configuration(uint32_t NUM_MMA_Q, uint32_t NUM_MMA_D, uint32_t NUM_MMA_KV, + uint32_t NUM_WARPS_Q, uint32_t NUM_WARPS_KV) { + return ((NUM_MMA_D < 4) || (NUM_MMA_D == 4 && NUM_MMA_KV % 2 == 1) || + (NUM_MMA_D > 4 && NUM_MMA_D % (2 * NUM_WARPS_Q) != 0) || + (NUM_MMA_Q * (8 * NUM_MMA_D + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) || + (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) ||
(sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); } @@ -174,8 +173,8 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half /*! * \brief Produce k/v fragments from global memory to shared memory. * \tparam fill_mode The fill mode of the shared memory. - * \tparam NUM_FRAGS_D The number of fragments in y dimension. - * \tparam NUM_FRAGS_KV The number of fragments in z dimension. + * \tparam NUM_MMA_D The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam T The data type of the input tensor. * \param smem The shared memory to store kv fragments. @@ -184,24 +183,24 @@ __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half * \param kv_len The length of kv tensor. */ template + uint32_t NUM_MMA_D, uint32_t NUM_MMA_KV, SwizzleMode swizzle_mode, typename T> __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, T** gptr, const uint32_t kv_stride_n, const uint32_t kv_idx_base, const uint32_t kv_len) { // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE(Zihao): NUM_FRAGS_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_FRAGS_KV * 4 / num_warps - static_assert(NUM_FRAGS_KV * 4 % NUM_WARPS_Q == 0); + // NOTE(Zihao): NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV * 4 / NUM_WARPS_Q; ++i) { 
+ for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll - for (uint32_t j = 0; j < NUM_FRAGS_D / (8 / sizeof(T)); ++j) { + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(T)); ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -209,28 +208,28 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* kv_idx += num_warps * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - - sizeof(T) * NUM_FRAGS_D; - *gptr += num_warps * 4 * kv_stride_n - sizeof(T) * NUM_FRAGS_D * num_elems_per_128b(); + sizeof(T) * NUM_MMA_D; + *gptr += num_warps * 4 * kv_stride_n - sizeof(T) * NUM_MMA_D * num_elems_per_128b(); } - *smem_offset -= NUM_WARPS_KV * NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE(Zihao): NUM_FRAGS_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_FRAGS_KV * 2 / num_warps - static_assert(NUM_FRAGS_KV * 2 % NUM_WARPS_Q == 0); + // NOTE(Zihao): NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV * 2 / NUM_WARPS_Q; ++i) { + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row(*smem_offset); kv_idx += num_warps * 8; *gptr += num_warps * 8 * kv_stride_n; } - *smem_offset -= NUM_WARPS_KV * NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; } } -template +template __device__ __forceinline__ void page_produce_kv(smem_t smem, uint32_t* smem_offset, const paged_kv_t& paged_kv, const uint32_t kv_idx_base, const size_t* kv_offset, @@ -238,19 +237,19 @@ 
__device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 // NOTE(Zihao): for fp8, this function doesn't work for head_dim = 64 at the moment constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE(Zihao): NUM_FRAGS_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_FRAGS_KV * 4 / num_warps - static_assert(NUM_FRAGS_KV * 4 % NUM_WARPS_Q == 0); + // NOTE(Zihao): NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV * 4 / NUM_WARPS_Q; ++i) { + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { DType* gptr = produce_v ? 
paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; #pragma unroll - for (uint32_t j = 0; j < NUM_FRAGS_D / (8 / sizeof(DType)); ++j) { + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); gptr += 8 * num_elems_per_128b(); @@ -258,73 +257,72 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 kv_idx += num_warps * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - - sizeof(DType) * NUM_FRAGS_D; + sizeof(DType) * NUM_MMA_D; } - *smem_offset -= NUM_WARPS_KV * NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE(Zihao): NUM_FRAGS_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_FRAGS_KV * 2 / num_warps - static_assert(NUM_FRAGS_KV * 2 % NUM_WARPS_Q == 0); + // NOTE(Zihao): NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV * 2 / NUM_WARPS_Q; ++i) { + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { DType* gptr = produce_v ? 
paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); kv_idx += num_warps * 8; *smem_offset = smem.template advance_offset_by_row(*smem_offset); } - *smem_offset -= NUM_WARPS_KV * NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *smem_offset -= NUM_WARPS_KV * NUM_MMA_KV * 16 * channel_size_128b_kv; } } -template +template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float log2_rope_rcp_scale, const float log2_rope_rcp_theta) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; const uint32_t lane_idx = threadIdx.x; #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D / 2; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D / 2; ++mma_d) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { - rope_freq[fd][j] = math::ptx_exp2( - log2_rope_rcp_scale + - log2_rope_rcp_theta * - float(2 * ((fd * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (head_dim / 2))) / - float(head_dim)); + rope_freq[mma_d][j] = + math::ptx_exp2(log2_rope_rcp_scale + + log2_rope_rcp_theta * + float(2 * ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % + (head_dim / 2))) / + float(head_dim)); } } } -template -__device__ __forceinline__ void init_states(AttentionVariant variant, - float (*o_frag)[NUM_FRAGS_D][8], DTypeQKAccum (*m)[2], - float (*d)[2]) { +template +__device__ __forceinline__ void init_states(AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], + DTypeQKAccum (*m)[2], float (*d)[2]) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - o_frag[fq][fd][reg_id] = 0.f; + o_frag[mma_q][mma_d][reg_id] = 0.f; } } } if constexpr (variant.use_softmax) { #pragma unroll - for 
(uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - m[fq][j] = DTypeQKAccum(-math::inf); - d[fq][j] = 1.f; + m[mma_q][j] = DTypeQKAccum(-math::inf); + d[mma_q][j] = 1.f; } } } } -template __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t qo_upper_bound, @@ -332,120 +330,120 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t q_stride_h, const uint_fastdiv group_size, smem_t* q_smem) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x, warp_idx_x = get_warp_idx_q(); if (get_warp_idx_kv() == 0) { uint32_t q_smem_offset_w = q_smem->get_permuted_offset( - warp_idx_x * NUM_FRAGS_Q * 16 + lane_idx / 8, lane_idx % 8); + warp_idx_x * NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { uint32_t q, r; - group_size.divmod(packed_offset + lane_idx / 8 + fq * 16 + j * 4, q, r); + group_size.divmod(packed_offset + lane_idx / 8 + mma_q * 16 + j * 4, q, r); const uint32_t q_idx = q; DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; #pragma unroll - for (uint32_t fdo = 0; fdo < NUM_FRAGS_D / 4; ++fdo) { + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D / 4; ++mma_do) { // load q fragment from gmem to smem q_smem->load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); - q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, fdo); + q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do); q_ptr += 8 * num_elems_per_128b(); } q_smem_offset_w = q_smem->template advance_offset_by_row<4, 
channel_size_128b_q>(q_smem_offset_w) - - 2 * NUM_FRAGS_D; + 2 * NUM_MMA_D; } } } } -template __device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4]) { if (get_warp_idx_kv() == 0) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; - static_assert(NUM_FRAGS_D % 4 == 0, "NUM_FRAGS_D must be a multiple of 4"); + static_assert(NUM_MMA_D % 4 == 0, "NUM_MMA_D must be a multiple of 4"); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll - for (uint32_t fdi = 0; fdi < NUM_FRAGS_D / 2; ++fdi) { + for (uint32_t mma_di = 0; mma_di < NUM_MMA_D / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope( - (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fdi], - q_packed_idx + kv_len * group_size - qo_len * group_size + fq * 16 + lane_idx / 4, + (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[mma_di], + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4, group_size); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, fdi); + 
q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } *q_smem_offset_r += 16 * channel_size_128b_q; } - *q_smem_offset_r -= NUM_FRAGS_Q * 16 * channel_size_128b_q; + *q_smem_offset_r -= NUM_MMA_Q * 16 * channel_size_128b_q; } } -template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const IdType* q_offset, smem_t* q_smem, const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4]) { if (get_warp_idx_kv() == 0) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); const uint32_t lane_idx = threadIdx.x; uint32_t q_frag_local[2][4]; - static_assert(NUM_FRAGS_D % 4 == 0, "NUM_FRAGS_D must be a multiple of 4"); + static_assert(NUM_MMA_D % 4 == 0, "NUM_MMA_D must be a multiple of 4"); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll - for (uint32_t fdi = 0; fdi < NUM_FRAGS_D / 2; ++fdi) { + for (uint32_t mma_di = 0; mma_di < NUM_MMA_D / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column(q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope_with_pos( - (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[fdi], - q_packed_idx_base + fq * 16 + lane_idx / 4, group_size, q_offset); + (DTypeQ*)q_frag_local[0], (DTypeQ*)q_frag_local[1], rope_freq[mma_di], + q_packed_idx_base + mma_q * 16 + lane_idx / 4, group_size, q_offset); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, 
q_frag_local[0]); q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, fdi); + q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } *q_smem_offset_r += 16 * channel_size_128b_q; } - *q_smem_offset_r -= NUM_FRAGS_Q * 16 * channel_size_128b_q; + *q_smem_offset_r -= NUM_MMA_Q * 16 * channel_size_128b_q; } } -template __device__ __forceinline__ void q_smem_inplace_transform( const typename AttentionVariant::ParamsT& params, AttentionVariant variant, smem_t* q_smem) { using DTypeQ = typename AttentionVariant::DTypeQ; const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t num_warps = NUM_WARPS_Q * NUM_WARPS_KV; #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_Q * head_dim / (NUM_WARPS_KV * 16); ++i) { + for (uint32_t i = 0; i < NUM_MMA_Q * head_dim / (NUM_WARPS_KV * 16); ++i) { vec_t tmp; tmp.load((DTypeQ*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); #pragma unroll @@ -456,18 +454,18 @@ __device__ __forceinline__ void q_smem_inplace_transform( } } -template __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, smem_t* k_smem, uint32_t* k_smem_offset_r, float (*rope_freq)[4]) { static_assert(sizeof(DTypeKV) == 2); - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); uint32_t k_frag_local[2][4]; const uint32_t lane_idx = threadIdx.x; - if constexpr (NUM_FRAGS_D == 4 && NUM_WARPS_Q == 4) { + if constexpr (NUM_MMA_D == 4 && NUM_WARPS_Q == 4) { static_assert(NUM_WARPS_KV == 1); const uint32_t warp_idx = get_warp_idx_q(); // horizontal-axis: y @@ -476,96 +474,94 @@ __device__ __forceinline__ void 
k_smem_inplace_apply_rotary(const uint32_t kv_id // | 1-16 | 16-32 | 32-48 | 48-64 | // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | - static_assert(NUM_FRAGS_KV % 2 == 0, - "when NUM_FRAGS_D == 4, NUM_FRAGS_KV must be a multiple of 2"); + static_assert(NUM_MMA_KV % 2 == 0, "when NUM_MMA_D == 4, NUM_MMA_KV must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * channel_size_128b_kv; #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV / 2; ++i) { - // uint32_t fkv = warp_idx / 2 + i * 2; + for (uint32_t i = 0; i < NUM_MMA_KV / 2; ++i) { + // uint32_t mma_kv = warp_idx / 2 + i * 2; uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; - uint32_t fdi = (warp_idx % 2); + uint32_t mma_di = (warp_idx % 2); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], - rope_freq[fdi], kv_idx); + rope_freq[mma_di], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * channel_size_128b_kv; kv_idx += 32; } *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - - ((warp_idx / 2) + NUM_FRAGS_KV) * 16 * channel_size_128b_kv; + ((warp_idx / 2) + NUM_MMA_KV) * 16 * channel_size_128b_kv; } else { const uint32_t warp_idx_x = get_warp_idx_q(), warp_idx_z = get_warp_idx_kv(); - static_assert(NUM_FRAGS_D % (2 * NUM_WARPS_Q) == 0); + static_assert(NUM_MMA_D % (2 * NUM_WARPS_Q) == 0); // horizontal axis: y // vertical axis: z // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 
| ... - // | 1-16*NUM_FRAGS_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) | ... - // | 16*NUM_FRAGS_KV-32*NUM_FRAGS_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) | ... + // | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) | ... + // | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) | ... // ... - uint32_t kv_idx = kv_idx_base + (warp_idx_z * NUM_FRAGS_KV * 16) + lane_idx / 4; + uint32_t kv_idx = kv_idx_base + (warp_idx_z * NUM_MMA_KV * 16) + lane_idx / 4; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll - for (uint32_t i = 0; i < NUM_FRAGS_KV; ++i) { + for (uint32_t i = 0; i < NUM_MMA_KV; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; #pragma unroll - for (uint32_t j = 0; j < NUM_FRAGS_D / (2 * NUM_WARPS_Q); ++j) { - uint32_t fdi = warp_idx_x + j * NUM_WARPS_Q; + for (uint32_t j = 0; j < NUM_MMA_D / (2 * NUM_WARPS_Q); ++j) { + uint32_t mma_di = warp_idx_x + j * NUM_WARPS_Q; k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = - k_smem->template advance_offset_by_column(k_smem_offset_r_first_half, 0); + k_smem->template advance_offset_by_column(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], - rope_freq[fdi], kv_idx); + rope_freq[mma_di], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); k_smem_offset_r_first_half = k_smem->template advance_offset_by_column<2 * NUM_WARPS_Q>( - k_smem_offset_r_first_half, fdi); + k_smem_offset_r_first_half, mma_di); } *k_smem_offset_r += 16 * channel_size_128b_kv; kv_idx += 16; } *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - NUM_FRAGS_KV * 16 * channel_size_128b_kv; + (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - NUM_MMA_KV * 16 * channel_size_128b_kv; } } -template +template __device__ 
__forceinline__ void compute_qk(smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, - DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + DTypeQKAccum (*s_frag)[NUM_MMA_KV][8]) { + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - uint32_t a_frag[NUM_FRAGS_Q][4], b_frag[4]; + uint32_t a_frag[NUM_MMA_Q][4], b_frag[4]; // compute q*k^T #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { - q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fq]); + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]); *q_smem_offset_r = q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); } - *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fd) - - NUM_FRAGS_Q * 16 * channel_size_128b_q; + *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, mma_d) - + NUM_MMA_Q * 16 * channel_size_128b_q; #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { if constexpr (sizeof(DTypeKV) == 1) { uint32_t b_frag_f8[2]; - if (fd % 2 == 0) { + if (mma_d % 2 == 0) { k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); } else { k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); @@ -580,76 +576,80 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { if constexpr (std::is_same_v) { - if (fd == 0) { 
- mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fq][fkv], - a_frag[fq], b_frag); + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } else { - mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[fq][fkv], a_frag[fq], b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f32(s_frag[mma_q][mma_kv], a_frag[mma_q], + b_frag); } } else if (std::is_same_v) { - if (fd == 0) { - mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[fq][fkv], - a_frag[fq], b_frag); + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f16( + (uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } else { - mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[fq][fkv], a_frag[fq], - b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[mma_q][mma_kv], + a_frag[mma_q], b_frag); } } } } if constexpr (sizeof(DTypeKV) == 1) { - if (fd % 2 == 1) { - *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fd / 2); + if (mma_d % 2 == 1) { + *k_smem_offset_r = + k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d / 2); } - *k_smem_offset_r -= NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *k_smem_offset_r -= NUM_MMA_KV * 16 * channel_size_128b_kv; } else { - *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fd) - - NUM_FRAGS_KV * 16 * channel_size_128b_kv; + *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d) - + NUM_MMA_KV * 16 * channel_size_128b_kv; } } - *q_smem_offset_r -= NUM_FRAGS_D * 2; - *k_smem_offset_r -= NUM_FRAGS_D * sizeof(DTypeKV); + *q_smem_offset_r -= NUM_MMA_D * 2; + *k_smem_offset_r -= NUM_MMA_D * sizeof(DTypeKV); } -template +template __device__ __forceinline__ void logits_transform(const typename AttentionVariant::ParamsT& params, AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const 
uint32_t kv_len, const uint_fastdiv group_size, - DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) { + DTypeQKAccum (*s_frag)[NUM_MMA_KV][8]) { const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z; - uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2]; + uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]); + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], + r[mma_q][j]); } } #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 + - 2 * (lane_idx % 4) + - 8 * (reg_id / 4) + reg_id % 2; - const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2]; - s_frag[fq][fkv][reg_id] = variant.LogitsTransform( - params, s_frag[fq][fkv][reg_id], batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; + s_frag[mma_q][mma_kv][reg_id] = + variant.LogitsTransform(params, s_frag[mma_q][mma_kv][reg_id], batch_idx, q_idx, kv_idx, + qo_head_idx, kv_head_idx); } } } } -template __device__ __forceinline__ void logits_mask(const typename AttentionVariant::ParamsT& params, AttentionVariant variant, const uint32_t batch_idx, @@ -657,114 +657,122 @@ __device__ __forceinline__ void logits_mask(const typename 
AttentionVariant::Par const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, - DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8]) { + DTypeQKAccum (*s_frag)[NUM_MMA_KV][8]) { const uint32_t lane_idx = threadIdx.x, kv_head_idx = blockIdx.z; - uint32_t q[NUM_FRAGS_Q][2], r[NUM_FRAGS_Q][2]; + uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + fq * 16 + lane_idx / 4 + 8 * j, q[fq][j], r[fq][j]); + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], + r[mma_q][j]); } } #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - const uint32_t q_idx = q[fq][(reg_id % 4) / 2], kv_idx = kv_idx_base + fkv * 16 + - 2 * (lane_idx % 4) + - 8 * (reg_id / 4) + reg_id % 2; - const uint32_t qo_head_idx = kv_head_idx * group_size + r[fq][(reg_id % 4) / 2]; + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; const bool mask = (!(MASK_MODE == MaskMode::kCausal ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) : kv_idx >= chunk_end)) && variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); - s_frag[fq][fkv][reg_id] = - (mask) ? s_frag[fq][fkv][reg_id] + s_frag[mma_q][mma_kv][reg_id] = + (mask) ? s_frag[mma_q][mma_kv][reg_id] : (variant.use_softmax ? 
DTypeQKAccum(-math::inf) : DTypeQKAccum(0.f)); } } } } -template __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, - DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8], - float (*o_frag)[NUM_FRAGS_D][8], + DTypeQKAccum (*s_frag)[NUM_MMA_KV][8], + float (*o_frag)[NUM_MMA_D][8], DTypeQKAccum (*m)[2], float (*d)[2]) { if constexpr (variant.use_softmax) { if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - float m_prev = m[fq][j]; + float m_prev = m[mma_q][j]; #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { - float m_local = max(max(s_frag[fq][fkv][j * 2 + 0], s_frag[fq][fkv][j * 2 + 1]), - max(s_frag[fq][fkv][j * 2 + 4], s_frag[fq][fkv][j * 2 + 5])); - m[fq][j] = max(m[fq][j], m_local); + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + float m_local = + max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), + max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); + m[mma_q][j] = max(m[mma_q][j], m_local); } - m[fq][j] = max(m[fq][j], math::shfl_xor_sync(m[fq][j], 0x2)); - m[fq][j] = max(m[fq][j], math::shfl_xor_sync(m[fq][j], 0x1)); + m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x1)); - float o_scale = math::ptx_exp2(m_prev - m[fq][j]); - d[fq][j] *= o_scale; + float o_scale = math::ptx_exp2(m_prev - m[mma_q][j]); + d[mma_q][j] *= o_scale; #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { - o_frag[fq][fd][j * 2 + 0] *= o_scale; - o_frag[fq][fd][j * 2 + 1] *= o_scale; - o_frag[fq][fd][j * 2 + 4] *= o_scale; - o_frag[fq][fd][j * 2 + 5] *= o_scale; + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 
4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { - s_frag[fq][fkv][j * 2 + 0] = math::ptx_exp2(s_frag[fq][fkv][j * 2 + 0] - m[fq][j]); - s_frag[fq][fkv][j * 2 + 1] = math::ptx_exp2(s_frag[fq][fkv][j * 2 + 1] - m[fq][j]); - s_frag[fq][fkv][j * 2 + 4] = math::ptx_exp2(s_frag[fq][fkv][j * 2 + 4] - m[fq][j]); - s_frag[fq][fkv][j * 2 + 5] = math::ptx_exp2(s_frag[fq][fkv][j * 2 + 5] - m[fq][j]); + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + s_frag[mma_q][mma_kv][j * 2 + 0] = + math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 0] - m[mma_q][j]); + s_frag[mma_q][mma_kv][j * 2 + 1] = + math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 1] - m[mma_q][j]); + s_frag[mma_q][mma_kv][j * 2 + 4] = + math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 4] - m[mma_q][j]); + s_frag[mma_q][mma_kv][j * 2 + 5] = + math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 5] - m[mma_q][j]); } } } } else if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { half m_prev[2]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - m_prev[j] = m[fq][j]; + m_prev[j] = m[mma_q][j]; #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { - half2 m_local = - __hmax2(*(half2*)&s_frag[fq][fkv][j * 2], *(half2*)&s_frag[fq][fkv][j * 2 + 4]); - m[fq][j] = __hmax(m[fq][j], __hmax(m_local.x, m_local.y)); + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + half2 m_local = __hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2], + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]); + m[mma_q][j] = __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y)); } } - *(half2*)&m[fq] = __hmax2(*(half2*)&m[fq], math::shfl_xor_sync(*(half2*)&m[fq], 0x2)); - *(half2*)&m[fq] = __hmax2(*(half2*)&m[fq], math::shfl_xor_sync(*(half2*)&m[fq], 0x1)); + *(half2*)&m[mma_q] = + __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2)); + 
*(half2*)&m[mma_q] = + __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1)); #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - float o_scale = math::ptx_exp2(float(m_prev[j] - m[fq][j])); - d[fq][j] *= o_scale; -#pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { - o_frag[fq][fd][j * 2 + 0] *= o_scale; - o_frag[fq][fd][j * 2 + 1] *= o_scale; - o_frag[fq][fd][j * 2 + 4] *= o_scale; - o_frag[fq][fd][j * 2 + 5] *= o_scale; + float o_scale = math::ptx_exp2(float(m_prev[j] - m[mma_q][j])); + d[mma_q][j] *= o_scale; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } - half2 m2 = make_half2(m[fq][j], m[fq][j]); + half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]); #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { - *(half2*)&s_frag[fq][fkv][j * 2] = - math::ptx_exp2(*(half2*)&s_frag[fq][fkv][j * 2] - m2); - *(half2*)&s_frag[fq][fkv][j * 2 + 4] = - math::ptx_exp2(*(half2*)&s_frag[fq][fkv][j * 2 + 4] - m2); + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + *(half2*)&s_frag[mma_q][mma_kv][j * 2] = + math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2] - m2); + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = + math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] - m2); } } } @@ -772,50 +780,49 @@ __device__ __forceinline__ void update_mdo_states(AttentionVariant variant, } } -template +template __device__ __forceinline__ void compute_sfm_v(AttentionVariant variant, smem_t* v_smem, uint32_t* v_smem_offset_r, - DTypeQKAccum (*s_frag)[NUM_FRAGS_KV][8], - float (*o_frag)[NUM_FRAGS_D][8], float (*d)[2]) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + DTypeQKAccum (*s_frag)[NUM_MMA_KV][8], + float (*o_frag)[NUM_MMA_D][8], float (*d)[2]) { + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t 
channel_size_128b_kv = head_dim / num_elems_per_128b(); - DTypeQ s_frag_f16[NUM_FRAGS_Q][NUM_FRAGS_KV][8]; + DTypeQ s_frag_f16[NUM_MMA_Q][NUM_MMA_KV][8]; if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { - vec_cast::cast<8>(s_frag_f16[fq][fkv], s_frag[fq][fkv]); + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + vec_cast::cast<8>(s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); } } } if constexpr (variant.use_softmax) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { if constexpr (std::is_same_v) { - mma::rowsum_f16f16f32(d[fq], s_frag_f16[fq][fkv]); + mma::rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); } else { - mma::rowsum_f16f16f32(d[fq], s_frag[fq][fkv]); + mma::rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); } } } } #pragma unroll - for (uint32_t fkv = 0; fkv < NUM_FRAGS_KV; ++fkv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { uint32_t b_frag[4]; if constexpr (sizeof(DTypeKV) == 1) { uint32_t b_frag_f8[2]; - if (fd % 2 == 0) { + if (mma_d % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); } else { v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); @@ -828,53 +835,54 @@ __device__ __forceinline__ void compute_sfm_v(AttentionVariant variant, v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { if constexpr (std::is_same_v) { 
mma::mma_sync_m16n16k16_row_col_f16f16f32( - o_frag[fq][fd], (uint32_t*)(s_frag_f16[fq][fkv]), b_frag); + o_frag[mma_q][mma_d], (uint32_t*)(s_frag_f16[mma_q][mma_kv]), b_frag); } else { - mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[fq][fd], - (uint32_t*)s_frag[fq][fkv], b_frag); + mma::mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[mma_q][mma_d], (uint32_t*)s_frag[mma_q][mma_kv], b_frag); } } if constexpr (sizeof(DTypeKV) == 1) { - if (fd % 2 == 1) { - *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fd / 2); + if (mma_d % 2 == 1) { + *v_smem_offset_r = + v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d / 2); } } else { - *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fd); + *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d); } } *v_smem_offset_r = v_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*v_smem_offset_r) - - sizeof(DTypeKV) * NUM_FRAGS_D; + sizeof(DTypeKV) * NUM_MMA_D; } - *v_smem_offset_r -= 16 * NUM_FRAGS_KV * channel_size_128b_kv; + *v_smem_offset_r -= 16 * NUM_MMA_KV * channel_size_128b_kv; } -template -__device__ __forceinline__ void normalize_d(AttentionVariant variant, - float (*o_frag)[NUM_FRAGS_D][8], DTypeQKAccum (*m)[2], - float (*d)[2]) { +template +__device__ __forceinline__ void normalize_d(AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], + DTypeQKAccum (*m)[2], float (*d)[2]) { if constexpr (variant.use_softmax) { - float d_rcp[NUM_FRAGS_Q][2]; + float d_rcp[NUM_MMA_Q][2]; // compute reciprocal of d #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - d_rcp[fq][j] = (m[fq][j] != DTypeQKAccum(-math::inf)) ? math::ptx_rcp(d[fq][j]) : 0.f; + d_rcp[mma_q][j] = + (m[mma_q][j] != DTypeQKAccum(-math::inf)) ? 
math::ptx_rcp(d[mma_q][j]) : 0.f; } } #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { - o_frag[fq][fd][reg_id] = o_frag[fq][fd][reg_id] * d_rcp[fq][(reg_id % 4) / 2]; + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; } } } @@ -884,42 +892,42 @@ __device__ __forceinline__ void normalize_d(AttentionVariant variant, /*! * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. */ -template __device__ __forceinline__ void threadblock_sync_mdo_states( - AttentionVariant variant, float (*o_frag)[NUM_FRAGS_D][8], float* smem_workspace, + AttentionVariant variant, float (*o_frag)[NUM_MMA_D][8], float* smem_workspace, DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t warp_idx, const uint32_t lane_idx) { // only necessary when blockDim.z > 1 if constexpr (NUM_WARPS_KV > 1) { - float2* smem_md = (float2*)(smem_workspace + NUM_FRAGS_Q * NUM_FRAGS_D * NUM_WARPS_Q * - NUM_WARPS_KV * WARP_SIZE * 8); - // o: [num_warps, NUM_FRAGS_Q, NUM_FRAGS_D, WARP_SIZE(32), 8] - // md: [num_warps, NUM_FRAGS_Q, 2, WARP_SIZE(32), 2 (m/d)] + float2* smem_md = (float2*)(smem_workspace + + NUM_MMA_Q * NUM_MMA_D * NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE * 8); + // o: [num_warps, NUM_MMA_Q, NUM_MMA_D, WARP_SIZE(32), 8] + // md: [num_warps, NUM_MMA_Q, 2, WARP_SIZE(32), 2 (m/d)] #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { vec_t::memcpy( smem_workspace + - (((warp_idx * NUM_FRAGS_Q + fq) * NUM_FRAGS_D + fd) * WARP_SIZE + lane_idx) * 8, - o_frag[fq][fd]); + (((warp_idx * 
NUM_MMA_Q + mma_q) * NUM_MMA_D + mma_d) * WARP_SIZE + lane_idx) * 8, + o_frag[mma_q][mma_d]); } } if constexpr (variant.use_softmax) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - smem_md[((warp_idx * NUM_FRAGS_Q + fq) * 2 + j) * WARP_SIZE + lane_idx] = - make_float2(float(m[fq][j]), d[fq][j]); + smem_md[((warp_idx * NUM_MMA_Q + mma_q) * 2 + j) * WARP_SIZE + lane_idx] = + make_float2(float(m[mma_q][j]), d[mma_q][j]); } } // synchronize m,d first __syncthreads(); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { float o_scale[2][NUM_WARPS_KV]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { @@ -927,8 +935,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t i = 0; i < NUM_WARPS_KV; ++i) { float2 md = smem_md[(((i * NUM_WARPS_Q + get_warp_idx_q()) * - NUM_FRAGS_Q + - fq) * + NUM_MMA_Q + + mma_q) * 2 + j) * WARP_SIZE + @@ -941,8 +949,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t i = 0; i < NUM_WARPS_KV; ++i) { float2 md = smem_md[(((i * NUM_WARPS_Q + get_warp_idx_q()) * - NUM_FRAGS_Q + - fq) * + NUM_MMA_Q + + mma_q) * 2 + j) * WARP_SIZE + @@ -950,92 +958,90 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( float mi = md.x; o_scale[j][i] = math::ptx_exp2(float(mi - m_new)); } - m[fq][j] = DTypeQKAccum(m_new); - d[fq][j] = d_new; + m[mma_q][j] = DTypeQKAccum(m_new); + d[mma_q][j] = d_new; } #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < NUM_WARPS_KV; ++i) { vec_t oi; - oi.load( - smem_workspace + - ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_FRAGS_Q + - fq) * - NUM_FRAGS_D + - fd) * - WARP_SIZE + - lane_idx) * - 8); + 
oi.load(smem_workspace + + ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q + + mma_q) * + NUM_MMA_D + + mma_d) * + WARP_SIZE + + lane_idx) * + 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; } } - o_new.store(o_frag[fq][fd]); + o_new.store(o_frag[mma_q][mma_d]); } } } else { // synchronize m,d first __syncthreads(); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < NUM_WARPS_KV; ++i) { vec_t oi; - oi.load( - smem_workspace + - ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_FRAGS_Q + - fq) * - NUM_FRAGS_D + - fd) * - WARP_SIZE + - lane_idx) * - 8); + oi.load(smem_workspace + + ((((i * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q + + mma_q) * + NUM_MMA_D + + mma_d) * + WARP_SIZE + + lane_idx) * + 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_new[reg_id] += oi[reg_id]; } } - o_new.store(o_frag[fq][fd]); + o_new.store(o_frag[mma_q][mma_d]); } } } } } -template __device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[NUM_FRAGS_D][8], smem_t* o_smem, DTypeO* o_ptr_base, + float (*o_frag)[NUM_MMA_D][8], smem_t* o_smem, DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size) { - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); const uint32_t warp_idx_x = get_warp_idx_q(); const uint32_t lane_idx = threadIdx.x; if (get_warp_idx_kv() == 0) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { 
#pragma unroll - for (uint32_t fd = 0; fd < NUM_FRAGS_D; ++fd) { + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D; ++mma_d) { uint32_t o_frag_f16[4]; - vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[fq][fd]); + vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = o_smem->get_permuted_offset( - (warp_idx_x * NUM_FRAGS_Q + fq) * 16 + lane_idx % 16, fd * 2 + lane_idx / 16); + (warp_idx_x * NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else uint32_t o_smem_offset_w = o_smem->get_permuted_offset( - (warp_idx_x * NUM_FRAGS_Q + fq) * 16 + lane_idx / 4, fd * 2); + (warp_idx_x * NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[1]; @@ -1047,27 +1053,27 @@ __device__ __forceinline__ void write_o_reg_gmem( } uint32_t o_smem_offset_w = o_smem->get_permuted_offset( - warp_idx_x * NUM_FRAGS_Q * 16 + lane_idx / 8, lane_idx % 8); + warp_idx_x * NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / 8 + fq * 16 + j * 4, q, r); + group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r); const uint32_t o_idx = q; DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h; #pragma unroll - for (uint32_t fdo = 0; fdo < NUM_FRAGS_D / 4; ++fdo) { + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D / 4; ++mma_do) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } o_ptr += 8 * num_elems_per_128b(); - o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, fdo); + 
o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do); } o_smem_offset_w = o_smem->template advance_offset_by_row<4, channel_size_128b_out>(o_smem_offset_w) - - 2 * NUM_FRAGS_D; + 2 * NUM_MMA_D; } } } @@ -1080,9 +1086,9 @@ __device__ __forceinline__ void write_o_reg_gmem( * \tparam partition_kv Whether to split kv_len into chunks. * \tparam mask_mode The mask mode used in the attention operation. * \tparam POS_ENCODING_MODE The positional encoding mode. - * \tparam NUM_FRAGS_Q The number of fragments in x dimension. - * \tparam NUM_FRAGS_D The number of fragments in y dimension. - * \tparam NUM_FRAGS_KV The number of fragments in z dimension. + * \tparam NUM_MMA_Q The number of fragments in x dimension. + * \tparam NUM_MMA_D The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam DTypeQ The data type of the query tensor. * \tparam DTypeKV The data type of the key/value tensor. @@ -1098,8 +1104,8 @@ __device__ __forceinline__ void write_o_reg_gmem( * \param log2_rope_rcp_theta log2(1/(rope_theta)), where rope_theta is the theta * used in RoPE. 
*/ -template __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKVCacheKernel( @@ -1132,9 +1138,9 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV const uint32_t lane_idx = threadIdx.x, warp_idx = get_warp_idx(); const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; - constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16; + constexpr uint32_t num_rows_per_cta = NUM_MMA_Q * NUM_WARPS_Q * 16; const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, /*head_dim=*/NUM_FRAGS_D * 16); + kv_stride_n, kv_stride_h, /*head_dim=*/NUM_MMA_D * 16); const uint32_t num_chunks = gridDim.y; const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; @@ -1148,26 +1154,26 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t window_left = variant.window_left; - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8]; - alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8]; - DTypeQKAccum m[NUM_FRAGS_Q][2]; - float d[NUM_FRAGS_Q][2]; - float rope_freq[NUM_FRAGS_D / 2][4]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float log2_rope_rcp_scale = 
params.log2_rope_rcp_scale; const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; - init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); + init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = - (bx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_FRAGS_Q * 16; + (bx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); DTypeQ* q_ptr_base = @@ -1182,10 +1188,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV (lane_idx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q() * NUM_FRAGS_Q * 16 + lane_idx % 16, + get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); cp_async::commit_group(); @@ -1193,12 +1199,12 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq); block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = @@ -1206,9 +1212,9 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 
8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_FRAGS_KV * sizeof(DTypeKV)) * + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + + NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * 16 * head_dim); const uint32_t num_iterations = ceil_div( @@ -1217,12 +1223,12 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV sub_if_greater_or_zero( kv_len - qo_len + ((bx + 1) * num_rows_per_cta) / group_size, chunk_start)) : chunk_size, - 16 * NUM_WARPS_KV * NUM_FRAGS_KV); + 16 * NUM_WARPS_KV * NUM_MMA_KV); const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta, qo_len + window_left + chunk_start), - (16 * NUM_WARPS_KV * NUM_FRAGS_KV)); + (16 * NUM_WARPS_KV * NUM_MMA_KV)); const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal @@ -1230,7 +1236,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV sub_if_greater_or_zero(kv_len + (bx * num_rows_per_cta) / group_size - qo_len, chunk_start)) : chunk_size) / - (16 * NUM_WARPS_KV * NUM_FRAGS_KV); + (16 * NUM_WARPS_KV * NUM_MMA_KV); DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( @@ -1242,18 +1248,18 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV (lane_idx % kv_frag_cols) * num_elems_per_128b()); uint32_t k_smem_offset_r = k_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + + get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + lane_idx % 16, + get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), kv_smem_offset_w = k_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); - 
produce_kv( + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); - produce_kv( + produce_kv( v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); @@ -1263,65 +1269,65 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( - chunk_start + iter * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, &k_smem, &k_smem_offset_r, + chunk_start + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag); // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - produce_kv( + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, chunk_size); + (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - produce_kv(v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, chunk_size); + produce_kv( + 
v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, + (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); // write back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? num_qo_heads * head_dim * num_chunks : num_qo_heads * head_dim, @@ -1332,20 +1338,20 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void SinglePrefillWithKV if (lse != nullptr || partition_kv) { if (get_warp_idx_kv() == 0) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + fq * 16, q, r); + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { if (partition_kv) { lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fq][j]) + float(m[fq][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[qo_idx * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fq][j]) + float(m[fq][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -1380,7 +1386,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params const uint32_t group_size = num_qo_heads / num_kv_heads; const uint_fastdiv group_size_fastdiv(group_size); - constexpr uint32_t NUM_FRAGS_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; uint32_t cta_tile_q = 0; int64_t unpacked_qo_len = qo_len * group_size; if 
(unpacked_qo_len > 64 && HEAD_DIM < 256) { @@ -1405,7 +1411,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); - constexpr uint32_t NUM_FRAGS_Q = get_num_frags_q(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); using DTypeQKAccum = typename std::conditional, half, @@ -1421,37 +1427,37 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - const uint32_t max_num_frags_kv_reg = - (HEAD_DIM >= 128 && NUM_FRAGS_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 
2 - : (8 / NUM_FRAGS_Q); + : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation - const uint32_t max_num_frags_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_FRAGS_Q * NUM_WARPS_Q) / + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); - // control NUM_FRAGS_KV for maximum warp occupancy - DISPATCH_NUM_FRAGS_KV(min(max_num_frags_kv_smem, max_num_frags_kv_reg), NUM_FRAGS_KV, { + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_FRAGS_Q, NUM_FRAGS_D, NUM_FRAGS_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_FRAGS_Q=" << NUM_FRAGS_Q - << " NUM_FRAGS_D=" << NUM_FRAGS_D << " NUM_FRAGS_KV=" << NUM_FRAGS_KV + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q + << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; - constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16; - auto kernel = SinglePrefillWithKVCacheKernel; // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_FRAGS_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * + uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * 16 * HEAD_DIM; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, 
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); @@ -1507,8 +1513,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched(typename AttentionVariant::Params return cudaSuccess; } -template __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRaggedKVCacheKernel( @@ -1546,7 +1552,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeO) == 2); - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); auto block = cg::this_thread_block(); @@ -1558,7 +1564,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; - constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16; + constexpr uint32_t num_rows_per_cta = NUM_MMA_Q * NUM_WARPS_Q * 16; extern __shared__ uint8_t smem[]; AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, @@ -1570,7 +1576,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag partition_kv ? 
min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, - kv_stride_n, kv_stride_h, /*head_dim=*/NUM_FRAGS_D * 16); + kv_stride_n, kv_stride_h, /*head_dim=*/NUM_MMA_D * 16); const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); @@ -1578,22 +1584,21 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8]; - alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8]; - DTypeQKAccum m[NUM_FRAGS_Q][2]; - float d[NUM_FRAGS_Q][2]; - float rope_freq[NUM_FRAGS_D / 2][4]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float log2_rope_rcp_scale = params.log2_rope_rcp_scale; const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; - init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); + init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = - (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_FRAGS_Q * - 16; + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); @@ -1610,10 +1615,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag (lane_idx % 8) * num_elems_per_128b()); uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q() * NUM_FRAGS_Q * 16 + 
lane_idx % 16, + get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); @@ -1623,18 +1628,18 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { if (!q_offset) { - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq); } block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); const uint32_t num_iterations = ceil_div( @@ -1644,12 +1649,12 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, chunk_start)) : chunk_size), - 16 * NUM_WARPS_KV * NUM_FRAGS_KV); + 16 * NUM_WARPS_KV * NUM_MMA_KV); const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta, qo_len + window_left + chunk_start), - (16 * NUM_WARPS_KV * NUM_FRAGS_KV)); + (16 * NUM_WARPS_KV * NUM_MMA_KV)); const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal @@ -1657,24 +1662,24 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) : chunk_size) / - (16 * NUM_WARPS_KV * NUM_FRAGS_KV); + (16 * NUM_WARPS_KV * NUM_MMA_KV); constexpr SwizzleMode swizzle_mode_kv = (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 
4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_FRAGS_KV * sizeof(DTypeKV)) * + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + + NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * 16 * head_dim); uint32_t k_smem_offset_r = k_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + + get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + lane_idx % 16, + get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), kv_smem_offset_w = k_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); @@ -1690,10 +1695,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); - produce_kv( + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); - produce_kv( + produce_kv( v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); @@ -1703,68 +1708,68 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( (k_rope_pos_offset == nullptr ? 
0 : k_rope_pos_offset[request_idx]) + chunk_start + - iter * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag); // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - produce_kv( + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, chunk_size); + (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - produce_kv(v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, - (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, chunk_size); + produce_kv( + v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, + (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, chunk_size); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); const uint32_t 
num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, @@ -1775,21 +1780,21 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag if (lse != nullptr) { if (get_warp_idx_kv() == 0) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + fq * 16, q, r); + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[fq][j]) + float(m[fq][j]); + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fq][j]) + float(m[fq][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -1802,8 +1807,8 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithRag #endif } -template __global__ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPagedKVCacheKernel( @@ -1846,7 +1851,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; - constexpr uint32_t num_rows_per_cta = NUM_FRAGS_Q * NUM_WARPS_Q * 16; + constexpr uint32_t num_rows_per_cta = NUM_MMA_Q * NUM_WARPS_Q * 16; extern __shared__ uint8_t smem[]; AttentionVariant variant(params, /*batch_idx=*/request_idx, 
smem); const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, @@ -1860,27 +1865,26 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr uint32_t head_dim = NUM_FRAGS_D * 16; + constexpr uint32_t head_dim = NUM_MMA_D * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - DTypeQKAccum s_frag[NUM_FRAGS_Q][NUM_FRAGS_KV][8]; - alignas(16) float o_frag[NUM_FRAGS_Q][NUM_FRAGS_D][8]; - DTypeQKAccum m[NUM_FRAGS_Q][2]; - float d[NUM_FRAGS_Q][2]; - float rope_freq[NUM_FRAGS_D / 2][4]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D / 2][4]; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float log2_rope_rcp_scale = params.log2_rope_rcp_scale; const float log2_rope_rcp_theta = params.log2_rope_rcp_theta; - init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); + init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); } - init_states(variant, o_frag, m, d); + init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = - (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_FRAGS_Q * - 16; + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q()) * NUM_MMA_Q * 16; const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h; constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); @@ -1896,10 +1900,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag (lane_idx % 8) * num_elems_per_128b(), num_qo_heads * head_dim, head_dim); uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - 
get_warp_idx_q() * NUM_FRAGS_Q * 16 + lane_idx % 16, + get_warp_idx_q() * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem( + load_q_global_smem( qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); @@ -1909,18 +1913,18 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { if (q_offset == nullptr) { - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq); } else { - q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq); } block.sync(); } - q_smem_inplace_transform( + q_smem_inplace_transform( params, variant, &qo_smem); constexpr SwizzleMode swizzle_mode_kv = @@ -1928,18 +1932,18 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 8 : 4; smem_t k_smem(smem + - (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (NUM_WARPS_Q * NUM_FRAGS_Q * sizeof(DTypeQ) + - NUM_WARPS_KV * NUM_FRAGS_KV * sizeof(DTypeKV)) * + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ)) * 16 * head_dim), + v_smem(smem + (NUM_WARPS_Q * NUM_MMA_Q * sizeof(DTypeQ) + + NUM_WARPS_KV * NUM_MMA_KV * sizeof(DTypeKV)) * 16 * head_dim); - size_t kv_offset[NUM_FRAGS_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q]; + size_t kv_offset[NUM_MMA_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 
4 : 2) / NUM_WARPS_Q]; uint32_t k_smem_offset_r = k_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + + get_warp_idx_kv() * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.get_permuted_offset( - get_warp_idx_kv() * NUM_FRAGS_KV * 16 + lane_idx % 16, + get_warp_idx_kv() * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), kv_smem_offset_w = k_smem.get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); @@ -1949,7 +1953,7 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; #pragma unroll for (uint32_t i = 0; - i < NUM_FRAGS_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { + i < NUM_MMA_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { uint32_t page_iter, entry_idx; paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols + @@ -1959,10 +1963,10 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag page_iter, kv_head_idx, entry_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b(), last_indptr); } - page_produce_kv( + page_produce_kv( k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); cp_async::commit_group(); - page_produce_kv( + page_produce_kv( v_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); cp_async::commit_group(); @@ -1973,12 +1977,12 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag kv_len - qo_len + ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, chunk_start)) : chunk_size), - 16 * NUM_WARPS_KV * NUM_FRAGS_KV); + 16 * NUM_WARPS_KV * NUM_MMA_KV); const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * num_rows_per_cta, qo_len + window_left + chunk_start), - (16 * NUM_WARPS_KV * NUM_FRAGS_KV)); + (16 * NUM_WARPS_KV * NUM_MMA_KV)); 
const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal @@ -1986,14 +1990,14 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag kv_len + (qo_tile_idx * num_rows_per_cta) / group_size - qo_len, chunk_start)) : chunk_size) / - (16 * NUM_WARPS_KV * NUM_FRAGS_KV); + (16 * NUM_WARPS_KV * NUM_MMA_KV); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - packed_page_iter_base += 16 * NUM_WARPS_KV * NUM_FRAGS_KV; + packed_page_iter_base += 16 * NUM_WARPS_KV * NUM_MMA_KV; #pragma unroll for (uint32_t i = 0; - i < NUM_FRAGS_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { + i < NUM_MMA_KV * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { uint32_t page_iter, entry_idx; paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols + @@ -2007,51 +2011,51 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag block.sync(); if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[request_idx]) + - chunk_start + iter * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, + chunk_start + iter * 16 * NUM_WARPS_KV * NUM_MMA_KV, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); - logits_transform( + logits_transform( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag); // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { - logits_mask( + logits_mask( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv()) * - NUM_FRAGS_KV * 16, + NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag); } // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); - page_produce_kv( - k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, + page_produce_kv( + k_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, kv_offset, chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v( + compute_sfm_v( variant, &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); - page_produce_kv( - v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_FRAGS_KV, + page_produce_kv( + v_smem, &kv_smem_offset_w, paged_kv, (iter + 1) * 16 * NUM_WARPS_KV * NUM_MMA_KV, kv_offset, chunk_size); cp_async::commit_group(); } @@ -2059,16 +2063,16 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( + threadblock_sync_mdo_states( variant, o_frag, (float*)smem, m, d, 
warp_idx, lane_idx); // normalize d - normalize_d(variant, o_frag, m, d); + normalize_d(variant, o_frag, m, d); const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // write_back - write_o_reg_gmem( + write_o_reg_gmem( o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? num_qo_heads * head_dim * num_kv_chunks : num_qo_heads * head_dim, @@ -2079,21 +2083,21 @@ __launch_bounds__(NUM_WARPS_Q* NUM_WARPS_KV* WARP_SIZE) void BatchPrefillWithPag if (lse != nullptr) { if (get_warp_idx_kv() == 0) { #pragma unroll - for (uint32_t fq = 0; fq < NUM_FRAGS_Q; ++fq) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + fq * 16, q, r); + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_upper_bound) { if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + - qo_head_idx] = math::ptx_log2(d[fq][j]) + float(m[fq][j]); + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = - math::ptx_log2(d[fq][j]) + float(m[fq][j]); + math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } @@ -2118,7 +2122,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P const uint32_t num_kv_heads = params.num_kv_heads; const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads); const uint32_t total_num_rows = params.total_num_rows; - constexpr uint32_t NUM_FRAGS_Q = get_num_frags_q(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); @@ -2130,7 +2134,7 @@ 
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - constexpr uint32_t NUM_FRAGS_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; using DTypeQKAccum = typename std::conditional, half, float>::type; @@ -2145,36 +2149,36 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(typename AttentionVariant::P const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - const uint32_t max_num_frags_kv_reg = - (HEAD_DIM >= 128 && NUM_FRAGS_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 2 - : (8 / NUM_FRAGS_Q); + : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation - const uint32_t max_num_frags_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_FRAGS_Q * NUM_WARPS_Q) / + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); - DISPATCH_NUM_FRAGS_KV(min(max_num_frags_kv_smem, max_num_frags_kv_reg), NUM_FRAGS_KV, { + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_FRAGS_Q, NUM_FRAGS_D, NUM_FRAGS_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_FRAGS_Q=" << NUM_FRAGS_Q - << " NUM_FRAGS_D=" << NUM_FRAGS_D << " NUM_FRAGS_KV=" << NUM_FRAGS_KV + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q + << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << 
NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_FRAGS_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * + uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * 16 * HEAD_DIM; auto kernel = - BatchPrefillWithRaggedKVCacheKernel; + BatchPrefillWithRaggedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (tmp_v == nullptr) { @@ -2219,7 +2223,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa const uint32_t num_kv_heads = params.paged_kv.num_heads; const uint_fastdiv group_size_fastdiv(num_qo_heads / num_kv_heads); const uint32_t total_num_rows = params.total_num_rows; - constexpr uint32_t NUM_FRAGS_Q = get_num_frags_q(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); @@ -2232,7 +2236,7 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); - constexpr uint32_t NUM_FRAGS_D = HEAD_DIM / 16; + constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; using DTypeQKAccum = typename std::conditional, half, float>::type; @@ -2247,35 +2251,35 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched(typename AttentionVariant::Pa const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 
2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - const uint32_t max_num_frags_kv_reg = - (HEAD_DIM >= 128 && NUM_FRAGS_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 2 - : (8 / NUM_FRAGS_Q); + : (8 / NUM_MMA_Q); // TODO(Zihao): fix the following computation - const uint32_t max_num_frags_kv_smem = - (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_FRAGS_Q * NUM_WARPS_Q) / + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - NUM_MMA_Q * NUM_WARPS_Q) / (2 * NUM_WARPS_KV); - DISPATCH_NUM_FRAGS_KV(min(max_num_frags_kv_smem, max_num_frags_kv_reg), NUM_FRAGS_KV, { + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { if constexpr (is_invalid_configuration( - NUM_FRAGS_Q, NUM_FRAGS_D, NUM_FRAGS_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { + NUM_MMA_Q, NUM_MMA_D, NUM_MMA_KV, NUM_WARPS_Q, NUM_WARPS_KV)) { // Invalid configuration, skip std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_FRAGS_Q=" << NUM_FRAGS_Q - << " NUM_FRAGS_D=" << NUM_FRAGS_D << " NUM_FRAGS_KV=" << NUM_FRAGS_KV + err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q + << " NUM_MMA_D=" << NUM_MMA_D << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { // TODO(Zihao): fix the following computation - uint32_t smem_size = (NUM_FRAGS_Q * NUM_WARPS_Q * sizeof(DTypeQ) + - NUM_FRAGS_KV * NUM_WARPS_KV * 2 * sizeof(DTypeQ)) * + uint32_t smem_size = (NUM_MMA_Q * NUM_WARPS_Q * sizeof(DTypeQ) + + NUM_MMA_KV * NUM_WARPS_KV * 2 * 
sizeof(DTypeQ)) * 16 * HEAD_DIM; auto kernel = - BatchPrefillWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index a4841a31..b36145fb 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -63,36 +63,36 @@ __VA_ARGS__ \ } -#define DISPATCH_NUM_FRAGS_Q(num_frags_q, NUM_FRAGS_Q, ...) \ - if (num_frags_q == 1) { \ - constexpr size_t NUM_FRAGS_Q = 1; \ - __VA_ARGS__ \ - } else if (num_frags_q == 2) { \ - constexpr size_t NUM_FRAGS_Q = 2; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported num_frags_q: " << num_frags_q; \ - FLASHINFER_ERROR(err_msg.str()); \ +#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \ + if (num_mma_q == 1) { \ + constexpr size_t NUM_MMA_Q = 1; \ + __VA_ARGS__ \ + } else if (num_mma_q == 2) { \ + constexpr size_t NUM_MMA_Q = 2; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_mma_q: " << num_mma_q; \ + FLASHINFER_ERROR(err_msg.str()); \ } -#define DISPATCH_NUM_FRAGS_KV(max_frags_kv, NUM_FRAGS_KV, ...) \ - if (max_frags_kv >= 8) { \ - constexpr size_t NUM_FRAGS_KV = 8; \ - __VA_ARGS__ \ - } else if (max_frags_kv >= 4) { \ - constexpr size_t NUM_FRAGS_KV = 4; \ - __VA_ARGS__ \ - } else if (max_frags_kv >= 2) { \ - constexpr size_t NUM_FRAGS_KV = 2; \ - __VA_ARGS__ \ - } else if (max_frags_kv >= 1) { \ - constexpr size_t NUM_FRAGS_KV = 1; \ - __VA_ARGS__ \ - } else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported max_frags_kv: " << max_frags_kv; \ - FLASHINFER_ERROR(err_msg.str()); \ +#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) 
\ + if (max_mma_kv >= 8) { \ + constexpr size_t NUM_MMA_KV = 8; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 4) { \ + constexpr size_t NUM_MMA_KV = 4; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 2) { \ + constexpr size_t NUM_MMA_KV = 2; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 1) { \ + constexpr size_t NUM_MMA_KV = 1; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ + FLASHINFER_ERROR(err_msg.str()); \ } #define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \