842 changes: 769 additions & 73 deletions tests/v1/generation/test_batch_invariance.py

Large diffs are not rendered by default.

346 changes: 346 additions & 0 deletions tests/v1/generation/test_rms_norm_batch_invariant.py
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test batch-invariant RMS normalization against standard implementations.

This test compares the Triton-based batch-invariant RMS norm implementation
with the standard CUDA-based implementation to ensure numerical accuracy.
"""

import pytest
import torch

from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_norm
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.platforms import current_platform


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("hidden_size", [512, 2048, 4096, 8192])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("eps", [1e-6, 1e-5])
def test_rms_norm_batch_invariant_vs_standard(
batch_size: int, hidden_size: int, dtype: torch.dtype, eps: float
):
"""
Compare batch-invariant Triton RMS norm against standard CUDA implementation.

Tests that the Triton-based batch-invariant RMS norm produces numerically
equivalent results to the standard CUDA implementation across various
configurations.
"""
device = torch.device("cuda")

# Create test input and weight
torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Standard implementation (CUDA ops)
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()

standard_output = rms_norm_layer.forward_cuda(input_tensor)

# Batch-invariant implementation (Triton)
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Compare outputs
# Use looser tolerance for bfloat16 due to its lower precision
if dtype == torch.bfloat16:
rtol, atol = 1e-1, 1e-1 # 10% relative tolerance for bfloat16
else:
rtol, atol = 1e-2, 1e-2  # 1% for float16

torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for batch_size={batch_size}, "
f"hidden_size={hidden_size}, "
f"dtype={dtype}, eps={eps}",
)


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("batch_size", [1, 16, 128])
@pytest.mark.parametrize("seq_len", [1, 32, 512])
@pytest.mark.parametrize("hidden_size", [2048, 4096])
def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
"""
Test RMS norm with 3D input tensors (batch, seq_len, hidden_size).

Ensures that the batch-invariant RMS norm correctly handles multi-dimensional
inputs that are common in transformer models.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6

torch.manual_seed(42)
input_tensor = torch.randn(
batch_size, seq_len, hidden_size, dtype=dtype, device=device
)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)

# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16

torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for 3D input with batch_size={batch_size}, "
f"seq_len={seq_len}, hidden_size={hidden_size}",
)


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_numerical_stability():
"""
Test RMS norm numerical stability with extreme values.

Ensures that both implementations handle edge cases like very small or large
values without producing NaN or Inf.
"""
device = torch.device("cuda")
dtype = torch.float16
eps = 1e-6
hidden_size = 2048

# Test cases with extreme values
test_cases = [
# Very small values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e-5,
# Very large values
torch.ones(4, hidden_size, dtype=dtype, device=device) * 1e4,
# Mixed small and large
torch.randn(4, hidden_size, dtype=dtype, device=device) * 100,
# Values near zero
torch.randn(4, hidden_size, dtype=dtype, device=device) * 1e-6,
]

weight = torch.ones(hidden_size, dtype=dtype, device=device)

for idx, input_tensor in enumerate(test_cases):
# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)

# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Check for NaN or Inf
assert not torch.isnan(standard_output).any(), (
f"Standard RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(standard_output).any(), (
f"Standard RMS norm produced Inf for test case {idx}"
)
assert not torch.isnan(triton_output).any(), (
f"Triton RMS norm produced NaN for test case {idx}"
)
assert not torch.isinf(triton_output).any(), (
f"Triton RMS norm produced Inf for test case {idx}"
)

# Compare outputs - very lenient for extreme values with float16
torch.testing.assert_close(
triton_output,
standard_output,
rtol=2e-1, # 20% tolerance for extreme values
atol=2e-1,
msg=f"RMS norm mismatch for extreme value test case {idx}",
)


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_formula():
"""
Test that RMS norm follows the correct mathematical formula.

Verifies: output = input / sqrt(mean(input^2) + eps) * weight
"""
device = torch.device("cuda")
dtype = torch.float32 # Use float32 for higher precision in formula check
eps = 1e-6
hidden_size = 1024

torch.manual_seed(42)
input_tensor = torch.randn(8, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Compute expected output using the formula
variance = (input_tensor.pow(2).mean(dim=-1, keepdim=True)).to(dtype)
expected_output = input_tensor * torch.rsqrt(variance + eps) * weight

# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Compare against formula
torch.testing.assert_close(
triton_output,
expected_output,
rtol=1e-4,
atol=1e-4,
msg="Triton RMS norm doesn't match expected formula",
)


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
@pytest.mark.parametrize("hidden_size", [128, 1024, 4096, 16384])
def test_rms_norm_different_hidden_sizes(hidden_size: int):
"""
Test RMS norm with various hidden sizes to ensure block size handling.

The Triton kernel uses a fixed BLOCK_SIZE=1024, so this tests that it
correctly handles hidden sizes both smaller and larger than the block size.
A pure-PyTorch sketch of this block-wise reduction follows this test.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
batch_size = 16

torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)

# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Use looser tolerance for bfloat16
rtol, atol = 1e-1, 1e-1 # 10% tolerance for bfloat16

torch.testing.assert_close(
triton_output,
standard_output,
rtol=rtol,
atol=atol,
msg=f"RMS norm mismatch for hidden_size={hidden_size}",
)
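
For context, here is a minimal pure-PyTorch sketch of the block-wise reduction order that the docstring above refers to. It assumes only a fixed block size and float32 accumulation; it is an illustration of the technique, not the actual Triton kernel in vllm.model_executor.layers.batch_invariant.

def blockwise_rms_norm_reference(
    x: torch.Tensor, weight: torch.Tensor, eps: float, block_size: int = 1024
) -> torch.Tensor:
    """Reference RMS norm that reduces the hidden dimension in fixed-size blocks.

    The per-row sum of squares is accumulated block by block in float32, so the
    result for each row does not depend on how many other rows are in the batch.
    """
    orig_shape = x.shape
    x2d = x.reshape(-1, orig_shape[-1])
    hidden = x2d.shape[-1]

    # Accumulate the per-row sum of squares one block at a time, in float32.
    acc = torch.zeros(x2d.shape[0], dtype=torch.float32, device=x.device)
    for start in range(0, hidden, block_size):
        block = x2d[:, start : start + block_size].to(torch.float32)
        acc += block.pow(2).sum(dim=-1)

    inv_rms = torch.rsqrt(acc / hidden + eps).unsqueeze(-1)
    out = x2d.to(torch.float32) * inv_rms * weight.to(torch.float32)
    return out.to(x.dtype).reshape(orig_shape)

Comparing this reference against triton_rms_norm with the bfloat16 tolerances above is one way to sanity-check the handling of hidden sizes that are not multiples of the block size.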


@pytest.mark.skipif(
not current_platform.has_device_capability(90),
reason="Batch invariance tests only supported on Hopper (SM90)",
)
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="Requires CUDA for RMS norm kernels"
)
def test_rms_norm_determinism():
"""
Test that batch-invariant RMS norm produces deterministic results.

Runs the same input through the kernel multiple times and verifies
identical outputs.
"""
device = torch.device("cuda")
dtype = torch.bfloat16
eps = 1e-6
hidden_size = 4096
batch_size = 32

torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Run multiple times
outputs = []
for _ in range(5):
output = triton_rms_norm(input_tensor.clone(), weight, eps=eps)
outputs.append(output)

# All outputs should be identical
reference = outputs[0]
for idx, output in enumerate(outputs[1:], start=1):
torch.testing.assert_close(
output,
reference,
rtol=0.0,
atol=0.0,
msg=f"RMS norm not deterministic: run {idx} differs from reference",
)


if __name__ == "__main__":
# Run a quick smoke test
print("Running quick smoke test of RMS norm implementations...")

device = torch.device("cuda")
batch_size = 8
hidden_size = 4096
dtype = torch.bfloat16
eps = 1e-6

torch.manual_seed(42)
input_tensor = torch.randn(batch_size, hidden_size, dtype=dtype, device=device)
weight = torch.randn(hidden_size, dtype=dtype, device=device)

# Standard implementation
rms_norm_layer = RMSNorm(hidden_size, eps=eps, dtype=dtype).to(device)
rms_norm_layer.weight.data = weight.clone()
standard_output = rms_norm_layer.forward_cuda(input_tensor)

# Batch-invariant implementation
triton_output = triton_rms_norm(input_tensor, weight, eps=eps)

# Compare
max_diff = (triton_output - standard_output).abs().max().item()
mean_diff = (triton_output - standard_output).abs().mean().item()

print(f"Max difference: {max_diff:.6e}")
print(f"Mean difference: {mean_diff:.6e}")
print(f"Standard output sample: {standard_output[0, :5].tolist()}")
print(f"Triton output sample: {triton_output[0, :5].tolist()}")

if max_diff < 1e-1:  # bfloat16: match the tolerance used in the tests above
print("✓ Smoke test passed!")
else:
print("✗ Smoke test failed - differences too large")
4 changes: 3 additions & 1 deletion vllm/compilation/caching.py
@@ -3,6 +3,7 @@

import hashlib
import inspect
import os
import pickle
from unittest.mock import patch

@@ -168,7 +169,8 @@ def _compute_code_hash(files: set[str]) -> str:
)
file_contents = {}
for filepath in files:
if filepath == "<string>":
# Skip files that don't exist (e.g., <string>, <frozen modules>, etc.)
if not os.path.isfile(filepath):
file_contents[filepath] = ""
else:
with open(filepath) as f:
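
A minimal sketch of the behaviour this hunk introduces, written as a standalone helper. The name _hash_source_files and the SHA-256 digest are illustrative assumptions; the real _compute_code_hash in vllm/compilation/caching.py hashes additional state not shown here.

import hashlib
import os


def _hash_source_files(files: set[str]) -> str:
    """Hash source file contents, treating pseudo-paths as empty.

    Paths such as "<string>" or "<frozen importlib._bootstrap>" are not real
    files on disk, so they contribute an empty string instead of raising
    FileNotFoundError when opened.
    """
    file_contents = {}
    for filepath in files:
        if not os.path.isfile(filepath):
            file_contents[filepath] = ""
        else:
            with open(filepath) as f:
                file_contents[filepath] = f.read()

    hasher = hashlib.sha256()
    for filepath in sorted(file_contents):
        hasher.update(filepath.encode())
        hasher.update(file_contents[filepath].encode())
    return hasher.hexdigest()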
7 changes: 7 additions & 0 deletions vllm/config/model.py
@@ -20,6 +20,9 @@
from vllm.config.scheduler import RunnerType
from vllm.config.utils import assert_hashable, config, getattr_iter
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
vllm_kernel_override_batch_invariant,
)
from vllm.platforms import current_platform
from vllm.transformers_utils.config import (
ConfigFormat,
@@ -419,6 +422,10 @@ def __post_init__(
skip_mm_profiling: bool | None,
video_pruning_rate: float | None,
) -> None:
# Enable batch invariance settings if requested
if vllm_kernel_override_batch_invariant():
self.enforce_eager = True

# Set the default seed to 0 in V1.
# NOTE(woosuk): In V0, we set the default seed to None because the
# driver worker shares the same process as the user process, and thus
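
A small usage sketch of what this hook implies for callers. The environment-variable name below is an assumption about how vllm_kernel_override_batch_invariant() is toggled; confirm it in vllm/model_executor/layers/batch_invariant.py before relying on it.

import os

# Assumed toggle for the batch-invariant override; verify the exact variable
# name in vllm.model_executor.layers.batch_invariant.
os.environ["VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"] = "1"

from vllm.model_executor.layers.batch_invariant import (
    vllm_kernel_override_batch_invariant,
)

# When the override is active, ModelConfig.__post_init__ (hunk above) forces
# enforce_eager = True so compiled/cudagraph paths cannot bypass the
# batch-invariant kernels.
print(vllm_kernel_override_batch_invariant())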