From fc374b8f962d9816c133fdf0d63dab39763b3d38 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 2 Jan 2024 14:58:30 +0000 Subject: [PATCH 1/3] [Dy2St] Fix NameloadJstTransformer missing transform call kwargs --- python/paddle/base/dygraph/math_op_patch.py | 4 +-- .../transformers/basic_api_transformer.py | 2 ++ python/paddle/pir/math_op_patch.py | 30 +++++++++++++++++++ .../test_load_transformer.py | 22 ++++++++++++++ test/legacy_test/test_math_op_patch_pir.py | 10 +++++++ 5 files changed, 65 insertions(+), 3 deletions(-) diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py index 172f73bf7f531..3f7b7a40ffa46 100644 --- a/python/paddle/base/dygraph/math_op_patch.py +++ b/python/paddle/base/dygraph/math_op_patch.py @@ -167,9 +167,7 @@ def _size_(var): def _T_(var): if len(var.shape) == 1: return var - perm = [] - for i in range(len(var.shape)): - perm.insert(0, i) + perm = list(reversed(range(len(var.shape)))) out = _C_ops.transpose(var, perm) return out diff --git a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py index 0902a3558b2b0..01b831706cceb 100644 --- a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py +++ b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py @@ -152,6 +152,8 @@ def visit_Call(self, node): Can't convert name of function call, bacause this will affect CallTransformer. """ node.args = [self.visit(arg) for arg in node.args] + for keyword in node.keywords: + keyword.value = self.visit(keyword.value) node.func = self.visit(node.func) return node diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 45f8917bf04de..7728e94e288d8 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -405,6 +405,35 @@ def _size_(self): """ return paddle.numel(self) + @property + def _T_(self): + """ + + Permute current Variable with its dimensions reversed. + + If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`. + + Examples: + .. code-block:: python + + >>> import paddle + >>> paddle.enable_static() + + >>> x = paddle.ones(shape=[2, 3, 5]) + >>> x_T = x.T + + >>> exe = paddle.static.Executor() + >>> x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0] + >>> print(x_T_np.shape) + (5, 3, 2) + + """ + if len(self.shape) == 1: + return self + perm = list(reversed(range(len(self.shape)))) + + return _C_ops.transpose(self, perm) + def clone(self): """ Returns a new static Value, which is the clone of the original static @@ -511,6 +540,7 @@ def value_hash(self): ('ndim', _ndim), ('astype', astype), ('size', _size_), + ('T', _T_), ('clone', clone), ('clear_gradient', clear_gradient), ('append', append), diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py index 6698ba7ef6075..80652734e933e 100644 --- a/test/dygraph_to_static/test_load_transformer.py +++ b/test/dygraph_to_static/test_load_transformer.py @@ -71,5 +71,27 @@ def func(x): np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) +class LoadInCallKwargsNet(paddle.nn.Layer): + def __init__(self): + super().__init__() + self.extra_inputs = [] + + def forward(self, x): + for i in range(len(self.extra_inputs)): + x = paddle.nn.functional.linear(weight=self.extra_inputs[i].T, x=x) + return x + + +class TestLoadInCallKwargs(Dy2StTestBase): + @test_legacy_and_pt_and_pir + def test_name_load_nograd(self): + net = LoadInCallKwargsNet() + x = paddle.rand([10, 10]) + net.extra_inputs.append(paddle.rand([10, 10])) + output_st = paddle.jit.to_static(net)(x) + output_dy = net(x) + np.testing.assert_allclose(output_dy.numpy(), output_st.numpy()) + + if __name__ == "__main__": unittest.main() diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index dc2fe9abed1a9..6dd274229b39a 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -450,6 +450,16 @@ def test_size(self): (output_x,) = exe.run(main_program, fetch_list=[x.size]) self.assertEqual(output_x, 24) + def test_T(self): + with paddle.pir_utils.IrGuard(): + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.assign(np.random.rand(2, 3, 4).astype("float32")) + x_T = x.T + self.assertEqual(x_T.shape, [4, 3, 2]) + (output_x,) = exe.run(main_program, fetch_list=[x_T]) + self.assertEqual(output_x.shape, (4, 3, 2)) + def test_hash_error(self): with paddle.pir_utils.IrGuard(): _, _, program_guard = new_program() From c39f2b6bba7611df053722add37e52367483ea6f Mon Sep 17 00:00:00 2001 From: Nyakku Shigure Date: Wed, 3 Jan 2024 10:22:30 +0800 Subject: [PATCH 2/3] `Variable` -> `Value` Co-authored-by: gouzil <66515297+gouzil@users.noreply.github.com> --- python/paddle/pir/math_op_patch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py index 7728e94e288d8..74cb7157c6f24 100644 --- a/python/paddle/pir/math_op_patch.py +++ b/python/paddle/pir/math_op_patch.py @@ -409,7 +409,7 @@ def _size_(self): def _T_(self): """ - Permute current Variable with its dimensions reversed. + Permute current Value with its dimensions reversed. If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`. From 0e861a32d4c6c0b129bbb66721c66cfeb457007a Mon Sep 17 00:00:00 2001 From: SigureMo Date: Thu, 4 Jan 2024 02:44:37 +0000 Subject: [PATCH 3/3] add more cases for test_T --- test/legacy_test/test_math_op_patch_pir.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py index 6dd274229b39a..dbd57c1999115 100644 --- a/test/legacy_test/test_math_op_patch_pir.py +++ b/test/legacy_test/test_math_op_patch_pir.py @@ -452,13 +452,17 @@ def test_size(self): def test_T(self): with paddle.pir_utils.IrGuard(): - main_program, exe, program_guard = new_program() - with program_guard: - x = paddle.assign(np.random.rand(2, 3, 4).astype("float32")) - x_T = x.T - self.assertEqual(x_T.shape, [4, 3, 2]) - (output_x,) = exe.run(main_program, fetch_list=[x_T]) - self.assertEqual(output_x.shape, (4, 3, 2)) + for ndim in range(5): + # shape is [], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4] + shape = list(range(1, ndim + 1)) + out_shape = list(reversed(shape)) + main_program, exe, program_guard = new_program() + with program_guard: + x = paddle.rand(shape, dtype="float32") + x_T = x.T + self.assertEqual(x_T.shape, out_shape) + (output_x,) = exe.run(main_program, fetch_list=[x_T]) + self.assertEqual(output_x.shape, tuple(out_shape)) def test_hash_error(self): with paddle.pir_utils.IrGuard():