feat: adding sm_scale field for all attention APIs (#145)
Some of our attention APIs have this field and some don't; this PR adds an `sm_scale` field to all attention APIs to make them consistent.
yzh119 authored Mar 1, 2024
1 parent 660c559 commit 85d4018
Showing 11 changed files with 207 additions and 105 deletions.
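
For context, `sm_scale` is the factor applied to the query-key dot products before the softmax. Every public wrapper in this diff defaults it to 1/sqrt(head_dim) when the caller passes `std::nullopt`, via `maybe_sm_scale.value_or(...)`. A minimal standalone sketch of that convention (not FlashInfer code; the `head_dim` values are illustrative):

#include <cmath>
#include <cstdio>
#include <optional>

// Resolve the softmax scale the way the wrappers in this PR do:
// use the caller's value if provided, otherwise 1/sqrt(head_dim).
float ResolveSmScale(std::optional<float> maybe_sm_scale, unsigned head_dim) {
  return maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
}

int main() {
  // Default: head_dim = 128 gives 1/sqrt(128) ~= 0.0884.
  std::printf("%f\n", ResolveSmScale(std::nullopt, 128));
  // Explicit override, e.g. a model trained with a non-standard scale.
  std::printf("%f\n", ResolveSmScale(0.05f, 128));
  return 0;
}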
40 changes: 24 additions & 16 deletions include/flashinfer/attention/decode.cuh
@@ -25,6 +25,7 @@
 
 #include <cuda/pipeline>
 #include <iostream>
+#include <optional>
 #include <random>
 
 #include "../cp_async.cuh"
@@ -816,9 +817,10 @@ cudaError_t SingleDecodeWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut
 uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len,
 uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD,
 RotaryMode rotary_mode = RotaryMode::kNone,
+std::optional<float> maybe_sm_scale = std::nullopt,
 float rope_scale = 1.f, float rope_theta = 1e4,
 cudaStream_t stream = nullptr) {
-const float sm_scale = 1.f / std::sqrt(float(head_dim));
+float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 const float rope_rcp_scale = 1.f / rope_scale;
 const float rope_rcp_theta = 1.f / rope_theta;
 if (num_qo_heads % num_kv_heads != 0) {
@@ -1134,8 +1136,7 @@ cudaError_t BatchDecodeWithPagedKVCacheDispatched(
 DTypeIn* q, IdType* q_rope_position,
 paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
 kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
-float rope_scale, float rope_theta, cudaStream_t stream) {
-const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
+float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
 const float rope_rcp_scale = 1.f / rope_scale;
 const float rope_rcp_theta = 1.f / rope_theta;
 const uint32_t num_kv_heads = paged_kv.num_heads;
@@ -1229,11 +1230,13 @@ cudaError_t BatchDecodeWithPagedKVCache(
 DTypeIn* q, IdType* q_rope_position,
 paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
 kv_partition_info_t<IdType> kv_partition_info, DTypeOut* o, DTypeOut* tmp, float* lse,
-uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone, float rope_scale = 1.f,
+uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone,
+std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
 float rope_theta = 1e4, cudaStream_t stream = nullptr) {
 const uint32_t num_kv_heads = paged_kv.num_heads;
 const uint32_t head_dim = paged_kv.head_dim;
 const uint32_t batch_size = paged_kv.batch_size;
+const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 if (num_qo_heads % num_kv_heads != 0) {
 std::ostringstream err_msg;
 err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads "
@@ -1243,12 +1246,14 @@ cudaError_t BatchDecodeWithPagedKVCache(
 
 DISPATCH_GQA_GROUP_SIZE(
 num_qo_heads / num_kv_heads, GROUP_SIZE,
-{DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
-return BatchDecodeWithPagedKVCacheDispatched<
-GROUP_SIZE, HEAD_DIM, page_storage, kv_layout, ROTARY_MODE, DTypeIn,
-DTypeOut, IdType>(q, q_rope_position, paged_kv, kv_partition_info, o,
-tmp, lse, rope_scale, rope_theta, stream);
-})})});
+{DISPATCH_HEAD_DIM(
+head_dim, HEAD_DIM, {DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, {
+return BatchDecodeWithPagedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, page_storage,
+kv_layout, ROTARY_MODE, DTypeIn, DTypeOut,
+IdType>(
+q, q_rope_position, paged_kv, kv_partition_info, o, tmp, lse, sm_scale, rope_scale,
+rope_theta, stream);
+})})});
 
 return cudaSuccess;
 }
@@ -1258,9 +1263,8 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
 cudaError_t BatchDecodeWithPaddedKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o,
 DTypeOut* tmp, float* lse, uint32_t batch_size,
 uint32_t padded_kv_len, uint32_t num_qo_heads,
-float rope_scale, float rope_theta,
-cudaStream_t stream) {
-const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
+float sm_scale, float rope_scale,
+float rope_theta, cudaStream_t stream) {
 const float rope_rcp_scale = 1.f / rope_scale;
 const float rope_rcp_theta = 1.f / rope_theta;
 const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
@@ -1301,7 +1305,11 @@ cudaError_t BatchDecodeWithPaddedKVCache(
 DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeIn* o, DTypeOut* tmp, float* lse, uint32_t batch_size,
 uint32_t padded_kv_len, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
 QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone,
-float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+std::optional<float> sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4,
+cudaStream_t stream = nullptr) {
+if (!sm_scale.has_value()) {
+sm_scale = 1.f / std::sqrt(float(head_dim));
+}
 if (num_qo_heads % num_kv_heads != 0) {
 std::ostringstream err_msg;
 err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads "
@@ -1317,8 +1325,8 @@
 rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
 return BatchDecodeWithPaddedKVCacheDispatched<GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
 ROTARY_MODE, DTypeIn, DTypeOut>(
-q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, rope_scale,
-rope_theta, stream);
+q, k, v, o, tmp, lse, batch_size, padded_kv_len, num_qo_heads, sm_scale,
+rope_scale, rope_theta, stream);
 })})})});
 return cudaSuccess;
 }
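
The pattern in decode.cuh above: the public entry points take `std::optional<float>` and resolve the default once on the host, while the `*Dispatched` functions now take a plain `float sm_scale`, so the kernel-facing signatures stay optional-free. A minimal sketch of that layering with hypothetical names (`RunKernel`, `Run`), assuming nothing beyond what the hunks show:

#include <cmath>
#include <optional>

// Kernel-facing layer: always receives a concrete scale (mirrors
// BatchDecodeWithPagedKVCacheDispatched taking `float sm_scale`).
template <unsigned HEAD_DIM>
void RunKernel(float sm_scale) {
  // ... launch the kernel with sm_scale in its launch parameters ...
}

// Public layer: optional parameter, default resolved before dispatch
// (mirrors BatchDecodeWithPagedKVCache).
void Run(unsigned head_dim, std::optional<float> maybe_sm_scale = std::nullopt) {
  const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
  if (head_dim == 128) RunKernel<128>(sm_scale);  // dispatch on head_dim
}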
62 changes: 34 additions & 28 deletions include/flashinfer/attention/prefill.cuh
@@ -23,6 +23,7 @@
 #endif
 #include <cuda_runtime.h>
 
+#include <optional>
 #include <tuple>
 
 #include "../cp_async.cuh"
@@ -1563,9 +1564,9 @@ template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMod
 bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut>
 cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o,
 float* tmp, float* lse, uint32_t num_kv_heads,
-uint32_t qo_len, uint32_t kv_len, float rope_scale,
-float rope_theta, cudaStream_t stream) {
-const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
+uint32_t qo_len, uint32_t kv_len, float sm_scale,
+float rope_scale, float rope_theta,
+cudaStream_t stream) {
 const float log2_rope_rcp_scale = -std::log2f(rope_scale);
 const float log2_rope_rcp_theta = -std::log2f(rope_theta);
 if (kv_len < qo_len && CAUSAL) {
@@ -1714,14 +1715,14 @@ cudaError_t SinglePrefillWithKVCacheDispatched(DTypeIn* q, DTypeIn* k, DTypeIn*
 * \return status Indicates whether CUDA calls are successful
 */
 template <typename DTypeIn, typename DTypeOut>
-cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp,
-float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads,
-uint32_t qo_len, uint32_t kv_len, uint32_t head_dim,
-bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
-RotaryMode rotary_mode = RotaryMode::kNone,
-bool allow_fp16_qk_reduction = false, float rope_scale = 1.f,
-float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+cudaError_t SinglePrefillWithKVCache(
+DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_heads,
+uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, bool causal = true,
+QKVLayout kv_layout = QKVLayout::kNHD, RotaryMode rotary_mode = RotaryMode::kNone,
+bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
+float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
 const uint32_t group_size = num_qo_heads / num_kv_heads;
+const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 DISPATCH_ALLOW_FP16_QK_REDUCTION(
 allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION,
 {DISPATCH_GQA_GROUP_SIZE(
@@ -1734,9 +1735,9 @@ cudaError_t SinglePrefillWithKVCache(DTypeIn* q, DTypeIn* k, DTypeIn* v, DTypeOu
 rotary_mode, ROTARY_MODE, {DISPATCH_LAYOUT(kv_layout, KV_LAYOUT, {
 SinglePrefillWithKVCacheDispatched<GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
 ROTARY_MODE, ALLOW_FP16_QK_REDUCTION,
-CAUSAL>(q, k, v, o, tmp, lse,
-num_kv_heads, qo_len, kv_len,
-rope_scale, rope_theta, stream);
+CAUSAL>(
+q, k, v, o, tmp, lse, num_kv_heads, qo_len, kv_len, sm_scale,
+rope_scale, rope_theta, stream);
 })})})})})});
 return cudaSuccess;
 }
@@ -1748,9 +1749,8 @@ cudaError_t BatchPrefillWithRaggedKVCacheDispatched(
 DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr, DTypeIn* k,
 DTypeIn* v, IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o,
 float* tmp, float* lse, const uint32_t batch_size, const uint32_t num_qo_tiles,
-const uint32_t num_kv_heads, const float rope_scale, const float rope_theta,
-cudaStream_t stream = nullptr) {
-const float sm_scale = 1.f / std::sqrt(float(HEAD_DIM));
+const uint32_t num_kv_heads, const float sm_scale, const float rope_scale,
+const float rope_theta, cudaStream_t stream = nullptr) {
 const float log2_rope_rcp_scale = -std::log2f(rope_scale);
 const float log2_rope_rcp_theta = -std::log2f(rope_theta);
 constexpr uint32_t num_warps = 4;
@@ -1826,8 +1826,10 @@ cudaError_t BatchPrefillWithRaggedKVCache(
 const uint32_t batch_size, const uint32_t num_qo_heads, const uint32_t num_kv_heads,
 const uint32_t head_dim, bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
 RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false,
-const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+std::optional<float> maybe_sm_scale = std::nullopt, const float rope_scale = 1.f,
+const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
 const uint32_t group_size = num_qo_heads / num_kv_heads;
+const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 
 uint32_t num_frags_x, num_qo_tiles;
 std::vector<IdType> request_indices_h, tile_indices_h;
@@ -1865,7 +1867,8 @@
 ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
 q, request_indices_d, tile_indices_d, qo_indptr, k, v, kv_indptr,
 q_rope_position, k_rope_pos_offset, o, tmp, lse, batch_size,
-num_qo_tiles, num_kv_heads, rope_scale, rope_theta, stream);
+num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta,
+stream);
 })})})})})})});
 
 FLASHINFER_CUDA_CALL(cudaFreeAsync(request_indices_d, stream));
@@ -1901,12 +1904,14 @@ template <PageStorage page_storage, QKVLayout kv_layout, uint32_t num_frags_x, u
 cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched(
 DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr,
 IdType* q_rope_position, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
-DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale = 1.f,
+DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles,
+std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
 float rope_theta = 1e4, cudaStream_t stream = nullptr) {
 constexpr QKVLayout KV_LAYOUT = QKVLayout::kNHD;
 const uint32_t num_kv_heads = paged_kv.num_heads;
 const uint32_t head_dim = paged_kv.head_dim;
 const uint32_t batch_size = paged_kv.batch_size;
+const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 
 std::vector<IdType> kv_indptr_h(paged_kv.batch_size + 1);
 
@@ -1929,8 +1934,8 @@ cudaError_t BatchPrefillWithPagedKVCacheFallbackDispatched(
 ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut,
 IdType>(
 q, request_indices, tile_indices, qo_indptr, keys, values, kv_indptr, q_rope_position,
-paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, rope_scale,
-rope_theta, stream);
+paged_kv.rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale,
+rope_scale, rope_theta, stream);
 
 FLASHINFER_CUDA_CALL(cudaFreeAsync(keys, stream));
 FLASHINFER_CUDA_CALL(cudaFreeAsync(values, stream));
@@ -1946,9 +1951,8 @@ template <PageStorage page_storage, QKVLayout kv_layout, uint32_t num_frags_x, u
 cudaError_t BatchPrefillWithPagedKVCacheDispatched(
 DTypeIn* q, IdType* request_indices, IdType* tile_indices, IdType* qo_indptr,
 IdType* q_rope_position, paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv,
-DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float rope_scale, float rope_theta,
-cudaStream_t stream) {
-const float sm_scale = 1.f / std::sqrt(float(paged_kv.head_dim));
+DTypeOut* o, float* tmp, float* lse, uint32_t num_qo_tiles, float sm_scale, float rope_scale,
+float rope_theta, cudaStream_t stream) {
 const float log2_rope_rcp_scale = -std::log2f(rope_scale);
 const float log2_rope_rcp_theta = -std::log2f(rope_theta);
 constexpr uint32_t num_warps = 4;
@@ -2022,11 +2026,13 @@ cudaError_t BatchPrefillWithPagedKVCache(
 paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* tmp,
 float* lse, uint32_t num_qo_heads, bool causal = true,
 RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false,
-float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
+std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
+float rope_theta = 1e4, cudaStream_t stream = nullptr) {
 const uint32_t num_kv_heads = paged_kv.num_heads;
 const uint32_t head_dim = paged_kv.head_dim;
 const uint32_t batch_size = paged_kv.batch_size;
 const uint32_t group_size = num_qo_heads / num_kv_heads;
+const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
 
 uint32_t num_frags_x, num_qo_tiles;
 std::vector<IdType> request_indices_h, tile_indices_h;
@@ -2068,16 +2074,16 @@
 ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn,
 DTypeOut, IdType>(q, request_indices_d, tile_indices_d,
 qo_indptr, q_rope_position, paged_kv, o,
-tmp, lse, num_qo_tiles, rope_scale,
-rope_theta, stream);
+tmp, lse, num_qo_tiles, sm_scale,
+rope_scale, rope_theta, stream);
 } else {
 return BatchPrefillWithPagedKVCacheDispatched<
 page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE,
 HEAD_DIM, ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
 DTypeIn, DTypeOut, IdType>(
 q, request_indices_d, tile_indices_d, qo_indptr,
 q_rope_position, paged_kv, o, tmp, lse, num_qo_tiles,
-rope_scale, rope_theta, stream);
+sm_scale, rope_scale, rope_theta, stream);
 }
 })
 
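One detail worth noting in the diff: most wrappers resolve the default with `value_or`, while `BatchDecodeWithPaddedKVCache` reassigns the optional in place via `has_value()`. The two styles compute the same value; a minimal standalone comparison (not FlashInfer code; the `head_dim` value is illustrative):

#include <cassert>
#include <cmath>
#include <optional>

int main() {
  const unsigned head_dim = 128;

  // Style A: value_or, used by most wrappers in this PR.
  std::optional<float> a = std::nullopt;
  const float sm_scale_a = a.value_or(1.f / std::sqrt(float(head_dim)));

  // Style B: in-place reassignment, used by BatchDecodeWithPaddedKVCache.
  std::optional<float> b = std::nullopt;
  if (!b.has_value()) {
    b = 1.f / std::sqrt(float(head_dim));
  }

  assert(sm_scale_a == *b);  // both yield 1/sqrt(head_dim)
  return 0;
}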