Skip to content

Commit

Permalink
Merge remote-tracking branch 'tx/ad' into lml/lower2prim
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Apr 19, 2022
2 parents c55a6e4 + 0d38236 commit 9a94776
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,9 @@ def gather_transpose(op, check_dot, y_bar):
assert check_dot(x)
axis = op.attr('axis')
zeros = fill_const(0.0, x.shape, x.dtype)
return scatter_add(zeros, y_bar, indextensor, axis=axis), None
x_bar = scatter_add(zeros, y_bar, indextensor, axis=axis)
indextensor_bar = None
return x_bar, indextensor_bar


@REGISTER_TRANSPOSE('scatter_add_p')
Expand All @@ -658,6 +660,7 @@ def scatter_add_transpose(op, check_dot, z_bar):
assert check_dot(x) and check_dot(y)
axis = op.attr('axis')
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
return scatter_add(
z_bar, zeros, indextensor, axis=axis), gather(
z_bar, indextensor, axis=axis), None
x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis)
y_bar = gather(z_bar, indextensor, axis=axis)
indextensor_bar = None
return x_bar, y_bar, indextensor_bar
2 changes: 2 additions & 0 deletions python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ def erase_dots(self, vars_to_erase):
block = self.block
for var in vars_to_erase:
block.desc._remove_var(cpt.to_bytes(var.name))
del block.vars[var.name]
block._sync_with_cpp()


def var2dot_rec(self, vars, defaults=None):

if isinstance(vars, paddle.fluid.framework.Variable):
Expand Down

0 comments on commit 9a94776

Please sign in to comment.