Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support ALiBi #146

Merged
merged 9 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ num_qo_heads = 32
q = torch.randn(num_qo_heads, head_dim).half().to(0)

o = flashinfer.single_decode_with_kv_cache(q, k, v) # decode attention without RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, rotary_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly
o_rope_on_the_fly = flashinfer.single_decode_with_kv_cache(q, k, v, pos_encoding_mode="LLAMA") # decode with LLaMA style RoPE on-the-fly

# append attention
append_qo_len = 128
q = torch.randn(append_qo_len, num_qo_heads, head_dim).half().to(0) # append attention, the last 128 tokens in the KV-Cache are the new tokens
o = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True) # append attention without RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, rotary_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask
o_rope_on_the_fly = flashinfer.single_prefill_with_kv_cache(q, k, v, causal=True, pos_encoding_mode="LLAMA") # append attention with LLaMA style RoPE on-the-fly, apply causal mask

# prefill attention
qo_len = 2048
Expand Down
2 changes: 1 addition & 1 deletion include/flashinfer.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
#include "flashinfer/attention.cuh"
#include "flashinfer/layout.cuh"
#include "flashinfer/page.cuh"
#include "flashinfer/rope.cuh"
#include "flashinfer/pos_enc.cuh"

#endif // FLASHINFER_CUH_
182 changes: 93 additions & 89 deletions include/flashinfer/attention/decode.cuh

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions include/flashinfer/attention/handler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include <unordered_map>
#include <vector>

#include "../rope.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "decode.cuh"

Expand Down Expand Up @@ -81,15 +81,15 @@ class BatchDecodeHandler {
cudaError_t BeginForward(void* buffer, size_t workspace_size_in_bytes, IdType* indptr,
IdType* last_page_len, uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, uint32_t page_size,
RotaryMode rotary_mode) {
PosEncodingMode pos_encoding_mode) {
batch_size_before_partition_ = batch_size;
uint32_t tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size;
auto work_estimation_func =
BatchDecodeWithPagedKVCacheWorkEstimation<page_storage, kv_layout, DTypeIn, DTypeOut,
IdType>;
FLASHINFER_CUDA_CALL(work_estimation_func(
tmp_size, max_grid_size, max_num_pages_per_batch, new_batch_size, batch_size, indptr,
num_qo_heads, num_kv_heads, head_dim, page_size, rotary_mode, stream_));
num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode, stream_));
batch_size_after_partition_ = new_batch_size;
if (tmp_size > 0) {
AlignedAlloactor allocator(buffer, workspace_size_in_bytes);
Expand Down
298 changes: 190 additions & 108 deletions include/flashinfer/attention/prefill.cuh

Large diffs are not rendered by default.

71 changes: 36 additions & 35 deletions include/flashinfer/attention/wrapper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace flashinfer {
* \param o The output tensor.
* \param lse The logsumexp values.
* \param num_qo_heads The number of heads.
* \param rotary_mode The rotary mode.
* \param pos_encoding_mode The positional encoding mode.
* \param rope_scale The scale of rope.
* \param rope_theta The theta of rope.
* \param stream The CUDA stream.
Expand All @@ -46,9 +46,9 @@ namespace flashinfer {
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchDecodeWithPagedKVCacheWrapper(
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_rope_position,
BatchDecodeHandler* handler, DTypeIn* q, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, RotaryMode rotary_mode = RotaryMode::kNone,
uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
std::optional<float> maybe_sm_scale = std::nullopt, float rope_scale = 1.f,
float rope_theta = 1e4, cudaStream_t stream = nullptr) {
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> new_paged_kv = paged_kv;
Expand All @@ -73,15 +73,15 @@ cudaError_t BatchDecodeWithPagedKVCacheWrapper(
throw std::runtime_error(err_msg.str());
}
return BatchDecodeWithPagedKVCache<page_storage, kv_layout, DTypeIn, DTypeOut, IdType>(
q, q_rope_position, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, rotary_mode,
q, q_offset, new_paged_kv, kv_partition_info, o, tmp, lse, num_qo_heads, pos_encoding_mode,
maybe_sm_scale, rope_scale, rope_theta, stream);
}

template <PageStorage page_storage, QKVLayout kv_layout, uint32_t GROUP_SIZE, uint32_t HEAD_DIM,
RotaryMode ROTARY_MODE, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn,
typename DTypeOut, typename IdType>
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL,
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
float sm_scale, float rope_scale, float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
Expand All @@ -105,15 +105,15 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
num_frags_x, NUM_FRAGS_X, {DISPATCH_PAGE_SIZE(paged_kv.page_size, PAGE_SIZE, {
if constexpr (PAGE_SIZE == 0) {
return BatchPrefillWithPagedKVCacheFallbackDispatched<
page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
page_storage, kv_layout, NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse,
num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
} else {
return BatchPrefillWithPagedKVCacheDispatched<
page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_rope_position, paged_kv, o, tmp, lse,
page_storage, kv_layout, NUM_FRAGS_X, PAGE_SIZE, GROUP_SIZE, HEAD_DIM,
pos_encoding_mode, ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, q_offset, paged_kv, o, tmp, lse,
num_qo_tiles, sm_scale, rope_scale, rope_theta, stream);
}
})});
Expand All @@ -123,9 +123,10 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapperDispatched(
template <PageStorage page_storage, QKVLayout kv_layout, typename DTypeIn, typename DTypeOut,
typename IdType>
cudaError_t BatchPrefillWithPagedKVCacheWrapper(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_rope_position,
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, IdType* q_offset,
paged_kv_t<page_storage, kv_layout, DTypeIn, IdType> paged_kv, DTypeOut* o, float* lse,
uint32_t num_qo_heads, bool causal = true, RotaryMode rotary_mode = RotaryMode::kNone,
uint32_t num_qo_heads, bool causal = true,
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
float rope_scale = 1.f, float rope_theta = 1e4, cudaStream_t stream = nullptr) {
const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim)));
Expand All @@ -137,25 +138,25 @@ cudaError_t BatchPrefillWithPagedKVCacheWrapper(
head_dim, HEAD_DIM,
{DISPATCH_CAUSAL(
causal, CAUSAL,
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{DISPATCH_POS_ENCODING_MODE(
pos_encoding_mode, pos_encoding_mode,
{DISPATCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithPagedKVCacheWrapperDispatched<
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, ROTARY_MODE,
page_storage, kv_layout, GROUP_SIZE, HEAD_DIM, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, q_rope_position, paged_kv, o, lse, sm_scale,
rope_scale, rope_theta, stream);
handler, q, qo_indptr, q_offset, paged_kv, o, lse, sm_scale, rope_scale,
rope_theta, stream);
})})})})});
return cudaSuccess;
}

