|
14 | 14 |
|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
| 17 | +import inspect |
17 | 18 | import numbers |
18 | | -from typing import TYPE_CHECKING |
| 19 | +from typing import TYPE_CHECKING, Any |
| 20 | + |
| 21 | +from typing_extensions import overload |
19 | 22 |
|
20 | 23 | import paddle |
21 | 24 | from paddle import _C_ops, in_dynamic_mode |
@@ -681,34 +684,52 @@ def local_response_norm( |
681 | 684 | return res |
682 | 685 |
|
683 | 686 |
|
684 | | -@param_two_alias(["x", "input"], ["epsilon", "eps"]) |
| 687 | +@overload |
685 | 688 | def group_norm( |
686 | 689 | x: Tensor, |
687 | 690 | num_groups: int, |
| 691 | + epsilon: float = 1e-05, |
688 | 692 | weight: Tensor | None = None, |
689 | 693 | bias: Tensor | None = None, |
690 | | - epsilon: float = 1e-05, |
691 | 694 | data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW', |
692 | 695 | name: str | None = None, |
693 | | -) -> Tensor: |
| 696 | +) -> Tensor: ... |
| 697 | + |
| 698 | + |
| 699 | +@overload |
| 700 | +def group_norm( |
| 701 | + input: Tensor, |
| 702 | + num_groups: int, |
| 703 | + weight: Tensor | None = None, |
| 704 | + bias: Tensor | None = None, |
| 705 | + eps: float = 1e-05, |
| 706 | +) -> Tensor: ... |
| 707 | + |
| 708 | + |
| 709 | +def group_norm(*args: Any, **kwargs: Any) -> Tensor: |
694 | 710 | """ |
695 | 711 | nn.GroupNorm is recommended. |
696 | 712 | For more information, please refer to :ref:`api_paddle_nn_GroupNorm` . |
697 | 713 |
|
698 | | - .. note:: |
699 | | - Alias Support: The parameter name ``input`` can be used as an alias for ``x``. |
700 | | - For example, ``group_norm(input=tensor_x, ...)`` is equivalent to ``group_norm(x=tensor_x, ...)``. |
| 714 | + This function has two functionalities, depending on the parameters passed: |
| 715 | +
|
| 716 | + 1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``: |
| 717 | + PyTorch-compatible group_norm. |
| 718 | +
|
| 719 | + 2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, |
| 720 | + DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``: |
| 721 | + The original paddle.nn.functional.group_norm; see the docs below. |
701 | 722 |
|
702 | 723 | Parameters: |
703 | 724 | x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`. |
704 | 725 | alias: ``input``. |
705 | 726 | num_groups(int): The number of groups that divided from channels. |
| 727 | + epsilon(float, optional): The small value added to the variance to prevent |
| 728 | + division by zero. Default: 1e-05. |
706 | 729 | weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`. |
707 | 730 | Default: None. |
708 | 731 | bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`. |
709 | 732 | Default: None. |
710 | | - epsilon(float, optional): The small value added to the variance to prevent |
711 | | - division by zero. Default: 1e-05. |
712 | 733 | data_format(str, optional): Specify the input data format. Support "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW". |
713 | 734 | name(str|None, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.. |
714 | 735 |
|
@@ -750,6 +771,43 @@ def group_norm( |
750 | 771 | [[-1.34163547, -0.44721183], |
751 | 772 | [ 0.44721183, 1.34163547]]]]) |
752 | 773 | """ |
| 774 | + |
| 775 | + len_args = len(args) |
| 776 | + if len_args + len(kwargs) < 2: |
| 777 | + raise TypeError( |
| 778 | + f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n" |
| 779 | + " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n" |
| 780 | + " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, " |
| 781 | + "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)" |
| 782 | + ) |
| 783 | + |
| 784 | + is_origin_format = False |
| 785 | + if len_args >= 3: |
| 786 | + is_origin_format |= isinstance(args[2], float) |
| 787 | + |
| 788 | + if is_origin_format: |
| 789 | + return _group_norm_wrapper(*args, **kwargs) |
| 790 | + else: |
| 791 | + # transform params from (input,num_groups,weight,bias,eps) to (x,num_groups,epsilon,weight,bias) |
| 792 | + param_keys = ['input', 'num_groups', 'weight', 'bias', 'eps'] |
| 793 | + for idx, arg in enumerate(args): |
| 794 | + key = param_keys[idx] |
| 795 | + if key in kwargs: |
| 796 | + raise TypeError(f"got multiple values for argument '{key}'") |
| 797 | + kwargs[key] = arg |
| 798 | + return _group_norm_wrapper(**kwargs) |
| 799 | + |
| 800 | + |
| 801 | +@param_two_alias(["x", "input"], ["epsilon", "eps"]) |
| 802 | +def _group_norm_wrapper( |
| 803 | + x: Tensor, |
| 804 | + num_groups: int, |
| 805 | + epsilon: float = 1e-05, |
| 806 | + weight: Tensor | None = None, |
| 807 | + bias: Tensor | None = None, |
| 808 | + data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW', |
| 809 | + name: str | None = None, |
| 810 | +) -> Tensor: |
753 | 811 | if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']: |
754 | 812 | raise ValueError("unsupported data layout:" + data_format) |
755 | 813 |
|
@@ -800,3 +858,6 @@ def group_norm( |
800 | 858 | ) |
801 | 859 |
|
802 | 860 | return helper.append_activation(group_norm_out) |
| 861 | + |
| 862 | + |
| 863 | +group_norm.__signature__ = inspect.signature(_group_norm_wrapper) |
0 commit comments