diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json
index 6c3db1179..fe016e0fe 100644
--- a/paconvert/api_mapping.json
+++ b/paconvert/api_mapping.json
@@ -8214,7 +8214,7 @@
         "paddle_api": "paddle.linalg.matrix_rank",
         "min_input_args": 1,
         "args_list": [
-            "A",
+            "input",
             "tol",
             "hermitian",
             "*",
@@ -8224,7 +8224,7 @@
             "out"
         ],
         "kwargs_change": {
-            "A": "x",
+            "input": "x",
             "atol": "tol",
             "rtol": ""
         }
diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py
index 9ca2a04d4..94f3d8ffa 100644
--- a/paconvert/api_matcher.py
+++ b/paconvert/api_matcher.py
@@ -1003,6 +1003,7 @@ def get_paddle_nodes(self, args, kwargs):
         else:
             code = "{}({})".format(paddle_api, self.kwargs_to_str(new_kwargs))
 
+        self.api_mapping["args_list"] = ["input", "dim", "keepdim", "*", "out"]
         return ast.parse(code).body
 
         # the case of one tensor
diff --git a/tests/test_autograd_grad.py b/tests/test_autograd_grad.py
index b5c926434..1588fb867 100644
--- a/tests/test_autograd_grad.py
+++ b/tests/test_autograd_grad.py
@@ -127,3 +127,65 @@ def test_case_7():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        grad = torch.tensor(2.0)
+        y = x * x + z
+
+        result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], grad_outputs=grad, retain_graph=True,
+                                     create_graph=False, only_inputs=True, allow_unused=True, is_grads_batched=False, materialize_grads=False)
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        unsupport=True,
+        reason="paddle does not support 'only_inputs' now!",
+    )
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        grad = torch.tensor(2.0)
+        y = x * x + z
+
+        result = torch.autograd.grad([y.sum()], [x, z], grad, True, False, True, True, False, False)
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        unsupport=True,
+        reason="paddle does not support 'only_inputs' now!",
+    )
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
+        grad = torch.tensor(2.0)
+        y = x * x + z
+
+        result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], retain_graph=True, allow_unused=True,
+                                     create_graph=False, only_inputs=True, is_grads_batched=False, grad_outputs=grad, materialize_grads=False)
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        unsupport=True,
+        reason="paddle does not support 'only_inputs' now!",
+    )
diff --git a/tests/test_bernoulli.py b/tests/test_bernoulli.py
index 67cd499b7..f9f36788b 100644
--- a/tests/test_bernoulli.py
+++ b/tests/test_bernoulli.py
@@ -115,3 +115,43 @@ def test_case_8():
         """
     )
     obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.ones(3, 3)
+        out = torch.zeros(3, 3)
+        result = torch.bernoulli(generator=torch.Generator(), input=a, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.ones(3, 3)
+        out = torch.zeros(3, 3)
+        result = torch.bernoulli(input=a, generator=torch.Generator(), out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_11():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.rand(3, 3)
+        result = torch.bernoulli(input=a, p=0.0, generator=torch.Generator())
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["a", "result"],
+        unsupport=True,
+        reason="paddle does not support parameter 'p'",
+    )
diff --git a/tests/test_chain_matmul.py b/tests/test_chain_matmul.py
index f5aca148b..017cc8636 100644
--- a/tests/test_chain_matmul.py
+++ b/tests/test_chain_matmul.py
@@ -69,3 +69,14 @@ def test_case_4():
         """
     )
     obj.run(
         pytorch_code,
         ["result"],
         unsupport=True,
         reason="paddle does not support variable parameter",
     )
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        v = torch.tensor([[3., 6, 9], [1, 3, 5], [2, 2, 2]])
+        result = torch.chain_matmul(v)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_cumulative_trapezoid.py b/tests/test_cumulative_trapezoid.py
index 64d0ebcd0..30fbf015f 100644
--- a/tests/test_cumulative_trapezoid.py
+++ b/tests/test_cumulative_trapezoid.py
@@ -84,3 +84,37 @@ def test_case_6():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
+        result = torch.cumulative_trapezoid(y=y, dx=0.05, dim=0)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
+        result = torch.cumulative_trapezoid(y, dx=0.05, dim=0)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
+        x = torch.tensor([1, 2, 3, 0, 1]).type(torch.float32)
+        result = torch.cumulative_trapezoid(dim=0, y=y, x=x)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_distributions_Bernoulli.py b/tests/test_distributions_Bernoulli.py
index 20406ef8f..c4af8fb11 100644
--- a/tests/test_distributions_Bernoulli.py
+++ b/tests/test_distributions_Bernoulli.py
@@ -89,3 +89,71 @@ def test_case_6():
         """
     )
     obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), logits=None, validate_args=False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Bernoulli(torch.tensor([0.3]), None, False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), validate_args=False, logits=None)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Bernoulli(probs=None, validate_args=False, logits=3.5)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
diff --git a/tests/test_distributions_Categorical.py b/tests/test_distributions_Categorical.py
index ae753b27a..8ce1d2533 100644
--- a/tests/test_distributions_Categorical.py
+++ b/tests/test_distributions_Categorical.py
@@ -78,3 +78,71 @@ def test_case_5():
         """
     )
     obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Categorical(probs=None, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False)