template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT, RotaryMode ROTARY_MODE,
bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL, typename DTypeIn, typename DTypeOut,
typename IdType>
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, QKVLayout KV_LAYOUT,
PosEncodingMode pos_encoding_mode, bool ALLOW_FP16_QK_REDUCTION, bool CAUSAL,
typename DTypeIn, typename DTypeOut, typename IdType>
cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(
BatchPrefillHandler* handler, DTypeIn* q, IdType* qo_indptr, DTypeIn* k, DTypeIn* v,
IdType* kv_indptr, IdType* q_rope_position, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
IdType* kv_indptr, IdType* q_offset, IdType* k_rope_pos_offset, DTypeOut* o, float* lse,
const uint32_t batch_size, const uint32_t num_kv_heads, const float sm_scale,
const float rope_scale, const float rope_theta, cudaStream_t stream) {
float* tmp = nullptr;
Expand All @@ -177,11 +178,11 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapperDispatched(

DISPATCH_NUM_FRAGS_X(num_frags_x, NUM_FRAGS_X, {
return BatchPrefillWithRaggedKVCacheDispatched<NUM_FRAGS_X, GROUP_SIZE, HEAD_DIM, KV_LAYOUT,
ROTARY_MODE, ALLOW_FP16_QK_REDUCTION, CAUSAL,
DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_rope_position,
k_rope_pos_offset, o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale,
rope_scale, rope_theta, stream);
pos_encoding_mode, ALLOW_FP16_QK_REDUCTION,
CAUSAL, DTypeIn, DTypeOut, IdType>(
q, request_indices, tile_indices, qo_indptr, k, v, kv_indptr, q_offset, k_rope_pos_offset,
o, tmp, lse, batch_size, num_qo_tiles, num_kv_heads, sm_scale, rope_scale, rope_theta,
stream);
});
return cudaSuccess;
}
Expand All @@ -192,9 +193,9 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
IdType* kv_indptr, DTypeOut* o, float* lse, const uint32_t batch_size,
const uint32_t num_qo_heads, const uint32_t num_kv_heads, const uint32_t head_dim,
bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD,
RotaryMode rotary_mode = RotaryMode::kNone, bool allow_fp16_qk_reduction = false,
std::optional<float> maybe_sm_scale = std::nullopt, const float rope_scale = 1.f,
const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone,
bool allow_fp16_qk_reduction = false, std::optional<float> maybe_sm_scale = std::nullopt,
const float rope_scale = 1.f, const float rope_theta = 1e4, cudaStream_t stream = nullptr) {
const float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim)));
DISPATCH_LAYOUT(
kv_layout, KV_LAYOUT,
Expand All @@ -204,14 +205,14 @@ cudaError_t BatchPrefillWithRaggedKVCacheWrapper(
head_dim, HEAD_DIM,
{DISPATCH_CAUSAL(
causal, CAUSAL,
{DISPATCH_ROTARY_MODE(
rotary_mode, ROTARY_MODE,
{DISPATCH_POS_ENCODING_MODE(
pos_encoding_mode, pos_encoding_mode,
{DISPATCH_ALLOW_FP16_QK_REDUCTION(
allow_fp16_qk_reduction, ALLOW_FP16_QK_REDUCTION, {
return BatchPrefillWithRaggedKVCacheWrapperDispatched<
GROUP_SIZE, HEAD_DIM, KV_LAYOUT, ROTARY_MODE,
GROUP_SIZE, HEAD_DIM, KV_LAYOUT, pos_encoding_mode,
ALLOW_FP16_QK_REDUCTION, CAUSAL, DTypeIn, DTypeOut, IdType>(
handler, q, qo_indptr, k, v, kv_indptr, /*q_rope_position=*/nullptr,
handler, q, qo_indptr, k, v, kv_indptr, /*q_offset=*/nullptr,
/*k_rope_pos_offset=*/nullptr, o, lse, batch_size, num_kv_heads,
sm_scale, rope_scale, rope_theta, stream);
})})})})})});
Expand Down
34 changes: 22 additions & 12 deletions include/flashinfer/rope.cuh → include/flashinfer/pos_enc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ROPE_CUH_
#define FLASHINFER_ROPE_CUH_
#ifndef FLASHINFER_POS_ENC_CUH_
#define FLASHINFER_POS_ENC_CUH_

