diff --git a/python/paddle/incubate/nn/functional/fused_layer_norm.py b/python/paddle/incubate/nn/functional/fused_layer_norm.py index 6929fca78b6a7..e8ccbf3d02764 100644 --- a/python/paddle/incubate/nn/functional/fused_layer_norm.py +++ b/python/paddle/incubate/nn/functional/fused_layer_norm.py @@ -12,11 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, overload import paddle from paddle import _C_ops from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode +if TYPE_CHECKING: + from paddle import Tensor + + +@overload +def fused_layer_norm( + x: Tensor, + norm_weight: Tensor, + norm_bias: Tensor, + epsilon: float, + residual_alpha: float = ..., + begin_norm_axis: int = ..., + bias: Tensor | None = ..., + residual: None = ..., + quant_scale: float = ..., + quant_round_type: float = ..., + quant_max_bound: float = ..., + quant_min_bound: float = ..., +) -> Tensor: ... + + +@overload +def fused_layer_norm( + x: Tensor, + norm_weight: Tensor, + norm_bias: Tensor, + epsilon: float, + residual_alpha: float = ..., + begin_norm_axis: int = ..., + bias: Tensor | None = ..., + residual: Tensor = ..., + quant_scale: float = ..., + quant_round_type: float = ..., + quant_max_bound: float = ..., + quant_min_bound: float = ..., +) -> tuple[Tensor, Tensor]: ... + def fused_layer_norm( x,