Skip to content

Commit 70d2048

Browse files
committed
fix pre-commit
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
1 parent 04182dd commit 70d2048

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

vllm/model_executor/layers/layernorm.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,9 +12,9 @@
1212
rms_norm_batch_invariant,
1313
vllm_is_batch_invariant,
1414
)
15+
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
1516
from vllm.platforms import current_platform
1617
from vllm.utils.torch_utils import direct_register_custom_op
17-
from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
1818

1919

2020
def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -369,6 +369,7 @@ def forward_cuda(
369369
self._is_compiled = True
370370
return self.forward_native(x, residual)
371371

372+
372373
@CustomOp.register("rms_norm_gated")
373374
class RMSNormGated(CustomOp):
374375
"""RMS Normalization with optional gating.
@@ -414,7 +415,9 @@ def __init__(
414415
def reset_parameters(self):
415416
torch.nn.init.ones_(self.weight)
416417

417-
def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
418+
def forward_native(
419+
self, x: torch.Tensor, z: torch.Tensor | None = None
420+
) -> torch.Tensor:
418421
"""
419422
Native PyTorch implementation of RMS normalization with gating.
420423
@@ -454,7 +457,9 @@ def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torc
454457

455458
return out
456459

457-
def forward_cuda(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
460+
def forward_cuda(
461+
self, x: torch.Tensor, z: torch.Tensor | None = None
462+
) -> torch.Tensor:
458463
return rmsnorm_fn(
459464
x,
460465
self.weight,

0 commit comments

Comments (0)