Commit

Unit test case coverage 2 (#455)
* case add2

* case add2

* fix MaxMinMatcher

* fix testcase
Xuxuanang authored Aug 23, 2024
1 parent 75a684c commit 30d04ee
Showing 20 changed files with 839 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paconvert/api_mapping.json
@@ -8218,7 +8218,7 @@
"paddle_api": "paddle.linalg.matrix_rank",
"min_input_args": 1,
"args_list": [
"A",
"input",
"tol",
"hermitian",
"*",
@@ -8228,7 +8228,7 @@
"out"
],
"kwargs_change": {
"A": "x",
"input": "x",
"atol": "tol",
"rtol": ""
}
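For context, a kwargs_change map like the one above renames keyword arguments when the Paddle call is generated, so a call passing input=t and atol=1e-5 to the mapped torch API ends up as paddle.linalg.matrix_rank(x=t, tol=1e-5); an empty target such as "rtol": "" appears to mean the argument is dropped. The helper below is a minimal, hypothetical sketch of that renaming rule, not paconvert's actual implementation:

# Hypothetical illustration of applying a kwargs_change map; not paconvert code.
def rename_kwargs(kwargs, kwargs_change):
    new_kwargs = {}
    for name, value in kwargs.items():
        target = kwargs_change.get(name, name)
        if target == "":  # empty target: drop the argument entirely
            continue
        new_kwargs[target] = value
    return new_kwargs

# {'input': 't', 'atol': 1e-05, 'rtol': 0.0} -> {'x': 't', 'tol': 1e-05}
print(rename_kwargs({"input": "t", "atol": 1e-5, "rtol": 0.0},
                    {"input": "x", "atol": "tol", "rtol": ""}))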
1 change: 1 addition & 0 deletions paconvert/api_matcher.py
@@ -1003,6 +1003,7 @@ def get_paddle_nodes(self, args, kwargs):
else:
code = "{}({})".format(paddle_api, self.kwargs_to_str(new_kwargs))

self.api_mapping["args_list"] = ["input", "dim", "keepdim", "*", "out"]
return ast.parse(code).body

# the case of one tensor
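The added line resets args_list to the max/min style signature (input, dim, keepdim, *, out), presumably the MaxMinMatcher fix named in the commit message, so that later positional-to-keyword binding uses these names. A rough, hypothetical sketch of what such an args_list is used for (bind_args is illustrative only, not paconvert code):

# Hypothetical sketch: bind positional args to the names in args_list.
# "*" marks the start of keyword-only parameters, so positional binding stops there.
def bind_args(args_list, args, kwargs):
    cut = args_list.index("*") if "*" in args_list else len(args_list)
    bound = dict(zip(args_list[:cut], args))
    bound.update(kwargs)
    return bound

# A torch.max(x, 1, True) style call -> {'input': 'x', 'dim': 1, 'keepdim': True}
print(bind_args(["input", "dim", "keepdim", "*", "out"], ["x", 1, True], {}))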
62 changes: 62 additions & 0 deletions tests/test_autograd_grad.py
@@ -127,3 +127,65 @@ def test_case_7():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z
result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], grad_outputs=grad, retain_graph=True,
create_graph=False, only_inputs=True, allow_unused=True, is_grads_batched=False, materialize_grads=False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z
result = torch.autograd.grad([y.sum()], [x, z], grad, True, False, True, True, False, False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z
result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], retain_graph=True, allow_unused=True,
create_graph=False, only_inputs=True, is_grads_batched=False, grad_outputs=grad, materialize_grads=False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)
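For reference, the supported part of these calls maps naturally onto paddle.grad; the snippet below is only a hedged sketch of that Paddle-side equivalent (it omits only_inputs and the other torch-only flags, which is why the cases above are marked unsupported), not output produced by the converter:

import paddle

x = paddle.to_tensor([1.1, 2.2, 3.3], stop_gradient=False)
z = paddle.to_tensor([1.1, 2.2, 3.3], stop_gradient=False)
grad = paddle.to_tensor(2.0)
y = x * x + z
# roughly mirrors torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], ...)
result = paddle.grad(outputs=[y.sum()], inputs=[x, z], grad_outputs=[grad],
                     retain_graph=True, create_graph=False, allow_unused=True)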
40 changes: 40 additions & 0 deletions tests/test_bernoulli.py
@@ -115,3 +115,43 @@ def test_case_8():
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.ones(3, 3)
out = torch.zeros(3, 3)
result = torch.bernoulli(generator=torch.Generator(), input=a, out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.ones(3, 3)
out = torch.zeros(3, 3)
result = torch.bernoulli(input=a, generator=torch.Generator(), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_11():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.rand(3, 3)
result = torch.bernoulli(input=a, p=0.0, generator=torch.Generator())
"""
)
obj.run(
pytorch_code,
["a", "result"],
unsupport=True,
reason="paddle not support parameter 'p' ",
)
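For reference, the plain form of these calls corresponds to paddle.bernoulli, which samples from a tensor of probabilities; how out= and generator= are handled is up to the converter, so the sketch below only covers the basic call and is an assumption about the Paddle-side equivalent rather than generated code:

import paddle

a = paddle.ones([3, 3])       # per-element probability of drawing 1
result = paddle.bernoulli(a)  # counterpart of torch.bernoulli(input=a)
print(result)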
11 changes: 11 additions & 0 deletions tests/test_chain_matmul.py
@@ -69,3 +69,14 @@ def test_case_4():
unsupport=True,
reason="paddle does not support variable parameter",
)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
v = torch.tensor([[3., 6, 9], [1, 3, 5], [2, 2, 2]])
result = torch.chain_matmul(v)
"""
)
obj.run(pytorch_code, ["result"])
34 changes: 34 additions & 0 deletions tests/test_cumulative_trapezoid.py
@@ -84,3 +84,37 @@ def test_case_6():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y=y, dx=0.05, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y, dx=0.05, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
x = torch.tensor([1, 2, 3, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(dim=0, y=y, x=x)
"""
)
obj.run(pytorch_code, ["result"])
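For reference, recent Paddle releases provide paddle.cumulative_trapezoid with y, x, dx and axis arguments, so torch's dim maps to axis; the snippet below is a hedged sketch under that assumption, not converter output:

import paddle

y = paddle.to_tensor([1, 1, 1, 0, 1], dtype="float32")
x = paddle.to_tensor([1, 2, 3, 0, 1], dtype="float32")
result_dx = paddle.cumulative_trapezoid(y, dx=0.05, axis=0)  # spacing given by dx
result_x = paddle.cumulative_trapezoid(y, x=x, axis=0)       # spacing given by sample points x
print(result_dx, result_x)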
68 changes: 68 additions & 0 deletions tests/test_distributions_Bernoulli.py
@@ -89,3 +89,71 @@ def test_case_6():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), logits=None, validate_args=False)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(torch.tensor([0.3]), None, False)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), validate_args=False, logits=None)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=None, validate_args=False, logits=3.5)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)
68 changes: 68 additions & 0 deletions tests/test_distributions_Categorical.py
@@ -78,3 +78,71 @@ def test_case_5():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=None, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(None, torch.tensor([0.25, 0.25, 0.25, 0.25]), False)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=None, validate_args=False, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]))
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False, logits=None)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)
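For reference, paddle.distribution.Categorical takes a single logits tensor and its sample method takes a shape list; the sketch below is a hedged approximation of the closest Paddle-side call (the exact probs/logits semantics should be checked against the mapping), not converter output:

import paddle

m = paddle.distribution.Categorical(paddle.to_tensor([0.25, 0.25, 0.25, 0.25]))
result = m.sample([1])  # counterpart of torch's m.sample([1])
print(result)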