Commit 82e64c7

[PERF] [Qwen3-next] Speed up gated RMSNorm (#26207)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 4ca2040 commit 82e64c7
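
This change speeds up the gated RMSNorm on the Qwen3-Next path (vllm.model_executor.layers.fla.ops.layernorm_guard) and adds a test suite comparing the Triton kernel layer_norm_fwd against a pure-PyTorch reference. For orientation, a minimal sketch of the gated RMS norm semantics the tests validate; the function name is illustrative and the casting details follow the reference implementation in the diff, not the kernel's API:

    import torch
    import torch.nn.functional as F

    def gated_rms_norm_sketch(x, weight, z=None, eps=1e-6, norm_before_gate=True):
        # Compute in float32; optionally gate with SiLU before or after the norm.
        dtype = x.dtype
        x = x.float()
        if z is not None and not norm_before_gate:
            x = x * F.silu(z.float())  # gate the input before normalizing
        rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
        out = x * rstd * weight.float()
        if z is not None and norm_before_gate:
            out = out * F.silu(z.float())  # gate the normalized output
        return out.to(dtype)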

2 files changed: +475 −33 lines
New test file: 388 additions & 0 deletions
@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.fla.ops.layernorm_guard import (
    layer_norm_fwd,
    layernorm_fn,
    rms_norm_ref,
)
from vllm.platforms import current_platform


def layer_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Reference implementation for both layer norm and RMS norm."""
    if is_rms_norm:
        # Use the imported rms_norm_ref for RMS norm cases
        return rms_norm_ref(
            x,
            weight,
            bias,
            z=z,
            eps=eps,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            upcast=True,
        )

    # Layer norm implementation
    dtype = x.dtype
    x = x.float()
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    z = z.float() if z is not None else None

    if z is not None and not norm_before_gate:
        x = x * F.silu(z)

    if group_size is None:
        # Layer norm: subtract mean
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean).square()).mean(dim=-1, keepdim=True)
        rstd = 1 / torch.sqrt(var + eps)
        out = (x - mean) * rstd * weight
        if bias is not None:
            out = out + bias
    else:
        # Group norm
        from einops import rearrange

        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        mean = x_group.mean(dim=-1, keepdim=True)
        var = ((x_group - mean).square()).mean(dim=-1, keepdim=True)
        rstd = 1 / torch.sqrt(var + eps)
        x_group = (x_group - mean) * rstd
        out = rearrange(x_group, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias

    if z is not None and norm_before_gate:
        out *= F.silu(z)

    return out.to(dtype)


DTYPES = [torch.bfloat16, torch.float32]
# Test various M sizes to ensure rows_per_block logic works correctly
NUM_TOKENS = [
    1,
    7,
    16,
    63,
    128,
    256,
    512,
    1024,
    2048,
    4096,
    5789,
    8189,
    8191,
    16383,
    32767,
]
HIDDEN_SIZES = [64, 128, 256, 1024]
GROUP_SIZES = [None, 64, 128]  # None means full hidden size
NORM_BEFORE_GATE = [True, False]
IS_RMS_NORM = [True, False]
SEEDS = [0, 42]
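# The sizes just below powers of two (8189, 8191, 16383, 32767) presumably
# exercise partial final blocks in the kernel's rows_per_block splitting.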


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_basic(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    seed: int,
    is_rms_norm: bool,
) -> None:
    """Test basic layer norm forward pass without z (gate) tensor."""
    current_platform.seed_everything(seed)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, is_rms_norm=is_rms_norm
    )

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=is_rms_norm)

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    # Check mean and rstd shapes
    if not is_rms_norm:
        assert mean.shape == (num_tokens,)
        assert rstd.shape == (num_tokens,)


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("hidden_size", [128, 256, 1024])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("norm_before_gate", NORM_BEFORE_GATE)
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_with_gate(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    norm_before_gate: bool,
    is_rms_norm: bool,
) -> None:
    """Test layer norm forward pass with z (gate) tensor."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    z = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        z=z,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x,
        weight,
        bias,
        z=z,
        eps=eps,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [128, 512])
@pytest.mark.parametrize("hidden_size", [512, 1024])
@pytest.mark.parametrize("group_size", [64, 128, 256])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
@torch.inference_mode()
def test_layer_norm_fwd_with_groups(
    num_tokens: int,
    hidden_size: int,
    group_size: int,
    dtype: torch.dtype,
    is_rms_norm: bool,
) -> None:
    """Test layer norm forward pass with group normalization."""
    if hidden_size % group_size != 0:
        pytest.skip(
            f"hidden_size {hidden_size} not divisible by group_size {group_size}"
        )

    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    ngroups = hidden_size // group_size

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, group_size=group_size, is_rms_norm=is_rms_norm
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x, weight, bias, z=None, eps=eps, group_size=group_size, is_rms_norm=is_rms_norm
    )

    # Check outputs
    assert out.shape == x.shape
    assert out.dtype == x.dtype
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)

    # Check mean and rstd shapes for groups
    if not is_rms_norm:
        assert mean.shape == (ngroups * num_tokens,)
        assert rstd.shape == (ngroups * num_tokens,)


@pytest.mark.parametrize("num_tokens", [7, 63, 128, 513, 1024, 2049])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_layer_norm_rows_per_block(
    num_tokens: int,
    dtype: torch.dtype,
) -> None:
    """Test that rows_per_block logic works correctly for various M sizes."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    hidden_size = 1024

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel
    out, mean, rstd = layer_norm_fwd(x, weight, bias, eps, z=None, is_rms_norm=False)

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode()
def test_strided_input(dtype: torch.dtype) -> None:
    """Test that the kernel handles non-contiguous (strided)
    inputs correctly."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    num_tokens = 128
    hidden_size = 1024

    # Create a larger tensor and take a strided slice
    x_large = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device)
    x = x_large[:, :hidden_size]

    # Make it contiguous for the kernel
    x_contiguous = x.contiguous()

    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run the triton kernel with contiguous input
    out, mean, rstd = layer_norm_fwd(
        x_contiguous, weight, bias, eps, z=None, is_rms_norm=False
    )

    # Run reference implementation
    ref_out = layer_norm_ref(
        x_contiguous, weight, bias, z=None, eps=eps, is_rms_norm=False
    )

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize("num_tokens", [1, 128, 2048])
@pytest.mark.parametrize("hidden_size", [768, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_output_buffer_provided(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
) -> None:
    """Test that the kernel works when an output buffer is provided."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")

    # Create inputs
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Pre-allocate output buffer
    out_buffer = torch.empty_like(x)

    # Run the triton kernel with provided output
    out, mean, rstd = layer_norm_fwd(
        x, weight, bias, eps, z=None, out=out_buffer, is_rms_norm=False
    )

    # Check that the provided buffer was used
    assert out.data_ptr() == out_buffer.data_ptr()

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


@pytest.mark.parametrize(
    "shape",
    [
        (4, 16, 1024),  # 3D tensor
        (2, 8, 512, 256),  # 4D tensor
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@torch.inference_mode()
def test_multidimensional_input(
    shape: tuple,
    dtype: torch.dtype,
) -> None:
    """Test that the autograd function handles multidimensional inputs."""
    current_platform.seed_everything(42)
    device = torch.device("cuda:0")
    hidden_size = shape[-1]

    # Create inputs
    x = torch.randn(*shape, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)
    bias = torch.randn(hidden_size, dtype=dtype, device=device)
    eps = 1e-6

    # Run through autograd function
    out = layernorm_fn(x, weight, bias, z=None, eps=eps)

    # Run reference implementation
    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False)

    # Check outputs
    assert out.shape == x.shape
    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    # Run a quick smoke test
    test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False)
    test_layer_norm_fwd_with_gate(128, 1024, torch.float16, True, False)
    test_layer_norm_rows_per_block(513, torch.float16)
    print("All smoke tests passed!")
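
To run the full parametrized suite rather than the smoke test above, point pytest at the new file (the path is a placeholder, since the diff view does not show the file name):

    pytest -q <path-to-new-test-file>
    # e.g. only the gated cases:
    pytest -q <path-to-new-test-file> -k with_gate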
