Skip to content

Commit

Permalink
Fix transpose and UT
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Apr 22, 2022
1 parent 3defd25 commit d1acc77
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
2 changes: 1 addition & 1 deletion python/paddle/autograd/primx.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
self.add_vars_rec(ins_bar_rec)

ins_bar = flatten(ins_bar_rec)
ins = get_input_vars(op)
ins = flatten(op_position_inputs(op))
assert len(ins) == len(ins_bar)

for dot, bar in zip(ins, ins_bar):
Expand Down
32 changes: 14 additions & 18 deletions python/paddle/fluid/tests/unittests/test_autogard_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,6 @@ def test_run(self):
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_dot[k].shape, v)

import pdb
pdb.set_trace()
# Test transpose
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False)
transpose_ops = [op.type for op in self.main_program.block(0).ops]
Expand All @@ -132,8 +130,6 @@ def test_run(self):
for k, v in self.ys_shape_map.items():
self.assertEqual(flatten_ys_bar[k].shape, v)

import pdb
pdb.set_trace()
# Test prim2orig
prim2orig(block=self.main_program.block(0))
prim2orig_ops = [op.type for op in self.main_program.block(0).ops]
Expand Down Expand Up @@ -172,9 +168,9 @@ def init_data(self):
# linearized op
'reshape_p',
'mul_p',
'mul_p',
'add_p',
'add_p',
# 'mul_p', # JVP rules handle `None` input, some op will not be appended
# 'add_p',
# 'add_p',
'matmul_p',
'matmul_p',
'add_p'
Expand All @@ -191,7 +187,7 @@ def init_data(self):
'matmul_p',
'transpose_p',
'matmul_p',
'mul_p',
# 'mul_p',
'reshape_p',
]

Expand All @@ -211,7 +207,7 @@ def init_data(self):
'matmul_v2',
'transpose2',
'matmul_v2',
'elementwise_mul',
# 'elementwise_mul',
'reshape2',
]

Expand Down Expand Up @@ -265,7 +261,7 @@ def init_data(self):
'mul_p',
'add_p',
'reduce_p',
'fill_constant_p', # 'sqrt_p', Will not new sqrt_p op when apply JVP for sqrt_p
'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p
'mul_p',
'div_p',
'broadcast_p',
Expand Down Expand Up @@ -299,20 +295,20 @@ def init_data(self):
'split_p',
'fill_constant_p',
'scatter_add_p',
'add_p', # The output of the op is used by multiple subsequent ops
'add_p',
'add_p' # The output of the op is used by multiple subsequent ops
]

self.prim2orig_ops = [
'broadcast', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce', 'sqrt', 'broadcast', 'elementwise_sub', 'concat',
'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul',
'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat',
'gather', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'elementwise_mul', 'reduce', 'reshape2', 'reshape2',
'elementwise_mul', 'elementwise_mul', 'reshape2', 'broadcast',
'elementwise_div', 'reduce', 'reshape2', 'fill_constant',
'elementwise_sub', 'split', 'fill_constant', 'scatter_add',
'elementwise_add', 'elementwise_add'
'elementwise_mul', 'reduce_sum', 'reshape2', 'reshape2',
'elementwise_mul', 'elementwise_mul', 'reshape2', 'expand_v2',
'elementwise_div', 'reduce_sum', 'reshape2', 'fill_constant',
'elementwise_sub', 'split', 'fill_constant', 'fill_any_like',
'elementwise_add', 'scatter', 'elementwise_add', 'elementwise_add'
]


Expand Down

0 comments on commit d1acc77

Please sign in to comment.