+        result = m.sample([1])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support probs temporarily",
+    )
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Categorical(None, torch.tensor([0.25, 0.25, 0.25, 0.25]), False)
+        result = m.sample([1])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support probs temporarily",
+    )
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Categorical(probs=None, validate_args=False, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]))
+        result = m.sample([1])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support probs temporarily",
+    )
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False, logits=None)
+        result = m.sample([1])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support probs temporarily",
+    )
diff --git a/tests/test_distributions_Geometric.py b/tests/test_distributions_Geometric.py
index 3614c9875..158cc5167 100644
--- a/tests/test_distributions_Geometric.py
+++ b/tests/test_distributions_Geometric.py
@@ -89,3 +89,71 @@ def test_case_6():
         """
     )
     obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Geometric(probs=torch.tensor([0.3]), logits=None, validate_args=False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Geometric(probs=torch.tensor([0.3]), validate_args=False, logits=None)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Geometric(torch.tensor([0.3]), None, False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Geometric(probs=None, logits=15, validate_args=True)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
diff --git a/tests/test_distributions_Multinomial.py b/tests/test_distributions_Multinomial.py
index d3e3e0158..67ff84242 100644
--- a/tests/test_distributions_Multinomial.py
+++ b/tests/test_distributions_Multinomial.py
@@ -89,3 +89,71 @@ def test_case_6():
         """
     )
     obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Multinomial(total_count=1, probs=torch.tensor([0.3]), logits=None, validate_args=False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Multinomial(total_count=1, logits=None, probs=torch.tensor([0.3]), validate_args=False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Multinomial(1, torch.tensor([0.3]), None, False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        m = torch.distributions.Multinomial(None, torch.tensor([0.3]), 0.8, False)
+        result = m.sample([100])
+        """
+    )
+    obj.run(
+        pytorch_code,
+        ["result"],
+        check_value=False,
+        unsupport=True,
+        reason="paddle does not support logits temporarily",
+    )
diff --git a/tests/test_dsplit.py b/tests/test_dsplit.py
index 385cc94f1..ab66c9995 100644
--- a/tests/test_dsplit.py
+++ b/tests/test_dsplit.py
@@ -66,3 +66,51 @@ def test_case_4():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.arange(16.0).reshape(2, 2, 4)
+        result = torch.dsplit(input=a, sections=2)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.arange(12).reshape(3, 2, 2)
+        result = torch.dsplit(input=a, indices=[1, 1])
+        if len(result) > 2:
+            result = (result[0], result[2])
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.arange(12).reshape(3, 2, 2)
+        result = torch.dsplit(indices=[1, 1], input=a)
+        if len(result) > 2:
+            result = (result[0], result[2])
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.arange(16.0).reshape(2, 2, 4)
+        result = torch.dsplit(sections=2, input=a)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_hsplit.py b/tests/test_hsplit.py
index 0ddbece41..8864b0ae3 100644
--- a/tests/test_hsplit.py
+++ b/tests/test_hsplit.py
@@ -62,3 +62,47 @@ def test_case_4():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(16.0).reshape(4, 4)
+        result = torch.hsplit(input=t, sections=2)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(16.0).reshape(4, 4)
+        result = torch.hsplit(sections=2, input=t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(12).reshape(3, 2, 2)
+        result = torch.hsplit(input=t, indices=[1, 2])
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        t = torch.arange(12).reshape(3, 2, 2)
+        result = torch.hsplit(indices=[1, 2], input=t)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
diff --git a/tests/test_linalg_matrix_rank.py b/tests/test_linalg_matrix_rank.py
index 7f2c38d4b..1c51aebc4 100644
--- a/tests/test_linalg_matrix_rank.py
+++ b/tests/test_linalg_matrix_rank.py
@@ -134,3 +134,143 @@ def test_case_4():
         """
     )
     obj.run(pytorch_code, ["result"], check_dtype=False)
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
+                            [ 0.3198, -0.2976, -0.0585],
+                            [-1.6931, -0.3353, -0.2893]],
+                           [[ 1.9757, -0.5959, -1.2041],
+                            [ 0.8443, -0.4916, -1.6574],
+                            [-0.2654, -1.0447, -0.8138]],
+                           [[-0.4111, 1.0973, 0.2275],
+                            [ 1.1851, 1.8233, 0.8187],
+                            [-1.4107, -0.5473, 1.1431]],
+                           [[ 0.0327, -0.8295, 0.0457],
+                            [-0.6286, -0.2507, 0.7292],
+                            [ 0.4075, -1.3918, -0.5015]]],
+                          [[[-2.1256, 0.9310, 1.0743],
+                            [ 1.9577, -0.1513, 0.1668],
+                            [-0.1404, 1.6647, 0.7108]],
+                           [[ 0.9001, 1.6930, -0.4966],
+                            [-1.0432, -1.0742, 1.2273],
+                            [-0.2711, -0.4740, -0.6381]],
+                           [[-1.3099, -1.7540, 0.5443],
+                            [ 0.3565, -2.3821, 0.8638],
+                            [-1.3840, 0.8216, 0.2761]],
+                           [[-0.5989, -0.4732, 1.3252],
+                            [-0.7614, 1.0493, 0.8488],
+                            [-0.1300, 0.1287, 0.6234]]]])
+        out = torch.empty((2, 4), dtype=torch.int64)
+        result = torch.linalg.matrix_rank(input=A, atol=1.0, rtol=0.0, hermitian=True, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_dtype=False)
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
+                            [ 0.3198, -0.2976, -0.0585],
+                            [-1.6931, -0.3353, -0.2893]],
+                           [[ 1.9757, -0.5959, -1.2041],
+                            [ 0.8443, -0.4916, -1.6574],
+                            [-0.2654, -1.0447, -0.8138]],
+                           [[-0.4111, 1.0973, 0.2275],
+                            [ 1.1851, 1.8233, 0.8187],
+                            [-1.4107, -0.5473, 1.1431]],
+                           [[ 0.0327, -0.8295, 0.0457],
+                            [-0.6286, -0.2507, 0.7292],
+                            [ 0.4075, -1.3918, -0.5015]]],
+                          [[[-2.1256, 0.9310, 1.0743],
+                            [ 1.9577, -0.1513, 0.1668],
+                            [-0.1404, 1.6647, 0.7108]],
+                           [[ 0.9001, 1.6930, -0.4966],
+                            [-1.0432, -1.0742, 1.2273],
+                            [-0.2711, -0.4740, -0.6381]],
+                           [[-1.3099, -1.7540, 0.5443],
+                            [ 0.3565, -2.3821, 0.8638],
+                            [-1.3840, 0.8216, 0.2761]],
+                           [[-0.5989, -0.4732, 1.3252],
+                            [-0.7614, 1.0493, 0.8488],
+                            [-0.1300, 0.1287, 0.6234]]]])
+        out = torch.empty((2, 4), dtype=torch.int64)
+        result = torch.linalg.matrix_rank(hermitian=True, out=out, input=A, atol=1.0, rtol=0.0)
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_dtype=False)
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
+                            [ 0.3198, -0.2976, -0.0585],
+                            [-1.6931, -0.3353, -0.2893]],
+                           [[ 1.9757, -0.5959, -1.2041],
+                            [ 0.8443, -0.4916, -1.6574],
+                            [-0.2654, -1.0447, -0.8138]],
+                           [[-0.4111, 1.0973, 0.2275],
+                            [ 1.1851, 1.8233, 0.8187],
+                            [-1.4107, -0.5473, 1.1431]],
+                           [[ 0.0327, -0.8295, 0.0457],
+                            [-0.6286, -0.2507, 0.7292],
+                            [ 0.4075, -1.3918, -0.5015]]],
+                          [[[-2.1256, 0.9310, 1.0743],
+                            [ 1.9577, -0.1513, 0.1668],
+                            [-0.1404, 1.6647, 0.7108]],
+                           [[ 0.9001, 1.6930, -0.4966],
+                            [-1.0432, -1.0742, 1.2273],
+                            [-0.2711, -0.4740, -0.6381]],
+                           [[-1.3099, -1.7540, 0.5443],
+                            [ 0.3565, -2.3821, 0.8638],
+                            [-1.3840, 0.8216, 0.2761]],
+                           [[-0.5989, -0.4732, 1.3252],
+                            [-0.7614, 1.0493, 0.8488],
+                            [-0.1300, 0.1287, 0.6234]]]])
+        out = torch.empty((2, 4), dtype=torch.int64)
+        result = torch.linalg.matrix_rank(input=A, tol=torch.tensor(1.), hermitian=True, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_dtype=False)
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        A = torch.tensor([[[[-1.1079, -0.4803, 0.2296],
+                            [ 0.3198, -0.2976, -0.0585],
+                            [-1.6931, -0.3353, -0.2893]],
+                           [[ 1.9757, -0.5959, -1.2041],
+                            [ 0.8443, -0.4916, -1.6574],
+                            [-0.2654, -1.0447, -0.8138]],
+                           [[-0.4111, 1.0973, 0.2275],
+                            [ 1.1851, 1.8233, 0.8187],
+                            [-1.4107, -0.5473, 1.1431]],
+                           [[ 0.0327, -0.8295, 0.0457],
+                            [-0.6286, -0.2507, 0.7292],
+                            [ 0.4075, -1.3918, -0.5015]]],
+                          [[[-2.1256, 0.9310, 1.0743],
+                            [ 1.9577, -0.1513, 0.1668],
+                            [-0.1404, 1.6647, 0.7108]],
+                           [[ 0.9001, 1.6930, -0.4966],
+                            [-1.0432, -1.0742, 1.2273],
+                            [-0.2711, -0.4740, -0.6381]],
+                           [[-1.3099, -1.7540, 0.5443],
+                            [ 0.3565, -2.3821, 0.8638],
+                            [-1.3840, 0.8216, 0.2761]],
+                           [[-0.5989, -0.4732, 1.3252],
+                            [-0.7614, 1.0493, 0.8488],
+                            [-0.1300, 0.1287, 0.6234]]]])
+        out = torch.empty((2, 4), dtype=torch.int64)
+        result = torch.linalg.matrix_rank(input=A, hermitian=True, tol=torch.tensor(1.), out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_dtype=False)
diff --git a/tests/test_linalg_solve_triangular.py b/tests/test_linalg_solve_triangular.py
index 91f5ed22a..8ef3947b0 100644
--- a/tests/test_linalg_solve_triangular.py
+++ b/tests/test_linalg_solve_triangular.py
@@ -79,3 +79,16 @@ def test_case_5():
         """
     )
     obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        A = torch.tensor([[ 1.1527, -1.0753], [ 1.23, 0.7986]])
+        B = torch.tensor([[-0.0210, 2.3513, -1.5492], [ 1.5429, 0.7403, -1.0243]])
+        out = torch.tensor([])
+        result = torch.linalg.solve_triangular(input=A, unitriangular=False, upper=True, left=True, B=B, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
diff --git a/tests/test_load.py b/tests/test_load.py
index a664e708b..2472b0c0f 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -136,3 +136,31 @@ def test_case_9():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_10():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import pickle
+
+        result = torch.tensor([0., 1., 2., 3., 4.])
+        torch.save(result, 'tensor.pt', pickle_protocol=4)
+        result = torch.load(f='tensor.pt', map_location=torch.device('cpu'), pickle_module=pickle, weights_only=False, mmap=None)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="`mmap` is not supported in paddle")
+
+
+def test_case_11():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import pickle
+
+        result = torch.tensor([0., 1., 2., 3., 4.])
+        torch.save(result, 'tensor.pt', pickle_protocol=4)
+        result = torch.load(f='tensor.pt', pickle_module=pickle, map_location=torch.device('cpu'), weights_only=False, mmap=None)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="`mmap` is not supported in paddle")
diff --git a/tests/test_max.py b/tests/test_max.py
index 5debb5395..c947d9d32 100644
--- a/tests/test_max.py
+++ b/tests/test_max.py
@@ -150,3 +150,39 @@ def test_case_12():
         """
     )
     obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_13():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1, 2, 1], [3, 4, 6]])
