diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index a35ff8b29..15c55b3ce 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -844,7 +844,22 @@ "memory_format" ] }, - "torch.Tensor.cauchy_": {}, + "torch.Tensor.cauchy_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.cauchy_", + "min_input_args": 0, + "args_list": [ + "median", + "sigma", + "*", + "generator" + ], + "kwargs_change": { + "median": "loc", + "sigma":"scale", + "generator":"" + } + }, "torch.Tensor.cdouble": { "Matcher": "TensorCdoubleMatcher", "paddle_api": "paddle.Tensor.astype", @@ -1711,7 +1726,20 @@ "other": "y" } }, - "torch.Tensor.geometric_": {}, + "torch.Tensor.geometric_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.geometric_", + "min_input_args": 1, + "args_list": [ + "p", + "*", + "generator" + ], + "kwargs_change": { + "p": "probs", + "generator":"" + } + }, "torch.Tensor.geqrf": {}, "torch.Tensor.ger": { "Matcher": "GenericMatcher", @@ -2088,7 +2116,17 @@ "paddle_api": "paddle.Tensor.is_floating_point", "min_input_args": 0 }, - "torch.Tensor.is_inference": {}, + "torch.Tensor.is_inference": { + "Matcher": "Is_InferenceMatcher", + "min_input_args": 0 + }, + "torch.is_inference": { + "Matcher": "Is_InferenceMatcher", + "min_input_args": 1, + "args_list":[ + "input" + ] + }, "torch.Tensor.is_pinned": { "Matcher": "Is_PinnedMatcher", "min_input_args": 0 @@ -3166,7 +3204,22 @@ "Matcher": "UnchangeMatcher", "min_input_args": 0 }, - "torch.Tensor.random_": {}, + "torch.Tensor.random_": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.uniform_", + "min_input_args": 0, + "args_list": [ + "from", + "to", + "*", + "generator" + ], + "kwargs_change": { + "from": "min", + "to": "max", + "generator": "" + } + }, "torch.Tensor.ravel": { "Matcher": "GenericMatcher", "paddle_api": "paddle.Tensor.flatten", @@ -6298,6 +6351,18 @@ ], "min_input_args": 1 }, + "torch.distributed.rpc.remote":{ + "Matcher": "RpcRemoteMatcher", + "paddle_api": "paddle.distributed.rpc.rpc_async", + "min_input_args": 2, + "args_list": [ + "to", + "func", + "args", + "kwargs", + "timeout" + ] + }, "torch.distributed.rpc.shutdown": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distributed.rpc.shutdown", @@ -6381,6 +6446,52 @@ }, "min_input_args": 2 }, + "torch.distributions.lkj_cholesky.LKJCholesky":{ + "Matcher": "LKJCholeskyMatcher", + "paddle_api": "paddle.distribution.LKJCholesky", + "min_input_args": 1, + "args_list": [ + "dim", + "concentration", + "validate_args" + ] + }, + "torch.distributions.studentT.StudentT":{ + "Matcher": "StudentTMatcher", + "paddle_api": "paddle.distribution.StudentT", + "min_input_args": 1, + "args_list": [ + "df", + "loc", + "scale", + "validate_args" + ], + "kwargs_change": { + "validate_args": "" + } + }, + "torch.distributions.transforms.PositiveDefiniteTransform":{ + "Matcher": "TransformsPositiveDefiniteTransformMatcher", + "min_input_args": 0, + "args_list": [ + "cache_size" + ], + "kwargs_change": { + "cache_size": "" + } + }, + "torch.distributions.poisson.Poisson":{ + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.Poisson", + "min_input_args": 1, + "args_list": [ + "rate", + "validate_args" + ], + "kwargs_change": { + "validate_args": "" + } + }, "torch.distributions.Bernoulli": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Bernoulli", @@ -6431,6 +6542,18 @@ "total_count": "1" } }, + "torch.distributions.chi2.Chi2":{ + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.Chi2", + "min_input_args": 1, + "args_list": [ + "df", + "validate_args" + ], + "kwargs_change": { + "validate_args": "" + } + }, "torch.distributions.Categorical": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.Categorical", @@ -6473,6 +6596,24 @@ "cache_size": "" } }, + "torch.distributions.constraints.Constraint" : { + "Matcher": "DistributionsConstrainMatcher", + "paddle_api": "paddle.distribution.constraint.Constraint", + "abstract": true + }, + "torch.distributions.gamma.Gamma":{ + "Matcher": "GenericMatcher", + "paddle_api": "paddle.distribution.Gamma", + "min_input_args": 2, + "args_list": [ + "concentration", + "rate", + "validate_args" + ], + "kwargs_change": { + "validate_args": "" + } + }, "torch.distributions.ContinuousBernoulli": { "Matcher": "GenericMatcher", "paddle_api": "paddle.distribution.ContinuousBernoulli", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 947b3bb8f..1f5791e88 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -751,6 +751,127 @@ def get_paddle_nodes(self, args, kwargs): for i in range(1, len(new_args)): code = "{}({}, {})".format(self.get_paddle_api(), code, new_args[i]) return ast.parse(code).body + + +class StudentTMatcher(BaseMatcher): + def generate_aux_code(self): + API_TEMPLATE = textwrap.dedent( + """ + import paddle + class StudentT_Aux_Class: + def __init__(self, df, loc, scale): + self.df = paddle.to_tensor(df) + self.loc = paddle.to_tensor(loc) + self.scale = paddle.to_tensor(scale) + self.sT = paddle.distribution.StudentT(self.df, self.loc, self.scale) + def sample(self): + return paddle.reshape(self.sT.sample(), self.df.shape) + """ + ) + + return API_TEMPLATE + def generate_code(self, kwargs): + self.write_aux_code() + if "validate_args" in kwargs: + del kwargs["validate_args"] + if "loc" not in kwargs: + kwargs["loc"] = 0.1 + if "scale" not in kwargs: + kwargs["scale"] = 1.0 + kwargs = self.kwargs_to_str(kwargs) + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux.StudentT_Aux_Class({}) + """ + ) + code = API_TEMPLATE.format(kwargs) + return code + + +class TransformsPositiveDefiniteTransformMatcher(BaseMatcher): + def generate_aux_code(self): + API_TEMPLATE = textwrap.dedent( + """ + import paddle + class TransformsPositiveDefiniteTransform: + def __call__(self, x): + x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed() + return x @ x.T + + def inv(self, y): + y = paddle.linalg.cholesky(y) + return y.tril(-1) + y.diagonal(axis1=-2, axis2=-1).log().diag_embed() + """ + ) + + return API_TEMPLATE + def generate_code(self, kwargs): + self.write_aux_code() + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux.TransformsPositiveDefiniteTransform() + """ + ) + return API_TEMPLATE + + +class LKJCholeskyMatcher(BaseMatcher): + def generate_aux_code(self): + API_TEMPLATE = textwrap.dedent( + """ + import paddle + class LKJCholesky_Aux_Class: + def __init__(self, dim, concentration, sample_method='onion'): + self.lkj = paddle.distribution.LKJCholesky(dim, concentration, sample_method) + def sample(self): + return paddle.unsqueeze(self.lkj.sample(), axis=0) + """ + ) + + return API_TEMPLATE + def generate_code(self, kwargs): + self.write_aux_code() + if "validate_args" in kwargs: + del kwargs["validate_args"] + kwargs = self.kwargs_to_str(kwargs) + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux.LKJCholesky_Aux_Class({}) + """ + ) + code = API_TEMPLATE.format(kwargs) + return code + + + +class Is_InferenceMatcher(BaseMatcher): + def generate_code(self, kwargs): + if "input" not in kwargs: + kwargs["input"] = self.paddleClass + code = "{}.stop_gradient".format(kwargs["input"]) + return code + + +class DistributionsConstrainMatcher(BaseMatcher): + def generate_aux_code(self): + API_TEMPLATE = textwrap.dedent( + """ + import paddle + class DistributionsConstrain: + def check(self, value): + return paddle.distribution.constraint.Constraint()(value) + """ + ) + + return API_TEMPLATE + def generate_code(self, kwargs): + self.write_aux_code() + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux.DistributionsConstrain() + """ + ) + return API_TEMPLATE class IInfoMatcher(BaseMatcher): @@ -5067,6 +5188,35 @@ def generate_code(self, kwargs): return code +class RpcRemoteMatcher(BaseMatcher): + def generate_aux_code(self): + CODE_TEMPLATE = textwrap.dedent( + """ + class rpc_remote: + def __init__(self, remote_obj): + self.remote = remote_obj + + def to_here(self): + return self.remote.wait() + """ + ) + return CODE_TEMPLATE + + def generate_code(self, kwargs): + self.write_aux_code() + kwargs['fn'] = kwargs.pop('func') + kwargs = self.kwargs_to_str(kwargs) + API_TEMPLATE = textwrap.dedent( + """ + paddle_aux.rpc_remote(paddle.distributed.rpc.rpc_async({})) + """ + ) + code = API_TEMPLATE.format( + kwargs + ) + return code + + class GetNumThreadsMatcher(BaseMatcher): def generate_code(self, kwargs): API_TEMPLATE = textwrap.dedent( diff --git a/tests/test_Tensor_cauchy_.py b/tests/test_Tensor_cauchy_.py new file mode 100644 index 000000000..3fcccd75b --- /dev/null +++ b/tests/test_Tensor_cauchy_.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.cauchy_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]).cauchy_() + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.cauchy_() + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.cauchy_(median=0) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.cauchy_(median=0, sigma=1) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.cauchy_(median=0, sigma=1, generator=None) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.cauchy_(median=0, generator=None, sigma=1) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_Tensor_geometric_.py b/tests/test_Tensor_geometric_.py new file mode 100644 index 000000000..f550e3b34 --- /dev/null +++ b/tests/test_Tensor_geometric_.py @@ -0,0 +1,74 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.geometric_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]).geometric_(0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.geometric_(0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.geometric_(p=0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.geometric_(p=0.5, generator=None) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.geometric_(generator=None, p=0.5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_Tensor_is_inference.py b/tests/test_Tensor_is_inference.py new file mode 100644 index 000000000..314c2b871 --- /dev/null +++ b/tests/test_Tensor_is_inference.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.is_inference", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = x.is_inference() + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]).is_inference() + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_Tensor_random_.py b/tests/test_Tensor_random_.py new file mode 100644 index 000000000..dc706b9c5 --- /dev/null +++ b/tests/test_Tensor_random_.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.random_") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]).random_(0, 5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.random_(0, 5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.random_(0, to=5) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = input.random_(0, to=5, generator=None) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) diff --git a/tests/test_distributed_optim_DistributedOptimizer.py b/tests/test_distributed_optim_DistributedOptimizer.py new file mode 100644 index 000000000..7c4f2c62f --- /dev/null +++ b/tests/test_distributed_optim_DistributedOptimizer.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase +from optimizer_helper import generate_optimizer_test_code + +obj = APIBase("torch.distributed.optim.DistributedOptimizer") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import os + import torch + from torch import optim + import torch.distributed.rpc as rpc + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '29500' + os.environ['PADDLE_MASTER_ENDPOINT'] = 'localhost:29501' + rpc.init_rpc( + "worker1", + rank=0, + world_size=1 + ) + # Forward pass. + rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + # Optimizer. + dist_optim = torch.distributed.optim.DistributedOptimizer( + optim.SGD, + [rref1, rref2], + lr=0.05, + ) + rpc.shutdown() + + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="paddle does not support tensor in DistributedOptimizer",) diff --git a/tests/test_distributed_rpc_remote.py b/tests/test_distributed_rpc_remote.py new file mode 100644 index 000000000..2de74017a --- /dev/null +++ b/tests/test_distributed_rpc_remote.py @@ -0,0 +1,148 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributed.rpc.remote", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import os + import torch + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + start = 25000 + end = 30000 + for port in range(start, end): + try: + s.bind(('localhost', port)) + s.close() + break + except socket.error: + continue + print("port: " + str(port)) + + from torch.distributed import rpc + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + os.environ['PADDLE_MASTER_ENDPOINT'] = 'localhost:' + str(port) + rpc.init_rpc( + "worker1", + rank=0, + world_size=1 + ) + r = rpc.remote( + "worker1", + min, + args=(2, 1) + ) + result = r.to_here() + rpc.shutdown() + """ + ) + obj.run( + pytorch_code, + ["result"] + ) + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import os + import torch + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + start = 25000 + end = 30000 + for port in range(start, end): + try: + s.bind(('localhost', port)) + s.close() + break + except socket.error: + continue + print("port: " + str(port)) + + from torch.distributed import rpc + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + os.environ['PADDLE_MASTER_ENDPOINT'] = 'localhost:' + str(port) + rpc.init_rpc( + "worker1", + rank=0, + world_size=1 + ) + r = rpc.remote( + to="worker1", + func=min, + args=(2, 1), + kwargs=None, + timeout=-1 + ) + result = r.to_here() + rpc.shutdown() + """ + ) + obj.run( + pytorch_code, + ["result"] + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import os + import torch + import socket + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + start = 25000 + end = 30000 + for port in range(start, end): + try: + s.bind(('localhost', port)) + s.close() + break + except socket.error: + continue + print("port: " + str(port)) + + from torch.distributed import rpc + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = str(port) + os.environ['PADDLE_MASTER_ENDPOINT'] = 'localhost:' + str(port) + rpc.init_rpc( + "worker1", + rank=0, + world_size=1 + ) + r = rpc.remote( + to="worker1", + func=min, + args=(2, 1), + timeout=-1, + kwargs=None + ) + result = r.to_here() + rpc.shutdown() + """ + ) + obj.run( + pytorch_code, + ["result"] + ) diff --git a/tests/test_distributions_chi2_Chi2.py b/tests/test_distributions_chi2_Chi2.py new file mode 100644 index 000000000..66957ee5c --- /dev/null +++ b/tests/test_distributions_chi2_Chi2.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.chi2.Chi2") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.chi2.Chi2(x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.chi2.Chi2(df=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.chi2.Chi2(df=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.chi2.Chi2(validate_args=None, df=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) diff --git a/tests/test_distributions_constraints_Constraint.py b/tests/test_distributions_constraints_Constraint.py index 683ecce39..5d52a9f35 100644 --- a/tests/test_distributions_constraints_Constraint.py +++ b/tests/test_distributions_constraints_Constraint.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,19 +16,37 @@ from apibase import APIBase -obj = APIBase("torch.distributions.constraints.Constraint") +obj = APIBase("torch.distributions.constraints.Constraint", is_aux_api=True) def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - result = torch.distributions.constraints.Constraint() + try: + result = torch.distributions.constraints.Constraint().check(1) + except NotImplementedError: + result = torch.tensor(1) """ ) obj.run( pytorch_code, - ["result"], - unsupport=True, - reason="paddle does not support this function temporarily", + ["result"] + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + try: + con = torch.distributions.constraints.Constraint() + result = con.check(value=1) + except NotImplementedError: + result = torch.tensor(1) + """ + ) + obj.run( + pytorch_code, + ["result"] ) diff --git a/tests/test_distributions_gamma_Gamma.py b/tests/test_distributions_gamma_Gamma.py new file mode 100644 index 000000000..d1466a13a --- /dev/null +++ b/tests/test_distributions_gamma_Gamma.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.gamma.Gamma") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.gamma.Gamma(x, x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.gamma.Gamma(concentration=x, rate=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.gamma.Gamma(concentration=x, rate=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.gamma.Gamma(rate=x, concentration=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) diff --git a/tests/test_distributions_lkj_cholesky_LKJCholesky.py b/tests/test_distributions_lkj_cholesky_LKJCholesky.py new file mode 100644 index 000000000..29ec6c09b --- /dev/null +++ b/tests/test_distributions_lkj_cholesky_LKJCholesky.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.lkj_cholesky.LKJCholesky", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.lkj_cholesky.LKJCholesky(3, x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.lkj_cholesky.LKJCholesky(dim=3, concentration=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.lkj_cholesky.LKJCholesky(dim=3, concentration=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.lkj_cholesky.LKJCholesky(concentration=x, dim=3, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) diff --git a/tests/test_distributions_poisson_Poisson.py b/tests/test_distributions_poisson_Poisson.py new file mode 100644 index 000000000..b11597f55 --- /dev/null +++ b/tests/test_distributions_poisson_Poisson.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.poisson.Poisson") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.poisson.Poisson(x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.poisson.Poisson(rate=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.poisson.Poisson(rate=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.poisson.Poisson(validate_args=None, rate=x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) \ No newline at end of file diff --git a/tests/test_distributions_studentT_StudentT.py b/tests/test_distributions_studentT_StudentT.py new file mode 100644 index 000000000..67dad5a76 --- /dev/null +++ b/tests/test_distributions_studentT_StudentT.py @@ -0,0 +1,75 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.studentT.StudentT", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.studentT.StudentT(x).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.studentT.StudentT(df=x, loc=0.1, scale=1.0).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.studentT.StudentT(df=x, loc=0.1, scale=1.0, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.0]) + result = torch.distributions.studentT.StudentT(scale=1.0, loc=0.1, df=x, validate_args=None).sample() + """ + ) + obj.run( + pytorch_code, + ["result"], check_value=False + ) \ No newline at end of file diff --git a/tests/test_distributions_transforms_PositiveDefiniteTransform.py b/tests/test_distributions_transforms_PositiveDefiniteTransform.py new file mode 100644 index 000000000..27054768a --- /dev/null +++ b/tests/test_distributions_transforms_PositiveDefiniteTransform.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.distributions.transforms.PositiveDefiniteTransform", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 0.0], [2.0, 1.0]]) + a = torch.distributions.transforms.PositiveDefiniteTransform() + result = a(x) + result_inv = a.inv(result) + """ + ) + obj.run( + pytorch_code, + ["result", "result_inv"] + ) + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1.0, 0.0], [2.0, 1.0]]) + a = torch.distributions.transforms.PositiveDefiniteTransform(cache_size=0) + result = a(x) + result_inv = a.inv(result) + """ + ) + obj.run( + pytorch_code, + ["result", "result_inv"] + ) \ No newline at end of file diff --git a/tests/test_is_inference.py b/tests/test_is_inference.py new file mode 100644 index 000000000..2e1f12f84 --- /dev/null +++ b/tests/test_is_inference.py @@ -0,0 +1,41 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.is_inference", is_aux_api=True) + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = torch.is_inference(x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False) + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) + result = torch.is_inference(input = x) + """ + ) + obj.run(pytorch_code, ["result"], check_value=False)