@@ -12,9 +12,9 @@
     rms_norm_batch_invariant,
     vllm_is_batch_invariant,
 )
+from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
-from vllm.model_executor.layers.fla.ops.layernorm_guard import rmsnorm_fn
 
 
 def is_rocm_aiter_rmsnorm_enabled() -> bool:
@@ -369,6 +369,7 @@ def forward_cuda(
         self._is_compiled = True
         return self.forward_native(x, residual)
 
+
 @CustomOp.register("rms_norm_gated")
 class RMSNormGated(CustomOp):
     """RMS Normalization with optional gating.
@@ -414,7 +415,9 @@ def __init__(
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)
 
-    def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+    def forward_native(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
         """
         Native PyTorch implementation of RMS normalization with gating.
 
@@ -454,7 +457,9 @@ def forward_native(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
 
         return out
 
-    def forward_cuda(self, x: torch.Tensor, z: torch.Tensor | None = None) -> torch.Tensor:
+    def forward_cuda(
+        self, x: torch.Tensor, z: torch.Tensor | None = None
+    ) -> torch.Tensor:
         return rmsnorm_fn(
             x,
             self.weight,
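For context, `forward_native` in this diff is the pure-PyTorch fallback for the gated RMS norm, while `forward_cuda` dispatches to the imported `rmsnorm_fn` kernel. Below is a minimal reference sketch of gated RMS normalization, assuming the gate is applied as `x * silu(z)` before normalization; the gating order, the `eps` value, and the helper name `gated_rms_norm_ref` are illustrative assumptions, not taken from the diff or from vLLM's actual `RMSNormGated` implementation.

import torch
import torch.nn.functional as F


def gated_rms_norm_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    z: torch.Tensor | None = None,
    eps: float = 1e-5,  # illustrative epsilon, not taken from the diff
) -> torch.Tensor:
    """Reference gated RMS norm: optionally gate with SiLU(z), then normalize."""
    if z is not None:
        # Assumed gating order: multiply the input by SiLU(z) before normalizing.
        x = x * F.silu(z)
    orig_dtype = x.dtype
    x = x.float()  # accumulate the variance in fp32 for numerical stability
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight.float()).to(orig_dtype)

A `forward_cuda` call like the one in the diff would be expected to match such a reference only up to kernel-level numerical tolerance; treat the sketch as a checking aid rather than as the library's implementation.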