diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index 5dca8dcb..f42cb3f9 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -25,6 +25,7 @@ #include #include +#include #include #include "../cp_async.cuh" @@ -816,9 +817,10 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { - const float sm_scale = 1.f / std::sqrt(float(head_dim)); + float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; if (num_qo_heads % num_kv_heads != 0) { @@ -1134,8 +1136,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched( DTypeIn* q, IdType* q_rope_position, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, - float rope_scale, float rope_theta, cudaStream_t stream) { - const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { const float rope_rcp_scale = 1.f / rope_scale; const float rope_rcp_theta = 1.f / rope_theta; const uint32_t num_kv_heads = paged_kv.num_heads; @@ -1229,11 +1230,13 @@ cudaError_t BatchDecodeWithPagedKVCache( DTypeIn* q, IdType* q_rope_position, paged_kv_t paged_kv, kv_partition_info_t kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse, - uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, + uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " @@ -1243,12 +1246,14 @@ cudaError_t BatchDecodeWithPagedKVCache( DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, - {DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { - return BatchDecodeWithPagedKVCacheDispatched< - GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn, - DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o, - tmp, lse, rope_scale, rope_theta, stream); - })})}); + {DISPATCH_HEAD_DIM( + head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { + return BatchDecodeWithPagedKVCacheDispatched( + q, q_rope_position, paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale, + rope_theta, stream); + })})}); return cudaSuccess; } @@ -1258,9 +1263,8 @@ template sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, + cudaStream_t stream = nullptr) { + if (!sm_scale.has_value()) { + sm_scale = 1.f / std::sqrt(float(head_dim)); + } if (num_qo_heads % num_kv_heads != 0) { std::ostringstream err_msg; err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " @@ -1317,8 +1325,8 @@ cudaError_t BatchDecodeWithPaddedKVCache( rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { return BatchDecodeWithPaddedKVCacheDispatched( - q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, rope_scale, - rope_theta, stream); + q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale, + rope_scale, rope_theta, stream); })})})}); return cudaSuccess; } diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index f488c084..4a17d066 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -23,6 +23,7 @@ #endif #include +#include #include #include "../cp_async.cuh" @@ -1563,9 +1564,9 @@ template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float rope_scale, - float rope_theta, cudaStream_t stream) { - const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); + uint32_t qo_len, uint32_t kv_len, float sm_scale, + float rope_scale, float rope_theta, + cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); if (kv_len < qo_len && CAUSAL) { @@ -1714,14 +1715,14 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* * \return status Indicates whether CUDA calls are successful */ template -cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, - float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, - bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, - bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, - float rope_theta = 1e4, cudaStream_t stream = nullptr) { +cudaError_t SinglePrefillWithKVCache( + DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true, + QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); DISPATCH_ALLOW_FP16_QK_REDUCTION( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {DISPATCH_GQA_GROUP_SIZE( @@ -1734,9 +1735,9 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, { SinglePrefillWithKVCacheDispatched(q, k, v, o, tmp, lse, - num_kv_heads, qo_len, kv_len, - rope_scale, rope_theta, stream); + CAUSAL>( + q, k, v, o, tmp, lse, num_kv_heads, qo_len, kv_len, sm_scale, + rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } @@ -1748,9 +1749,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched( DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles, - const uint32_t num_kv_heads, const float rope_scale, const float rope_theta, - cudaStream_t stream = nullptr) { - const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM)); + const uint32_t num_kv_heads, const float sm_scale, const float rope_scale, + const float rope_theta, cudaStream_t stream = nullptr) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -1826,8 +1826,10 @@ cudaError_t BatchPrefillWithRaggedKVCache( const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, + const float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t group_size = num_qo_heads / num_kv_heads; + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); uint32_t num_frags_x, num_qo_tiles; std::vector request_indices_h, tile_indices_h; @@ -1865,7 +1867,8 @@ cudaError_t BatchPrefillWithRaggedKVCache( ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( q, request_indices_d, tile_indices_d, qo_indptr, k, v, kv_indptr, q_rope_position, k_rope_pos_offset, o, tmp, lse, batch_size, - num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream); + num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta, + stream); })})})})})})}); FLASHINFER_CUDA_CALL(cudaFreeAsync(request_indices_d, stream)); @@ -1901,12 +1904,14 @@ template paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale = 1.f, + DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD; const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); std::vector kv_indptr_h(paged_kv.batch_size + 1); @@ -1929,8 +1934,8 @@ cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched( ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, q_rope_position, - paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale, - rope_theta, stream); + paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, + rope_scale, rope_theta, stream); FLASHINFER_CUDA_CALL(cudaFreeAsync(keys, stream)); FLASHINFER_CUDA_CALL(cudaFreeAsync(values, stream)); @@ -1946,9 +1951,8 @@ template paged_kv, - DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale, float rope_theta, - cudaStream_t stream) { - const float sm_scale = 1.f / std::sqrt(float(paged_kv.head_dim)); + DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale, + float rope_theta, cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); constexpr uint32_t num_warps = 4; @@ -2022,11 +2026,13 @@ cudaError_t BatchPrefillWithPagedKVCache( paged_kv_t paged_kv, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, + float rope_theta = 1e4, cudaStream_t stream = nullptr) { const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; const uint32_t batch_size = paged_kv.batch_size; const uint32_t group_size = num_qo_heads / num_kv_heads; + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); uint32_t num_frags_x, num_qo_tiles; std::vector request_indices_h, tile_indices_h; @@ -2068,8 +2074,8 @@ cudaError_t BatchPrefillWithPagedKVCache( ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(q, request_indices_d, tile_indices_d, qo_indptr, q_rope_position, paged_kv, o, - tmp, lse, num_qo_tiles, rope_scale, - rope_theta, stream); + tmp, lse, num_qo_tiles, sm_scale, + rope_scale, rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, @@ -2077,7 +2083,7 @@ cudaError_t BatchPrefillWithPagedKVCache( DTypeIn, DTypeOut, IdType>( q, request_indices_d, tile_indices_d, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, num_qo_tiles, - rope_scale, rope_theta, stream); + sm_scale, rope_scale, rope_theta, stream); } }) diff --git a/include/flashinfer/attention/wrapper.cuh b/include/flashinfer/attention/wrapper.cuh index f644ec6b..73d0fc73 100644 --- a/include/flashinfer/attention/wrapper.cuh +++ b/include/flashinfer/attention/wrapper.cuh @@ -48,7 +48,8 @@ template paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f, + uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { paged_kv_t new_paged_kv = paged_kv; kv_partition_info_t kv_partition_info; @@ -73,7 +74,7 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper( } return BatchDecodeWithPagedKVCache( q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode, - rope_scale, rope_theta, stream); + maybe_sm_scale, rope_scale, rope_theta, stream); } template paged_kv, DTypeOut* o, float* lse, - float rope_scale, float rope_theta, cudaStream_t stream) { + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { float* tmp = nullptr; IdType* request_indices = nullptr; IdType* tile_indices = nullptr; @@ -107,13 +108,13 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched( page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, - num_qo_tiles, rope_scale, rope_theta, stream); + num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); } else { return BatchPrefillWithPagedKVCacheDispatched< page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse, - num_qo_tiles, rope_scale, rope_theta, stream); + num_qo_tiles, sm_scale, rope_scale, rope_theta, stream); } })}); return cudaSuccess; @@ -125,8 +126,9 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position, paged_kv_t paged_kv, DTypeOut* o, float* lse, uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, - bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { + bool allow_fp16_qk_reduction = false, std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) { + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); const uint32_t num_kv_heads = paged_kv.num_heads; const uint32_t head_dim = paged_kv.head_dim; DISPATCH_GQA_GROUP_SIZE( @@ -142,8 +144,8 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper( return BatchPrefillWithPagedKVCacheWrapperDispatched< page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale, - rope_theta, stream); + handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale, + rope_scale, rope_theta, stream); })})})})}); return cudaSuccess; } @@ -154,8 +156,8 @@ template ( q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position, - k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale, - rope_theta, stream); + k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, + rope_scale, rope_theta, stream); }); return cudaSuccess; } @@ -191,7 +193,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + std::optional maybe_sm_scale = std::nullopt, const float rope_scale = 1.f, + const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); DISPATCH_LAYOUT( kv_layout, KV_LAYOUT, {DISPATCH_GQA_GROUP_SIZE( @@ -209,7 +213,7 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper( ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>( handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr, /*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads, - rope_scale, rope_theta, stream); + sm_scale, rope_scale, rope_theta, stream); })})})})})}); return cudaSuccess; } diff --git a/python/csrc/flashinfer_decl.h b/python/csrc/flashinfer_decl.h index 5ae1c251..359d8d9f 100644 --- a/python/csrc/flashinfer_decl.h +++ b/python/csrc/flashinfer_decl.h @@ -26,7 +26,8 @@ CAUSAL, T, T, int32_t>(BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, \ int32_t* q_rope_position, \ paged_kv_t paged_kv, T* o, \ - float* lse, float rope_scale, float rope_theta, cudaStream_t stream); \ + float* lse, float sm_scale, float rope_scale, float rope_theta, \ + cudaStream_t stream); \ } #define INST_BatchPrefillRaggedWrapper(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, \ @@ -36,16 +37,17 @@ GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T, int32_t>( \ BatchPrefillHandler * handler, T* q, int32_t* qo_indptr, T* k, T* v, int32_t* kv_indptr, \ int32_t* q_rope_position, int32_t* k_rope_pos_offset, T* o, float* lse, uint32_t batch_size, \ - uint32_t num_kv_heads, float rope_scale, float rope_theta, cudaStream_t stream); \ + uint32_t num_kv_heads, float sm_scale, float rope_scale, float rope_theta, \ + cudaStream_t stream); \ } -#define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \ - ROTARY_MODE) \ - namespace flashinfer { \ - template cudaError_t SinglePrefillWithKVCacheDispatched< \ - GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T>( \ - T * q, T* k, T* v, T* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, \ - uint32_t kv_len, float rope_scale, float rope_theta, cudaStream_t stream); \ +#define INST_SinglePrefill(T, GROUP_SIZE, HEAD_DIM, CAUSAL, ALLOW_FP16_QK_REDUCTION, LAYOUT, \ + ROTARY_MODE) \ + namespace flashinfer { \ + template cudaError_t SinglePrefillWithKVCacheDispatched< \ + GROUP_SIZE, HEAD_DIM, LAYOUT, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, T, T>( \ + T * q, T* k, T* v, T* o, float* tmp, float* lse, uint32_t num_kv_heads, uint32_t qo_len, \ + uint32_t kv_len, float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); \ } namespace flashinfer { @@ -58,8 +60,8 @@ template paged_kv, DTypeOut* o, float* lse, - float rope_scale, float rope_theta, cudaStream_t stream); + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream); template cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_kv_heads, - uint32_t qo_len, uint32_t kv_len, float rope_scale, - float rope_theta, cudaStream_t stream); + uint32_t qo_len, uint32_t kv_len, float sm_scale, + float rope_scale, float rope_theta, + cudaStream_t stream); } // namespace flashinfer diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h index 6213b52a..b8060ce1 100644 --- a/python/csrc/flashinfer_ops.h +++ b/python/csrc/flashinfer_ops.h @@ -31,8 +31,8 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc std::vector single_prefill_with_kv_cache( torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal, - unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, - float rope_theta, bool return_lse); + unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float sm_scale, + float rope_scale, float rope_theta, bool return_lse); void append_paged_kv_cache(torch::Tensor append_key, torch::Tensor append_value, torch::Tensor append_indptr, torch::Tensor kv_data, @@ -64,7 +64,8 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper { std::vector Forward(torch::Tensor q, torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr, torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, unsigned int rotary_mode, - float rope_scale, float rope_theta, bool return_lse); + float sm_scale, float rope_scale, float rope_theta, + bool return_lse); private: BatchDecodeWithPagedKVCachePyTorchWrapper(unsigned int layout) @@ -87,7 +88,8 @@ class BatchPrefillWithPagedKVCachePyTorchWrapper { torch::Tensor paged_kv_indices, torch::Tensor paged_kv_last_page_len, bool causal, unsigned int rotary_mode, bool allow_fp16_qk_reduction, - float rope_scale, float rope_theta, bool return_lse); + float sm_scale, float rope_scale, float rope_theta, + bool return_lse); private: BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout) @@ -108,7 +110,8 @@ class BatchPrefillWithRaggedKVCachePyTorchWrapper { std::vector Forward(torch::Tensor q, torch::Tensor qo_indptr, torch::Tensor k, torch::Tensor v, torch::Tensor kv_indptr, bool causal, unsigned int rotary_mode, bool allow_fp16_qk_reduction, - float rope_scale, float rope_theta, bool return_lse); + float sm_scale, float rope_scale, float rope_theta, + bool return_lse); private: BatchPrefillWithRaggedKVCachePyTorchWrapper(unsigned int layout) diff --git a/python/flashinfer/cascade.py b/python/flashinfer/cascade.py index 1c320b8e..27922237 100644 --- a/python/flashinfer/cascade.py +++ b/python/flashinfer/cascade.py @@ -272,6 +272,7 @@ def batch_decode_with_shared_prefix_padded_kv_cache( rotary_mode="NONE", kv_layout=kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, + sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, ) @@ -458,6 +459,7 @@ def forward( v_shared: torch.Tensor, unique_kv_data: torch.Tensor, allow_fp16_qk_reduction=False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -488,6 +490,8 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -507,6 +511,7 @@ def forward( rotary_mode="NONE", kv_layout=self._kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, + sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, ) @@ -514,6 +519,7 @@ def forward( q, unique_kv_data, rotary_mode="NONE", + sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, ) @@ -698,6 +704,7 @@ def forward( unique_kv_data: torch.Tensor, causal: bool = True, allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -730,6 +737,8 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -749,6 +758,7 @@ def forward( rotary_mode="NONE", kv_layout=self._kv_layout, allow_fp16_qk_reduction=allow_fp16_qk_reduction, + sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, ) @@ -758,6 +768,7 @@ def forward( causal=causal, rotary_mode="NONE", allow_fp16_qk_reduction=allow_fp16_qk_reduction, + sm_scale=sm_scale, rope_scale=rope_scale, rope_theta=rope_theta, ) diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 0139c76c..0742edff 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -507,6 +507,7 @@ def forward( q: torch.Tensor, paged_kv_data: torch.Tensor, rotary_mode: str = "NONE", + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -525,6 +526,8 @@ def forward( rotary_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -537,10 +540,14 @@ def forward( The attention output, shape: ``[batch_size, num_qo_heads, head_dim]``. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + head_dim = q.shape[-1] + sm_scale = 1.0 / math.sqrt(head_dim) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: rope_theta = 1e4 + paged_kv_data = expand_5d(paged_kv_data, self._kv_layout) return self._wrapper.forward( q, @@ -549,6 +556,7 @@ def forward( self._paged_kv_indices, self._paged_kv_last_page_len, getattr(RotaryMode, rotary_mode), + sm_scale, rope_scale, rope_theta, False, @@ -559,6 +567,7 @@ def forward_return_lse( q: torch.Tensor, paged_kv_data: torch.Tensor, rotary_mode: str = "NONE", + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -578,6 +587,8 @@ def forward_return_lse( rotary_mode : str Whether to apply RoPE on-the-fly inside attention kernels, could be ``NONE`` or ``LLAMA`` (LLAMA style rotary embedding). + sm_scale : Optional[float] + The scale of softmax, if not provided, will be set to ``1 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -597,6 +608,9 @@ def forward_return_lse( explanation of the log-sum-exp function and attention states. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + head_dim = q.shape[-1] + sm_scale = 1.0 / math.sqrt(head_dim) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -609,6 +623,7 @@ def forward_return_lse( self._paged_kv_indices, self._paged_kv_last_page_len, getattr(RotaryMode, rotary_mode), + sm_scale, rope_scale, rope_theta, True, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 1112f423..b5c082da 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -68,6 +68,7 @@ def single_prefill_with_kv_cache( kv_layout: str = "NHD", rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -96,6 +97,8 @@ def single_prefill_with_kv_cache( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to 1.0. rope_theta : Optional[float] @@ -133,6 +136,8 @@ def single_prefill_with_kv_cache( check_rotary_mode(rotary_mode) check_kv_layout(kv_layout) tmp = _get_cache_buf("single_prefill_with_kv_cache_tmp", 8 * 1024 * 1024, q.device) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -146,6 +151,7 @@ def single_prefill_with_kv_cache( getattr(TensorLayout, kv_layout), getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, False, @@ -160,6 +166,7 @@ def single_prefill_with_kv_cache_return_lse( kv_layout: str = "NHD", rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -188,6 +195,8 @@ def single_prefill_with_kv_cache_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] @@ -233,6 +242,8 @@ def single_prefill_with_kv_cache_return_lse( tmp = _get_cache_buf( "single_prefill_with_kv_cache_return_lse_tmp", 8 * 1024 * 1024, q.device ) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -246,6 +257,7 @@ def single_prefill_with_kv_cache_return_lse( getattr(TensorLayout, kv_layout), getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, True, @@ -428,6 +440,7 @@ def forward( causal: bool = True, rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -451,6 +464,9 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -463,6 +479,8 @@ def forward( The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -478,6 +496,7 @@ def forward( causal, getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, False, @@ -490,6 +509,7 @@ def forward_return_lse( causal: bool = True, rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -513,6 +533,9 @@ def forward_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -528,6 +551,8 @@ def forward_return_lse( ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -543,6 +568,7 @@ def forward_return_lse( causal, getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, True, @@ -699,6 +725,7 @@ def forward( causal: bool = True, rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -721,6 +748,9 @@ def forward( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. @@ -733,6 +763,8 @@ def forward( The attention output, shape: ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -746,6 +778,7 @@ def forward( causal, getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, False, @@ -759,6 +792,7 @@ def forward_return_lse( causal: bool = True, rotary_mode: str = "NONE", allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, rope_scale: Optional[float] = None, rope_theta: Optional[float] = None, ): @@ -781,6 +815,9 @@ def forward_return_lse( allow_fp16_qk_reduction : bool Whether to use f16 for qk reduction (faster at the cost of slight precision loss). + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to + ``1.0 / sqrt(head_dim)``. rope_scale : Optional[float] The scale used in RoPE interpolation, if not provided, will be set to ``1.0``. rope_theta : Optional[float] @@ -795,6 +832,8 @@ def forward_return_lse( ``[qo_indptr[-1], num_qo_heads, head_dim]``. """ check_rotary_mode(rotary_mode) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(q.size(-1)) if rope_scale is None: rope_scale = 1.0 if rope_theta is None: @@ -808,6 +847,7 @@ def forward_return_lse( causal, getattr(RotaryMode, rotary_mode), allow_fp16_qk_reduction, + sm_scale, rope_scale, rope_theta, True, diff --git a/src/bench_single_decode.cu b/src/bench_single_decode.cu index 5bbe8755..d157e6ee 100644 --- a/src/bench_single_decode.cu +++ b/src/bench_single_decode.cu @@ -49,8 +49,10 @@ void bench_flashinfer_single_decode(nvbench::state& state) { thrust::raw_pointer_cast(Q.data()), thrust::raw_pointer_cast(K.data()), thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, num_qo_heads, num_kv_heads, - seq_len, head_dim, QKVLayout(kv_layout), RotaryMode(rotary_mode), 1.f, 1e4, - launch.get_stream()); + seq_len, head_dim, QKVLayout(kv_layout), RotaryMode(rotary_mode), + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } @@ -91,7 +93,10 @@ void bench_flashinfer_single_decode_with_prefill(nvbench::state& state) { /*qo_len=*/1, /*kv_len=*/seq_len, head_dim, /*causal=*/false, QKVLayout(kv_layout), RotaryMode(rotary_mode), - /*allow_fp16_qk_reduction=*/false, 1.f, 1e4, launch.get_stream()); + /*allow_fp16_qk_reduction=*/false, + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } diff --git a/src/bench_single_prefill.cu b/src/bench_single_prefill.cu index 78766124..14466583 100644 --- a/src/bench_single_prefill.cu +++ b/src/bench_single_prefill.cu @@ -58,8 +58,10 @@ void bench_flashinfer_single_prefill(nvbench::state& state) { thrust::raw_pointer_cast(V.data()), thrust::raw_pointer_cast(O.data()), /*tmp=*/cooperative ? thrust::raw_pointer_cast(tmp.data()) : nullptr, /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, - QKVLayout(kv_layout), RotaryMode(rotary_mode), allow_fp16_qk_reduction, 1.f, 1e4, - launch.get_stream()); + QKVLayout(kv_layout), RotaryMode(rotary_mode), allow_fp16_qk_reduction, + /*maybe_sm_scale=*/std::nullopt, + /*rope_scale=*/1.f, + /*rope_theta=*/1e4, launch.get_stream()); if (status != cudaSuccess) { state.skip("CUDA error: " + std::string(cudaGetErrorString(status))); } diff --git a/src/tvm_wrapper.cu b/src/tvm_wrapper.cu index d96f33a7..17899bd4 100644 --- a/src/tvm_wrapper.cu +++ b/src/tvm_wrapper.cu @@ -58,6 +58,7 @@ cudaError_t _SinglePrefillWithKVCacheNoLSE( CHECK(head_dim == 128) << "The head dimension must be 128"; CHECK(kv_layout == QKVLayout::kNHD) << "The KV layout must be NHD"; const uint32_t group_size = num_qo_heads / num_kv_heads; + const float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); DISPATCH_ALLOW_FP16_QK_REDUCTION( allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, @@ -68,7 +69,7 @@ cudaError_t _SinglePrefillWithKVCacheNoLSE( GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL>( q, k, v, o, tmp, /*lse=*/nullptr, num_kv_heads, qo_len, kv_len, - rope_scale, rope_theta, stream); + sm_scale, rope_scale, rope_theta, stream); })})})}); return cudaSuccess; } @@ -186,9 +187,8 @@ template paged_kv, DTypeOut* o, float* lse, - uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone, - bool allow_fp16_qk_reduction = false, float rope_scale = 1.f, float rope_theta = 1e4, - cudaStream_t stream = nullptr) { + uint32_t num_qo_heads, bool causal, RotaryMode rotary_mode, bool allow_fp16_qk_reduction, + float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) { CHECK(lse != nullptr) << "The lse buffer must be provided"; CHECK(allow_fp16_qk_reduction == false) << "The fp16 qk reduction is not supported"; CHECK(paged_kv.head_dim == 128) << "The head dimension must be 128"; @@ -202,8 +202,8 @@ cudaError_t _BatchPrefillWithPagedKVCacheWrapper( return BatchPrefillWithPagedKVCacheWrapperDispatched< page_storage, kv_layout, GROUP_SIZE, /*head_dim=*/128, ROTARY_MODE, /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( - handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, rope_scale, - rope_theta, stream); + handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale, + rope_scale, rope_theta, stream); })})}); return cudaSuccess; } @@ -218,10 +218,10 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q DLTensor* q_rope_position, // DLTensor* output, // DLTensor* lse, // - int64_t causal = 1, // - int64_t rotary_mode = 0, // - double rope_scale = 1.0f, // - double rope_theta = 1e4, + int64_t causal, // + int64_t rotary_mode, // + double rope_scale, // + double rope_theta, double attn_score_scaling_factor = 1.0f) { CHECK(handler_id < max_num_handlers) << "The handler id must be less than " << max_num_handlers; CHECK_EQ(q_data->device.device_type, kDLCUDA) << "The device of q_data must be CUDA."; @@ -239,7 +239,6 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q CHECK_EQ(qo_indptr->device.device_type, kDLCUDA) << "The device of qo_indptr matrix must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; - CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0."; int32_t dev_id = q_data->device.device_id; CHECK_EQ(pages->device.device_id, dev_id); @@ -295,6 +294,7 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; + const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); DISPATCH_TVM_CUDA_DTYPE( pages->dtype, dtype_in, @@ -315,7 +315,8 @@ void _FlashInferAttentionPrefillWithPagedKVCache(int64_t handler_id, DLTensor* q static_cast(output->data), /*lse=*/static_cast(lse->data), nhead_qo, /*causal=*/causal, RotaryMode(rotary_mode), /*allow_fp16_qk_reduction=*/false, - rope_scale, rope_theta, 0); + sm_scale, rope_scale, rope_theta, + /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -373,7 +374,6 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) << "The device of k_rope_pos_offset matrix must be CUDA."; CHECK_EQ(output->device.device_type, kDLCUDA) << "The device of output must be CUDA."; - CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0."; int32_t dev_id = q_data->device.device_id; CHECK_EQ(pages->device.device_id, dev_id); @@ -425,6 +425,7 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ constexpr PageStorage page_storage = PageStorage::kIndices; constexpr QKVLayout kv_layout = QKVLayout::kHND; + const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); DISPATCH_TVM_CUDA_DTYPE( pages->dtype, dtype_in, @@ -441,8 +442,9 @@ void _FlashInferAttentionDecodeWithPagedKVCache(int64_t handler_id, DLTensor* q_ &batch_decode_handlers[handler_id], static_cast(q_data->data), static_cast(q_rope_position->data), cache, static_cast(output->data), - /*lse=*/static_cast(lse->data), nhead_qo, RotaryMode(rotary_mode), - rope_scale, rope_theta, 0); + /*lse=*/static_cast(lse->data), nhead_qo, RotaryMode(rotary_mode), sm_scale, + rope_scale, rope_theta, + /*stream=*/0); if (status != cudaSuccess) { LOG(FATAL) << "FlashInfer CUDA kernel error " << cudaGetErrorString(status); } @@ -491,13 +493,14 @@ cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v, IdType* kv_indptr, IdType* q_rope_position_map, IdType* k_rope_pos_offset, DTypeOut* o, float* lse, const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads, - const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, - RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false, - const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) { + const uint32_t head_dim, bool causal, QKVLayout kv_layout, RotaryMode rotary_mode, + bool allow_fp16_qk_reduction, const float sm_scale, const float rope_scale, + const float rope_theta, cudaStream_t stream) { CHECK(lse != nullptr) << "The lse buffer must be provided"; CHECK(head_dim == 128) << "The head dimension must be 128"; CHECK(kv_layout == QKVLayout::kNHD) << "The layout must be NHD"; CHECK(allow_fp16_qk_reduction == false) << "The fp16 qk reduction is not supported"; + DISPATCH_GQA_GROUP_SIZE( num_qo_heads / num_kv_heads, GROUP_SIZE, {DISPATCH_CAUSAL(causal, CAUSAL, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, { @@ -505,8 +508,8 @@ cudaError_t _BatchPrefillWithRaggedKVCacheWrapper( GROUP_SIZE, /*head_dim=*/128, /*layout=*/QKVLayout::kNHD, ROTARY_MODE, /*allow_fp16_qk_reduction=*/false, CAUSAL, DTypeIn, DTypeOut, IdType>( handler, q, qo_indptr, k, v, kv_indptr, q_rope_position_map, - k_rope_pos_offset, o, lse, batch_size, num_kv_heads, rope_scale, - rope_theta, stream); + k_rope_pos_offset, o, lse, batch_size, num_kv_heads, sm_scale, + rope_scale, rope_theta, stream); })})}); return cudaSuccess; } @@ -527,7 +530,6 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( << "The device of q_rope_position_map must be CUDA."; CHECK_EQ(k_rope_pos_offset->device.device_type, kDLCUDA) << "The device of k_rope_pos_offset must be CUDA."; - CHECK_EQ(attn_score_scaling_factor, 1.0f) << "The attention score scaling factor must be 1.0."; int dev_id = q_data->device.device_id; CHECK_EQ(qo_indptr->device.device_id, dev_id); @@ -580,6 +582,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( CHECK_EQ(k_rope_pos_offset->ndim, 1); CHECK_EQ(k_rope_pos_offset->shape[0], batch_size); + const float sm_scale = attn_score_scaling_factor / std::sqrt(static_cast(nfeat)); + DISPATCH_TVM_CUDA_DTYPE( q_data->dtype, dtype_in, {DISPATCH_TVM_CUDA_DTYPE( @@ -594,7 +598,8 @@ void _FlashInferAttentionPrefillWithRaggedKVCache( static_cast(output->data), /*lse=*/static_cast(lse->data), batch_size, nhead_qo, nhead_kv, nfeat, /*causal=*/bool(causal), QKVLayout::kNHD, RotaryMode(rotary_mode), - /*allow_fp16_qk_reduction=*/false, rope_scale, rope_theta, 0); + /*allow_fp16_qk_reduction=*/false, sm_scale, rope_scale, rope_theta, + /*sm_scale=*/0); })})}) }