#include <string>

#include "layout.cuh"
#include "math.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"

Expand All @@ -28,28 +29,37 @@ namespace flashinfer {
* \brief An enumeration class that defines different modes for applying RoPE
* (Rotary Positional Embeddings).
*/
enum class RotaryMode {
enum class PosEncodingMode {
// No rotary positional embeddings
kNone = 0U,
// Apply Llama-style rope.
kLlama = 1U,
kRoPELlama = 1U,
// Apply ALiBi bias
kALiBi = 2U
};

/*!
* \brief Convert RotaryMode to string
* \param rotary_mode A RotaryMode value
* \brief Convert PosEncodingMode to string
* \param pos_encoding_mode A PosEncodingMode value
*/
inline std::string RotaryModeToString(const RotaryMode& rotary_mode) {
switch (rotary_mode) {
case RotaryMode::kNone:
inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_mode) {
switch (pos_encoding_mode) {
case PosEncodingMode::kNone:
return "None";
case RotaryMode::kLlama:
case PosEncodingMode::kRoPELlama:
return "Llama";
case PosEncodingMode::kALiBi:
return "ALiBi";
default:
return "Unknown";
}
}

__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) {
// NOTE(Zihao): here we assume that num_heads is a power of 2
return math::ptx_exp2(-8. * float(head_idx + 1) / float(num_heads));
}

/*!
* \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim],
* return thread-local vector
Expand All @@ -63,7 +73,7 @@ inline std::string RotaryModeToString(const RotaryMode& rotary_mode) {
*/
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope(
const T* x, const vec_t<float, vec_size>& freq, uint32_t offset) {
const T* x, const vec_t<float, vec_size>& freq, int32_t offset) {
constexpr uint32_t head_dim = vec_size * bdx;
vec_t<float, vec_size> permuted_vec, vec;
vec.cast_load(x + threadIdx.x * vec_size);
Expand Down Expand Up @@ -170,4 +180,4 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__

} // namespace flashinfer

#endif // FLASHINFER_ROPE_CUH_
#endif // FLASHINFER_POS_ENC_CUH_
39 changes: 22 additions & 17 deletions include/flashinfer/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -178,23 +178,28 @@
} \
}

#define DISPATCH_ROTARY_MODE(rotary_mode, ROTARY_MODE, ...) \
switch (rotary_mode) { \
case RotaryMode::kNone: { \
constexpr RotaryMode ROTARY_MODE = RotaryMode::kNone; \
__VA_ARGS__ \
break; \
} \
case RotaryMode::kLlama: { \
constexpr RotaryMode ROTARY_MODE = RotaryMode::kLlama; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported rotary_mode: " << int(rotary_mode); \
throw std::invalid_argument(err_msg.str()); \
} \
#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \
switch (pos_encoding_mode) { \
case PosEncodingMode::kNone: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kRoPELlama: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \
__VA_ARGS__ \
break; \
} \
case PosEncodingMode::kALiBi: { \
constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \
__VA_ARGS__ \
break; \
} \
default: { \
std::ostringstream err_msg; \
err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \
throw std::invalid_argument(err_msg.str()); \
} \
}

namespace flashinfer {
Expand Down
Loading