Commit

update for use prim2orig in minimize
levi131 committed Apr 18, 2022
1 parent cff965f commit 4659530
Showing 2 changed files with 12 additions and 6 deletions.
python/paddle/autograd/primrules.py: 16 changes (11 additions, 5 deletions)
@@ -234,17 +234,17 @@ def add_prim2orig(op, x, y):

 @REGISTER_PRIM2ORIG('sub_p')
 def sub_prim2orig(op, x, y):
-    return paddle.sub(x, y)
+    return paddle.subtract(x, y)


 @REGISTER_PRIM2ORIG('mul_p')
 def mul_prim2orig(op, x, y):
-    return paddle.mul(x, y)
+    return paddle.multiply(x, y)


 @REGISTER_PRIM2ORIG('div_p')
 def div_prim2orig(op, x, y):
-    return paddle.div(x, y)
+    return paddle.divide(x, y)


 @REGISTER_PRIM2ORIG('sqrt_p')
@@ -259,7 +259,7 @@ def tanh_prim2orig(op, x):

 @REGISTER_PRIM2ORIG('reshape_p')
 def reshape_prim2orig(op, x):
-    y, _ = paddle.reshape(x, shape=op.attr('shape'))
+    y = paddle.reshape(x, shape=op.attr('shape'))
     return y


@@ -288,7 +288,7 @@ def concat_prim2orig(op, *xs):


 @REGISTER_PRIM2ORIG('reduce_p')
-def reduce_prim2orig(op, *xs):
+def reduce_prim2orig(op, xs):
     return paddle.sum(xs, axis=op.attr('axis'), keepdim=op.attr('keepdim'))


@@ -339,6 +339,12 @@ def scatter_add_prim2orig(op, index_t, x, y):
         x, index_t, y, axis=op.attr('axis'), reduce='add')


+@REGISTER_PRIM2ORIG('fill_constant_p')
+def fill_constant_prim2orig(op):
+    return paddle.full(
+        shape=op.attr('shape'), fill_value=op.attr('value'), dtype='float32')
+
+
 ## Register linearize rules


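Note on the rules above: each REGISTER_PRIM2ORIG decorator ties a primitive op type such as 'sub_p' to a function that rebuilds the computation with the public Paddle API (paddle.subtract, paddle.multiply, and so on). As a rough illustration of that registry pattern only, here is a minimal, self-contained sketch; the names register_prim2orig, _RULES and lower are hypothetical, plain Python arithmetic stands in for the paddle calls, and this is not Paddle's actual implementation.

# Minimal prim2orig-style rule registry (illustrative sketch, hypothetical names).
_RULES = {}

def register_prim2orig(op_type):
    # Decorator: record a lowering function for one primitive op type.
    def deco(fn):
        _RULES[op_type] = fn
        return fn
    return deco

@register_prim2orig('sub_p')
def sub_rule(x, y):
    # Stand-in for the real rule's paddle.subtract(x, y).
    return x - y

def lower(op_type, *args):
    # Apply the registered rule for op_type to the given inputs.
    return _RULES[op_type](*args)

# Example: lower('sub_p', 3.0, 1.0) evaluates to 2.0.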
python/paddle/fluid/tests/unittests/test_primops.py: 2 changes (1 addition, 1 deletion)
@@ -248,7 +248,7 @@ def test_minimize(self):
         loss = paddle.norm(y, p=2)
         opt = AdamOptimizer(0.01)
         opt.minimize(loss)
-        # prim2orig(x.block, update_var_list=[loss])
+        prim2orig(x.block, update_var_list=[loss])

         print(f'-------test_minimize: orig-------')
         print(x.block)
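For intuition, a prim2orig pass conceptually walks the ops of a block and replaces every primitive op that has a registered rule with its original-op equivalent, which is why the test now calls prim2orig(x.block, ...) after opt.minimize and before printing the block. The sketch below shows only that idea on a plain list of (op_type, inputs) tuples with a caller-supplied rules dict; the data layout and names are hypothetical, and the real pass rewrites a Paddle Block in place.

# Conceptual sketch of a prim2orig-style lowering pass (hypothetical data layout, not Paddle's transform).
def prim2orig_pass(ops, rules):
    # ops: list of (op_type, inputs) tuples; rules: dict mapping op_type -> lowering function.
    lowered = []
    for op_type, inputs in ops:
        rule = rules.get(op_type)
        if rule is not None:
            # Primitive op: replace it with the result of its registered rule.
            lowered.append(('lowered', rule(*inputs)))
        else:
            # Not a primitive op: keep it unchanged.
            lowered.append((op_type, inputs))
    return lowered

# Example: prim2orig_pass([('sub_p', (3.0, 1.0))], {'sub_p': lambda x, y: x - y})
# returns [('lowered', 2.0)].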
