Skip to content

Commit

Permalink
Fix elementwise_sub orig2prim
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Apr 21, 2022
1 parent d123705 commit 1441a7e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
6 changes: 3 additions & 3 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,14 +235,14 @@ def index_select_orig2prim(op, index_t, x):
def elementwise_sub_orig2prim(op, x, y):
if x.shape != y.shape:
y = broadcast(y, shape=x.shape)
if op.attr('Scale_x'):
if op.attr('Scale_x') - 1.0 > 1e-5:
tmp = fill_const(shape=x.shape, dtype=x.dtype, value=op.attr('Scale_x'))
x = mul(x, tmp)
if op.attr('Scale_y'):
if op.attr('Scale_y') - 1.0 > 1e-5:
tmp = fill_const(shape=y.shape, dtype=y.dtype, value=op.attr('Scale_y'))
y = mul(y, tmp)
z = sub(x, y)
if op.attr('Scale_out'):
if op.attr('Scale_out') - 1.0 > 1e-5:
tmp = fill_const(
shape=z.shape, dtype=z.dtype, value=op.attr('Scale_out'))
z = mul(z, tmp)
Expand Down
5 changes: 1 addition & 4 deletions python/paddle/fluid/tests/unittests/test_orig2prig_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,10 +257,7 @@ def init_data(self):
self.orig2prim_args = (
X,
Y, )
self.all_ops = [
'elementwise_sub', 'broadcast_p', 'fill_constant_p', 'mul_p',
'fill_constant_p', 'mul_p', 'sub_p', 'fill_constant_p', 'mul_p'
]
self.all_ops = ['elementwise_sub', 'broadcast_p', 'sub_p']
self.out_map = {0: self.output['Out']}


Expand Down

0 comments on commit 1441a7e

Please sign in to comment.