From d1acc7749b18b98e3fe2590952d4981b5c43ae78 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Fri, 22 Apr 2022 04:16:16 +0000 Subject: [PATCH] Fix transpose and UT --- python/paddle/autograd/primx.py | 2 +- .../unittests/test_autogard_transform.py | 32 ++++++++----------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/python/paddle/autograd/primx.py b/python/paddle/autograd/primx.py index 2e292b0b04271..4a47dbeac303a 100644 --- a/python/paddle/autograd/primx.py +++ b/python/paddle/autograd/primx.py @@ -360,7 +360,7 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False): self.add_vars_rec(ins_bar_rec) ins_bar = flatten(ins_bar_rec) - ins = get_input_vars(op) + ins = flatten(op_position_inputs(op)) assert len(ins) == len(ins_bar) for dot, bar in zip(ins, ins_bar): diff --git a/python/paddle/fluid/tests/unittests/test_autogard_transform.py b/python/paddle/fluid/tests/unittests/test_autogard_transform.py index 350a38bb468d9..cef2c2017b31e 100644 --- a/python/paddle/fluid/tests/unittests/test_autogard_transform.py +++ b/python/paddle/fluid/tests/unittests/test_autogard_transform.py @@ -117,8 +117,6 @@ def test_run(self): for k, v in self.ys_shape_map.items(): self.assertEqual(flatten_ys_dot[k].shape, v) - import pdb - pdb.set_trace() # Test transpose ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, retain_fwd=False) transpose_ops = [op.type for op in self.main_program.block(0).ops] @@ -132,8 +130,6 @@ def test_run(self): for k, v in self.ys_shape_map.items(): self.assertEqual(flatten_ys_bar[k].shape, v) - import pdb - pdb.set_trace() # Test prim2orig prim2orig(block=self.main_program.block(0)) prim2orig_ops = [op.type for op in self.main_program.block(0).ops] @@ -172,9 +168,9 @@ def init_data(self): # linearized op 'reshape_p', 'mul_p', - 'mul_p', - 'add_p', - 'add_p', + # 'mul_p', # JVP rules handle `None` input, some op will not be appended + # 'add_p', + # 'add_p', 'matmul_p', 'matmul_p', 'add_p' @@ -191,7 +187,7 @@ def init_data(self): 'matmul_p', 'transpose_p', 'matmul_p', - 'mul_p', + # 'mul_p', 'reshape_p', ] @@ -211,7 +207,7 @@ def init_data(self): 'matmul_v2', 'transpose2', 'matmul_v2', - 'elementwise_mul', + # 'elementwise_mul', 'reshape2', ] @@ -265,7 +261,7 @@ def init_data(self): 'mul_p', 'add_p', 'reduce_p', - 'fill_constant_p', # 'sqrt_p', Will not new sqrt_p op when apply JVP for sqrt_p + 'fill_constant_p', # 'sqrt_p', Will not append sqrt_p op when apply JVP for sqrt_p 'mul_p', 'div_p', 'broadcast_p', @@ -299,20 +295,20 @@ def init_data(self): 'split_p', 'fill_constant_p', 'scatter_add_p', + 'add_p', # The output of the op is used by multiple subsequent ops 'add_p', - 'add_p' # The output of the op is used by multiple subsequent ops ] self.prim2orig_ops = [ - 'broadcast', 'elementwise_add', 'reshape2', 'elementwise_mul', - 'reduce', 'sqrt', 'broadcast', 'elementwise_sub', 'concat', + 'expand_v2', 'elementwise_add', 'reshape2', 'elementwise_mul', + 'reduce_sum', 'sqrt', 'expand_v2', 'elementwise_sub', 'concat', 'gather', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', - 'elementwise_mul', 'reduce', 'reshape2', 'reshape2', - 'elementwise_mul', 'elementwise_mul', 'reshape2', 'broadcast', - 'elementwise_div', 'reduce', 'reshape2', 'fill_constant', - 'elementwise_sub', 'split', 'fill_constant', 'scatter_add', - 'elementwise_add', 'elementwise_add' + 'elementwise_mul', 'reduce_sum', 'reshape2', 'reshape2', + 'elementwise_mul', 'elementwise_mul', 'reshape2', 'expand_v2', + 'elementwise_div', 'reduce_sum', 'reshape2', 'fill_constant', + 'elementwise_sub', 'split', 'fill_constant', 'fill_any_like', + 'elementwise_add', 'scatter', 'elementwise_add', 'elementwise_add' ]