
Commit 1a76e10

bwasti and yewentao256 authored and committed
Deepseek-v3 Batch Invariant on 8xH100 (vllm-project#26609)
Signed-off-by: Bram Wasti <bwasti@meta.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 20f38c0 commit 1a76e10

File tree

21 files changed: +1567 -102 lines


tests/v1/generation/test_batch_invariance.py

Lines changed: 769 additions & 73 deletions
Large diffs are not rendered by default.
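Since GitHub does not render the diff for this file, here is a minimal, hedged way to run the updated suite locally from Python. The path comes from the file tree above; pytest must be installed, and the batch-invariance tests in this commit gate themselves to suitable CUDA hardware via their own skip markers.

# Sketch: invoke the batch-invariance test suite programmatically.
# Equivalent to running pytest on the file from the repository root.
import pytest

if __name__ == "__main__":
    # "-v" prints one line per parametrized case.
    raise SystemExit(
        pytest.main(["-v", "tests/v1/generation/test_batch_invariance.py"])
    )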
Lines changed: 346 additions & 0 deletions
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant RMS normalization against standard implementations.

This test compares the Triton-based batch-invariant RMS norm implementation
with the standard CUDA-based implementation to ensure numerical accuracy.
"""

import pytest
import torch

from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
    batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
):
    """
    Compare batch-invariant Triton RMS norm against standard CUDA implementation.

    Tests that the Triton-based batch-invariant RMS norm produces numerically
    equivalent results to the standard CUDA implementation across various
    configurations.
    """
    device = torch.device("cuda")

    # Create test input and weight
    torch.manual_seed(42)
    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Standard implementation (CUDA ops)
    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
    rms_norm_layer.weight.data = weight.clone()

    standard_output = rms_norm_layer.forward_cuda(input_tensor)

    # Batch-invariant implementation (Triton)
    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

    # Compare outputs
    # Use looser tolerance for bfloat16 due to its lower precision
    if dtype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1  # 10% relative tolerance for bfloat16
    else:
        rtol, atol = 1e-2, 1e-2  # 1% for float16/float32

    torch.testing.assert_close(
        triton_output,
        standard_output,
        rtol=rtol,
        atol=atol,
        msg=f"RMS norm mismatch for batch_size={batch_size}, "
        f"hidden_size={hidden_size}, "
        f"dtype={dtype}, eps={eps}",
    )


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
    """
    Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).

    Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
    inputs that are common in transformer models.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    eps = 1e-6

    torch.manual_seed(42)
    input_tensor = torch.randn(
        batch_size, seq_len, hidden_size, dtype=dtype, device=device
    )
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Standard implementation
    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
    rms_norm_layer.weight.data = weight.clone()
    standard_output = rms_norm_layer.forward_cuda(input_tensor)

    # Batch-invariant implementation
    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

    # Use looser tolerance for bfloat16
    rtol, atol = 1e-1, 1e-1  # 10% tolerance for bfloat16

    torch.testing.assert_close(
        triton_output,
        standard_output,
        rtol=rtol,
        atol=atol,
        msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
        f"seq_len={seq_len}, hidden_size={hidden_size}",
    )


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_numerical_stability():
    """
    Test RMS norm numerical stability with extreme values.

    Ensures that both implementations handle edge cases like very small or large
    values without producing NaN or Inf.
    """
    device = torch.device("cuda")
    dtype = torch.float16
    eps = 1e-6
    hidden_size = 2048

    # Test cases with extreme values
    test_cases = [
        # Very small values
        torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
        # Very large values
        torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
        # Mixed small and large
        torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
        # Values near zero
        torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
    ]

    weight = torch.ones(hidden_size, dtype=dtype, device=device)

    for idx, input_tensor in enumerate(test_cases):
        # Standard implementation
        rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
        rms_norm_layer.weight.data = weight.clone()
        standard_output = rms_norm_layer.forward_cuda(input_tensor)

        # Batch-invariant implementation
        triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

        # Check for NaN or Inf
        assert not torch.isnan(standard_output).any(), (
            f"Standard RMS norm produced NaN for test case {idx}"
        )
        assert not torch.isinf(standard_output).any(), (
            f"Standard RMS norm produced Inf for test case {idx}"
        )
        assert not torch.isnan(triton_output).any(), (
            f"Triton RMS norm produced NaN for test case {idx}"
        )
        assert not torch.isinf(triton_output).any(), (
            f"Triton RMS norm produced Inf for test case {idx}"
        )

        # Compare outputs - very lenient for extreme values with float16
        torch.testing.assert_close(
            triton_output,
            standard_output,
            rtol=2e-1,  # 20% tolerance for extreme values
            atol=2e-1,
            msg=f"RMS norm mismatch for extreme value test case {idx}",
        )


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_formula():
    """
    Test that RMS norm follows the correct mathematical formula.

    Verifies: output = input / sqrt(mean(input^2) + eps) * weight
    """
    device = torch.device("cuda")
    dtype = torch.float32  # Use float32 for higher precision in formula check
    eps = 1e-6
    hidden_size = 1024

    torch.manual_seed(42)
    input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Compute expected output using the formula
    variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
    expected_output = input_tensor * torch.rsqrt(variance + eps) * weight

    # Batch-invariant implementation
    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

    # Compare against formula
    torch.testing.assert_close(
        triton_output,
        expected_output,
        rtol=1e-4,
        atol=1e-4,
        msg="Triton RMS norm doesn't match expected formula",
    )


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
    """
    Test RMS norm with various hidden sizes to ensure block size handling.

    The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
    correctly handles hidden sizes both smaller and larger than the block size.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    eps = 1e-6
    batch_size = 16

    torch.manual_seed(42)
    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Standard implementation
    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
    rms_norm_layer.weight.data = weight.clone()
    standard_output = rms_norm_layer.forward_cuda(input_tensor)

    # Batch-invariant implementation
    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

    # Use looser tolerance for bfloat16
    rtol, atol = 1e-1, 1e-1  # 10% tolerance for bfloat16

    torch.testing.assert_close(
        triton_output,
        standard_output,
        rtol=rtol,
        atol=atol,
        msg=f"RMS norm mismatch for hidden_size={hidden_size}",
    )


@pytest.mark.skipif(
    not current_platform.has_device_capability(90),
    reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_determinism():
    """
    Test that batch-invariant RMS norm produces deterministic results.

    Runs the same input through the kernel multiple times and verifies
    identical outputs.
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16
    eps = 1e-6
    hidden_size = 4096
    batch_size = 32

    torch.manual_seed(42)
    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Run multiple times
    outputs = []
    for _ in range(5):
        output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
        outputs.append(output)

    # All outputs should be identical
    reference = outputs[0]
    for idx, output in enumerate(outputs[1:], start=1):
        torch.testing.assert_close(
            output,
            reference,
            rtol=0.0,
            atol=0.0,
            msg=f"RMS norm not deterministic: run {idx} differs from reference",
        )


if __name__ == "__main__":
    # Run a quick smoke test
    print("Running quick smoke test of RMS norm implementations...")

    device = torch.device("cuda")
    batch_size = 8
    hidden_size = 4096
    dtype = torch.bfloat16
    eps = 1e-6

    torch.manual_seed(42)
    input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
    weight = torch.randn(hidden_size, dtype=dtype, device=device)

    # Standard implementation
    rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
    rms_norm_layer.weight.data = weight.clone()
    standard_output = rms_norm_layer.forward_cuda(input_tensor)

    # Batch-invariant implementation
    triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

    # Compare
    max_diff = (triton_output - standard_output).abs().max().item()
    mean_diff = (triton_output - standard_output).abs().mean().item()

    print(f"Max difference: {max_diff:.6e}")
    print(f"Mean difference: {mean_diff:.6e}")
    print(f"Standard output sample: {standard_output[0, :5].tolist()}")
    print(f"Triton output sample: {triton_output[0, :5].tolist()}")

    if max_diff < 1e-3:
        print("✓ Smoke test passed!")
    else:
        print("✗ Smoke test failed - differences too large")

vllm/compilation/caching.py

Lines changed: 3 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 import hashlib
 import inspect
+import os
 import pickle
 from unittest.mock import patch
 
@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
     )
     file_contents = {}
     for filepath in files:
-        if filepath == "<string>":
+        # Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
+        if not os.path.isfile(filepath):
             file_contents[filepath] = ""
         else:
             with open(filepath) as f:
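The old guard only special-cased the literal "<string>" filename; the new os.path.isfile() check also covers other pseudo-paths that can appear in compiled code objects. A small, self-contained sketch of the same guard; safe_read is a hypothetical stand-in for the loop body of _compute_code_hash, not vLLM API:

import os

def safe_read(filepath: str) -> str:
    # Pseudo-filenames such as "<string>" or "<frozen importlib._bootstrap>"
    # are not real files, so treat them as empty content instead of open()-ing.
    if not os.path.isfile(filepath):
        return ""
    with open(filepath) as f:
        return f.read()

print(repr(safe_read("<string>")))                        # '' rather than FileNotFoundError
print(repr(safe_read("<frozen importlib._bootstrap>")))   # ''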

vllm/config/model.py

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,9 @@
 from vllm.config.scheduler import RunnerType
 from vllm.config.utils import assert_hashable, config, getattr_iter
 from vllm.logger import init_logger
+from vllm.model_executor.layers.batch_invariant import (
+    vllm_kernel_override_batch_invariant,
+)
 from vllm.platforms import current_platform
 from vllm.transformers_utils.config import (
     ConfigFormat,
@@ -419,6 +422,10 @@ def __post_init__(
         skip_mm_profiling: bool | None,
         video_pruning_rate: float | None,
     ) -> None:
+        # Enable batch invariance settings if requested
+        if vllm_kernel_override_batch_invariant():
+            self.enforce_eager = True
+
         # Set the default seed to 0 in V1.
         # NOTE(woosuk): In V0, we set the default seed to None because the
         # driver worker shares the same process as the user process, and thus
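The practical effect of the new __post_init__ hook: when batch-invariant kernels are requested, the model config forces eager execution. A minimal sketch of that behavior, assuming a vLLM install; SimpleNamespace stands in for the real ModelConfig, and the motivation for forcing eager mode (keeping execution on the batch-invariant code paths rather than captured CUDA graphs) is an inference, not something stated in the diff:

from types import SimpleNamespace

from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)

# Illustrative stand-in for ModelConfig; the real class has many more fields.
config = SimpleNamespace(enforce_eager=False)

# Mirrors the new hook: batch invariance implies enforce_eager=True.
if vllm_kernel_override_batch_invariant():
    config.enforce_eager = True

print(f"enforce_eager={config.enforce_eager}")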
