1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -149,6 +149,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
"csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
173 changes: 173 additions & 0 deletions benchmarks/fused_kernels/layernorm_rms_benchmarks.py
@@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm
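
# Compares the fused RMSNorm + dynamic per-token quantization custom op against
# the unfused reference path (RMSNorm followed by a separate quant op), for both
# int8 and fp8 outputs.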


@dataclass
class bench_params_t:
num_tokens: int
hidden_size: int
add_residual: bool
dtype: torch.dtype

def description(self):
return (f'N {self.num_tokens} '
f'x D {self.hidden_size} '
f'x R {self.add_residual} '
f'x DT {self.dtype}')


def get_bench_params() -> List[bench_params_t]:
## Test Fixtures
NUM_TOKENS = [2**x for x in range(11)]
HIDDEN_SIZES = list(range(1024, 8129, 1024))
ADD_RESIDUAL = [True, False]
DTYPES = [torch.bfloat16, torch.float]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    bench_params = [bench_params_t(*combo) for combo in combinations]
    return bench_params


# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
# Norm
torch_out = None
if residual is None:
torch_out = rms_norm_layer.forward_cuda(x, residual)
else:
torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

# Quant
torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
rms_norm_layer: RMSNorm, # this stores the weights
x: torch.Tensor,
residual: Optional[torch.Tensor],
quant_dtype: torch.dtype):
out, _ = ops.rms_norm_dynamic_per_token_quant(x,
rms_norm_layer.weight,
1e-6,
quant_dtype,
residual=residual)
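
# A minimal standalone usage sketch (not part of the benchmark), assuming the
# extension is built and a CUDA device is available; the fused op returns the
# quantized tensor together with one scale per token:
#
#   layer = RMSNorm(4096, 1e-6).to(dtype=torch.bfloat16, device='cuda')
#   x = torch.randn(16, 4096, dtype=torch.bfloat16, device='cuda')
#   out, scales = ops.rms_norm_dynamic_per_token_quant(
#       x, layer.weight, 1e-6, torch.float8_e4m3fn)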


# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor,
             residual: Optional[torch.Tensor], quant_dtype: torch.dtype,
             label: str, sub_label: str, fn: Callable,
             description: str) -> TMeasurement:

min_run_time = 1

globals = {
"rms_norm_layer": rms_norm_layer,
"x": x,
"residual": residual,
"quant_dtype": quant_dtype,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)

def bench(params: bench_params_t, label: str, sub_label: str) \
-> Iterable[TMeasurement]:

# Make inputs
layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
# Make weights
layer.weight.data.normal_(mean=1.0, std=0.1)
# Make inputs
scale = 1 / params.hidden_size
x = torch.randn(params.num_tokens,
params.hidden_size,
dtype=params.dtype,
device='cuda') * scale
residual = (torch.randn_like(x) * scale).to(device='cuda') \
if params.add_residual else None

timers = []

# unfused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label,
unfused_int8_impl, "unfused_int8_impl"))

# unfused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
unfused_fp8_impl, "unfused_fp8_impl"))

# fused int8 impl.
timers.append(
bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
"fused_int8_impl"))

# fused fp8 impl.
timers.append(
bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
fused_impl, "fused_fp8_impl"))

print_timers(timers)

return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()


def main():
torch.set_default_device('cuda')
bench_params = get_bench_params()

timers = []
for bp in tqdm(bench_params):
timers.extend(
bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
print_timers(timers)

# pickle all the results
timestamp = int(time.time())
with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
pkl.dump(timers, f)


if __name__ == '__main__':
main()
7 changes: 7 additions & 0 deletions csrc/dispatch_utils.h
@@ -14,6 +14,13 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

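// Dispatch helpers for the quantized output dtypes used by the fused
// norm + quant kernels: fp8 (e4m3fn) and int8.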
#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
8 changes: 8 additions & 0 deletions csrc/ops.h
@@ -30,6 +30,14 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
torch::Tensor& weight, double epsilon);

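// Fused RMSNorm (with optional residual add) followed by dynamic per-token
// quantization to int8 or fp8. Writes the quantized activations to `out` and
// the per-token scales to `scales`.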
void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
torch::Tensor const& input,
torch::Tensor const& weight,
torch::Tensor& scales,
double const epsilon,
std::optional<torch::Tensor> scale_ub,
std::optional<torch::Tensor> residual);

void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox);
160 changes: 160 additions & 0 deletions csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu
@@ -0,0 +1,160 @@

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
#include "layernorm_utils.cuh"
#include "quant_conversions.cuh"

namespace vllm {

template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__device__ void rms_norm_dynamic_per_token_quant_vec(
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
float rms = 0.0f;
float token_scale = 0.0f;

// Compute rms
vllm::vectorized::compute_rms<scalar_t, has_residual>(
&rms, input, hidden_size, var_epsilon, residual);

// Compute scale
vllm::vectorized::compute_dynamic_per_token_scales<scalar_t, scalar_out_t,
has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
hidden_size, residual);

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, true,
has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
} else {
    // FP8 - Do not invert token_scale for exact match with FBGemm
vllm::vectorized::norm_and_quant<scalar_t, scalar_out_t, false,
has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
}
}

// RMS norm + quant kernel
template <typename scalar_t, typename scalar_out_t, bool has_residual = false>
__global__ void rms_norm_dynamic_per_token_quant_kernel(
scalar_out_t* __restrict__ out, // [..., hidden_size]
float* __restrict__ scales, // [num_tokens]
scalar_t const* __restrict__ input, // [..., hidden_size]
scalar_t const* __restrict__ weight, // [hidden_size]
float const* scale_ub, float const var_epsilon,
float const min_scaling_factor, int32_t const hidden_size,
scalar_t* __restrict__ residual = nullptr) {
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;

if (can_vectorize) {
return rms_norm_dynamic_per_token_quant_vec<scalar_t, scalar_out_t,
has_residual>(
out, scales, input, weight, scale_ub, var_epsilon, min_scaling_factor,
hidden_size, residual);
}

float rms = 0.0f;
float token_scale = 0.0f;

// Compute RMS
vllm::compute_rms<scalar_t, has_residual>(&rms, input, hidden_size,
var_epsilon, residual);
// Compute Scale
vllm::compute_dynamic_per_token_scales<scalar_t, scalar_out_t, has_residual>(
&token_scale, scales, input, weight, rms, scale_ub, min_scaling_factor,
hidden_size, residual);

// RMS Norm + Quant
if constexpr (std::is_same_v<scalar_out_t, int8_t>) {
vllm::norm_and_quant<scalar_t, scalar_out_t, true, has_residual>(
out, input, weight, rms, 1.0f / token_scale, hidden_size, residual);
} else {
    // FP8 - Do not invert token_scale for exact match with FBGemm
vllm::norm_and_quant<scalar_t, scalar_out_t, false, has_residual>(
out, input, weight, rms, token_scale, hidden_size, residual);
}
}
} // namespace vllm

// Residual add + RMS norm + dynamic per token
template <typename scalar_in_t>
void rms_norm_dynamic_per_token_quant_dispatch(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> const& scale_ub,
std::optional<at::Tensor>& residual) {
int32_t hidden_size = input.size(-1);
int32_t num_tokens = input.numel() / hidden_size;

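  // One thread block per token; threads in a block cooperate across the hidden
  // dimension.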
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

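  // Lower bound applied to the computed per-token scale: FLT_EPSILON for int8
  // output, 1 / (fp8_max * 512) for fp8 output.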
const float min_scaling_factor =
out.dtype() == torch::kInt8
? std::numeric_limits<float>::epsilon()
: 1.0f / (std::numeric_limits<c10::Float8_e4m3fn>::max() * 512.f);

if (residual.has_value()) {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
true>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, min_scaling_factor, hidden_size,
residual->data_ptr<scalar_in_t>());
});

} else {
VLLM_DISPATCH_QUANT_TYPES(
out.scalar_type(), "rms_norm_dynamic_per_token_quant_kernel", [&] {
vllm::rms_norm_dynamic_per_token_quant_kernel<scalar_in_t, scalar_t,
false>
<<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(), scales.data_ptr<float>(),
input.data_ptr<scalar_in_t>(), weight.data_ptr<scalar_in_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
var_epsilon, min_scaling_factor, hidden_size, nullptr);
});
}
}

void rms_norm_dynamic_per_token_quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor const& input, // [..., hidden_size]
torch::Tensor const& weight, // [hidden_size]
torch::Tensor& scales, // [num_tokens]
double const var_epsilon, // Variance epsilon used in norm calculation
std::optional<at::Tensor> scale_ub, std::optional<at::Tensor> residual) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn ||
out.dtype() == torch::kInt8);

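  // An upper bound on the per-token scale (scale_ub) is only supported for fp8
  // output.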
if (scale_ub.has_value()) {
TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn);
}

VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "rms_norm_dynamic_per_token_quant_dispatch", [&] {
rms_norm_dynamic_per_token_quant_dispatch<scalar_t>(
out, input, weight, scales, var_epsilon, scale_ub, residual);
});
}