+        out = [torch.tensor(0), torch.tensor(1)]
+        result = torch.max(input=x, dim=1, keepdim=False, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_14():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[1, 2, 1], [3, 4, 6]])
+        out = [torch.tensor(0), torch.tensor(1)]
+        result = torch.max(input=x, keepdim=False, dim=1, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
+
+
+def test_case_15():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        other = torch.tensor([[1, 0, 3], [3, 4, 3]])
+        out = torch.tensor([[1, 0, 3], [3, 4, 3]])
+        result = torch.max(other=other, out=out, input=torch.tensor([[1, 2, 3], [3, 4, 6]]))
+        """
+    )
+    obj.run(pytorch_code, ["result", "out"])
diff --git a/tests/test_max_pool1d.py b/tests/test_max_pool1d.py
index 9a5df19e8..60233a224 100644
--- a/tests/test_max_pool1d.py
+++ b/tests/test_max_pool1d.py
@@ -108,3 +108,29 @@ def test_case_7():
         """
     )
     obj.run(pytorch_code, unsupport=True, reason="Not support dilation")
+
+
+def test_case_8():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487],
+                               [-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873],
+                               [ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]])
+        result = torch.max_pool1d(input, 5, 2, 2, 1, True)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="Not support dilation")
+
+
+def test_case_9():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        input = torch.tensor([[[ 1.1524, 0.4714, 0.2857, 0.4586, 0.9876, 0.5487],
+                               [-1.2533, -0.9829, -1.0981, 0.7655, 0.8541, 0.9873],
+                               [ 0.1507, -1.1431, -2.0361, 0.2344, 0.5675, 0.1546]]])
+        result = torch.max_pool1d(input=input, padding=2, kernel_size=5, dilation=1, stride=2, ceil_mode=True)
+        """
+    )
+    obj.run(pytorch_code, unsupport=True, reason="Not support dilation")
diff --git a/tests/test_nn_init_xavier_uniform_.py b/tests/test_nn_init_xavier_uniform_.py
index 66872a2ae..f0e629f4d 100644
--- a/tests/test_nn_init_xavier_uniform_.py
+++ b/tests/test_nn_init_xavier_uniform_.py
@@ -53,3 +53,27 @@ def test_case_3():
         """
     )
     obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_4():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        conv = torch.nn.Conv2d(3, 6, (3, 3))
