Skip to content

Commit

Permalink
fix dot2bar delete.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Apr 19, 2022
1 parent 72fe343 commit 3adfe6b
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
11 changes: 7 additions & 4 deletions python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,9 @@ def delete(self, key_var):

def delete_keyvars(self, key_vars):
for var in key_vars:
del self.tab[id(var)]
varid = id(var)
if varid in self.tab:
del self.tab[varid]

def delete_valuevars(self, value_vars):
ids = [id(v) for v in value_vars]
Expand Down Expand Up @@ -408,9 +410,10 @@ def _gradients(ys, xs, ys_bar=None):
orig2prim(block, new_vars)

ad = Transform(block)
new_xs = new_vars[:len(xs)]
new_ys = new_vars[len(xs):]
xs_dot, ys_dot = ad.linearize(new_xs, new_ys)
xs = new_vars[:len(xs)]
ys = new_vars[len(xs):]

xs_dot, ys_dot = ad.linearize(xs, ys)
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, ys_bar)
# remove xs_dot and their constructor ops

Expand Down
7 changes: 3 additions & 4 deletions python/paddle/fluid/tests/unittests/test_primops.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,12 @@ def test_gradients_set2(self):
x = paddle.static.data('X', shape=[3, 3], dtype='float32')
y = paddle.static.data('Y', shape=[3, 3], dtype='float32')
# z = prog2(x, y)
t = paddle.multiply(x, x)
t = paddle.matmul(x, x)
z = paddle.norm(t, p=2)
orig2prim(x.block)
# x_grad, y_grad = _gradients([z], [x, y])
x_grad, y_grad = _gradients([z], [x, y])
# path, _, _ = topo_path([x, y], [x_grad, y_grad])
# print(f'-------test_gradients_set2-------')
# print(x.block)
print(x.block)

def test_gradients_set3(self):
main = paddle.static.Program()
Expand Down

0 comments on commit 3adfe6b

Please sign in to comment.