Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

单测case覆盖2 #455

Merged
merged 4 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -8214,7 +8214,7 @@
"paddle_api": "paddle.linalg.matrix_rank",
"min_input_args": 1,
"args_list": [
"A",
"input",
"tol",
"hermitian",
"*",
Expand All @@ -8224,7 +8224,7 @@
"out"
],
"kwargs_change": {
"A": "x",
"input": "x",
"atol": "tol",
"rtol": ""
}
Expand Down
1 change: 1 addition & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,7 @@ def get_paddle_nodes(self, args, kwargs):
else:
code = "{}({})".format(paddle_api, self.kwargs_to_str(new_kwargs))

self.api_mapping["args_list"] = ["input", "dim", "keepdim", "*", "out"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个已经组装代码完成了,改这个应该没什么意义吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个如果不改的话会影响到后面的testcase。

return ast.parse(code).body

# the case of one tensor
Expand Down
62 changes: 62 additions & 0 deletions tests/test_autograd_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,65 @@ def test_case_7():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z

result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], grad_outputs=grad, retain_graph=True,
create_graph=False, only_inputs=True, allow_unused=True, is_grads_batched=False, materialize_grads=False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z

result = torch.autograd.grad([y.sum()], [x, z], grad, True, False, True, True, False, False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
z = torch.tensor([1.1, 2.2, 3.3], requires_grad=True)
grad = torch.tensor(2.0)
y = x * x + z

result = torch.autograd.grad(outputs=[y.sum()], inputs=[x, z], retain_graph=True, allow_unused=True,
create_graph=False, only_inputs=True, is_grads_batched=False, grad_outputs=grad, materialize_grads=False)
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle dose not support 'only_inputs' now!",
)
40 changes: 40 additions & 0 deletions tests/test_bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,43 @@ def test_case_8():
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.ones(3, 3)
out = torch.zeros(3, 3)
result = torch.bernoulli(generator=torch.Generator(), input=a, out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.ones(3, 3)
out = torch.zeros(3, 3)
result = torch.bernoulli(input=a, generator=torch.Generator(), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])


def test_case_11():
pytorch_code = textwrap.dedent(
"""
import torch
a = torch.rand(3, 3)
result = torch.bernoulli(input=a, p=0.0, generator=torch.Generator())
"""
)
obj.run(
pytorch_code,
["a", "result"],
unsupport=True,
reason="paddle not support parameter 'p' ",
)
11 changes: 11 additions & 0 deletions tests/test_chain_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,14 @@ def test_case_4():
unsupport=True,
reason="paddle does not support variable parameter",
)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
v = torch.tensor([[3., 6, 9], [1, 3, 5], [2, 2, 2]])
result = torch.chain_matmul(v)
"""
)
obj.run(pytorch_code, ["result"])
34 changes: 34 additions & 0 deletions tests/test_cumulative_trapezoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,37 @@ def test_case_6():
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y=y, dx=0.05, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y, dx=0.05, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
x = torch.tensor([1, 2, 3, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(dim=0, y=y, x=x)
"""
)
obj.run(pytorch_code, ["result"])
68 changes: 68 additions & 0 deletions tests/test_distributions_Bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,71 @@ def test_case_6():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), logits=None, validate_args=False)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(torch.tensor([0.3]), None, False)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), validate_args=False, logits=None)
Xuxuanang marked this conversation as resolved.
Show resolved Hide resolved
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(probs=None, validate_args=False, logits=3.5)
result = m.sample([100])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support logits temporarily",
)
68 changes: 68 additions & 0 deletions tests/test_distributions_Categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,71 @@ def test_case_5():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=None, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(None, torch.tensor([0.25, 0.25, 0.25, 0.25]), False)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=None, validate_args=False, logits=torch.tensor([0.25, 0.25, 0.25, 0.25]))
Xuxuanang marked this conversation as resolved.
Show resolved Hide resolved
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Categorical(probs=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False,logits=None)
result = m.sample([1])
"""
)
obj.run(
pytorch_code,
["result"],
check_value=False,
unsupport=True,
reason="paddle does not support probs temporarily",
)
Loading