[Typing][C-73] Add type annotations for python/paddle/incubate/nn/functional/fused_layer_norm.py #67159

Merged: 8 commits, Aug 18, 2024
40 changes: 40 additions & 0 deletions python/paddle/incubate/nn/functional/fused_layer_norm.py
@@ -12,11 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, overload

import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode

if TYPE_CHECKING:
    from paddle import Tensor


@overload
def fused_layer_norm(
Contributor commented:
Use overloads that depend on residual:

from __future__ import annotations

from typing import TYPE_CHECKING, overload

import paddle
from paddle import _C_ops
from paddle.framework import LayerHelper, in_dynamic_mode, in_pir_mode

if TYPE_CHECKING:
    from paddle import Tensor

@overload
def fused_layer_norm(
    x: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    epsilon: float,
    residual_alpha: float = ...,
    begin_norm_axis: int = ...,
    bias: Tensor | None = ...,
    residual: None = ...,
    quant_scale: float = ...,
    quant_round_type: float = ...,
    quant_max_bound: float = ...,
    quant_min_bound: float = ...,
) -> Tensor: ...

@overload
def fused_layer_norm(
    x: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    epsilon: float,
    residual_alpha: float = ...,
    begin_norm_axis: int = ...,
    bias: Tensor | None = ...,
    residual: Tensor = ...,
    quant_scale: float = ...,
    quant_round_type: float = ...,
    quant_max_bound: float = ...,
    quant_min_bound: float = ...,
) -> tuple[Tensor, Tensor]: ...

def fused_layer_norm(
    x,
    norm_weight,
    norm_bias,
    epsilon,
    residual_alpha=1.0,
    begin_norm_axis=1,
    bias=None,
    residual=None,
    quant_scale=-1,
    quant_round_type=0,
    quant_max_bound=0,
    quant_min_bound=0,
):
    r"""
    Apply the Fused LayerNorm kernel. Also supports the LayerNorm(bias + residual_alpha * residual + x) fused pattern.

Contributor Author replied:
OK, thanks.
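
The suggestion above uses typing.overload so that the declared return type depends on whether residual is passed. A minimal, self-contained sketch of that pattern, with toy names rather than Paddle's API:

from __future__ import annotations

from typing import overload


@overload
def fused_op(x: float, residual: None = ...) -> float: ...
@overload
def fused_op(x: float, residual: float = ...) -> tuple[float, float]: ...


def fused_op(x, residual=None):
    # Single runtime implementation; the overloads above only change what a
    # static type checker infers at each call site.
    if residual is None:
        return 2.0 * x
    return 2.0 * x, x + residual


out = fused_op(1.0)            # a checker such as mypy infers: float
out_pair = fused_op(1.0, 0.5)  # inferred: tuple[float, float]
print(out, out_pair)

With the annotations in this diff, the same idea applies to fused_layer_norm: a call without residual is typed as returning a single Tensor, while passing a residual Tensor is typed as returning tuple[Tensor, Tensor].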

    x: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    epsilon: float,
    residual_alpha: float = ...,
    begin_norm_axis: int = ...,
    bias: Tensor | None = ...,
    residual: None = ...,
    quant_scale: float = ...,
    quant_round_type: float = ...,
    quant_max_bound: float = ...,
    quant_min_bound: float = ...,
) -> Tensor: ...


@overload
def fused_layer_norm(
    x: Tensor,
    norm_weight: Tensor,
    norm_bias: Tensor,
    epsilon: float,
    residual_alpha: float = ...,
    begin_norm_axis: int = ...,
    bias: Tensor | None = ...,
    residual: Tensor = ...,
    quant_scale: float = ...,
    quant_round_type: float = ...,
    quant_max_bound: float = ...,
    quant_min_bound: float = ...,
) -> tuple[Tensor, Tensor]: ...


def fused_layer_norm(
    x,
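
The docstring quoted in the review thread describes the op as fusing LayerNorm(bias + residual_alpha * residual + x). A rough, unfused sketch of that computation using the plain paddle.nn.functional.layer_norm API; the shapes and values below are illustrative assumptions, not part of the PR:

import paddle
import paddle.nn.functional as F

hidden = 8
x = paddle.rand([2, hidden])
residual = paddle.rand([2, hidden])
bias = paddle.zeros([hidden])
norm_weight = paddle.ones([hidden])
norm_bias = paddle.zeros([hidden])
residual_alpha = 1.0

# Unfused reference: add the bias and the scaled residual first, then apply
# layer normalization over the last dimension; the fused kernel performs
# these steps in one pass.
y = F.layer_norm(
    bias + residual_alpha * residual + x,
    normalized_shape=[hidden],
    weight=norm_weight,
    bias=norm_bias,
    epsilon=1e-5,
)
print(y.shape)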