update for matmul_v2 and reshape2 orig2prim
levi131 committed Apr 17, 2022
1 parent e347444 commit 84994d7
Showing 2 changed files with 13 additions and 5 deletions.
13 changes: 10 additions & 3 deletions python/paddle/autograd/primrules.py
@@ -87,6 +87,11 @@ def trans(shape):

assert len(x.shape) < 4 and len(
    y.shape) < 4, 'Do not support multi batchsize dimensions currently.'

if len(x.shape) == 1:
    x = broadcast(x, shape=[1, x.shape[0]])
if len(y.shape) == 1:
    y = broadcast(y, shape=[y.shape[0], 1])
if op.attr('trans_x'):
    x = transpose(x, shape=trans(x.shape))
if op.attr('trans_y'):
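
The two new branches above promote 1-D operands to 2-D before the transpose/matmul lowering: a 1-D x is broadcast to a row matrix of shape [1, n] and a 1-D y to a column matrix of shape [n, 1], matching NumPy-style matmul semantics for vectors. A minimal NumPy sketch of the same promotion (illustrative only, not the prim ops used in the rule):

import numpy as np

def promote_for_matmul(x, y):
    # Promote 1-D operands to 2-D, mirroring matmul's vector semantics.
    if x.ndim == 1:
        x = x.reshape(1, -1)   # left vector -> row matrix of shape (1, n)
    if y.ndim == 1:
        y = y.reshape(-1, 1)   # right vector -> column matrix of shape (n, 1)
    return x, y

x = np.arange(3.0)             # shape (3,)
y = np.arange(3.0)             # shape (3,)
xm, ym = promote_for_matmul(x, y)
print((xm @ ym).shape)         # (1, 1); holds the same value as np.dot(x, y)
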
@@ -117,13 +122,15 @@ def tanh_orig2prim(op, x):
return tanh(x)


## NOTE(lml): The second output of reshape2, Xshape, can't be described by prim ops; we use the paddle.shape() interface instead.
## NOTE(lml): The second output of reshape2, Xshape, is only used in reshape2_grad and is meaningless in the new autograd mechanism, so we use a zero tensor instead.
@REGISTER_ORIG2PRIM('reshape2')
def reshape2_orig2prim(op, shape_t, shape_tl, x):
assert shape_t is None, 'Can not lower reshape2 into prim ops with shapetensor.'
assert shape_tl is None, 'Can not lower reshape2 into prim ops with shapetensorlist.'
y, _ = get_output_vars(op)
return reshape(x, shape=y.shape), paddle.shape(x)
y, xshape = get_output_vars(op)
return reshape(
    x, shape=y.shape), fill_const(
        shape=xshape.shape, dtype=xshape.dtype, value=0.0)


@REGISTER_ORIG2PRIM('concat')
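
For context on the reshape2 rule above: the XShape output exists only to let reshape2_grad recover the input shape, and the primitive-based autograd never calls reshape2_grad, so the lowered rule only needs a placeholder with XShape's shape and dtype, which fill_const supplies. A rough standalone sketch of the idea in NumPy, assuming Paddle's convention that XShape's dims are the input dims prefixed with a 0 (so the tensor holds no data):

import numpy as np

def reshape2_like(x, target_shape):
    # Real output of reshape2 ...
    y = x.reshape(target_shape)
    # ... plus a dummy stand-in for XShape: zero-filled, dims [0] + x.shape,
    # kept only so the op's output arity stays intact.
    xshape = np.zeros((0,) + x.shape, dtype=x.dtype)
    return y, xshape

y, xshape = reshape2_like(np.ones((2, 3)), (3, 2))
print(y.shape, xshape.shape)   # (3, 2) (0, 2, 3)
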
5 changes: 3 additions & 2 deletions python/paddle/autograd/primx.py
@@ -293,6 +293,7 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):

return ys_bar, xs_bar


def _gradients(ys, xs, ys_bar=None):
""" A drop-in replacement of paddle.gradients for computing
the gradients of `xs` against `ys` using primitive ops based
@@ -383,8 +384,8 @@ def _lower(block, reverse, update_var_list):
for op_idx in reversed(ops_to_remove):
    block._remove_op(op_idx)
for var_name in vars_to_remove:
    # block._remove_var(var_name)
    del block.vars[var_name]
    block._remove_var(var_name)
    # del block.vars[var_name]

if update_var_list is not None:
    for i in range(len(update_var_list)):
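
One detail in the _lower cleanup above: ops are removed in reversed index order so that deleting one op does not shift the indices of the ops still waiting to be removed. A tiny self-contained illustration of that pattern, with a plain Python list standing in for the block's op list (block._remove_op itself is Paddle-internal):

ops = ['fill', 'matmul', 'scale', 'reshape', 'add']
ops_to_remove = [1, 3]            # indices collected while lowering

# Deleting from the back keeps the earlier indices valid.
for op_idx in reversed(sorted(ops_to_remove)):
    del ops[op_idx]

print(ops)                        # ['fill', 'scale', 'add']
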