+        torch.nn.init.xavier_uniform_(conv.weight, 3.)
+        result = conv.weight
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_value=False)
+
+
+def test_case_5():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        conv = torch.nn.Conv2d(3, 6, (3, 3))
+        torch.nn.init.xavier_uniform_(gain=2., tensor=conv.weight)
+        result = conv.weight
+        """
+    )
+    obj.run(pytorch_code, ["result"], check_value=False)
diff --git a/tests/test_nn_utils_vector_to_parameters.py b/tests/test_nn_utils_vector_to_parameters.py
index 0b6a08919..9237b408a 100644
--- a/tests/test_nn_utils_vector_to_parameters.py
+++ b/tests/test_nn_utils_vector_to_parameters.py
@@ -31,3 +31,31 @@ def test_case_1():
         """
     )
     obj.run(pytorch_code, ["result", "b"], check_value=False)
+
+
+def test_case_2():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import torch.nn as nn
+        model = nn.Linear(10, 20)
+        a = torch.nn.utils.parameters_to_vector(model.parameters())
+        b = torch.nn.utils.vector_to_parameters(vec=a, parameters=model.parameters())
+        result = a.detach()
+        """
+    )
+    obj.run(pytorch_code, ["result", "b"], check_value=False)
+
+
+def test_case_3():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        import torch.nn as nn
+        model = nn.Linear(10, 20)
+        a = torch.nn.utils.parameters_to_vector(model.parameters())
+        b = torch.nn.utils.vector_to_parameters(parameters=model.parameters(), vec=a)
+        result = a.detach()
+        """
+    )
+    obj.run(pytorch_code, ["result", "b"], check_value=False)
diff --git a/tests/test_nonzero.py b/tests/test_nonzero.py
index ff3e7094d..70322fb64 100644
--- a/tests/test_nonzero.py
+++ b/tests/test_nonzero.py
@@ -86,3 +86,33 @@ def test_case_5():
         """
     )
     obj.run(pytorch_code, ["result"])
+
+
+def test_case_6():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[0.6, 0.0, 0.0, 0.0],
+                          [0.0, 0.4, 0.0, 0.0],
+                          [0.0, 0.0, 1.2, 0.0],
+                          [0.0, 0.0, 0.0, -0.4]])
+        out = torch.tensor([1], dtype=torch.int64)
+        result = torch.nonzero(input=x, out=out, as_tuple=False)
+        """
+    )
+    obj.run(pytorch_code, ["result"])
+
+
+def test_case_7():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        x = torch.tensor([[0.6, 0.0, 0.0, 0.0],
+                          [0.0, 0.4, 0.0, 0.0],
+                          [0.0, 0.0, 1.2, 0.0],
+                          [0.0, 0.0, 0.0, -0.4]])
+        out = torch.tensor([1], dtype=torch.int64)
+        result = torch.nonzero(as_tuple=False, input=x, out=out)
+        """
+    )
+    obj.run(pytorch_code, ["result"])