Skip to content

Commit

Permalink
add gather_nd (PaddlePaddle#57640)
Browse files Browse the repository at this point in the history
* add gather_nd

* Update paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py

* Update test/legacy_test/test_gather_nd_op.py
  • Loading branch information
xiaoguoguo626807 authored and Frida-a committed Oct 14, 2023
1 parent 75678ea commit ba15a12
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
'fused_softmax_mask_upper_triangle',
'slice',
'transpose',
'slice_double',
'slice_grad',
'gather_nd',
'stack',
'poisson',
'gumbel_softmax',
'tril',
Expand Down Expand Up @@ -93,7 +95,9 @@
'fused_softmax_mask_upper_triangle',
'slice',
'transpose',
'slice_double',
'slice_grad',
'gather_nd',
'stack',
'poisson',
'gumbel_softmax',
'tril',
Expand Down
8 changes: 8 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@
'layer_norm_grad',
'embedding_grad',
'scale_grad',
'gather_nd_grad',
'stack_grad',
'squeeze_grad',
'unsqueeze_grad',
'poisson_grad',
'gumbel_softmax_grad',
]
Expand Down Expand Up @@ -173,6 +177,10 @@
'gumbel_softmax_grad',
'split',
'transpose',
'gather_nd_grad',
'stack_grad',
'squeeze_grad',
'unsqueeze_grad',
]


Expand Down
28 changes: 14 additions & 14 deletions test/legacy_test/test_gather_nd_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex):
Expand All @@ -80,7 +80,7 @@ def test_check_output(self):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_new_ir=False
place, ['X'], 'Out', check_prim=True, check_new_ir=True
)


Expand Down Expand Up @@ -117,7 +117,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithIndex1_ZeroDim(TestGatherNdOpWithIndex1):
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_check_output(self):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_new_ir=False
place, ['X'], 'Out', check_prim=True, check_new_ir=True
)


Expand Down Expand Up @@ -205,7 +205,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex):
Expand Down Expand Up @@ -233,7 +233,7 @@ def test_check_grad(self):
['X'],
'Out',
check_prim=True,
check_new_ir=False,
check_new_ir=True,
numeric_grad_delta=0.5,
)

Expand Down Expand Up @@ -280,7 +280,7 @@ def test_check_grad(self):
['X'],
'Out',
check_prim=True,
check_new_ir=False,
check_new_ir=True,
numeric_grad_delta=0.05,
)

Expand Down Expand Up @@ -310,7 +310,7 @@ def test_check_grad(self):
['X'],
'Out',
check_prim=True,
check_new_ir=False,
check_new_ir=True,
numeric_grad_delta=0.5,
)

Expand Down Expand Up @@ -345,7 +345,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX):
Expand Down Expand Up @@ -373,7 +373,7 @@ def test_check_grad(self):
['X'],
'Out',
check_prim=True,
check_new_ir=False,
check_new_ir=True,
numeric_grad_delta=0.5,
)

Expand Down Expand Up @@ -410,7 +410,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame):
Expand All @@ -434,7 +434,7 @@ def test_check_output(self):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_new_ir=False
place, ['X'], 'Out', check_prim=True, check_new_ir=True
)


Expand Down Expand Up @@ -471,7 +471,7 @@ def test_check_output(self):
self.check_output(check_new_ir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=False)
self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff):
Expand All @@ -495,7 +495,7 @@ def test_check_output(self):
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_new_ir=False
place, ['X'], 'Out', check_prim=True, check_new_ir=True
)


Expand Down

0 comments on commit ba15a12

Please sign in to comment.