@@ -14,6 +14,7 @@
 )
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
+from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn


 def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -368,8 +369,8 @@ def forward_cuda(
         self._is_compiled = True
         return self.forward_native(x, residual)

-
-class RMSNormGated(nn.Module):
+@CustomOp.register("rms_norm_gated")
+class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.

     This is a native PyTorch implementation that supports:
@@ -413,7 +414,7 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)

-    def forward(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
         """
         Native PyTorch implementation of RMS normalization with gating.

@@ -453,6 +454,17 @@ def forward(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:

         return out

+    def forward_cuda(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+        return rmsnorm_fn(
+            x,
+            self.weight,
+            self.bias,
+            z=z,
+            eps=self.eps,
+            group_size=self.group_size,
+            norm_before_gate=self.norm_before_gate,
+        )
+

 class LayerNorm(nn.Module):
     """
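With the @CustomOp.register("rms_norm_gated") registration, calls to the RMSNormGated module are dispatched by vLLM's CustomOp machinery: on CUDA the new forward_cuda is used, delegating to the imported rmsnorm_fn, while other platforms fall back to the pure-PyTorch forward_native. As a rough reference for what the gated path computes, the sketch below is illustrative only, not the vLLM kernel: the helper name gated_rms_norm_ref is made up, the norm_before_gate semantics (silu(z) gating applied after or before the RMS normalization) are assumed from the common gated-RMSNorm formulation, and grouped normalization (group_size) and the optional bias are omitted.

# Illustrative reference only -- assumed semantics, not the vLLM kernel.
import torch
import torch.nn.functional as F


def gated_rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    z: torch.Tensor | None = None,
    eps: float = 1e-5,
    norm_before_gate: bool = True,
) -> torch.Tensor:
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)  # gate the input before normalizing
    xf = x.float()  # compute the RMS statistic in float32 for stability
    rstd = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    out = (xf * rstd).to(x.dtype) * weight
    if z is not None and norm_before_gate:
        out = out * F.silu(z)  # gate the normalized output
    return out

A quick sanity check could compare this reference (with matching eps, weight, and gate) against the module's forward_native output on random tensors.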