diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py index 5432dcb65f0fb2..2a9dc22368ebfd 100644 --- a/python/paddle/nn/functional/norm.py +++ b/python/paddle/nn/functional/norm.py @@ -14,8 +14,11 @@ from __future__ import annotations +import inspect import numbers -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any + +from typing_extensions import overload import paddle from paddle import _C_ops, in_dynamic_mode @@ -681,6 +684,7 @@ def local_response_norm( return res +@overload def group_norm( x: Tensor, num_groups: int, @@ -689,16 +693,40 @@ def group_norm( bias: Tensor | None = None, data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW', name: str | None = None, -) -> Tensor: +) -> Tensor: ... + + +@overload +def group_norm( + input: Tensor, + num_groups: int, + weight: Tensor | None = None, + bias: Tensor | None = None, + eps: float = 1e-05, +) -> Tensor: ... + + +def group_norm(*args: Any, **kwargs: Any) -> Tensor: """ nn.GroupNorm is recommended. For more information, please refer to :ref:`api_paddle_nn_GroupNorm` . + This function has two functionalities, depending on the parameters passed: + + 1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``: + PyTorch compatible group_norm. + + 2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, + DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``: + The original paddle.nn.functional.group_norm, see the following docs. + Parameters: x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`. + alias: ``input``. num_groups(int): The number of groups that divided from channels. epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-05. + alias: ``eps``. weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`. Default: None. bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`. @@ -744,6 +772,44 @@ def group_norm( [[-1.34163547, -0.44721183], [ 0.44721183, 1.34163547]]]]) """ + + len_args = len(args) + if len_args + len(kwargs) < 2: + raise TypeError( + f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n" + " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n" + " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, " + "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)" + ) + + def safe_set_param(key: str, value: Any): + if key in kwargs: + raise TypeError(f"got multiple values for argument '{key}'") + kwargs[key] = value + + if 'input' in kwargs: + safe_set_param('x', kwargs.pop('input')) + + if 'eps' in kwargs: + safe_set_param('epsilon', kwargs.pop('eps')) + + if len_args >= 3 and not isinstance(args[2], float): + param_keys = ["weight", "bias", "epsilon"] + for idx in range(min(len_args - 2, len(param_keys))): + safe_set_param(param_keys[idx], args[idx + 2]) + args = args[:2] + return _group_norm_wrapper(*args, **kwargs) + + +def _group_norm_wrapper( + x: Tensor, + num_groups: int, + epsilon: float = 1e-05, + weight: Tensor | None = None, + bias: Tensor | None = None, + data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW', + name: str | None = None, +) -> Tensor: if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']: raise ValueError("unsupported data layout:" + data_format) @@ -794,3 +860,6 @@ def group_norm( ) return helper.append_activation(group_norm_out) + + +group_norm.__signature__ = inspect.signature(_group_norm_wrapper) diff --git a/test/legacy_test/test_group_norm_op_v2.py b/test/legacy_test/test_group_norm_op_v2.py index 19b0057c50dfec..e959fa343a985e 100644 --- a/test/legacy_test/test_group_norm_op_v2.py +++ b/test/legacy_test/test_group_norm_op_v2.py @@ -618,5 +618,130 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self): np.testing.assert_equal(dx.numpy(), dx_ref.numpy()) +class TestGroupNormParam(unittest.TestCase): + def setUp(self): + self.x_tensor = paddle.randn([2, 6, 4, 4], dtype='float32') + self.weight_tensor = paddle.randn([6], dtype='float32') + self.bias_tensor = paddle.randn([6], dtype='float32') + + def test_alias_input_for_x(self): + """test parameter alias input/x""" + out_with_input = paddle.nn.functional.group_norm( + input=self.x_tensor, + num_groups=3, + weight=self.weight_tensor, + bias=self.bias_tensor, + eps=1e-5, + ) + out_with_x = paddle.nn.functional.group_norm( + x=self.x_tensor, + num_groups=3, + weight=self.weight_tensor, + bias=self.bias_tensor, + eps=1e-5, + ) + + np.testing.assert_array_equal( + out_with_input.numpy(), out_with_x.numpy() + ) + + def test_params_consistency(self): + """test both paddle and torch formats works.""" + out_old = paddle.nn.functional.group_norm( + self.x_tensor, + 3, + 1e-5, + weight=self.weight_tensor, + bias=self.bias_tensor, + ) + + out_new = paddle.nn.functional.group_norm( + x=self.x_tensor, + num_groups=3, + weight=self.weight_tensor, + bias=self.bias_tensor, + eps=1e-5, + ) + + np.testing.assert_array_equal(out_old.numpy(), out_new.numpy()) + + def test_params_1(self): + """test all args with torch format""" + try: + out = paddle.nn.functional.group_norm( + self.x_tensor, + 3, + self.weight_tensor, + self.bias_tensor, + 1e-5, + ) + self.assertTrue(True, "Function call succeeded without error") + except Exception as e: + self.fail(f"Function raised an unexpected exception: {e}") + + def test_params_2(self): + """test all kwargs with torch format""" + try: + out = paddle.nn.functional.group_norm( + input=self.x_tensor, + num_groups=3, + weight=self.weight_tensor, + bias=self.bias_tensor, + epsilon=1e-5, + ) + self.assertTrue(True, "Function call succeeded without error") + except Exception as e: + self.fail(f"Function raised an unexpected exception: {e}") + + def test_params_3(self): + """test of passing both args and kwargs parameters""" + try: + out1 = paddle.nn.functional.group_norm( + self.x_tensor, + 3, + weight=self.weight_tensor, + bias=self.bias_tensor, + epsilon=1e-5, + ) + out2 = paddle.nn.functional.group_norm( + self.x_tensor, + 3, + 1e-5, + weight=self.weight_tensor, + bias=self.bias_tensor, + ) + self.assertTrue(True, "Function call succeeded without error") + except Exception as e: + self.fail(f"Function raised an unexpected exception: {e}") + + def test_params_4(self): + """test default parameters""" + try: + out1 = paddle.nn.functional.group_norm( + self.x_tensor, + 3, + self.weight_tensor, + ) + out2 = paddle.nn.functional.group_norm(self.x_tensor, 3, 1e-5) + self.assertTrue(True, "Function call succeeded without error") + except Exception as e: + self.fail(f"Function raised an unexpected exception: {e}") + + def test_params_5(self): + """test duplicate parameters""" + with self.assertRaises(TypeError): + out_1 = paddle.nn.functional.group_norm( + x=self.x_tensor, + input=self.x_tensor, + num_groups=3, + ) + with self.assertRaises(TypeError): + out_2 = paddle.nn.functional.group_norm( + self.x_tensor, + input=self.x_tensor, + num_groups=3, + ) + + if __name__ == '__main__': unittest.main()