From c554786254d79ffb4b2b0cc2a7d6ef5c379e22cd Mon Sep 17 00:00:00 2001 From: wangruting Date: Fri, 22 Sep 2023 07:33:50 +0000 Subject: [PATCH 1/3] add gather_nd --- .../op_generator/vjp_interface_gen_op_list.py | 8 +++-- paddle/fluid/primitive/codegen/gen.py | 8 +++++ test/legacy_test/test_gather_nd_op.py | 30 +++++++++---------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 9998a40ec2c87..245bb562030f2 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -54,7 +54,9 @@ 'fused_softmax_mask_upper_triangle', 'slice', 'transpose', - 'slice_double', + 'slice_grad', + 'gather_nd', + 'stack', ] vjp_interface_implementation_gen_op_list = [ "tanh", @@ -87,5 +89,7 @@ 'fused_softmax_mask_upper_triangle', 'slice', 'transpose', - 'slice_double', + 'slice_grad', + 'gather_nd', + 'stack', ] diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index e0eeeb10a3a4d..cfe87c30f1434 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -69,6 +69,10 @@ 'layer_norm_grad', 'embedding_grad', 'scale_grad', + 'gather_nd_grad', + 'stack_grad', + 'squeeze_grad', + 'unsqueeze_grad', ] @@ -163,6 +167,10 @@ 'uniform', 'split', 'transpose', + 'gather_nd_grad', + 'stack_grad', + 'squeeze_grad', + 'unsqueeze_grad', ] diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index dd1d996715eef..33c6e5235b248 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -55,8 +55,8 @@ def config_dtype(self): def test_check_output(self): self.check_output(check_new_ir=True) - def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + def test_x(self): + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex): @@ -80,7 +80,7 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=False + place, ['X'], 'Out', check_prim=True, check_new_ir=True ) @@ -117,7 +117,7 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithIndex1_ZeroDim(TestGatherNdOpWithIndex1): @@ -168,7 +168,7 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=False + place, ['X'], 'Out', check_prim=True, check_new_ir=True ) @@ -205,7 +205,7 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex): @@ -233,7 +233,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=False, + check_new_ir=True, numeric_grad_delta=0.5, ) @@ -280,7 +280,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=False, + check_new_ir=True, numeric_grad_delta=0.05, ) @@ -310,7 +310,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=False, + check_new_ir=True, numeric_grad_delta=0.5, ) @@ -345,7 +345,7 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX): @@ -373,7 +373,7 @@ def test_check_grad(self): ['X'], 'Out', check_prim=True, - check_new_ir=False, + check_new_ir=True, numeric_grad_delta=0.5, ) @@ -410,7 +410,7 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame): @@ -434,7 +434,7 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=False + place, ['X'], 'Out', check_prim=True, check_new_ir=True ) @@ -471,7 +471,7 @@ def test_check_output(self): self.check_output(check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff): @@ -495,7 +495,7 @@ def test_check_output(self): def test_check_grad(self): place = core.CUDAPlace(0) self.check_grad_with_place( - place, ['X'], 'Out', check_prim=True, check_new_ir=False + place, ['X'], 'Out', check_prim=True, check_new_ir=True ) From f120adbb89fda58b3e352eafc26d8d4b60983f18 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:37:14 +0800 Subject: [PATCH 2/3] Update paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py --- .../fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index 5a75587c44026..63ee7cd615163 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -96,7 +96,6 @@ 'slice_grad', 'gather_nd', 'stack', - 'slice_double', 'poisson', 'gumbel_softmax', ] From ba4e5d8fd0b313100e92bb400d634c26a2af41d3 Mon Sep 17 00:00:00 2001 From: xiaoguoguo626807 <100397923+xiaoguoguo626807@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:37:51 +0800 Subject: [PATCH 3/3] Update test/legacy_test/test_gather_nd_op.py --- test/legacy_test/test_gather_nd_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_gather_nd_op.py b/test/legacy_test/test_gather_nd_op.py index 33c6e5235b248..a10faff2ac1f3 100644 --- a/test/legacy_test/test_gather_nd_op.py +++ b/test/legacy_test/test_gather_nd_op.py @@ -55,7 +55,7 @@ def config_dtype(self): def test_check_output(self): self.check_output(check_new_ir=True) - def test_x(self): + def test_check_grad(self): self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)