Skip to content

Commit

Permalink
转换规则 No. 181-184 (#195)
Browse files Browse the repository at this point in the history
* Add tests

* Fix
  • Loading branch information
co63oc authored Jul 31, 2023
1 parent cbc21e1 commit 27ebdae
Show file tree
Hide file tree
Showing 27 changed files with 649 additions and 102 deletions.
25 changes: 25 additions & 0 deletions paconvert/api_alias_mapping.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
{
"torch.distributions.bernoulli.Bernoulli": "torch.distributions.Bernoulli",
"torch.distributions.beta.Beta": "torch.distributions.Beta",
"torch.distributions.binomial.Binomial": "torch.distributions.Binomial",
"torch.distributions.categorical.Categorical": "torch.distributions.Categorical",
"torch.distributions.cauchy.Cauchy": "torch.distributions.Cauchy",
"torch.distributions.continuous_bernoulli.ContinuousBernoulli": "torch.distributions.ContinuousBernoulli",
"torch.distributions.dirichlet.Dirichlet": "torch.distributions.Dirichlet",
"torch.distributions.exp_family.ExponentialFamily": "torch.distributions.ExponentialFamily",
"torch.distributions.exponential.Exponential": "torch.distributions.Exponential",
"torch.distributions.geometric.Geometric": "torch.distributions.Geometric",
"torch.distributions.gumbel.Gumbel": "torch.distributions.Gumbel",
"torch.distributions.independent.Independent": "torch.distributions.Independent",
"torch.distributions.laplace.Laplace": "torch.distributions.Laplace",
"torch.distributions.log_normal.LogNormal": "torch.distributions.LogNormal",
"torch.distributions.multinomial.Multinomial": "torch.distributions.Multinomial",
"torch.distributions.normal.Normal": "torch.distributions.Normal",
"torch.distributions.one_hot_categorical.OneHotCategorical": "torch.distributions.OneHotCategorical",
"torch.distributions.transformed_distribution.TransformedDistribution": "torch.distributions.TransformedDistribution",
"torch.distributions.transforms.ComposeTransform": "torch.distributions.ComposeTransform",
"torch.distributions.transforms.SigmoidTransform": "torch.distributions.SigmoidTransform",
"torch.distributions.transforms.SoftmaxTransform": "torch.distributions.SoftmaxTransform",
"torch.distributions.transforms.StackTransform": "torch.distributions.StackTransform",
"torch.distributions.transforms.StickBreakingTransform": "torch.distributions.StickBreakingTransform",
"torch.distributions.transforms.Transform": "torch.distributions.Transform",
"torch.distributions.uniform.Uniform": "torch.distributions.Uniform",
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm",
"torch.nn.modules.activation.ReLU": "torch.nn.ReLU",
"torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d",
Expand Down
129 changes: 93 additions & 36 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3948,7 +3948,7 @@
"pg_options"
]
},
"torch.distributions.bernoulli.Bernoulli": {
"torch.distributions.Bernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Bernoulli",
"args_list": [
Expand All @@ -3963,7 +3963,7 @@
"logits"
]
},
"torch.distributions.beta.Beta": {
"torch.distributions.Beta": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Beta",
"args_list": [
Expand All @@ -3976,7 +3976,7 @@
"concentration0": "beta"
}
},
"torch.distributions.categorical.Categorical": {
"torch.distributions.Categorical": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Categorical",
"args_list": [
Expand All @@ -3991,7 +3991,7 @@
"probs"
]
},
"torch.distributions.cauchy.Cauchy": {
"torch.distributions.Cauchy": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Cauchy",
"args_list": [
Expand All @@ -4003,7 +4003,19 @@
"validate_args": ""
}
},
"torch.distributions.dirichlet.Dirichlet": {
"torch.distributions.ComposeTransform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ChainTransform",
"args_list": [
"parts",
"cache_size"
],
"kwargs_change": {
"parts": "transforms",
"cache_size": ""
}
},
"torch.distributions.Dirichlet": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Dirichlet",
"args_list": [
Expand All @@ -4014,7 +4026,7 @@
"validate_args": ""
}
},
"torch.distributions.exp_family.ExponentialFamily": {
"torch.distributions.ExponentialFamily": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ExponentialFamily",
"args_list": [
Expand All @@ -4026,7 +4038,7 @@
"validate_args": ""
}
},
"torch.distributions.geometric.Geometric": {
"torch.distributions.Geometric": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Geometric",
"args_list": [
Expand All @@ -4041,7 +4053,7 @@
"logits"
]
},
"torch.distributions.gumbel.Gumbel": {
"torch.distributions.Gumbel": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Gumbel",
"args_list": [
Expand All @@ -4053,7 +4065,7 @@
"validate_args": ""
}
},
"torch.distributions.independent.Independent": {
"torch.distributions.Independent": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Independent",
"args_list": [
Expand All @@ -4067,27 +4079,7 @@
"validate_args": ""
}
},
"torch.distributions.kl.kl_divergence": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.kl_divergence",
"args_list": [
"p",
"q"
]
},
"torch.distributions.kl.register_kl": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.kl.register_kl",
"args_list": [
"type_p",
"type_q"
],
"kwargs_change": {
"type_p": "cls_p",
"type_q": "cls_q"
}
},
"torch.distributions.laplace.Laplace": {
"torch.distributions.Laplace": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Laplace",
"args_list": [
Expand All @@ -4099,7 +4091,7 @@
"validate_args": ""
}
},
"torch.distributions.log_normal.LogNormal": {
"torch.distributions.LogNormal": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.LogNormal",
"args_list": [
Expand All @@ -4111,7 +4103,7 @@
"validate_args": ""
}
},
"torch.distributions.multinomial.Multinomial": {
"torch.distributions.Multinomial": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Multinomial",
"args_list": [
Expand All @@ -4127,7 +4119,7 @@
"logits"
]
},
"torch.distributions.normal.Normal": {
"torch.distributions.Normal": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Normal",
"args_list": [
Expand All @@ -4139,7 +4131,17 @@
"validate_args": ""
}
},
"torch.distributions.transforms.SoftmaxTransform": {
"torch.distributions.SigmoidTransform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.SigmoidTransform",
"args_list": [
"cache_size"
],
"kwargs_change": {
"cache_size": ""
}
},
"torch.distributions.SoftmaxTransform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.SoftmaxTransform",
"args_list": [
Expand All @@ -4149,7 +4151,7 @@
"cache_size": ""
}
},
"torch.distributions.transforms.StackTransform": {
"torch.distributions.StackTransform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.StackTransform",
"args_list": [
Expand All @@ -4163,7 +4165,7 @@
"cache_size": ""
}
},
"torch.distributions.transforms.StickBreakingTransform": {
"torch.distributions.StickBreakingTransform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.StickBreakingTransform",
"args_list": [
Expand All @@ -4173,6 +4175,61 @@
"cache_size": ""
}
},
"torch.distributions.Transform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Transform",
"args_list": [
"cache_size"
],
"kwargs_change": {
"cache_size": ""
}
},
"torch.distributions.TransformedDistribution": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.TransformedDistribution",
"args_list": [
"base_distribution",
"transforms",
"validate_args"
],
"kwargs_change": {
"base_distribution": "base",
"validate_args": ""
}
},
"torch.distributions.Uniform": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Uniform",
"args_list": [
"low",
"high",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.kl.kl_divergence": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.kl_divergence",
"args_list": [
"p",
"q"
]
},
"torch.distributions.kl.register_kl": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.kl.register_kl",
"args_list": [
"type_p",
"type_q"
],
"kwargs_change": {
"type_p": "cls_p",
"type_q": "cls_q"
}
},
"torch.div": {
"Matcher": "DivMatcher",
"args_list": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from apibase import APIBase

obj = APIBase("torch.distributions.bernoulli.Bernoulli")
obj = APIBase("torch.distributions.Bernoulli")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.bernoulli.Bernoulli(torch.tensor([0.3]))
m = torch.distributions.Bernoulli(torch.tensor([0.3]))
result = m.sample([100])
"""
)
Expand All @@ -34,7 +34,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.bernoulli.Bernoulli(probs=torch.tensor([0.3]), logits=None)
m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), logits=None)
result = m.sample([100])
"""
)
Expand All @@ -48,6 +48,17 @@ def test_case_2():


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Bernoulli(0.3, validate_args=False)
result = m.sample([100])
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,22 @@

from apibase import APIBase

obj = APIBase("torch.distributions.beta.Beta")
obj = APIBase("torch.distributions.Beta")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Beta(torch.tensor([0.5]), torch.tensor([0.5]))
n = torch.distributions.Beta(torch.tensor([0.3]), torch.tensor([0.7]))
result = torch.distributions.kl.kl_divergence(m, n)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@

from apibase import APIBase

obj = APIBase("torch.distributions.binomial.Binomial")
obj = APIBase("torch.distributions.Binomial")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.binomial.Binomial(100, torch.tensor([0, .2, .8, 1]))
m = torch.distributions.Binomial(100, torch.tensor([0, .2, .8, 1]))
result = m.sample()
"""
)
Expand All @@ -39,7 +39,7 @@ def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.binomial.Binomial(1, probs=torch.tensor([0.3]), logits=None)
m = torch.distributions.Binomial(1, probs=torch.tensor([0.3]), logits=None)
result = m.sample()
"""
)
Expand All @@ -52,6 +52,22 @@ def test_case_2():


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
m = torch.distributions.Binomial(1, 0.3, validate_args=False)
result = m.sample()
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down
Loading

0 comments on commit 27ebdae

Please sign in to comment.