From 27ebdaebb387f508f2aabb380de62d9d8bd00bfc Mon Sep 17 00:00:00 2001 From: co63oc Date: Mon, 31 Jul 2023 16:57:24 +0800 Subject: [PATCH] =?UTF-8?q?=E8=BD=AC=E6=8D=A2=E8=A7=84=E5=88=99=20No.=2018?= =?UTF-8?q?1-184=20(#195)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add tests * Fix --- paconvert/api_alias_mapping.json | 25 ++++ paconvert/api_mapping.json | 129 +++++++++++++----- ...lli.py => test_distributions_Bernoulli.py} | 17 ++- ...eta_Beta.py => test_distributions_Beta.py} | 14 +- ...mial.py => test_distributions_Binomial.py} | 22 ++- ...l.py => test_distributions_Categorical.py} | 17 ++- ...Cauchy.py => test_distributions_Cauchy.py} | 6 +- tests/test_distributions_ComposeTransform.py | 71 ++++++++++ ...test_distributions_ContinuousBernoulli.py} | 23 +++- ...let.py => test_distributions_Dirichlet.py} | 6 +- ...l.py => test_distributions_Exponential.py} | 23 +++- ...> test_distributions_ExponentialFamily.py} | 16 ++- ...ric.py => test_distributions_Geometric.py} | 17 ++- ...Gumbel.py => test_distributions_Gumbel.py} | 6 +- ...t.py => test_distributions_Independent.py} | 24 +++- ...place.py => test_distributions_Laplace.py} | 6 +- ...mal.py => test_distributions_LogNormal.py} | 6 +- ...l.py => test_distributions_Multinomial.py} | 17 ++- ...Normal.py => test_distributions_Normal.py} | 6 +- ...> test_distributions_OneHotCategorical.py} | 6 +- ...=> test_distributions_SigmoidTransform.py} | 17 +-- ...=> test_distributions_SoftmaxTransform.py} | 16 ++- tests/test_distributions_StackTransform.py | 71 ++++++++++ ...t_distributions_StickBreakingTransform.py} | 16 ++- tests/test_distributions_Transform.py | 64 +++++++++ ...t_distributions_TransformedDistribution.py | 58 ++++++++ tests/test_distributions_Uniform.py | 52 +++++++ 27 files changed, 649 insertions(+), 102 deletions(-) rename tests/{test_distributions_bernoulli_Bernoulli.py => test_distributions_Bernoulli.py} (75%) rename tests/{test_distributions_beta_Beta.py => test_distributions_Beta.py} (71%) rename tests/{test_distributions_binomial_Binomial.py => test_distributions_Binomial.py} (73%) rename tests/{test_distributions_categorical_Categorical.py => test_distributions_Categorical.py} (74%) rename tests/{test_distributions_cauchy_Cauchy.py => test_distributions_Cauchy.py} (85%) create mode 100644 tests/test_distributions_ComposeTransform.py rename tests/{test_distributions_continuous_bernoulli_ContinuousBernoulli.py => test_distributions_ContinuousBernoulli.py} (73%) rename tests/{test_distributions_dirichlet_Dirichlet.py => test_distributions_Dirichlet.py} (86%) rename tests/{test_distributions_exponential_Exponential.py => test_distributions_Exponential.py} (74%) rename tests/{test_distributions_exp_family_ExponentialFamily.py => test_distributions_ExponentialFamily.py} (75%) rename tests/{test_distributions_geometric_Geometric.py => test_distributions_Geometric.py} (75%) rename tests/{test_distributions_gumbel_Gumbel.py => test_distributions_Gumbel.py} (85%) rename tests/{test_distributions_independent_Independent.py => test_distributions_Independent.py} (60%) rename tests/{test_distributions_laplace_Laplace.py => test_distributions_Laplace.py} (84%) rename tests/{test_distributions_log_normal_LogNormal.py => test_distributions_LogNormal.py} (84%) rename tests/{test_distributions_multinomial_Multinomial.py => test_distributions_Multinomial.py} (74%) rename tests/{test_distributions_normal_Normal.py => test_distributions_Normal.py} (85%) rename tests/{test_distributions_one_hot_categorical_OneHotCategorical.py => test_distributions_OneHotCategorical.py} (83%) rename tests/{test_distributions_transforms_StackTransform.py => test_distributions_SigmoidTransform.py} (67%) rename tests/{test_distributions_transforms_SoftmaxTransform.py => test_distributions_SoftmaxTransform.py} (75%) create mode 100644 tests/test_distributions_StackTransform.py rename tests/{test_distributions_transforms_StickBreakingTransform.py => test_distributions_StickBreakingTransform.py} (74%) create mode 100644 tests/test_distributions_Transform.py create mode 100644 tests/test_distributions_TransformedDistribution.py create mode 100644 tests/test_distributions_Uniform.py diff --git a/paconvert/api_alias_mapping.json b/paconvert/api_alias_mapping.json index 199106dc4..e52062a78 100644 --- a/paconvert/api_alias_mapping.json +++ b/paconvert/api_alias_mapping.json @@ -1,4 +1,29 @@ { + "torch.distributions.bernoulli.Bernoulli": "torch.distributions.Bernoulli", + "torch.distributions.beta.Beta": "torch.distributions.Beta", + "torch.distributions.binomial.Binomial": "torch.distributions.Binomial", + "torch.distributions.categorical.Categorical": "torch.distributions.Categorical", + "torch.distributions.cauchy.Cauchy": "torch.distributions.Cauchy", + "torch.distributions.continuous_bernoulli.ContinuousBernoulli": "torch.distributions.ContinuousBernoulli", + "torch.distributions.dirichlet.Dirichlet": "torch.distributions.Dirichlet", + "torch.distributions.exp_family.ExponentialFamily": "torch.distributions.ExponentialFamily", + "torch.distributions.exponential.Exponential": "torch.distributions.Exponential", + "torch.distributions.geometric.Geometric": "torch.distributions.Geometric", + "torch.distributions.gumbel.Gumbel": "torch.distributions.Gumbel", + "torch.distributions.independent.Independent": "torch.distributions.Independent", + "torch.distributions.laplace.Laplace": "torch.distributions.Laplace", + "torch.distributions.log_normal.LogNormal": "torch.distributions.LogNormal", + "torch.distributions.multinomial.Multinomial": "torch.distributions.Multinomial", + "torch.distributions.normal.Normal": "torch.distributions.Normal", + "torch.distributions.one_hot_categorical.OneHotCategorical": "torch.distributions.OneHotCategorical", + "torch.distributions.transformed_distribution.TransformedDistribution": "torch.distributions.TransformedDistribution", + "torch.distributions.transforms.ComposeTransform": "torch.distributions.ComposeTransform", + "torch.distributions.transforms.SigmoidTransform": "torch.distributions.SigmoidTransform", + "torch.distributions.transforms.SoftmaxTransform": "torch.distributions.SoftmaxTransform", + "torch.distributions.transforms.StackTransform": "torch.distributions.StackTransform", + "torch.distributions.transforms.StickBreakingTransform": "torch.distributions.StickBreakingTransform", + "torch.distributions.transforms.Transform": "torch.distributions.Transform", + "torch.distributions.uniform.Uniform": "torch.distributions.Uniform", "torch.nn.modules.GroupNorm": "torch.nn.GroupNorm", "torch.nn.modules.activation.ReLU": "torch.nn.ReLU", "torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d", diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index a22f5aeda..658dbbe00 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -3948,7 +3948,7 @@ "pg_options" ] }, - "torch.distributions.bernoulli.Bernoulli": { + "torch.distributions.Bernoulli": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Bernoulli", "args_list": [ @@ -3963,7 +3963,7 @@ "logits" ] }, - "torch.distributions.beta.Beta": { + "torch.distributions.Beta": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Beta", "args_list": [ @@ -3976,7 +3976,7 @@ "concentration0": "beta" } }, - "torch.distributions.categorical.Categorical": { + "torch.distributions.Categorical": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Categorical", "args_list": [ @@ -3991,7 +3991,7 @@ "probs" ] }, - "torch.distributions.cauchy.Cauchy": { + "torch.distributions.Cauchy": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Cauchy", "args_list": [ @@ -4003,7 +4003,19 @@ "validate_args": "" } }, - "torch.distributions.dirichlet.Dirichlet": { + "torch.distributions.ComposeTransform": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.ChainTransform", + "args_list": [ + "parts", + "cache_size" + ], + "kwargs_change": { + "parts": "transforms", + "cache_size": "" + } + }, + "torch.distributions.Dirichlet": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Dirichlet", "args_list": [ @@ -4014,7 +4026,7 @@ "validate_args": "" } }, - "torch.distributions.exp_family.ExponentialFamily": { + "torch.distributions.ExponentialFamily": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.ExponentialFamily", "args_list": [ @@ -4026,7 +4038,7 @@ "validate_args": "" } }, - "torch.distributions.geometric.Geometric": { + "torch.distributions.Geometric": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Geometric", "args_list": [ @@ -4041,7 +4053,7 @@ "logits" ] }, - "torch.distributions.gumbel.Gumbel": { + "torch.distributions.Gumbel": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Gumbel", "args_list": [ @@ -4053,7 +4065,7 @@ "validate_args": "" } }, - "torch.distributions.independent.Independent": { + "torch.distributions.Independent": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Independent", "args_list": [ @@ -4067,27 +4079,7 @@ "validate_args": "" } }, - "torch.distributions.kl.kl_divergence": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.distribution.kl_divergence", - "args_list": [ - "p", - "q" - ] - }, - "torch.distributions.kl.register_kl": { - "Matcher": "GenericMatcher", - "paddle_api": "paddle.distribution.kl.register_kl", - "args_list": [ - "type_p", - "type_q" - ], - "kwargs_change": { - "type_p": "cls_p", - "type_q": "cls_q" - } - }, - "torch.distributions.laplace.Laplace": { + "torch.distributions.Laplace": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Laplace", "args_list": [ @@ -4099,7 +4091,7 @@ "validate_args": "" } }, - "torch.distributions.log_normal.LogNormal": { + "torch.distributions.LogNormal": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.LogNormal", "args_list": [ @@ -4111,7 +4103,7 @@ "validate_args": "" } }, - "torch.distributions.multinomial.Multinomial": { + "torch.distributions.Multinomial": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Multinomial", "args_list": [ @@ -4127,7 +4119,7 @@ "logits" ] }, - "torch.distributions.normal.Normal": { + "torch.distributions.Normal": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Normal", "args_list": [ @@ -4139,7 +4131,17 @@ "validate_args": "" } }, - "torch.distributions.transforms.SoftmaxTransform": { + "torch.distributions.SigmoidTransform": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.SigmoidTransform", + "args_list": [ + "cache_size" + ], + "kwargs_change": { + "cache_size": "" + } + }, + "torch.distributions.SoftmaxTransform": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.SoftmaxTransform", "args_list": [ @@ -4149,7 +4151,7 @@ "cache_size": "" } }, - "torch.distributions.transforms.StackTransform": { + "torch.distributions.StackTransform": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.StackTransform", "args_list": [ @@ -4163,7 +4165,7 @@ "cache_size": "" } }, - "torch.distributions.transforms.StickBreakingTransform": { + "torch.distributions.StickBreakingTransform": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.StickBreakingTransform", "args_list": [ @@ -4173,6 +4175,61 @@ "cache_size": "" } }, + "torch.distributions.Transform": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.Transform", + "args_list": [ + "cache_size" + ], + "kwargs_change": { + "cache_size": "" + } + }, + "torch.distributions.TransformedDistribution": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.TransformedDistribution", + "args_list": [ + "base_distribution", + "transforms", + "validate_args" + ], + "kwargs_change": { + "base_distribution": "base", + "validate_args": "" + } + }, + "torch.distributions.Uniform": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.Uniform", + "args_list": [ + "low", + "high", + "validate_args" + ], + "kwargs_change": { + "validate_args": "" + } + }, + "torch.distributions.kl.kl_divergence": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.kl_divergence", + "args_list": [ + "p", + "q" + ] + }, + "torch.distributions.kl.register_kl": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.kl.register_kl", + "args_list": [ + "type_p", + "type_q" + ], + "kwargs_change": { + "type_p": "cls_p", + "type_q": "cls_q" + } + }, "torch.div": { "Matcher": "DivMatcher", "args_list": [ diff --git a/tests/test_distributions_bernoulli_Bernoulli.py b/tests/test_distributions_Bernoulli.py similarity index 75% rename from tests/test_distributions_bernoulli_Bernoulli.py rename to tests/test_distributions_Bernoulli.py index 0e069dcd2..23a4b876b 100644 --- a/tests/test_distributions_bernoulli_Bernoulli.py +++ b/tests/test_distributions_Bernoulli.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.bernoulli.Bernoulli") +obj = APIBase("torch.distributions.Bernoulli") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.bernoulli.Bernoulli(torch.tensor([0.3])) + m = torch.distributions.Bernoulli(torch.tensor([0.3])) result = m.sample([100]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.bernoulli.Bernoulli(probs=torch.tensor([0.3]), logits=None) + m = torch.distributions.Bernoulli(probs=torch.tensor([0.3]), logits=None) result = m.sample([100]) """ ) @@ -48,6 +48,17 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Bernoulli(0.3, validate_args=False) + result = m.sample([100]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_beta_Beta.py b/tests/test_distributions_Beta.py similarity index 71% rename from tests/test_distributions_beta_Beta.py rename to tests/test_distributions_Beta.py index a5cca0421..b7a0308be 100644 --- a/tests/test_distributions_beta_Beta.py +++ b/tests/test_distributions_Beta.py @@ -16,10 +16,22 @@ from apibase import APIBase -obj = APIBase("torch.distributions.beta.Beta") +obj = APIBase("torch.distributions.Beta") def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Beta(torch.tensor([0.5]), torch.tensor([0.5])) + n = torch.distributions.Beta(torch.tensor([0.3]), torch.tensor([0.7])) + result = torch.distributions.kl.kl_divergence(m, n) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_binomial_Binomial.py b/tests/test_distributions_Binomial.py similarity index 73% rename from tests/test_distributions_binomial_Binomial.py rename to tests/test_distributions_Binomial.py index eda0eddb4..de3a077ee 100644 --- a/tests/test_distributions_binomial_Binomial.py +++ b/tests/test_distributions_Binomial.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.binomial.Binomial") +obj = APIBase("torch.distributions.Binomial") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.binomial.Binomial(100, torch.tensor([0, .2, .8, 1])) + m = torch.distributions.Binomial(100, torch.tensor([0, .2, .8, 1])) result = m.sample() """ ) @@ -39,7 +39,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.binomial.Binomial(1, probs=torch.tensor([0.3]), logits=None) + m = torch.distributions.Binomial(1, probs=torch.tensor([0.3]), logits=None) result = m.sample() """ ) @@ -52,6 +52,22 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Binomial(1, 0.3, validate_args=False) + result = m.sample() + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support this function temporarily", + ) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_categorical_Categorical.py b/tests/test_distributions_Categorical.py similarity index 74% rename from tests/test_distributions_categorical_Categorical.py rename to tests/test_distributions_Categorical.py index be757742e..e58b42423 100644 --- a/tests/test_distributions_categorical_Categorical.py +++ b/tests/test_distributions_Categorical.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.categorical.Categorical") +obj = APIBase("torch.distributions.Categorical") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.categorical.Categorical(logits=torch.tensor([0.25, 0.25, 0.25, 0.25])) + m = torch.distributions.Categorical(logits=torch.tensor([0.25, 0.25, 0.25, 0.25])) result = m.sample([1]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.categorical.Categorical(probs=None, logits=torch.tensor([0.25, 0.25, 0.25, 0.25])) + m = torch.distributions.Categorical(probs=None, logits=torch.tensor([0.25, 0.25, 0.25, 0.25])) result = m.sample([1]) """ ) @@ -48,6 +48,17 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Categorical(logits=torch.tensor([0.25, 0.25, 0.25, 0.25]), validate_args=False) + result = m.sample([1]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_cauchy_Cauchy.py b/tests/test_distributions_Cauchy.py similarity index 85% rename from tests/test_distributions_cauchy_Cauchy.py rename to tests/test_distributions_Cauchy.py index 892ba750c..2e755d4d2 100644 --- a/tests/test_distributions_cauchy_Cauchy.py +++ b/tests/test_distributions_Cauchy.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.cauchy.Cauchy") +obj = APIBase("torch.distributions.Cauchy") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.cauchy.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) + m = torch.distributions.Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) result = m.sample([1]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.cauchy.Cauchy(loc=torch.tensor([0.0]), scale=torch.tensor([1.0])) + m = torch.distributions.Cauchy(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]), validate_args=False) result = m.sample([1]) """ ) diff --git a/tests/test_distributions_ComposeTransform.py b/tests/test_distributions_ComposeTransform.py new file mode 100644 index 000000000..53881bf21 --- /dev/null +++ b/tests/test_distributions_ComposeTransform.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.ComposeTransform") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((1, 2)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.ComposeTransform(tseq) + result = t.forward_shape([1,2]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((1, 2)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.ComposeTransform(tseq, cache_size=0) + result = t.forward_shape([1,2]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2, 1)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.ComposeTransform(parts=tseq) + result = t.forward_shape([1,2]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2, 1)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.transforms.ComposeTransform(parts=tseq) + result = t.forward_shape([1,2]) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_distributions_continuous_bernoulli_ContinuousBernoulli.py b/tests/test_distributions_ContinuousBernoulli.py similarity index 73% rename from tests/test_distributions_continuous_bernoulli_ContinuousBernoulli.py rename to tests/test_distributions_ContinuousBernoulli.py index b7d32c074..a0b40d781 100644 --- a/tests/test_distributions_continuous_bernoulli_ContinuousBernoulli.py +++ b/tests/test_distributions_ContinuousBernoulli.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.continuous_bernoulli.ContinuousBernoulli") +obj = APIBase("torch.distributions.ContinuousBernoulli") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.continuous_bernoulli.ContinuousBernoulli(torch.tensor([0.3])) + m = torch.distributions.ContinuousBernoulli(torch.tensor([0.3])) result = m.sample([100]) """ ) @@ -40,7 +40,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.continuous_bernoulli.ContinuousBernoulli(probs=torch.tensor([0.3]), logits=None) + m = torch.distributions.ContinuousBernoulli(probs=torch.tensor([0.3]), logits=None) result = m.sample([100]) """ ) @@ -54,6 +54,23 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.ContinuousBernoulli(0.3, validate_args=False) + result = m.sample([100]) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle does not support this function temporarily", + ) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_dirichlet_Dirichlet.py b/tests/test_distributions_Dirichlet.py similarity index 86% rename from tests/test_distributions_dirichlet_Dirichlet.py rename to tests/test_distributions_Dirichlet.py index 962792cd6..87ca6e2a8 100644 --- a/tests/test_distributions_dirichlet_Dirichlet.py +++ b/tests/test_distributions_Dirichlet.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.dirichlet.Dirichlet") +obj = APIBase("torch.distributions.Dirichlet") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.dirichlet.Dirichlet(torch.tensor([0.3])) + m = torch.distributions.Dirichlet(torch.tensor([0.3])) result = m.sample([100]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.dirichlet.Dirichlet(concentration=torch.tensor([0.3])) + m = torch.distributions.Dirichlet(concentration=torch.tensor([0.3]), validate_args=False) result = m.sample([100]) """ ) diff --git a/tests/test_distributions_exponential_Exponential.py b/tests/test_distributions_Exponential.py similarity index 74% rename from tests/test_distributions_exponential_Exponential.py rename to tests/test_distributions_Exponential.py index 619af3fab..5931c3e64 100644 --- a/tests/test_distributions_exponential_Exponential.py +++ b/tests/test_distributions_Exponential.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.exponential.Exponential") +obj = APIBase("torch.distributions.Exponential") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.exponential.Exponential(torch.tensor([1.0])) + m = torch.distributions.Exponential(torch.tensor([1.0])) result = m.sample([100]) """ ) @@ -40,7 +40,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.exponential.Exponential(rate=torch.tensor([1.0])) + m = torch.distributions.Exponential(rate=torch.tensor([1.0])) result = m.sample([100]) """ ) @@ -54,6 +54,23 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Exponential(torch.tensor([1.0]), validate_args=False) + result = m.sample([100]) + """ + ) + obj.run( + pytorch_code, + ["result"], + check_value=False, + unsupport=True, + reason="paddle does not support this function temporarily", + ) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_exp_family_ExponentialFamily.py b/tests/test_distributions_ExponentialFamily.py similarity index 75% rename from tests/test_distributions_exp_family_ExponentialFamily.py rename to tests/test_distributions_ExponentialFamily.py index 3d21e7441..c5076f04d 100644 --- a/tests/test_distributions_exp_family_ExponentialFamily.py +++ b/tests/test_distributions_ExponentialFamily.py @@ -33,14 +33,14 @@ def compare( return False -obj = ExponentialFamilyAPIBase("torch.distributions.exp_family.ExponentialFamily") +obj = ExponentialFamilyAPIBase("torch.distributions.ExponentialFamily") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - result = torch.distributions.exp_family.ExponentialFamily() + result = torch.distributions.ExponentialFamily() """ ) obj.run(pytorch_code, ["result"]) @@ -50,13 +50,23 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - result = torch.distributions.exp_family.ExponentialFamily(batch_shape=torch.Size([1]), event_shape=torch.Size([2])) + result = torch.distributions.ExponentialFamily(batch_shape=torch.Size([1]), event_shape=torch.Size([2])) """ ) obj.run(pytorch_code, ["result"]) def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.distributions.ExponentialFamily(batch_shape=torch.Size([1]), event_shape=torch.Size([2]), validate_args=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_geometric_Geometric.py b/tests/test_distributions_Geometric.py similarity index 75% rename from tests/test_distributions_geometric_Geometric.py rename to tests/test_distributions_Geometric.py index c69b3ba89..de1b67587 100644 --- a/tests/test_distributions_geometric_Geometric.py +++ b/tests/test_distributions_Geometric.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.geometric.Geometric") +obj = APIBase("torch.distributions.Geometric") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.geometric.Geometric(torch.tensor([0.3])) + m = torch.distributions.Geometric(torch.tensor([0.3])) result = m.sample([100]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.geometric.Geometric(probs=torch.tensor([0.3]), logits=None) + m = torch.distributions.Geometric(probs=torch.tensor([0.3]), logits=None) result = m.sample([100]) """ ) @@ -48,6 +48,17 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Geometric(0.3, validate_args=False) + result = m.sample([100]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_gumbel_Gumbel.py b/tests/test_distributions_Gumbel.py similarity index 85% rename from tests/test_distributions_gumbel_Gumbel.py rename to tests/test_distributions_Gumbel.py index 7451da118..159b8a4ac 100644 --- a/tests/test_distributions_gumbel_Gumbel.py +++ b/tests/test_distributions_Gumbel.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.gumbel.Gumbel") +obj = APIBase("torch.distributions.Gumbel") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.gumbel.Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) + m = torch.distributions.Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) result = m.sample([100]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.gumbel.Gumbel(loc=torch.tensor([1.0]), scale=torch.tensor([2.0])) + m = torch.distributions.Gumbel(loc=torch.tensor([1.0]), scale=torch.tensor([2.0]), validate_args=False) result = m.sample([100]) """ ) diff --git a/tests/test_distributions_independent_Independent.py b/tests/test_distributions_Independent.py similarity index 60% rename from tests/test_distributions_independent_Independent.py rename to tests/test_distributions_Independent.py index e81e74cda..9aeedc03c 100644 --- a/tests/test_distributions_independent_Independent.py +++ b/tests/test_distributions_Independent.py @@ -16,15 +16,15 @@ from apibase import APIBase -obj = APIBase("torch.distributions.independent.Independent") +obj = APIBase("torch.distributions.Independent") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - beta = torch.distributions.beta.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) - reinterpreted_beta = torch.distributions.independent.Independent(beta, 1) + beta = torch.distributions.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) + reinterpreted_beta = torch.distributions.Independent(beta, 1) result = reinterpreted_beta.log_prob(torch.tensor([0.2, 0.2])) """ ) @@ -35,8 +35,8 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - beta = torch.distributions.beta.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) - reinterpreted_beta = torch.distributions.independent.Independent(base_distribution=beta, reinterpreted_batch_ndims=1) + beta = torch.distributions.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) + reinterpreted_beta = torch.distributions.Independent(base_distribution=beta, reinterpreted_batch_ndims=1) result = reinterpreted_beta.log_prob(torch.tensor([0.2, 0.2])) """ ) @@ -47,7 +47,19 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - beta = torch.distributions.beta.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) + beta = torch.distributions.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) + reinterpreted_beta = torch.distributions.Independent(beta, 1, validate_args=False) + result = reinterpreted_beta.log_prob(torch.tensor([0.2, 0.2])) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + beta = torch.distributions.Beta(torch.tensor([0.5, 0.5]), torch.tensor([0.5, 0.5])) reinterpreted_beta = torch.distributions.independent.Independent(beta, 1, validate_args=False) result = reinterpreted_beta.log_prob(torch.tensor([0.2, 0.2])) """ diff --git a/tests/test_distributions_laplace_Laplace.py b/tests/test_distributions_Laplace.py similarity index 84% rename from tests/test_distributions_laplace_Laplace.py rename to tests/test_distributions_Laplace.py index a9dc9ad80..2e9be4fb8 100644 --- a/tests/test_distributions_laplace_Laplace.py +++ b/tests/test_distributions_Laplace.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.laplace.Laplace") +obj = APIBase("torch.distributions.Laplace") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.laplace.Laplace(torch.tensor([0.0]), torch.tensor([1.0])) + m = torch.distributions.Laplace(torch.tensor([0.0]), torch.tensor([1.0])) result = m.sample([1]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([1.0])) + m = torch.distributions.Laplace(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]), validate_args=False) result = m.sample([1]) """ ) diff --git a/tests/test_distributions_log_normal_LogNormal.py b/tests/test_distributions_LogNormal.py similarity index 84% rename from tests/test_distributions_log_normal_LogNormal.py rename to tests/test_distributions_LogNormal.py index 7ff74c19d..3a3106a17 100644 --- a/tests/test_distributions_log_normal_LogNormal.py +++ b/tests/test_distributions_LogNormal.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.log_normal.LogNormal") +obj = APIBase("torch.distributions.LogNormal") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.log_normal.LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) + m = torch.distributions.LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) result = m.sample([1]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.log_normal.LogNormal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0])) + m = torch.distributions.LogNormal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]), validate_args=False) result = m.sample([1]) """ ) diff --git a/tests/test_distributions_multinomial_Multinomial.py b/tests/test_distributions_Multinomial.py similarity index 74% rename from tests/test_distributions_multinomial_Multinomial.py rename to tests/test_distributions_Multinomial.py index 45ef6b572..4a755d3cd 100644 --- a/tests/test_distributions_multinomial_Multinomial.py +++ b/tests/test_distributions_Multinomial.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.multinomial.Multinomial") +obj = APIBase("torch.distributions.Multinomial") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.multinomial.Multinomial(1, torch.tensor([0.3])) + m = torch.distributions.Multinomial(1, torch.tensor([0.3])) result = m.sample([100]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.multinomial.Multinomial(total_count=1, probs=torch.tensor([0.3]), logits=None) + m = torch.distributions.Multinomial(total_count=1, probs=torch.tensor([0.3]), logits=None) result = m.sample([100]) """ ) @@ -48,6 +48,17 @@ def test_case_2(): def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Multinomial(1, torch.tensor([0.3]), validate_args=False) + result = m.sample([100]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_normal_Normal.py b/tests/test_distributions_Normal.py similarity index 85% rename from tests/test_distributions_normal_Normal.py rename to tests/test_distributions_Normal.py index 1229c31a8..b8c724ea6 100644 --- a/tests/test_distributions_normal_Normal.py +++ b/tests/test_distributions_Normal.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.normal.Normal") +obj = APIBase("torch.distributions.Normal") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.normal.Normal(torch.tensor([0.0]), torch.tensor([1.0])) + m = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) result = m.sample([1]) """ ) @@ -34,7 +34,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.normal.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0])) + m = torch.distributions.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([1.0]), validate_args=False) result = m.sample([1]) """ ) diff --git a/tests/test_distributions_one_hot_categorical_OneHotCategorical.py b/tests/test_distributions_OneHotCategorical.py similarity index 83% rename from tests/test_distributions_one_hot_categorical_OneHotCategorical.py rename to tests/test_distributions_OneHotCategorical.py index 8aebe8abc..1c603338b 100644 --- a/tests/test_distributions_one_hot_categorical_OneHotCategorical.py +++ b/tests/test_distributions_OneHotCategorical.py @@ -16,14 +16,14 @@ from apibase import APIBase -obj = APIBase("torch.distributions.one_hot_categorical.OneHotCategorical") +obj = APIBase("torch.distributions.OneHotCategorical") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.one_hot_categorical.OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + m = torch.distributions.OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) result = m.sample() """ ) @@ -39,7 +39,7 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - m = torch.distributions.one_hot_categorical.OneHotCategorical(probs=torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]), logits=None) + m = torch.distributions.OneHotCategorical(probs=torch.tensor([ 0.25, 0.25, 0.25, 0.25 ]), logits=None, validate_args=False) result = m.sample() """ ) diff --git a/tests/test_distributions_transforms_StackTransform.py b/tests/test_distributions_SigmoidTransform.py similarity index 67% rename from tests/test_distributions_transforms_StackTransform.py rename to tests/test_distributions_SigmoidTransform.py index 81bc8557c..4737743c8 100644 --- a/tests/test_distributions_transforms_StackTransform.py +++ b/tests/test_distributions_SigmoidTransform.py @@ -16,16 +16,15 @@ from apibase import APIBase -obj = APIBase("torch.distributions.transforms.StackTransform") +obj = APIBase("torch.distributions.SigmoidTransform") def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - x = torch.ones((1, 2)) - tseq = [torch.distributions.transforms.SoftmaxTransform()] - t = torch.distributions.transforms.StackTransform(tseq) + x = torch.ones((2,3)) + t = torch.distributions.SigmoidTransform() result = t(x) """ ) @@ -36,9 +35,8 @@ def test_case_2(): pytorch_code = textwrap.dedent( """ import torch - x = torch.ones((1, 2)) - tseq = [torch.distributions.transforms.SoftmaxTransform()] - t = torch.distributions.transforms.StackTransform(tseq, cache_size=0) + x = torch.ones((2,3)) + t = torch.distributions.SigmoidTransform(cache_size=0) result = t(x) """ ) @@ -49,9 +47,8 @@ def test_case_3(): pytorch_code = textwrap.dedent( """ import torch - x = torch.ones((2, 1)) - tseq = [torch.distributions.transforms.SoftmaxTransform()] - t = torch.distributions.transforms.StackTransform(tseq, dim=1) + x = torch.ones((2,3)) + t = torch.distributions.transforms.SigmoidTransform(cache_size=0) result = t(x) """ ) diff --git a/tests/test_distributions_transforms_SoftmaxTransform.py b/tests/test_distributions_SoftmaxTransform.py similarity index 75% rename from tests/test_distributions_transforms_SoftmaxTransform.py rename to tests/test_distributions_SoftmaxTransform.py index 42c542cb3..fef4f73a4 100644 --- a/tests/test_distributions_transforms_SoftmaxTransform.py +++ b/tests/test_distributions_SoftmaxTransform.py @@ -16,7 +16,7 @@ from apibase import APIBase -obj = APIBase("torch.distributions.transforms.SoftmaxTransform") +obj = APIBase("torch.distributions.SoftmaxTransform") def test_case_1(): @@ -24,7 +24,7 @@ def test_case_1(): """ import torch x = torch.ones((2,3)) - t = torch.distributions.transforms.SoftmaxTransform() + t = torch.distributions.SoftmaxTransform() result = t(x) """ ) @@ -32,6 +32,18 @@ def test_case_1(): def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2,3)) + t = torch.distributions.SoftmaxTransform(cache_size=0) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_StackTransform.py b/tests/test_distributions_StackTransform.py new file mode 100644 index 000000000..661d527db --- /dev/null +++ b/tests/test_distributions_StackTransform.py @@ -0,0 +1,71 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.StackTransform") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((1, 2)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.StackTransform(tseq) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((1, 2)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.StackTransform(tseq, cache_size=0) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2, 1)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.StackTransform(tseq, dim=1) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2, 1)) + tseq = [torch.distributions.SoftmaxTransform()] + t = torch.distributions.transforms.StackTransform(tseq, dim=1) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_distributions_transforms_StickBreakingTransform.py b/tests/test_distributions_StickBreakingTransform.py similarity index 74% rename from tests/test_distributions_transforms_StickBreakingTransform.py rename to tests/test_distributions_StickBreakingTransform.py index ddc928f78..b311cdabc 100644 --- a/tests/test_distributions_transforms_StickBreakingTransform.py +++ b/tests/test_distributions_StickBreakingTransform.py @@ -16,7 +16,7 @@ from apibase import APIBase -obj = APIBase("torch.distributions.transforms.StickBreakingTransform") +obj = APIBase("torch.distributions.StickBreakingTransform") def test_case_1(): @@ -24,7 +24,7 @@ def test_case_1(): """ import torch x = torch.ones((2,3)) - t = torch.distributions.transforms.StickBreakingTransform() + t = torch.distributions.StickBreakingTransform() result = t(x) """ ) @@ -32,6 +32,18 @@ def test_case_1(): def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.ones((2,3)) + t = torch.distributions.StickBreakingTransform(cache_size=0) + result = t(x) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): pytorch_code = textwrap.dedent( """ import torch diff --git a/tests/test_distributions_Transform.py b/tests/test_distributions_Transform.py new file mode 100644 index 000000000..ac479010e --- /dev/null +++ b/tests/test_distributions_Transform.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +import paddle +from apibase import APIBase + + +class TransformAPIBase(APIBase): + def compare( + self, + name, + pytorch_result, + paddle_result, + check_value=True, + check_dtype=True, + check_stop_gradient=True, + ): + assert isinstance(paddle_result, paddle.distribution.transform.Transform) + + +obj = TransformAPIBase("torch.distributions.Transform") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.distributions.Transform() + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.distributions.Transform(cache_size=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.distributions.transforms.Transform(cache_size=0) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_distributions_TransformedDistribution.py b/tests/test_distributions_TransformedDistribution.py new file mode 100644 index 000000000..d70b6a3bb --- /dev/null +++ b/tests/test_distributions_TransformedDistribution.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.TransformedDistribution") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + base_distribution = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) + transforms = [torch.distributions.SigmoidTransform()] + m = torch.distributions.TransformedDistribution(base_distribution, transforms) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + base_distribution = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) + transforms = [torch.distributions.SigmoidTransform()] + m = torch.distributions.TransformedDistribution(base_distribution, transforms=transforms, validate_args=None) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + base_distribution = torch.distributions.Normal(torch.tensor([0.0]), torch.tensor([1.0])) + transforms = [torch.distributions.SigmoidTransform()] + m = torch.distributions.transformed_distribution.TransformedDistribution(base_distribution, transforms=transforms, validate_args=None) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_distributions_Uniform.py b/tests/test_distributions_Uniform.py new file mode 100644 index 000000000..7996aa667 --- /dev/null +++ b/tests/test_distributions_Uniform.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.Uniform") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Uniform(torch.tensor([0.0]), torch.tensor([5.0])) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.Uniform(low=torch.tensor([0.0]), high=torch.tensor([5.0]), validate_args=None) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + m = torch.distributions.uniform.Uniform(low=torch.tensor([0.0]), high=torch.tensor([5.0]), validate_args=None) + result = m.sample([10]) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False)