feat: support fused add rmsnorm #419

Merged (2 commits, Aug 4, 2024)
132 changes: 115 additions & 17 deletions include/flashinfer/norm.cuh
@@ -28,7 +28,7 @@ namespace flashinfer {
namespace norm {

template <uint32_t VEC_SIZE, typename T>
__global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restrict__ y,
__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output,
const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
@@ -43,14 +43,14 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> x_vec;
x_vec.fill(0);
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
sum_sq += float(x_vec[j]) * float(x_vec[j]);
sum_sq += float(input_vec[j]) * float(input_vec[j]);
}
}

@@ -76,36 +76,36 @@ __global__ void RMSNormKernel(T* __restrict__ x, T* __restrict__ w, T* __restric
float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> x_vec;
vec_t<T, VEC_SIZE> w_vec;
vec_t<T, VEC_SIZE> y_vec;
x_vec.fill(0);
w_vec.fill(0);
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> output_vec;
input_vec.fill(0);
weight_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
x_vec.load(x + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
w_vec.load(w + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
y_vec[j] = float(x_vec[j]) * rms_rcp * float(w_vec[j]);
output_vec[j] = float(input_vec[j]) * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
y_vec.store(y + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
output_vec.store(output + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}
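As a quick orientation for the diff, the kernel above computes, for each row $x \in \mathbb{R}^d$ of the input (sketched notation; note that $\varepsilon$ is added to the mean of squares inside the square root, matching `math::rsqrt(smem[0] / float(d) + eps)`):

$$y_j = \frac{x_j \, w_j}{\sqrt{\tfrac{1}{d}\sum_{k=1}^{d} x_k^2 + \varepsilon}}$$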

template <typename T>
cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps = 1e-5,
cudaStream_t stream = 0) {
cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&x, &w, &y, &d, &eps};
void* args[] = {&input, &weight, &output, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = RMSNormKernel<VEC_SIZE, T>;
@@ -114,6 +114,104 @@ cudaError_t RMSNorm(T* x, T* w, T* y, uint32_t batch_size, uint32_t d, float eps
return cudaSuccess;
}

template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual,
T* __restrict__ weight, const uint32_t d, float eps) {
const uint32_t bx = blockIdx.x;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
constexpr uint32_t warp_size = 32;
const uint32_t num_warps = blockDim.y;
const uint32_t thread_id = tx + ty * warp_size;
const uint32_t num_threads = num_warps * warp_size;
const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
extern __shared__ float smem[];

float sum_sq = 0.f;

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
input_vec.fill(0);
vec_t<T, VEC_SIZE> residual_vec;
residual_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
float x = float(input_vec[j]);
x += float(residual_vec[j]);
sum_sq += x * x;
residual_vec[j] = (T)x;
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}

// first, warp reduce sum
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}

smem[ty] = sum_sq;
__syncthreads();
// then, cross warp reduce sum using only the first warp
if (ty == 0) {
sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
sum_sq += math::shfl_xor_sync(sum_sq, offset);
}
smem[0] = sum_sq;
}
__syncthreads();

float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

for (uint32_t i = 0; i < rounds; i++) {
vec_t<T, VEC_SIZE> input_vec;
vec_t<T, VEC_SIZE> weight_vec;
vec_t<T, VEC_SIZE> residual_vec;
input_vec.fill(0);
weight_vec.fill(0);
residual_vec.fill(0);
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
#pragma unroll
for (uint32_t j = 0; j < VEC_SIZE; j++) {
input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
}
if ((i * num_threads + thread_id) * VEC_SIZE < d) {
input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
}
}
}

template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
float eps = 1e-5, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
dim3 nblks(batch_size);
dim3 nthrs(32, num_warps);
const uint32_t smem_size = num_warps * sizeof(float);
void* args[] = {&input, &residual, &weight, &d, &eps};

DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
});

return cudaSuccess;
}

} // namespace norm

} // namespace flashinfer
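As a rough reference for the new fused path, its semantics can be sketched in PyTorch as follows (the helper name is hypothetical and this is a sketch rather than the PR's implementation; it mirrors the kernel's behaviour: fp32 accumulation, with both `input` and `residual` updated in place):

    import torch

    def fused_add_rmsnorm_reference(input: torch.Tensor,    # (batch_size, d), overwritten with the normalized output
                                    residual: torch.Tensor, # (batch_size, d), overwritten with input + residual
                                    weight: torch.Tensor,   # (d,)
                                    eps: float) -> None:
        # First pass: accumulate input + residual in fp32 and write the sum back to `residual`.
        x = input.float() + residual.float()
        residual.copy_(x.to(residual.dtype))
        # Reduce: mean of squares of the fp32 sum, with eps added inside the rsqrt.
        rms_rcp = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        # Second pass: normalize the (rounded) sum now stored in `residual`, scale by `weight`,
        # and write the result into `input`.
        input.copy_((residual.float() * rms_rcp * weight.float()).to(input.dtype))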
1 change: 1 addition & 0 deletions python/csrc/flashinfer_ops.cu
@@ -42,6 +42,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("chain_speculative_sampling", &chain_speculative_sampling,
"Speculative sampling from sequence of probabilities");
m.def("rmsnorm", &rmsnorm, "Root mean square normalization");
m.def("fused_add_rmsnorm", &fused_add_rmsnorm, "Fused add root mean square normalization");
m.def("apply_rope_inplace", &apply_rope_inplace, "Apply RoPE in-place");
m.def("apply_llama31_rope_inplace", &apply_llama31_rope_inplace,
"Apply Llama 3.1 style RoPE in-place");
5 changes: 4 additions & 1 deletion python/csrc/flashinfer_ops.h
@@ -78,7 +78,10 @@ torch::Tensor chain_speculative_sampling(torch::Tensor draft_probs, torch::Tenso
torch::Tensor uniform_samples, torch::Tensor target_probs,
bool deterministic);

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps);
torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps);

void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps);

void apply_rope_inplace(torch::Tensor q, torch::Tensor k, torch::Tensor indptr,
torch::Tensor offsets, bool interleave, float rope_scale, float rope_theta);
62 changes: 46 additions & 16 deletions python/csrc/norm.cu
@@ -20,26 +20,56 @@

using namespace flashinfer;

torch::Tensor rmsnorm(torch::Tensor x, torch::Tensor w, double eps) {
CHECK_INPUT(x);
CHECK_INPUT(w);
auto device = x.device();
CHECK_EQ(w.device(), device);
CHECK_DIM(2, x); // x: (batch_size, hidden_size)
CHECK_DIM(1, w); // w: (hidden_size)
CHECK_EQ(x.size(1), w.size(0));
unsigned int batch_size = x.size(0);
unsigned int hidden_size = x.size(1);
torch::Tensor rmsnorm(torch::Tensor input, torch::Tensor weight, double eps) {
CHECK_INPUT(input);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
auto y = torch::empty_like(x);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(x.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(
static_cast<c_type*>(x.data_ptr()), static_cast<c_type*>(w.data_ptr()),
static_cast<c_type*>(y.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
auto output = torch::empty_like(input);
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::RMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(weight.data_ptr()),
static_cast<c_type*>(output.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess,
"RMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
return true;
});
return y;
return output;
}

void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight,
double eps) {
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
auto device = input.device();
CHECK_EQ(residual.device(), device);
CHECK_EQ(weight.device(), device);
CHECK_DIM(2, input); // input: (batch_size, hidden_size)
CHECK_DIM(2, residual); // residual: (batch_size, hidden_size)
CHECK_DIM(1, weight); // weight: (hidden_size)
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
unsigned int batch_size = input.size(0);
unsigned int hidden_size = input.size(1);

cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
cudaError_t status = norm::FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()),
static_cast<c_type*>(residual.data_ptr()),
static_cast<c_type*>(weight.data_ptr()), batch_size,
hidden_size, eps, torch_current_stream);
TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " +
std::string(cudaGetErrorString(status)));
return true;
});
}
55 changes: 19 additions & 36 deletions python/flashinfer/__init__.py
@@ -14,44 +14,27 @@
limitations under the License.
"""

from .decode import (
single_decode_with_kv_cache,
BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
)
from .prefill import (
single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse,
BatchPrefillWithRaggedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
from .sparse import BlockSparseAttentionWrapper
from .cascade import (
merge_state,
merge_state_in_place,
merge_states,
BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
)
from .page import append_paged_kv_cache
from .sampling import (
sampling_from_probs,
top_p_sampling_from_probs,
top_k_sampling_from_probs,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
top_k_renorm_prob,
chain_speculative_sampling,
)
from .norm import rmsnorm
from .rope import (
apply_rope_inplace,
apply_llama31_rope_inplace,
apply_rope,
apply_llama31_rope,
)
from .cascade import (BatchDecodeWithSharedPrefixPagedKVCacheWrapper,
BatchPrefillWithSharedPrefixPagedKVCacheWrapper,
merge_state, merge_state_in_place, merge_states)
from .decode import (BatchDecodeWithPagedKVCacheWrapper,
CUDAGraphBatchDecodeWithPagedKVCacheWrapper,
single_decode_with_kv_cache)
from .group_gemm import SegmentGEMMWrapper
from .norm import fused_add_rmsnorm, rmsnorm
from .page import append_paged_kv_cache
from .prefill import (BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
single_prefill_with_kv_cache,
single_prefill_with_kv_cache_return_lse)
from .quantization import packbits, segment_packbits
from .rope import (apply_llama31_rope, apply_llama31_rope_inplace, apply_rope,
apply_rope_inplace)
from .sampling import (chain_speculative_sampling, sampling_from_probs,
top_k_renorm_prob, top_k_sampling_from_probs,
top_k_top_p_sampling_from_probs, top_p_renorm_prob,
top_p_sampling_from_probs)
from .sparse import BlockSparseAttentionWrapper

try:
from ._build_meta import __version__
33 changes: 27 additions & 6 deletions python/flashinfer/norm.py
@@ -20,8 +20,8 @@
try:
from . import _kernels
except ImportError as e:
import os
import logging
import os

if os.environ.get("BUILD_DOC", "0") == "1":
_kernels = None
@@ -30,21 +30,42 @@
raise e


def rmsnorm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
def rmsnorm(
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
r"""Root mean square normalization.

Parameters
----------
x: torch.Tensor
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
w: torch.Tensor
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.

Returns
-------
y: torch.Tensor
output: torch.Tensor
Normalized tensor, shape (batch_size, hidden_size).
"""
return _kernels.rmsnorm(x, w, eps)
return _kernels.rmsnorm(input, weight, eps)
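A short usage sketch for the wrapper above (shapes, dtype, and device are illustrative, and a CUDA build of the `_kernels` extension is assumed):

    import torch
    import flashinfer

    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    w = torch.ones(4096, dtype=torch.float16, device="cuda")
    y = flashinfer.rmsnorm(x, w)  # shape (8, 4096), same dtype and device as x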


def fused_add_rmsnorm(
input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
r"""Fused add root mean square normalization.

Parameters
----------
input: torch.Tensor
Input tensor, shape (batch_size, hidden_size).
residual: torch.Tensor
Residual tensor, shape (batch_size, hidden_size).
weight: torch.Tensor
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
"""
_kernels.fused_add_rmsnorm(input, residual, weight, eps)
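A short usage sketch (shapes, dtype, and device are illustrative, and a CUDA build of the `_kernels` extension is assumed). The call returns nothing: `input` is overwritten with the normalized output and `residual` with the pre-normalization sum:

    import torch
    import flashinfer

    hidden = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    resid = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    weight = torch.ones(4096, dtype=torch.float16, device="cuda")

    flashinfer.fused_add_rmsnorm(hidden, resid, weight, eps=1e-6)
    # hidden now holds rmsnorm(hidden_old + resid_old) * weight
    # resid  now holds hidden_old + resid_old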