From 3a86386b53185b1406f95f1a7d5b5449814dc656 Mon Sep 17 00:00:00 2001
From: haoyu <3253077866@qq.com>
Date: Mon, 19 Aug 2024 19:49:27 +0800
Subject: [PATCH] update fused_rms_norm.py

---
 .../incubate/nn/functional/fused_rms_norm.py  | 38 +++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/python/paddle/incubate/nn/functional/fused_rms_norm.py b/python/paddle/incubate/nn/functional/fused_rms_norm.py
index c5942135ad136..283f061f4488c 100644
--- a/python/paddle/incubate/nn/functional/fused_rms_norm.py
+++ b/python/paddle/incubate/nn/functional/fused_rms_norm.py
@@ -12,11 +12,49 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, overload
 import paddle
 from paddle import _C_ops
 from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
+
+@overload
+def fused_rms_norm(
+    x: Tensor,
+    norm_weight: Tensor,
+    norm_bias: Tensor,
+    epsilon: float,
+    begin_norm_axis: int,
+    bias: Tensor | None = ...,
+    residual: None = ...,
+    quant_scale: float = ...,
+    quant_round_type: float = ...,
+    quant_max_bound: float = ...,
+    quant_min_bound: float = ...,
+) -> Tensor: ...
+
+
+@overload
+def fused_rms_norm(
+    x: Tensor,
+    norm_weight: Tensor,
+    norm_bias: Tensor,
+    epsilon: float,
+    begin_norm_axis: int,
+    bias: Tensor | None = ...,
+    residual: Tensor = ...,
+    quant_scale: float = ...,
+    quant_round_type: float = ...,
+    quant_max_bound: float = ...,
+    quant_min_bound: float = ...,
+) -> tuple[Tensor, Tensor]: ...
+
 
 def fused_rms_norm(
     x,