|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | + |
| 4 | +import pytest |
| 5 | +import torch |
| 6 | +import torch.nn.functional as F |
| 7 | + |
| 8 | +from vllm.model_executor.layers.fla.ops.layernorm_guard import ( |
| 9 | + layer_norm_fwd, |
| 10 | + layernorm_fn, |
| 11 | + rms_norm_ref, |
| 12 | +) |
| 13 | +from vllm.platforms import current_platform |
| 14 | + |
| 15 | + |
| 16 | +def layer_norm_ref( |
| 17 | + x, |
| 18 | + weight, |
| 19 | + bias, |
| 20 | + z=None, |
| 21 | + eps=1e-6, |
| 22 | + group_size=None, |
| 23 | + norm_before_gate=True, |
| 24 | + is_rms_norm=False, |
| 25 | +): |
| 26 | + """Reference implementation for both layer norm and RMS norm.""" |
| 27 | + if is_rms_norm: |
| 28 | + # Use the imported rms_norm_ref for RMS norm cases |
| 29 | + return rms_norm_ref( |
| 30 | + x, |
| 31 | + weight, |
| 32 | + bias, |
| 33 | + z=z, |
| 34 | + eps=eps, |
| 35 | + group_size=group_size, |
| 36 | + norm_before_gate=norm_before_gate, |
| 37 | + upcast=True, |
| 38 | + ) |
| 39 | + |
| 40 | + # Layer norm implementation |
| 41 | + dtype = x.dtype |
| 42 | + x = x.float() |
| 43 | + weight = weight.float() |
| 44 | + bias = bias.float() if bias is not None else None |
| 45 | + z = z.float() if z is not None else None |
| 46 | + |
| 47 | + if z is not None and not norm_before_gate: |
| 48 | + x = x * F.silu(z) |
| 49 | + |
| 50 | + if group_size is None: |
| 51 | + # Layer norm: subtract mean |
| 52 | + mean = x.mean(dim=-1, keepdim=True) |
| 53 | + var = ((x - mean).square()).mean(dim=-1, keepdim=True) |
| 54 | + rstd = 1 / torch.sqrt(var + eps) |
| 55 | + out = (x - mean) * rstd * weight |
| 56 | + if bias is not None: |
| 57 | + out = out + bias |
| 58 | + else: |
| 59 | + # Group norm |
| 60 | + from einops import rearrange |
| 61 | + |
| 62 | + x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) |
| 63 | + mean = x_group.mean(dim=-1, keepdim=True) |
| 64 | + var = ((x_group - mean).square()).mean(dim=-1, keepdim=True) |
| 65 | + rstd = 1 / torch.sqrt(var + eps) |
| 66 | + x_group = (x_group - mean) * rstd |
| 67 | + out = rearrange(x_group, "... g d -> ... (g d)") * weight |
| 68 | + if bias is not None: |
| 69 | + out = out + bias |
| 70 | + |
| 71 | + if z is not None and norm_before_gate: |
| 72 | + out *= F.silu(z) |
| 73 | + |
| 74 | + return out.to(dtype) |
| 75 | + |
| 76 | + |
| 77 | +DTYPES = [torch.bfloat16, torch.float32] |
| 78 | +# Test various M sizes to ensure rows_per_block logic works correctly |
| 79 | +NUM_TOKENS = [ |
| 80 | + 1, |
| 81 | + 7, |
| 82 | + 16, |
| 83 | + 63, |
| 84 | + 128, |
| 85 | + 256, |
| 86 | + 512, |
| 87 | + 1024, |
| 88 | + 2048, |
| 89 | + 4096, |
| 90 | + 5789, |
| 91 | + 8189, |
| 92 | + 8191, |
| 93 | + 16383, |
| 94 | + 32767, |
| 95 | +] |
| 96 | +HIDDEN_SIZES = [64, 128, 256, 1024] |
| 97 | +GROUP_SIZES = [None, 64, 128] # None means full hidden size |
| 98 | +NORM_BEFORE_GATE = [True, False] |
| 99 | +IS_RMS_NORM = [True, False] |
| 100 | +SEEDS = [0, 42] |
| 101 | + |
| 102 | + |
| 103 | +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) |
| 104 | +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) |
| 105 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 106 | +@pytest.mark.parametrize("seed", SEEDS) |
| 107 | +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) |
| 108 | +@torch.inference_mode() |
| 109 | +def test_layer_norm_fwd_basic( |
| 110 | + num_tokens: int, |
| 111 | + hidden_size: int, |
| 112 | + dtype: torch.dtype, |
| 113 | + seed: int, |
| 114 | + is_rms_norm: bool, |
| 115 | +) -> None: |
| 116 | + """Test basic layer norm forward pass without z (gate) tensor.""" |
| 117 | + current_platform.seed_everything(seed) |
| 118 | + device = torch.device("cuda:0") |
| 119 | + |
| 120 | + # Create inputs |
| 121 | + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 122 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 123 | + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) |
| 124 | + eps = 1e-6 |
| 125 | + |
| 126 | + # Run the triton kernel |
| 127 | + out, mean, rstd = layer_norm_fwd( |
| 128 | + x, weight, bias, eps, z=None, is_rms_norm=is_rms_norm |
| 129 | + ) |
| 130 | + |
| 131 | + # Run reference implementation |
| 132 | + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=is_rms_norm) |
| 133 | + |
| 134 | + # Check outputs |
| 135 | + assert out.shape == x.shape |
| 136 | + assert out.dtype == x.dtype |
| 137 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 138 | + |
| 139 | + # Check mean and rstd shapes |
| 140 | + if not is_rms_norm: |
| 141 | + assert mean.shape == (num_tokens,) |
| 142 | + assert rstd.shape == (num_tokens,) |
| 143 | + |
| 144 | + |
| 145 | +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) |
| 146 | +@pytest.mark.parametrize("hidden_size", [128, 256, 1024]) |
| 147 | +@pytest.mark.parametrize("dtype", DTYPES) |
| 148 | +@pytest.mark.parametrize("norm_before_gate", NORM_BEFORE_GATE) |
| 149 | +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) |
| 150 | +@torch.inference_mode() |
| 151 | +def test_layer_norm_fwd_with_gate( |
| 152 | + num_tokens: int, |
| 153 | + hidden_size: int, |
| 154 | + dtype: torch.dtype, |
| 155 | + norm_before_gate: bool, |
| 156 | + is_rms_norm: bool, |
| 157 | +) -> None: |
| 158 | + """Test layer norm forward pass with z (gate) tensor.""" |
| 159 | + current_platform.seed_everything(42) |
| 160 | + device = torch.device("cuda:0") |
| 161 | + |
| 162 | + # Create inputs |
| 163 | + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 164 | + z = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 165 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 166 | + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) |
| 167 | + eps = 1e-6 |
| 168 | + |
| 169 | + # Run the triton kernel |
| 170 | + out, mean, rstd = layer_norm_fwd( |
| 171 | + x, |
| 172 | + weight, |
| 173 | + bias, |
| 174 | + eps, |
| 175 | + z=z, |
| 176 | + norm_before_gate=norm_before_gate, |
| 177 | + is_rms_norm=is_rms_norm, |
| 178 | + ) |
| 179 | + |
| 180 | + # Run reference implementation |
| 181 | + ref_out = layer_norm_ref( |
| 182 | + x, |
| 183 | + weight, |
| 184 | + bias, |
| 185 | + z=z, |
| 186 | + eps=eps, |
| 187 | + norm_before_gate=norm_before_gate, |
| 188 | + is_rms_norm=is_rms_norm, |
| 189 | + ) |
| 190 | + |
| 191 | + # Check outputs |
| 192 | + assert out.shape == x.shape |
| 193 | + assert out.dtype == x.dtype |
| 194 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 195 | + |
| 196 | + |
| 197 | +@pytest.mark.parametrize("num_tokens", [128, 512]) |
| 198 | +@pytest.mark.parametrize("hidden_size", [512, 1024]) |
| 199 | +@pytest.mark.parametrize("group_size", [64, 128, 256]) |
| 200 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 201 | +@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) |
| 202 | +@torch.inference_mode() |
| 203 | +def test_layer_norm_fwd_with_groups( |
| 204 | + num_tokens: int, |
| 205 | + hidden_size: int, |
| 206 | + group_size: int, |
| 207 | + dtype: torch.dtype, |
| 208 | + is_rms_norm: bool, |
| 209 | +) -> None: |
| 210 | + """Test layer norm forward pass with group normalization.""" |
| 211 | + if hidden_size % group_size != 0: |
| 212 | + pytest.skip( |
| 213 | + f"hidden_size {hidden_size} not divisible by group_size {group_size}" |
| 214 | + ) |
| 215 | + |
| 216 | + current_platform.seed_everything(42) |
| 217 | + device = torch.device("cuda:0") |
| 218 | + |
| 219 | + # Create inputs |
| 220 | + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 221 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 222 | + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) |
| 223 | + eps = 1e-6 |
| 224 | + |
| 225 | + ngroups = hidden_size // group_size |
| 226 | + |
| 227 | + # Run the triton kernel |
| 228 | + out, mean, rstd = layer_norm_fwd( |
| 229 | + x, weight, bias, eps, z=None, group_size=group_size, is_rms_norm=is_rms_norm |
| 230 | + ) |
| 231 | + |
| 232 | + # Run reference implementation |
| 233 | + ref_out = layer_norm_ref( |
| 234 | + x, weight, bias, z=None, eps=eps, group_size=group_size, is_rms_norm=is_rms_norm |
| 235 | + ) |
| 236 | + |
| 237 | + # Check outputs |
| 238 | + assert out.shape == x.shape |
| 239 | + assert out.dtype == x.dtype |
| 240 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 241 | + |
| 242 | + # Check mean and rstd shapes for groups |
| 243 | + if not is_rms_norm: |
| 244 | + assert mean.shape == (ngroups * num_tokens,) |
| 245 | + assert rstd.shape == (ngroups * num_tokens,) |
| 246 | + |
| 247 | + |
| 248 | +@pytest.mark.parametrize("num_tokens", [7, 63, 128, 513, 1024, 2049]) |
| 249 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 250 | +@torch.inference_mode() |
| 251 | +def test_layer_norm_rows_per_block( |
| 252 | + num_tokens: int, |
| 253 | + dtype: torch.dtype, |
| 254 | +) -> None: |
| 255 | + """Test that rows_per_block logic works correctly for various M sizes.""" |
| 256 | + current_platform.seed_everything(42) |
| 257 | + device = torch.device("cuda:0") |
| 258 | + hidden_size = 1024 |
| 259 | + |
| 260 | + # Create inputs |
| 261 | + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 262 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 263 | + bias = torch.randn(hidden_size, dtype=dtype, device=device) |
| 264 | + eps = 1e-6 |
| 265 | + |
| 266 | + # Run the triton kernel |
| 267 | + out, mean, rstd = layer_norm_fwd(x, weight, bias, eps, z=None, is_rms_norm=False) |
| 268 | + |
| 269 | + # Run reference implementation |
| 270 | + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) |
| 271 | + |
| 272 | + # Check outputs |
| 273 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 274 | + |
| 275 | + |
| 276 | +@pytest.mark.parametrize("dtype", [torch.bfloat16]) |
| 277 | +@torch.inference_mode() |
| 278 | +def test_strided_input(dtype: torch.dtype) -> None: |
| 279 | + """Test that the kernel handles non-contiguous (strided) |
| 280 | + inputs correctly.""" |
| 281 | + current_platform.seed_everything(42) |
| 282 | + device = torch.device("cuda:0") |
| 283 | + num_tokens = 128 |
| 284 | + hidden_size = 1024 |
| 285 | + |
| 286 | + # Create a larger tensor and take a strided slice |
| 287 | + x_large = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) |
| 288 | + x = x_large[:, :hidden_size] |
| 289 | + |
| 290 | + # Make it contiguous for the kernel |
| 291 | + x_contiguous = x.contiguous() |
| 292 | + |
| 293 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 294 | + bias = torch.randn(hidden_size, dtype=dtype, device=device) |
| 295 | + eps = 1e-6 |
| 296 | + |
| 297 | + # Run the triton kernel with contiguous input |
| 298 | + out, mean, rstd = layer_norm_fwd( |
| 299 | + x_contiguous, weight, bias, eps, z=None, is_rms_norm=False |
| 300 | + ) |
| 301 | + |
| 302 | + # Run reference implementation |
| 303 | + ref_out = layer_norm_ref( |
| 304 | + x_contiguous, weight, bias, z=None, eps=eps, is_rms_norm=False |
| 305 | + ) |
| 306 | + |
| 307 | + # Check outputs |
| 308 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 309 | + |
| 310 | + |
| 311 | +@pytest.mark.parametrize("num_tokens", [1, 128, 2048]) |
| 312 | +@pytest.mark.parametrize("hidden_size", [768, 4096]) |
| 313 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 314 | +@torch.inference_mode() |
| 315 | +def test_output_buffer_provided( |
| 316 | + num_tokens: int, |
| 317 | + hidden_size: int, |
| 318 | + dtype: torch.dtype, |
| 319 | +) -> None: |
| 320 | + """Test that the kernel works when an output buffer is provided.""" |
| 321 | + current_platform.seed_everything(42) |
| 322 | + device = torch.device("cuda:0") |
| 323 | + |
| 324 | + # Create inputs |
| 325 | + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) |
| 326 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 327 | + bias = torch.randn(hidden_size, dtype=dtype, device=device) |
| 328 | + eps = 1e-6 |
| 329 | + |
| 330 | + # Pre-allocate output buffer |
| 331 | + out_buffer = torch.empty_like(x) |
| 332 | + |
| 333 | + # Run the triton kernel with provided output |
| 334 | + out, mean, rstd = layer_norm_fwd( |
| 335 | + x, weight, bias, eps, z=None, out=out_buffer, is_rms_norm=False |
| 336 | + ) |
| 337 | + |
| 338 | + # Check that the provided buffer was used |
| 339 | + assert out.data_ptr() == out_buffer.data_ptr() |
| 340 | + |
| 341 | + # Run reference implementation |
| 342 | + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) |
| 343 | + |
| 344 | + # Check outputs |
| 345 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 346 | + |
| 347 | + |
| 348 | +@pytest.mark.parametrize( |
| 349 | + "shape", |
| 350 | + [ |
| 351 | + (4, 16, 1024), # 3D tensor |
| 352 | + (2, 8, 512, 256), # 4D tensor |
| 353 | + ], |
| 354 | +) |
| 355 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 356 | +@torch.inference_mode() |
| 357 | +def test_multidimensional_input( |
| 358 | + shape: tuple, |
| 359 | + dtype: torch.dtype, |
| 360 | +) -> None: |
| 361 | + """Test that the autograd function handles multidimensional inputs.""" |
| 362 | + current_platform.seed_everything(42) |
| 363 | + device = torch.device("cuda:0") |
| 364 | + hidden_size = shape[-1] |
| 365 | + |
| 366 | + # Create inputs |
| 367 | + x = torch.randn(*shape, dtype=dtype, device=device) |
| 368 | + weight = torch.randn(hidden_size, dtype=dtype, device=device) |
| 369 | + bias = torch.randn(hidden_size, dtype=dtype, device=device) |
| 370 | + eps = 1e-6 |
| 371 | + |
| 372 | + # Run through autograd function |
| 373 | + out = layernorm_fn(x, weight, bias, z=None, eps=eps) |
| 374 | + |
| 375 | + # Run reference implementation |
| 376 | + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) |
| 377 | + |
| 378 | + # Check outputs |
| 379 | + assert out.shape == x.shape |
| 380 | + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) |
| 381 | + |
| 382 | + |
| 383 | +if __name__ == "__main__": |
| 384 | + # Run a quick smoke test |
| 385 | + test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False) |
| 386 | + test_layer_norm_fwd_with_gate(128, 1024, torch.float16, True, False) |
| 387 | + test_layer_norm_rows_per_block(513, torch.float16) |
| 388 | + print("All smoke tests passed!") |
0 commit comments