Skip to content

Commit 39dddfe

Browse files
committed
[API compatibility] update paddle group_norm api
1 parent 79ed81e commit 39dddfe

File tree

3 files changed

+56
-4
lines changed

3 files changed

+56
-4
lines changed

python/paddle/nn/functional/norm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -681,12 +681,13 @@ def local_response_norm(
681681
return res
682682

683683

684+
@param_two_alias(["x", "input"], ["epsilon", "eps"])
684685
def group_norm(
685686
x: Tensor,
686687
num_groups: int,
687-
epsilon: float = 1e-05,
688688
weight: Tensor | None = None,
689689
bias: Tensor | None = None,
690+
epsilon: float = 1e-05,
690691
data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
691692
name: str | None = None,
692693
) -> Tensor:
@@ -697,12 +698,12 @@ def group_norm(
697698
Parameters:
698699
x(Tensor): Input Tensor with shape: attr:`(batch, num_features, *)`.
699700
num_groups(int): The number of groups that divided from channels.
700-
epsilon(float, optional): The small value added to the variance to prevent
701-
division by zero. Default: 1e-05.
702701
weight(Tensor, optional): The weight Tensor of group_norm, with shape: attr:`[num_channels]`.
703702
Default: None.
704703
bias(Tensor, optional): The bias Tensor of group_norm, with shape: attr:`[num_channels]`.
705704
Default: None.
705+
epsilon(float, optional): The small value added to the variance to prevent
706+
division by zero. Default: 1e-05.
706707
data_format(str, optional): Specify the input data format. Support "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
707708
name(str|None, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`..
708709

python/paddle/nn/layer/norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,9 +562,9 @@ def forward(self, input: Tensor) -> Tensor:
562562
return group_norm(
563563
input,
564564
self._num_groups,
565-
self._epsilon,
566565
self.weight,
567566
self.bias,
567+
self._epsilon,
568568
self._data_format,
569569
)
570570

test/legacy_test/test_group_norm_op_v2.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,5 +618,56 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self):
618618
np.testing.assert_equal(dx.numpy(), dx_ref.numpy())
619619

620620

621+
class TestGroupNormParam(unittest.TestCase):
622+
def setUp(self):
623+
np.random.seed(42)
624+
self.x_np = np.random.randn(2, 6, 4, 4).astype('float32')
625+
self.weight_np = np.random.randn(6).astype('float32')
626+
self.bias_np = np.random.randn(6).astype('float32')
627+
628+
def test_alias_input_for_x(self):
629+
"""test parameter alias input/x"""
630+
x_tensor = paddle.to_tensor(self.x_np)
631+
weight_tensor = paddle.to_tensor(self.weight_np)
632+
bias_tensor = paddle.to_tensor(self.bias_np)
633+
634+
out_with_input = paddle.nn.functional.group_norm(
635+
input=x_tensor,
636+
num_groups=3,
637+
weight=weight_tensor,
638+
bias=bias_tensor,
639+
eps=1e-5,
640+
)
641+
out_with_x = paddle.nn.functional.group_norm(
642+
x=x_tensor,
643+
num_groups=3,
644+
weight=weight_tensor,
645+
bias=bias_tensor,
646+
eps=1e-5,
647+
)
648+
649+
np.testing.assert_array_equal(
650+
out_with_input.numpy(), out_with_x.numpy()
651+
)
652+
653+
def test_param_order(self):
654+
"""test order of parameters"""
655+
x_tensor = paddle.to_tensor(self.x_np)
656+
weight_tensor = paddle.to_tensor(self.weight_np)
657+
bias_tensor = paddle.to_tensor(self.bias_np)
658+
659+
try:
660+
out = paddle.nn.functional.group_norm(
661+
x_tensor, # x
662+
3, # num_groups
663+
weight_tensor, # weight
664+
bias_tensor, # bias
665+
1e-5, # epsilon
666+
)
667+
self.assertTrue(True, "Function call succeeded without error")
668+
except Exception as e:
669+
self.fail(f"Function raised an unexpected exception: {e}")
670+
671+
621672
if __name__ == '__main__':
622673
unittest.main()

0 commit comments

Comments (0)