Commit 04182dd

Make RMSNormGated a CustomOp
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
1 parent 8499fd6 commit 04182dd
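
Background for the change (an illustrative sketch only, not vLLM's actual CustomOp base class): registering a layer as a CustomOp lets vLLM route its forward call to a platform-specific method such as forward_cuda, with the pure-PyTorch forward_native as the fallback. Roughly:

import torch
import torch.nn as nn


class DispatchingOp(nn.Module):
    """Hypothetical stand-in for a CustomOp-style base class: forward()
    dispatches to a platform-specific implementation, falling back to the
    native PyTorch path."""

    def forward(self, *args, **kwargs):
        if torch.cuda.is_available():
            return self.forward_cuda(*args, **kwargs)
        return self.forward_native(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        # Subclasses override this with a fused kernel; by default reuse
        # the native implementation.
        return self.forward_native(*args, **kwargs)

With this commit, RMSNormGated's existing PyTorch code becomes forward_native, and the fla rmsnorm_fn kernel is used on CUDA via forward_cuda.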

File tree

1 file changed (+15, -3 lines)

vllm/model_executor/layers/layernorm.py

Lines changed: 15 additions & 3 deletions
@@ -14,6 +14,7 @@
 )
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
+from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -368,8 +369,8 @@ def forward_cuda(
         self._is_compiled = True
         return self.forward_native(x, residual)
 
-
-class RMSNormGated(nn.Module):
+@CustomOp.register("rms_norm_gated")
+class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.
 
     This is a native PyTorch implementation that supports:
@@ -413,7 +414,7 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
-    def forward(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+    def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
         """
         Native PyTorch implementation of RMS normalization with gating.
 
@@ -453,6 +454,17 @@ def forward(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
 
         return out
 
+    def forward_cuda(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+        return rmsnorm_fn(
+            x,
+            self.weight,
+            self.bias,
+            z=z,
+            eps=self.eps,
+            group_size=self.group_size,
+            norm_before_gate=self.norm_before_gate,
+        )
+
 
 class LayerNorm(nn.Module):
     """
