Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jul 21, 2023
1 parent 73527d9 commit 0b84789
Show file tree
Hide file tree
Showing 15 changed files with 621 additions and 6 deletions.
35 changes: 29 additions & 6 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -6673,7 +6673,10 @@
"torch.nn.Module.apply": {
"Matcher": "TensorUnchangeMatcher"
},
"torch.nn.Module.bfloat16": {},
"torch.nn.Module.bfloat16": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to"
},
"torch.nn.Module.buffers": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.buffers",
Expand All @@ -6690,16 +6693,25 @@
},
"torch.nn.Module.cpu": {},
"torch.nn.Module.cuda": {},
"torch.nn.Module.double": {},
"torch.nn.Module.double": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to"
},
"torch.nn.Module.eval": {
"Matcher": "TensorUnchangeMatcher"
},
"torch.nn.Module.float": {},
"torch.nn.Module.float": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to"
},
"torch.nn.Module.get_buffer": {},
"torch.nn.Module.get_extra_state": {},
"torch.nn.Module.get_parameter": {},
"torch.nn.Module.get_submodule": {},
"torch.nn.Module.half": {},
"torch.nn.Module.half": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to"
},
"torch.nn.Module.ipu": {},
"torch.nn.Module.load_state_dict": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -6840,13 +6852,24 @@
"keep_vars"
]
},
"torch.nn.Module.to": {},
"torch.nn.Module.to": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to",
"args_list": [
"device",
"dtype",
"non_blocking"
]
},
"torch.nn.Module.to_empty": {},
"torch.nn.Module.train": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.train"
},
"torch.nn.Module.type": {},
"torch.nn.Module.type": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Layer.to"
},
"torch.nn.Module.xpu": {},
"torch.nn.Module.zero_grad": {},
"torch.nn.ModuleDict": {
Expand Down
17 changes: 17 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3634,6 +3634,23 @@ def generate_code(self, kwargs):
)


class NnModuleTypeMatcher(BaseMatcher):
def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_kwargs(kwargs)
if "torch.nn.Module.float" == self.torch_api:
code = "paddle.nn.Layer.to(dtype='float32')"
elif "torch.nn.Module.double" == self.torch_api:
code = "paddle.nn.Layer.to(dtype='float64')"
elif "torch.nn.Module.half" == self.torch_api:
code = "paddle.nn.Layer.to(dtype='float16')"
elif "torch.nn.Module.bfloat16" == self.torch_api:
code = "paddle.nn.Layer.to(dtype='bfloat16')"
else:
code = "paddle.nn.Layer.to(dtype={})".format(kwargs["dst_type"])
node = ast.parse(code.strip("\n")).body
return node


class SizeAverageMatcher(BaseMatcher):
def generate_code(self, kwargs):
process_reduce_and_size_average(kwargs)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_nn_Module_bfloat16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.bfloat16")


def _test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.bfloat16()
result = module1.buffer
"""
)
obj.run(pytorch_code, ["result"])
33 changes: 33 additions & 0 deletions tests/test_nn_Module_double.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.double")


def _test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.double()
result = module1.buffer
"""
)
obj.run(pytorch_code, ["result"])
33 changes: 33 additions & 0 deletions tests/test_nn_Module_float.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.float")


def _test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.float()
result = module1.buffer
"""
)
obj.run(pytorch_code, ["result"])
33 changes: 33 additions & 0 deletions tests/test_nn_Module_half.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.half")


def _test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.half()
result = module1.buffer
"""
)
obj.run(pytorch_code, ["result"])
53 changes: 53 additions & 0 deletions tests/test_nn_Module_register_full_backward_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.register_full_backward_hook")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def backward_after_hook(module, data_input, data_output):
print("I am function after forward function.")
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 3, 3)
self.conv2 = nn.Conv2d(3, 3, 3)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
my_module = MyModule()
my_module.register_full_backward_hook(backward_after_hook)
result = None
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
53 changes: 53 additions & 0 deletions tests/test_nn_Module_register_full_backward_pre_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.register_full_backward_pre_hook")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def backward_pre_hook(module, data_input):
print("I am function before forward function.")
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 3, 3)
self.conv2 = nn.Conv2d(3, 3, 3)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
return x
my_module = MyModule()
my_module.register_full_backward_pre_hook(backward_pre_hook)
result = None
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
57 changes: 57 additions & 0 deletions tests/test_nn_Module_requires_grad_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.Module.requires_grad_")


def _test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.requires_grad_(True)
result = None
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)


def _test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1., 2., 3.])
module1 = torch.nn.Module()
module1.register_buffer('buffer', x)
module1.requires_grad_(requires_grad=True)
result = None
"""
)
obj.run(
pytorch_code,
["result"],
unsupport=True,
reason="paddle does not support this function temporarily",
)
Loading

0 comments on commit 0b84789

Please sign in to comment.