Skip to content

Commit

Permalink
fix bug in minimize
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Apr 18, 2022
1 parent fd4a3b8 commit 62cccfb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python/paddle/autograd/new_adam_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def append_backward_new(loss,
if parameter_list is None:
parameter_list = program.global_block().all_parameters()
param_dot, loss_dot = ad.linearize(parameter_list, update_var_list)
param_bar, loss_bar = ad.transpose(param_dot, loss_dot)
loss_bar, param_bar = ad.transpose(loss_dot, param_dot)

if len(parameter_list) == 1:
params_and_grads = [(paramteter_list, param_bar)]
Expand Down
23 changes: 13 additions & 10 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,28 +139,31 @@ def tanh_orig2prim(op, x):

## NOTE(lml): The second output of reshape2, Xshape, is only consumed by reshape2_grad and is meaningless in the new autograd mechanism, so we substitute a zero tensor for it.
@REGISTER_ORIG2PRIM('reshape2')
def reshape2_orig2prim(op, x):
    """Lower a ``reshape2`` op into primitive ops.

    Args:
        op: The original ``reshape2`` operator; its output vars supply the
            target shape.
        x: The input tensor variable to be reshaped.

    Returns:
        A 2-tuple ``(y, xshape_stub)`` where ``y`` is ``x`` reshaped to the
        op's first output shape, and ``xshape_stub`` is a zero-filled tensor
        standing in for the ``Xshape`` output (only used by ``reshape2_grad``,
        so its values never matter — see the module NOTE above this function).
    """
    # NOTE(review): ShapeTensor / ShapeTensorList inputs are not handled by
    # this lowering; the target shape is taken from the op's output var.
    y, xshape = get_output_vars(op)
    # Reshape to the recorded output shape; emit a zero tensor in place of
    # the meaningless Xshape output.
    return reshape(
        x, shape=y.shape), fill_const(
            shape=xshape.shape, dtype=xshape.dtype, value=0.0)


@REGISTER_ORIG2PRIM('concat')
def concat_orig2prim(op, *xs):
    """Lower a ``concat`` op into the primitive ``concat``.

    Args:
        op: The original ``concat`` operator; the concat axis is read from
            its ``axis`` attribute.
        *xs: The input tensor variables to concatenate.

    Returns:
        The primitive ``concat`` of ``xs`` along ``op.attr('axis')``.
    """
    # NOTE(review): an AxisTensor input is not supported by this lowering;
    # the axis must be a static op attribute.
    return concat(xs, axis=op.attr('axis'))


@REGISTER_ORIG2PRIM('slice')
def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl):
# def slice_orig2prim(op, ends_t, ends_tl, x, starts_t, starts_tl):
def slice_orig2prim(op, x):

assert starts_t is None, 'Can not lower concat into prim ops with startstensor.'
assert ends_t is None, 'Can not lower concat into prim ops with endstensor.'
assert starts_tl is None, 'Can not lower concat into prim ops with startstensorlist.'
assert ends_tl is None, 'Can not lower concat into prim ops with endstensorlist.'
# assert starts_t is None, 'Can not lower concat into prim ops with startstensor.'
# assert ends_t is None, 'Can not lower concat into prim ops with endstensor.'
# assert starts_tl is None, 'Can not lower concat into prim ops with startstensorlist.'
# assert ends_tl is None, 'Can not lower concat into prim ops with endstensorlist.'
starts = op.attr('starts')
ends = op.attr('ends')
strides = [1 for _ in starts]
Expand Down

0 comments on commit 62cccfb

Please sign in to comment.