diff --git a/python/paddle/base/dygraph/math_op_patch.py b/python/paddle/base/dygraph/math_op_patch.py
index 172f73bf7f531..3f7b7a40ffa46 100644
--- a/python/paddle/base/dygraph/math_op_patch.py
+++ b/python/paddle/base/dygraph/math_op_patch.py
@@ -167,9 +167,7 @@ def _size_(var):
     def _T_(var):
         if len(var.shape) == 1:
             return var
-        perm = []
-        for i in range(len(var.shape)):
-            perm.insert(0, i)
+        perm = list(reversed(range(len(var.shape))))
         out = _C_ops.transpose(var, perm)
         return out
 
diff --git a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py
index 0902a3558b2b0..01b831706cceb 100644
--- a/python/paddle/jit/dy2static/transformers/basic_api_transformer.py
+++ b/python/paddle/jit/dy2static/transformers/basic_api_transformer.py
@@ -152,6 +152,8 @@ def visit_Call(self, node):
         Can't convert name of function call, bacause this will affect CallTransformer.
         """
         node.args = [self.visit(arg) for arg in node.args]
+        for keyword in node.keywords:
+            keyword.value = self.visit(keyword.value)
         node.func = self.visit(node.func)
         return node
 
diff --git a/python/paddle/pir/math_op_patch.py b/python/paddle/pir/math_op_patch.py
index 45f8917bf04de..74cb7157c6f24 100644
--- a/python/paddle/pir/math_op_patch.py
+++ b/python/paddle/pir/math_op_patch.py
@@ -405,6 +405,35 @@ def _size_(self):
         """
         return paddle.numel(self)
 
+    @property
+    def _T_(self):
+        """
+
+        Permute current Value with its dimensions reversed.
+
+        If `n` is the dimensions of `x` , `x.T` is equivalent to `x.transpose([n-1, n-2, ..., 0])`.
+
+        Examples:
+            .. code-block:: python
+
+                >>> import paddle
+                >>> paddle.enable_static()
+
+                >>> x = paddle.ones(shape=[2, 3, 5])
+                >>> x_T = x.T
+
+                >>> exe = paddle.static.Executor()
+                >>> x_T_np = exe.run(paddle.static.default_main_program(), fetch_list=[x_T])[0]
+                >>> print(x_T_np.shape)
+                (5, 3, 2)
+
+        """
+        if len(self.shape) == 1:
+            return self
+        perm = list(reversed(range(len(self.shape))))
+
+        return _C_ops.transpose(self, perm)
+
     def clone(self):
         """
         Returns a new static Value, which is the clone of the original static
@@ -511,6 +540,7 @@ def value_hash(self):
         ('ndim', _ndim),
         ('astype', astype),
         ('size', _size_),
+        ('T', _T_),
         ('clone', clone),
         ('clear_gradient', clear_gradient),
         ('append', append),
diff --git a/test/dygraph_to_static/test_load_transformer.py b/test/dygraph_to_static/test_load_transformer.py
index 6698ba7ef6075..80652734e933e 100644
--- a/test/dygraph_to_static/test_load_transformer.py
+++ b/test/dygraph_to_static/test_load_transformer.py
@@ -71,5 +71,27 @@ def func(x):
         np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
 
 
+class LoadInCallKwargsNet(paddle.nn.Layer):
+    def __init__(self):
+        super().__init__()
+        self.extra_inputs = []
+
+    def forward(self, x):
+        for i in range(len(self.extra_inputs)):
+            x = paddle.nn.functional.linear(weight=self.extra_inputs[i].T, x=x)
+        return x
+
+
+class TestLoadInCallKwargs(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
+    def test_name_load_nograd(self):
+        net = LoadInCallKwargsNet()
+        x = paddle.rand([10, 10])
+        net.extra_inputs.append(paddle.rand([10, 10]))
+        output_st = paddle.jit.to_static(net)(x)
+        output_dy = net(x)
+        np.testing.assert_allclose(output_dy.numpy(), output_st.numpy())
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/legacy_test/test_math_op_patch_pir.py b/test/legacy_test/test_math_op_patch_pir.py
index dc2fe9abed1a9..dbd57c1999115 100644
--- a/test/legacy_test/test_math_op_patch_pir.py
+++ b/test/legacy_test/test_math_op_patch_pir.py
@@ -450,6 +450,20 @@ def test_size(self):
         (output_x,) = exe.run(main_program, fetch_list=[x.size])
         self.assertEqual(output_x, 24)
 
+    def test_T(self):
+        with paddle.pir_utils.IrGuard():
+            for ndim in range(5):
+                # shape is [], [1], [1, 2], [1, 2, 3], [1, 2, 3, 4]
+                shape = list(range(1, ndim + 1))
+                out_shape = list(reversed(shape))
+                main_program, exe, program_guard = new_program()
+                with program_guard:
+                    x = paddle.rand(shape, dtype="float32")
+                    x_T = x.T
+                    self.assertEqual(x_T.shape, out_shape)
+                    (output_x,) = exe.run(main_program, fetch_list=[x_T])
+                    self.assertEqual(output_x.shape, tuple(out_shape))
+
     def test_hash_error(self):
         with paddle.pir_utils.IrGuard():
             _, _, program_guard = new_program()