Skip to content

Commit

Permalink
转换规则 No. 114-120 (#122)
Browse files Browse the repository at this point in the history
* Add tests

* Fix
  • Loading branch information
co63oc authored Jul 20, 2023
1 parent 3709420 commit 73527d9
Show file tree
Hide file tree
Showing 6 changed files with 618 additions and 0 deletions.
101 changes: 101 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -2892,6 +2892,107 @@
"input": "x"
}
},
"torch.autograd.backward": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.autograd.backward",
"args_list": [
"tensors",
"grad_tensors",
"retain_graph",
"create_graph",
"grad_variables",
"inputs"
],
"unsupport_args": [
"create_graph",
"grad_variables",
"inputs"
]
},
"torch.autograd.functional.hessian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Hessian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"outer_jacobian_strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jacobian": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.Jacobian",
"args_list": [
"func",
"inputs",
"create_graph",
"strict",
"vectorize",
"strategy"
],
"unsupport_args": [
"create_graph",
"strict",
"vectorize",
"strategy"
],
"kwargs_change": {
"inputs": "xs"
},
"paddle_default_kwargs": {
"is_batched": false
}
},
"torch.autograd.functional.jvp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.jvp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.functional.vjp": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.incubate.autograd.vjp",
"args_list": [
"func",
"inputs",
"v",
"create_graph",
"strict"
],
"unsupport_args": [
"create_graph",
"strict"
],
"kwargs_change": {
"inputs": "xs"
}
},
"torch.autograd.grad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.grad",
Expand Down
77 changes: 77 additions & 0 deletions tests/test_autograd_backward.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.backward")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], True)
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
y = torch.tensor([[3, 2], [3, 4]], dtype=torch.float32)
grad_tensor1 = torch.tensor([[1,2], [2, 3]], dtype=torch.float32)
grad_tensor2 = torch.tensor([[1,1], [1, 1]], dtype=torch.float32)
z1 = torch.matmul(x, y)
z2 = torch.matmul(x, y)
torch.autograd.backward([z1, z2], [grad_tensor1, grad_tensor2], retain_graph=False)
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, requires_grad=True)
z1 = x.sum()
torch.autograd.backward([z1])
x.grad.requires_grad=False
result = x.grad
"""
)
obj.run(pytorch_code, ["result"])
108 changes: 108 additions & 0 deletions tests/test_autograd_functional_hessian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import textwrap

from apibase import APIBase

obj = APIBase("torch.autograd.functional.hessian")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return 2 * torch.sum(x * x + 3 * x)
x = torch.rand(2, 2)
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x)
x = torch.tensor([1.0, 2.0])
h = torch.autograd.functional.hessian(func, x)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x)
x = torch.tensor([1.0, 2.0])
h = torch.autograd.functional.hessian(func, x, create_graph=True)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(
pytorch_code, ["result"], unsupport=True, reason="paddle unsupport create_graph"
)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
def func(x):
return torch.sum(x)
x = torch.tensor([1.0, 2.0])
h = torch.autograd.functional.hessian(func, x, strict=False)
result = h[:]
result.requires_grad = False
result = torch.flatten(result)
"""
)
obj.run(pytorch_code, ["result"], unsupport=True, reason="paddle unsupport strict")
Loading

0 comments on commit 73527d9

Please sign in to comment.