
Commit ff20e57

ProExpertProg and Varun Sundar Rabindranath authored and committed
[torch.compile] Dynamic fp8 + rms_norm fusion (vllm-project#10906)
Signed-off-by: luka <luka@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
1 parent 1d07ac8 · commit ff20e57

20 files changed: +1736 -252 lines

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
   "csrc/quantization/gptq/q_gemm.cu"
   "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
   "csrc/quantization/fp8/common.cu"
+  "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu"
   "csrc/quantization/gguf/gguf_kernel.cu"
   "csrc/cuda_utils_kernels.cu"
   "csrc/prepare_inputs/advance_step.cu"

@@ -300,7 +301,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   #
   # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
   # kernels for the remaining archs that are not already built for 3x.
-  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
+  cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
     "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}")
   # subtract out the archs that are already built for 3x
   list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
import pickle as pkl
import time
from dataclasses import dataclass
from itertools import product
from typing import Callable, Iterable, List, Optional

import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from tqdm import tqdm

import vllm._custom_ops as ops
from vllm.model_executor.layers.layernorm import RMSNorm


@dataclass
class bench_params_t:
    num_tokens: int
    hidden_size: int
    add_residual: bool
    dtype: torch.dtype

    def description(self):
        return (f'N {self.num_tokens} '
                f'x D {self.hidden_size} '
                f'x R {self.add_residual} '
                f'x DT {self.dtype}')


def get_bench_params() -> List[bench_params_t]:
    ## Test Fixtures
    NUM_TOKENS = [2**x for x in range(11)]
    HIDDEN_SIZES = list(range(1024, 8129, 1024))
    ADD_RESIDUAL = [True, False]
    DTYPES = [torch.bfloat16, torch.float]

    combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES)
    bench_params = list(map(lambda x: \
        bench_params_t(x[0], x[1], x[2], x[3]), combinations))
    return bench_params


# Reference impls
def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                      residual: Optional[torch.Tensor],
                      quant_dtype: torch.dtype):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _, _ = ops.scaled_int8_quant(torch_out)


def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor,
                     residual: Optional[torch.Tensor],
                     quant_dtype: torch.dtype):
    # Norm
    torch_out = None
    if residual is None:
        torch_out = rms_norm_layer.forward_cuda(x, residual)
    else:
        torch_out, _ = rms_norm_layer.forward_cuda(x, residual)

    # Quant
    torch_out, _ = ops.scaled_fp8_quant(torch_out)


def fused_impl(
        rms_norm_layer: RMSNorm,  # this stores the weights
        x: torch.Tensor,
        residual: Optional[torch.Tensor],
        quant_dtype: torch.dtype):
    out, _ = ops.rms_norm_dynamic_per_token_quant(x,
                                                  rms_norm_layer.weight,
                                                  1e-6,
                                                  quant_dtype,
                                                  residual=residual)


# Bench functions
def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor,
             quant_dtype: torch.dtype, label: str, sub_label: str,
             fn: Callable, description: str) -> TMeasurement:

    min_run_time = 1

    globals = {
        "rms_norm_layer": rms_norm_layer,
        "x": x,
        "residual": residual,
        "quant_dtype": quant_dtype,
        "fn": fn,
    }
    return TBenchmark.Timer(
        stmt="fn(rms_norm_layer, x, residual, quant_dtype)",
        globals=globals,
        label=label,
        sub_label=sub_label,
        description=description,
    ).blocked_autorange(min_run_time=min_run_time)

def bench(params: bench_params_t, label: str, sub_label: str) \
        -> Iterable[TMeasurement]:

    # Make inputs
    layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype)
    # Make weights
    layer.weight.data.normal_(mean=1.0, std=0.1)
    # Make inputs
    scale = 1 / params.hidden_size
    x = torch.randn(params.num_tokens,
                    params.hidden_size,
                    dtype=params.dtype,
                    device='cuda') * scale
    residual = (torch.randn_like(x) * scale).to(device='cuda') \
        if params.add_residual else None

    timers = []

    # unfused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label,
                 unfused_int8_impl, "unfused_int8_impl"))

    # unfused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 unfused_fp8_impl, "unfused_fp8_impl"))

    # fused int8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl,
                 "fused_int8_impl"))

    # fused fp8 impl.
    timers.append(
        bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label,
                 fused_impl, "fused_fp8_impl"))

    print_timers(timers)

    return timers


# launch bench
# runner
def print_timers(timers: Iterable[TMeasurement]):
    compare = TBenchmark.Compare(timers)
    compare.print()


def main():
    torch.set_default_device('cuda')
    bench_params = get_bench_params()

    timers = []
    for bp in tqdm(bench_params):
        timers.extend(
            bench(bp, "rms-norm-dynamic-per-token-quant", bp.description()))
    print_timers(timers)

    # pickle all the results
    timestamp = int(time.time())
    with open(f"rms_norm_dpt_quant-{timestamp}.pkl", "wb") as f:
        pkl.dump(timers, f)


if __name__ == '__main__':
    main()
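
Since the script pickles the raw TMeasurement objects, the results can be reloaded and re-compared offline. A minimal sketch, assuming only the file produced by main() above (the concrete filename is an example of the rms_norm_dpt_quant-<timestamp>.pkl pattern):

import pickle as pkl

from torch.utils.benchmark import Compare

# Example path following the naming pattern used by main() above.
with open("rms_norm_dpt_quant-1733870000.pkl", "rb") as f:
    timers = pkl.load(f)

# Renders the same comparison table that print_timers() prints during the run.
Compare(timers).print()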

csrc/dispatch_utils.h

Lines changed: 14 additions & 0 deletions
@@ -14,6 +14,20 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

+// TODO(luka/varun): use FP8_TYPE macro after refactoring
+#ifndef USE_ROCM
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)   \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#else
+  #define VLLM_DISPATCH_CASE_QUANT_TYPES(...)                      \
+    AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fnuz, __VA_ARGS__) \
+    AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
+#endif
+
+#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
+
 #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)  \
   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)   \
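
The new VLLM_DISPATCH_QUANT_TYPES macro limits dispatch to the two quantized output dtypes the fused kernel handles. On the Python side those ScalarType cases correspond to the torch dtypes the benchmark above passes in; a small sketch of that mapping (the helper name here is illustrative, not part of the commit):

import torch

# at::ScalarType::Float8_e4m3fn   <-> torch.float8_e4m3fn   (CUDA builds)
# at::ScalarType::Float8_e4m3fnuz <-> torch.float8_e4m3fnuz (ROCm builds)
# at::ScalarType::Char            <-> torch.int8
SUPPORTED_QUANT_DTYPES = (torch.float8_e4m3fn, torch.int8)

def check_quant_dtype(quant_dtype: torch.dtype) -> None:
    # Mirrors what the AT_DISPATCH_SWITCH does at runtime: a dtype outside
    # the listed cases fails instead of selecting a kernel instantiation.
    if quant_dtype not in SUPPORTED_QUANT_DTYPES:
        raise ValueError(f"unsupported quant dtype: {quant_dtype}")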

csrc/ops.h

Lines changed: 8 additions & 0 deletions
@@ -66,6 +66,14 @@ void fused_add_rms_norm_static_fp8_quant(torch::Tensor& out,
                                          torch::Tensor& weight,
                                          torch::Tensor& scale, double epsilon);

+void rms_norm_dynamic_per_token_quant(torch::Tensor& out,
+                                      torch::Tensor const& input,
+                                      torch::Tensor const& weight,
+                                      torch::Tensor& scales,
+                                      double const epsilon,
+                                      std::optional<torch::Tensor> scale_ub,
+                                      std::optional<torch::Tensor> residual);
+
 void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
                       torch::Tensor& key, int64_t head_size,
                       torch::Tensor& cos_sin_cache, bool is_neox);
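
For orientation, this is the op the benchmark script above exercises through vllm._custom_ops. A minimal sketch of the Python-side call, assuming the wrapper returns the quantized tensor together with the per-token scales that back the scales output argument in the declaration above; the shapes are hypothetical:

import torch

import vllm._custom_ops as ops

# Hypothetical shapes: 16 tokens with a hidden size of 4096.
x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda")

# Fused RMSNorm + dynamic per-token quantization to fp8, mirroring
# fused_impl() in the benchmark; int8 output works the same way with
# quant_dtype=torch.int8.
out, scales = ops.rms_norm_dynamic_per_token_quant(x,
                                                   weight,
                                                   1e-6,
                                                   torch.float8_e4m3fn,
                                                   residual=None)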

csrc/quantization/fp8/common.cuh

Lines changed: 7 additions & 19 deletions
@@ -1,6 +1,9 @@
 #pragma once

+#include "quantization/vectorization.cuh"
+
 #include <cmath>
+#include <c10/core/ScalarType.h>

 #ifndef USE_ROCM
   #include <c10/util/Float8_e4m3fn.h>

@@ -15,6 +18,7 @@ using FP8_TYPE = c10::Float8_e4m3fnuz;
 // issue when running dynamic quantization. Here use 224.0f for rocm.
 constexpr auto FP8_E4M3_MAX = 224.0f;
 #endif
+constexpr static auto kFp8Type = c10::CppTypeToScalarType<FP8_TYPE>::value;

 namespace vllm {

@@ -89,22 +93,6 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
   }
 }

-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
-};
-
-typedef struct __align__(4) {
-  FP8_TYPE x;
-  FP8_TYPE y;
-  FP8_TYPE z;
-  FP8_TYPE w;
-}
-float8x4_t;
-
 template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,

@@ -139,10 +127,10 @@ __device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
+  using float8x4_t = q8x4_t<FP8_TYPE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  vec4_t<scalar_t> const* vectorized_in =
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
-  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);

   int64_t const num_vec_elems = num_elems >> 2;