feat: add llama 3.1 style rope #401

Merged · 13 commits · Jul 27, 2024
14 changes: 14 additions & 0 deletions docs/api/python/rope.rst
@@ -0,0 +1,14 @@
.. _apirope:

flashinfer.rope
===============

Kernels for applying rotary embeddings.

.. currentmodule:: flashinfer.rope

.. autosummary::
:toctree: _generate

apply_rope_inplace
apply_llama31_rope_inplace
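The two ops documented here are exposed through the Python package (see python/csrc/rope.cu and python/flashinfer/__init__.py below). A minimal usage sketch, assuming the Python wrappers simply forward the same positional arguments as the C++ bindings added in this PR (the wrapper module itself is not shown in this diff); the Llama 3.1 values at the end are illustrative reference settings, not defaults taken from this code:

```python
import torch
import flashinfer

num_qo_heads, num_kv_heads, head_dim = 32, 8, 128
append_lens = [7, 3]                                   # ragged batch of two requests
nnz = sum(append_lens)
indptr = torch.tensor([0, 7, 10], dtype=torch.int32, device="cuda")   # (B + 1)
offsets = torch.tensor([0, 0], dtype=torch.int32, device="cuda")      # (B) start positions

q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")

# Plain RoPE, applied in place to q and k. Arguments follow the C++ binding:
# q, k, indptr, offsets, interleave, rope_scale, rope_theta.
flashinfer.apply_rope_inplace(q, k, indptr, offsets, False, 1.0, 1e4)

# Llama 3.1 style RoPE with frequency smoothing. The trailing values are the
# commonly cited Llama 3.1 settings (rope_scale 8, theta 5e5, low/high freq
# factors 1/4, original context length 8192), shown here only as an example.
flashinfer.apply_llama31_rope_inplace(q, k, indptr, offsets, False, 8.0, 5e5,
                                      1.0, 4.0, 8192.0)
```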
1 change: 1 addition & 0 deletions docs/index.rst
@@ -35,4 +35,5 @@ FlashInfer is a library for Large Language Models that provides high-performance
api/python/sampling
api/python/group_gemm
api/python/norm
api/python/rope
api/python/quantization
178 changes: 145 additions & 33 deletions include/flashinfer/pos_enc.cuh
@@ -16,6 +16,7 @@
#ifndef FLASHINFER_POS_ENC_CUH_
#define FLASHINFER_POS_ENC_CUH_

#include <cmath>
#include <string>

#include "layout.cuh"
@@ -93,20 +94,56 @@ __device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope(
return vec;
}

template <uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType>
__global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __restrict__ k,
IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads,
float rope_rcp_scale, float rope_rcp_theta) {
/*!
 * \brief Apply RoPE (Rotary Positional Embeddings) to x[0: head_dim] with interleaving,
 *   and return the thread-local vector.
 * \tparam vec_size A template integer indicating the vector size used
 *   in the kernel
 * \tparam bdx A template integer indicating blockDim.x
 * \tparam T A template type indicating the x data type
 * \param x A pointer to the start of x data
 * \param freq A vector of float indicating the thread-local rope frequencies
 * \param offset An integer indicating the offset of the position in RoPE
 */
template <uint32_t vec_size, uint32_t bdx, typename T>
__device__ __forceinline__ vec_t<float, vec_size> vec_apply_llama_rope_interleave(
const T* x, const vec_t<float, vec_size>& freq, int32_t offset) {
vec_t<float, vec_size> vec, vec_before;
vec.cast_load(x + threadIdx.x * vec_size);
vec_before = vec;

#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
float embed = float(offset) * freq[i];
float cos, sin;
__sincosf(embed, &sin, &cos);
vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin;
}
return vec;
}

template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType,
typename IdType>
__global__ void BatchQKApplyRotaryInPlaceKernel(
DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, float smooth_a,
float smooth_b, float rope_rcp_scale, float rope_rcp_theta) {
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
const uint32_t bdy = blockDim.y;
vec_t<float, vec_size> freq;
#pragma unroll
for (uint32_t i = 0; i < vec_size; ++i) {
freq[i] =
rope_rcp_scale *
__powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
if constexpr (interleave) {
freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(head_dim));
} else {
freq[i] = __powf(rope_rcp_theta,
float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim));
}

float smooth = freq[i] * smooth_a + smooth_b;
smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1]
freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i];
}

if (bx < batch_size * num_qo_heads) {
@@ -120,8 +157,13 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __
vec_t<float, vec_size> q_vec;
if (i * bdy + ty < seq_len) {
DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0,
num_qo_heads * head_dim, head_dim);
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
q_stride_n, q_stride_h);
if constexpr (interleave) {
q_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
} else {
q_vec = vec_apply_llama_rope<vec_size, bdx>(q_ptr, freq, offset + i * bdy + ty);
}
q_vec.cast_store(q_ptr + tx * vec_size);
}
}
@@ -136,42 +178,112 @@ __global__ void BatchQKApplyRotaryInPlaceKernel(DType* __restrict__ q, DType* __
vec_t<float, vec_size> k_vec;
if (i * bdy + ty < seq_len) {
DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0,
num_kv_heads * head_dim, head_dim);
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
k_stride_n, k_stride_h);
if constexpr (interleave) {
k_vec =
vec_apply_llama_rope_interleave<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
} else {
k_vec = vec_apply_llama_rope<vec_size, bdx>(k_ptr, freq, offset + i * bdy + ty);
}
k_vec.cast_store(k_ptr + tx * vec_size);
}
}
}
}

#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \
if (interleave) { \
const bool INTERLEAVE = true; \
__VA_ARGS__ \
} else { \
const bool INTERLEAVE = false; \
__VA_ARGS__ \
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k,
IdType* __restrict__ indptr, IdType* __restrict__ offsets,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim,
float rope_scale = 1.f, float rope_theta = 1e4,
uint32_t num_kv_heads, uint32_t head_dim, size_t q_stride_n,
size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta,
cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = 0.f;
float smooth_b = 0.f;

DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel =
BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
}

template <typename DType, typename IdType>
cudaError_t BatchQKApplyLlama31RotaryInPlace(
DType* __restrict__ q, DType* __restrict__ k, IdType* __restrict__ indptr,
IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
float high_freq_factor, float old_context_length, cudaStream_t stream = nullptr) {
float rope_rcp_scale = 1.0f / rope_scale;
float rope_rcp_theta = 1.0f / rope_theta;
float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor);
float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f);

DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel = BatchQKApplyRotaryInPlaceKernel<HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
constexpr uint32_t bdx = HEAD_DIM / vec_size;
uint32_t num_threads = std::max(128U, bdx);
uint32_t bdy = num_threads / bdx;
dim3 nblks(batch_size * (num_qo_heads + num_kv_heads));
dim3 nthrs(bdx, bdy);
auto kernel =
BatchQKApplyRotaryInPlaceKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
void* args[] = {(void*)&q,
(void*)&k,
(void*)&indptr,
(void*)&offsets,
(void*)&batch_size,
(void*)&num_qo_heads,
(void*)&num_kv_heads,
(void*)&q_stride_n,
(void*)&q_stride_h,
(void*)&k_stride_n,
(void*)&k_stride_h,
(void*)&smooth_a,
(void*)&smooth_b,
(void*)&rope_rcp_scale,
(void*)&rope_rcp_theta};
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
});
});

return cudaSuccess;
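Both entry points above launch the same kernel; the Llama 3.1 variant only changes how each per-dimension frequency is pre-smoothed through smooth_a/smooth_b before the rotation, and with smooth_a = smooth_b = 0 the clamp pins smooth to 0 so the expression reduces to the ordinary freq * rope_rcp_scale. A scalar Python sketch of that frequency computation, mirroring the hunks above (illustrative reference, not part of the PR):

```python
import math

def rope_freq(i, head_dim, rope_theta, rope_scale,
              smooth_a=0.0, smooth_b=0.0, interleave=False):
    """Per-dimension rotary frequency as set up at the top of
    BatchQKApplyRotaryInPlaceKernel (scalar reference)."""
    if interleave:
        exponent = 2.0 * (i // 2) / head_dim                 # pairs (2j, 2j+1)
    else:
        exponent = 2.0 * (i % (head_dim // 2)) / head_dim    # pairs (j, j + head_dim // 2)
    freq = rope_theta ** (-exponent)

    smooth = min(1.0, max(0.0, freq * smooth_a + smooth_b))  # clamp to [0, 1]
    # smooth == 0 -> fully scaled (long-context) frequency; smooth == 1 -> original.
    return (1.0 - smooth) * (freq / rope_scale) + smooth * freq

def llama31_smooth_coeffs(low_freq_factor, high_freq_factor, old_context_length):
    """smooth_a / smooth_b as computed in BatchQKApplyLlama31RotaryInPlace."""
    smooth_a = old_context_length / (2 * math.pi * (high_freq_factor - low_freq_factor))
    smooth_b = -1.0 / (high_freq_factor / low_freq_factor - 1.0)
    return smooth_a, smooth_b
```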
3 changes: 3 additions & 0 deletions python/csrc/flashinfer_ops.cu
@@ -42,6 +42,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
m.def("packbits", &packbits, "GPU packbits operator");
m.def("segment_packbits", &segment_packbits, "GPU segment packbits operator");
py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
8 changes: 8 additions & 0 deletions python/csrc/flashinfer_ops.h
@@ -75,6 +75,14 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length);

torch::Tensor packbits(torch::Tensor x, const std::string& bitorder);

torch::Tensor segment_packbits(torch::Tensor x, torch::Tensor input_indptr,
105 changes: 105 additions & 0 deletions python/csrc/rope.cu
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <flashinfer/pos_enc.cuh>

#include "flashinfer_ops.h"
#include "pytorch_extension_utils.h"

using namespace flashinfer;

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyRotaryInPlace(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyRotaryInPlace failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}

void apply_llama31_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale,
float rope_theta, float low_freq_factor, float high_freq_factor,
float old_context_length) {
CHECK_CUDA(q); // not necessarily contiguous
CHECK_CUDA(k); // not necessarily contiguous
CHECK_INPUT(indptr);
CHECK_INPUT(offsets);

auto device = q.device();
CHECK_EQ(k.device(), device);
CHECK_DIM(3, q); // q: (nnz, H_Q, D)
CHECK_DIM(3, k); // k: (nnz, H_K, D)
CHECK_DIM(1, indptr); // indptr: (B + 1)
CHECK_DIM(1, offsets); // offsets: (B)
CHECK_EQ(q.size(0), k.size(0));
CHECK_EQ(q.size(2), k.size(2));
unsigned int num_qo_heads = q.size(1);
unsigned int num_kv_heads = k.size(1);
unsigned int head_dim = q.size(2);
unsigned int batch_size = offsets.size(0);
CHECK_EQ(indptr.size(0), batch_size + 1);
size_t q_stride_n = q.stride(0);
size_t q_stride_h = q.stride(1);
size_t k_stride_n = k.stride(0);
size_t k_stride_h = k.stride(1);
indptr = indptr.to(torch::kInt32);
offsets = offsets.to(torch::kInt32);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
cudaError_t status = BatchQKApplyLlama31RotaryInPlace(
static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
static_cast<int32_t*>(indptr.data_ptr()), static_cast<int32_t*>(offsets.data_ptr()),
batch_size, num_qo_heads, num_kv_heads, head_dim, q_stride_n, q_stride_h, k_stride_n,
k_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor,
old_context_length, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "BatchQKApplyLlama31RotaryInPlace failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
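The checks above pin down the ragged layout the ops expect: q and k stack all tokens in the batch along dim 0, indptr (B + 1 entries) marks where each request's rows begin, and offsets (B entries) gives each request's starting RoPE position, typically the number of tokens already in the KV cache. A small helper of the kind a caller might write to build these tensors (hypothetical, not part of this PR):

```python
import torch

def build_rope_metadata(append_lens, past_lens, device="cuda"):
    """Build (indptr, offsets) for the apply_rope_inplace-style ops.

    append_lens[i]: number of new tokens for request i (rows of q/k it owns).
    past_lens[i]:   tokens already cached, i.e. the RoPE position offset.
    """
    indptr = torch.zeros(len(append_lens) + 1, dtype=torch.int32, device=device)
    indptr[1:] = torch.cumsum(torch.tensor(append_lens, device=device), dim=0)
    offsets = torch.tensor(past_lens, dtype=torch.int32, device=device)
    return indptr, offsets

# Example: two requests appending 7 and 3 tokens with 100 and 42 cached tokens.
indptr, offsets = build_rope_metadata([7, 3], [100, 42])
# indptr  -> [0, 7, 10]   (B + 1 entries)
# offsets -> [100, 42]    (B entries)
```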
1 change: 1 addition & 0 deletions python/flashinfer/__init__.py
@@ -44,6 +44,7 @@
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import apply_rope_inplace, apply_llama31_rope_inplace
from .group_gemm import SegmentGEMMWrapper
from .quantization import packbits, segment_packbits
