
Commit 0bb8d2e

update: compatible with old parameter order
1 parent 3b98c51 commit 0bb8d2e

File tree

3 files changed: +170 additions, -34 deletions

python/paddle/nn/functional/norm.py

Lines changed: 70 additions & 9 deletions
@@ -14,8 +14,11 @@
 
 from __future__ import annotations
 
+import inspect
 import numbers
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
+
+from typing_extensions import overload
 
 import paddle
 from paddle import _C_ops, in_dynamic_mode
@@ -681,34 +684,52 @@ def local_response_norm(
     return res
 
 
-@param_two_alias(["x", "input"], ["epsilon", "eps"])
+@overload
 def group_norm(
     x: Tensor,
     num_groups: int,
+    epsilon: float = 1e-05,
     weight: Tensor | None = None,
     bias: Tensor | None = None,
-    epsilon: float = 1e-05,
     data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
     name: str | None = None,
-) -> Tensor:
+) -> Tensor: ...
+
+
+@overload
+def group_norm(
+    input: Tensor,
+    num_groups: int,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    eps: float = 1e-05,
+) -> Tensor: ...
+
+
+def group_norm(*args: Any, **kwargs: Any) -> Tensor:
     """
     nn.GroupNorm is recommended.
     For more information, please refer to :ref:`api_paddle_nn_GroupNorm` .
 
-    .. note::
-        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
-        For example, ``group_norm(input=tensor_x, ...)`` is equivalent to ``group_norm(x=tensor_x, ...)``.
+    This function has two functionalities, depending on the parameters passed:
+
+    1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``:
+        PyTorch compatible group_norm.
+
+    2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None,
+        DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``:
+        The original paddle.nn.functional.group_norm, see the following docs.
 
     Parameters:
         x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`.
            alias: ``input``.
         num_groups(int): The number of groups that divided from channels.
+        epsilon(float, optional): The small value added to the variance to prevent
+            division by zero. Default: 1e-05.
         weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`.
            Default: None.
         bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`.
            Default: None.
-        epsilon(float, optional): The small value added to the variance to prevent
-            division by zero. Default: 1e-05.
         data_format(str, optional): Specify the input data format. Support "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
         name(str|None, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
@@ -750,6 +771,43 @@ def group_norm(
             [[-1.34163547, -0.44721183],
             [ 0.44721183,  1.34163547]]]])
     """
+
+    len_args = len(args)
+    if len_args + len(kwargs) < 2:
+        raise TypeError(
+            f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n"
+            " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n"
+            " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, "
+            "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)"
+        )
+
+    is_origin_format = False
+    if len_args >= 3:
+        is_origin_format |= isinstance(args[2], float)
+
+    if is_origin_format:
+        return _group_norm_wrapper(*args, **kwargs)
+    else:
+        # transform params from (input,num_groups,weight,bias,eps) to (x,num_groups,epsilon,weight,bias)
+        param_keys = ['input', 'num_groups', 'weight', 'bias', 'eps']
+        for idx, arg in enumerate(args):
+            key = param_keys[idx]
+            if key in kwargs:
+                raise TypeError(f"got multiple values for argument '{key}'")
+            kwargs[key] = arg
+        return _group_norm_wrapper(**kwargs)
+
+
+@param_two_alias(["x", "input"], ["epsilon", "eps"])
+def _group_norm_wrapper(
+    x: Tensor,
+    num_groups: int,
+    epsilon: float = 1e-05,
+    weight: Tensor | None = None,
+    bias: Tensor | None = None,
+    data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
+    name: str | None = None,
+) -> Tensor:
     if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']:
         raise ValueError("unsupported data layout:" + data_format)
 
@@ -800,3 +858,6 @@ def group_norm(
     )
 
     return helper.append_activation(group_norm_out)
+
+
+group_norm.__signature__ = inspect.signature(_group_norm_wrapper)
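
With the dispatcher in place, both documented call orders resolve to the same _group_norm_wrapper. A minimal usage sketch (assuming a Paddle build that includes this commit; tensors and shapes are illustrative):

import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 6, 4, 4])
w = paddle.randn([6])
b = paddle.randn([6])

# Original Paddle order: (x, num_groups, epsilon, ...).
# args[2] is a float, so is_origin_format is True and the call passes through unchanged.
out_paddle = F.group_norm(x, 3, 1e-5, weight=w, bias=b)

# PyTorch-compatible order: (input, num_groups, weight, bias, eps).
# args[2] is a Tensor, so the positional args are remapped onto
# ['input', 'num_groups', 'weight', 'bias', 'eps'] before calling the wrapper.
out_torch = F.group_norm(x, 3, w, b, 1e-5)

# Both paths end in _group_norm_wrapper, so the results match exactly.
np.testing.assert_array_equal(out_paddle.numpy(), out_torch.numpy())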

python/paddle/nn/layer/norm.py

Lines changed: 1 addition & 1 deletion
@@ -562,9 +562,9 @@ def forward(self, input: Tensor) -> Tensor:
         return group_norm(
             input,
             self._num_groups,
+            self._epsilon,
             self.weight,
             self.bias,
-            self._epsilon,
             self._data_format,
         )
 
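
For context, a short sketch of how this layer change interacts with the dispatcher: GroupNorm.forward now passes self._epsilon as the third positional argument, so the functional API sees a float at args[2] and takes the original-format path (a hedged illustration, not part of the diff):

import paddle

# GroupNorm.forward now calls
#   group_norm(input, num_groups, epsilon, weight, bias, data_format)
# so the dispatcher's isinstance(args[2], float) check selects the original path.
gn = paddle.nn.GroupNorm(num_groups=3, num_channels=6, epsilon=1e-5)
x = paddle.randn([2, 6, 4, 4])
y = gn(x)
print(y.shape)  # [2, 6, 4, 4]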

test/legacy_test/test_group_norm_op_v2.py

Lines changed: 99 additions & 24 deletions
@@ -620,54 +620,129 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self):
 
 class TestGroupNormParam(unittest.TestCase):
     def setUp(self):
-        np.random.seed(42)
-        self.x_np = np.random.randn(2, 6, 4, 4).astype('float32')
-        self.weight_np = np.random.randn(6).astype('float32')
-        self.bias_np = np.random.randn(6).astype('float32')
+        self.x_tensor = paddle.randn([2, 6, 4, 4], dtype='float32')
+        self.weight_tensor = paddle.randn([6], dtype='float32')
+        self.bias_tensor = paddle.randn([6], dtype='float32')
 
     def test_alias_input_for_x(self):
         """test parameter alias input/x"""
-        x_tensor = paddle.to_tensor(self.x_np)
-        weight_tensor = paddle.to_tensor(self.weight_np)
-        bias_tensor = paddle.to_tensor(self.bias_np)
-
         out_with_input = paddle.nn.functional.group_norm(
-            input=x_tensor,
+            input=self.x_tensor,
             num_groups=3,
-            weight=weight_tensor,
-            bias=bias_tensor,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
             eps=1e-5,
         )
         out_with_x = paddle.nn.functional.group_norm(
-            x=x_tensor,
+            x=self.x_tensor,
             num_groups=3,
-            weight=weight_tensor,
-            bias=bias_tensor,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
            eps=1e-5,
         )
 
         np.testing.assert_array_equal(
             out_with_input.numpy(), out_with_x.numpy()
         )
 
-    def test_param_order(self):
-        """test order of parameters"""
-        x_tensor = paddle.to_tensor(self.x_np)
-        weight_tensor = paddle.to_tensor(self.weight_np)
-        bias_tensor = paddle.to_tensor(self.bias_np)
+    def test_params_consistency(self):
+        """test both paddle and torch formats works."""
+        out_old = paddle.nn.functional.group_norm(
+            self.x_tensor,
+            3,
+            1e-5,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+        )
+
+        out_new = paddle.nn.functional.group_norm(
+            x=self.x_tensor,
+            num_groups=3,
+            weight=self.weight_tensor,
+            bias=self.bias_tensor,
+            eps=1e-5,
+        )
+
+        np.testing.assert_array_equal(out_old.numpy(), out_new.numpy())
 
+    def test_params_1(self):
+        """test all args with torch format"""
         try:
             out = paddle.nn.functional.group_norm(
-                x_tensor,  # x
-                3,  # num_groups
-                weight_tensor,  # weight
-                bias_tensor,  # bias
-                1e-5,  # epsilon
+                self.x_tensor,
+                3,
+                self.weight_tensor,
+                self.bias_tensor,
+                1e-5,
             )
             self.assertTrue(True, "Function call succeeded without error")
         except Exception as e:
             self.fail(f"Function raised an unexpected exception: {e}")
 
+    def test_params_2(self):
+        """test all kwargs with torch format"""
+        try:
+            out = paddle.nn.functional.group_norm(
+                input=self.x_tensor,
+                num_groups=3,
+                weight=self.weight_tensor,  # weight
+                bias=self.bias_tensor,  # bias
+                epsilon=1e-5,
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_3(self):
+        """test of passing both args and kwargs parameters"""
+        try:
+            out1 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+                epsilon=1e-5,
+            )
+            out2 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                1e-5,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_4(self):
+        """test default parameters"""
+        try:
+            out1 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                3,
+                weight=self.weight_tensor,
+                bias=self.bias_tensor,
+            )
+            out2 = paddle.nn.functional.group_norm(self.x_tensor, 3, 1e-5)
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+    def test_params_5(self):
+        """test duplicate parameters"""
+        with self.assertRaises(TypeError):
+            out_1 = paddle.nn.functional.group_norm(
+                x=self.x_tensor,
+                input=self.x_tensor,
+                num_groups=3,
+            )
+        with self.assertRaises(TypeError):
+            out_2 = paddle.nn.functional.group_norm(
+                self.x_tensor,
+                input=self.x_tensor,
+                num_groups=3,
+            )
+
 
 if __name__ == '__main__':
     unittest.main()
