Skip to content

Commit

Permalink
Fix scatter_add_p prim2orig
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Apr 19, 2022
1 parent 28c2fca commit dc7b7d2
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 28 deletions.
8 changes: 4 additions & 4 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,10 @@ def gather_prim2orig(op, index_t, x):

@REGISTER_PRIM2ORIG('scatter_add_p')
def scatter_add_prim2orig(op, index_t, x, y):
# assert op.attr('axis') == 0
# using scatter_nd_add
return paddle.put_along_axis(
x, index_t, y, axis=op.attr('axis'), reduce='add')
assert op.attr('axis') == 0, 'Only support axis==0 currently'
zeros = paddle.zeros_like(x=x, dtype=x.dtype)
tmp = paddle.scatter(x=zeros, index=index_t, updates=y, overwrite=False)
return paddle.add(x, tmp)


@REGISTER_PRIM2ORIG('fill_constant_p')
Expand Down
44 changes: 20 additions & 24 deletions python/paddle/fluid/tests/unittests/test_prim2orig_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,30 +341,26 @@ def init_data(self):
self.out_map = {self.output['Y']: 0}


# class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
# def init_data(self):
# self.op_type = 'scatter_add_p'
# X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
# Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
# IndexTensor = paddle.static.data(name='IndexTensor', shape=[3], dtype='int32')

# self.input = {
# 'X': X,
# 'Y': Y,
# 'IndexTensor': IndexTensor
# }
# self.output = {
# 'Z': self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
# }
# self.attrs = {
# 'axis': 0,
# }

# self.prim2orig_args = (IndexTensor, X, Y)
# self.all_ops = ['scatter_add_p', 'expand_v2', 'expand_v2', 'put_along_axis', ]
# self.out_map = {
# self.output['Z']: 0
# }
class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
def init_data(self):
self.op_type = 'scatter_add_p'
X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
IndexTensor = paddle.static.data(
name='IndexTensor', shape=[3], dtype='int32')

self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
self.output = {
'Z':
self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
}
self.attrs = {'axis': 0, }

self.prim2orig_args = (IndexTensor, X, Y)
self.all_ops = [
'scatter_add_p', 'fill_any_like', 'scatter', 'elementwise_add'
]
self.out_map = {self.output['Z']: 0}


class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
Expand Down

0 comments on commit dc7b7d2

Please sign in to comment.