Skip to content

Commit

Permalink
pass first linearize+transpose test.
Browse files Browse the repository at this point in the history
  • Loading branch information
tongxin committed Apr 13, 2022
1 parent e980c9d commit 4b99805
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ def broadcast_transpose(op, check_dot, y_bar):
axis = list(range(bat))
keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
axis += keepdim
return reduce(y_bar, axis=axis, keepdim=keepdim)
# TODO: Change it. keepdim boolean
return reduce(y_bar, axis=axis, keepdim=False)


@REGISTER_TRANSPOSE('transpose_p')
Expand Down Expand Up @@ -263,10 +264,12 @@ def reduce_transpose(op, check_dot, y_bar):
def matmul_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
assert check_dot(x) ^ check_dot(y)
# TODO: replace it. this is hacky
axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
if check_dot(x):
return matmul(z_bar, transpose(y)), None
return matmul(z_bar, transpose(y, axis=axis)), None
else:
return None, matmul(transpose(x), z_bar)
return None, matmul(transpose(x, axis=axis), z_bar)


@REGISTER_TRANSPOSE('slice_select_p')
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def add_vars(self, new_vars):

def erase_dots(self, vars_to_erase):
for var in vars_to_erase:
del self.vars[id(var)]
if id(var) in self.vars:
del self.vars[id(var)]
self.dot2bar.delete_keyvars(vars_to_erase)
self.var2dot.delete_valuevars(vars_to_erase)
for var in vars_to_erase:
Expand Down

0 comments on commit 4b99805

Please sign in to comment.