[Hackathon 7th No.39] Add API conversion rules for the Paddle code conversion tool (Group 6) (#477)
* [Hackathon 7th No.39] Add API conversion rules for the Paddle code conversion tool (Group 6)

* add skip

* update Softmin

* update

* update

* fix

* fix codestyle check

* fix

* fix

* fix
Asthestarsfalll authored Nov 27, 2024
1 parent 785f24b commit 3c73e2e
Showing 15 changed files with 1,766 additions and 5 deletions.
186 changes: 185 additions & 1 deletion paconvert/api_mapping.json
@@ -10132,6 +10132,20 @@
"output_size"
]
},
"torch.nn.AdaptiveLogSoftmaxWithLoss": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.AdaptiveLogSoftmaxWithLoss",
"min_input_args": 3,
"args_list": [
"in_features",
"n_classes",
"cutoffs",
"div_value",
"head_bias",
"device",
"dtype"
]
},
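For reference, a minimal sketch of the conversion this GenericMatcher entry drives (there is no kwargs_change, so arguments pass through unchanged; the tool's exact emitted formatting may differ):

# PyTorch source
m = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [2], div_value=2.0)

# Expected Paddle output
m = paddle.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [2], div_value=2.0)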
"torch.nn.AdaptiveMaxPool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.AdaptiveMaxPool1D",
@@ -10345,6 +10359,17 @@
"groups"
]
},
"torch.nn.CircularPad3d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Pad3D",
"min_input_args": 1,
"args_list": [
"padding"
],
"paddle_default_kwargs": {
"mode": "'circular'"
}
},
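Here paddle_default_kwargs injects mode='circular' into the Pad3D call even when the user supplies only padding; a sketch of the expected rewrite:

# PyTorch source
pad = torch.nn.CircularPad3d(2)

# Expected Paddle output, with the injected default
pad = paddle.nn.Pad3D(2, mode='circular')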
"torch.nn.ConstantPad1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Pad1D",
@@ -10981,6 +11006,28 @@
],
"min_input_args": 0
},
"torch.nn.LPPool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.LPPool1D",
"min_input_args": 2,
"args_list": [
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
]
},
"torch.nn.LPPool2d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.LPPool2D",
"min_input_args": 2,
"args_list": [
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
]
},
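Both LPPool entries are plain one-to-one renames; a sketch of the expected rewrite (see the functional variants below for the input-to-x rename):

# PyTorch source
pool = torch.nn.LPPool2d(norm_type=2, kernel_size=3, stride=2)

# Expected Paddle output (only the class name changes)
pool = paddle.nn.LPPool2D(norm_type=2, kernel_size=3, stride=2)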
"torch.nn.LSTM": {
"Matcher": "RNNMatcher",
"paddle_api": "paddle.nn.LSTM",
@@ -11948,7 +11995,7 @@
"dim": "axis"
},
"paddle_default_kwargs": {
"axis": 0
"axis": null
}
},
"torch.nn.Softmax2d": {
@@ -11959,6 +12006,20 @@
},
"min_input_args": 0
},
"torch.nn.Softmin": {
"Matcher": "SoftminMatcher",
"paddle_api": "paddle.nn.Softmax",
"args_list": [
"dim"
],
"kwargs_change": {
"dim": "axis"
},
"paddle_default_kwargs": {
"axis": null
},
"min_input_args": 0
},
"torch.nn.Softplus": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.Softplus",
@@ -13162,6 +13223,36 @@
"input": "x"
}
},
"torch.nn.functional.lp_pool1d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.lp_pool1d",
"min_input_args": 2,
"args_list": [
"input",
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
],
"kwargs_change": {
"input": "x"
}
},
"torch.nn.functional.lp_pool2d": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.lp_pool2d",
"min_input_args": 2,
"args_list": [
"input",
"norm_type",
"kernel_size",
"stride",
"ceil_mode"
],
"kwargs_change": {
"input": "x"
}
},
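For the functional variants, kwargs_change renames the input keyword to x; roughly:

# PyTorch source
out = torch.nn.functional.lp_pool1d(input=t, norm_type=2, kernel_size=3)

# Expected Paddle output
out = paddle.nn.functional.lp_pool1d(x=t, norm_type=2, kernel_size=3)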
"torch.nn.functional.margin_ranking_loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.margin_ranking_loss",
@@ -13751,6 +13842,19 @@
"input": "x"
}
},
"torch.nn.functional.threshold_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.functional.thresholded_relu_",
"min_input_args": 3,
"args_list": [
"input",
"threshold",
"value"
],
"kwargs_change": {
"input": "x"
}
},
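A sketch of the rewrite this entry produces, assuming the value argument of paddle.nn.functional.thresholded_relu_ matches torch's in-place threshold_ semantics as the mapping implies:

# PyTorch source (in-place)
torch.nn.functional.threshold_(input=t, threshold=0.1, value=20.0)

# Expected Paddle output
paddle.nn.functional.thresholded_relu_(x=t, threshold=0.1, value=20.0)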
"torch.nn.functional.triplet_margin_loss": {
"Matcher": "SizeAverageMatcher",
"paddle_api": "paddle.nn.functional.triplet_margin_loss",
@@ -14142,6 +14246,19 @@
},
"min_input_args": 1
},
"torch.nn.utils.parametrizations.weight_norm": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.utils.weight_norm",
"args_list": [
"module",
"name",
"dim"
],
"kwargs_change": {
"module": "layer"
},
"min_input_args": 1
},
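A sketch of the expected rewrite (module is renamed to layer; name and dim pass through):

# PyTorch source
conv = torch.nn.utils.parametrizations.weight_norm(module=conv, name='weight', dim=0)

# Expected Paddle output
conv = paddle.nn.utils.weight_norm(layer=conv, name='weight', dim=0)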
"torch.nn.utils.remove_weight_norm": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nn.utils.remove_weight_norm",
@@ -14519,6 +14636,40 @@
"lr": "learning_rate"
}
},
"torch.optim.NAdam": {
"Matcher": "OptimAdamMatcher",
"paddle_api": "paddle.optimizer.NAdam",
"min_input_args": 1,
"args_list": [
"params",
"lr",
"betas",
"eps",
"weight_decay",
"momentum_decay",
"decoupled_weight_decay",
"*",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"unsupport_args": [
"decoupled_weight_decay",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"kwargs_change": {
"params": "parameters",
"lr": "learning_rate",
"eps": "epsilon"
},
"paddle_default_kwargs": {
"weight_decay": 0.0
}
},
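A sketch of the rewrite for a typical call (betas handling is delegated to OptimAdamMatcher and omitted here):

# PyTorch source
opt = torch.optim.NAdam(model.parameters(), lr=2e-3, eps=1e-8)

# Expected Paddle output: params/lr/eps renamed, weight_decay defaulted to 0.0
opt = paddle.optimizer.NAdam(parameters=model.parameters(), learning_rate=2e-3, epsilon=1e-8, weight_decay=0.0)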
"torch.optim.Optimizer": {
"Matcher": "OptimOptimizerMatcher",
"paddle_api": "paddle.optimizer.Optimizer",
@@ -14559,6 +14710,39 @@
"torch.optim.Optimizer.zero_grad": {
"min_input_args": 0
},
"torch.optim.RAdam": {
"Matcher": "OptimAdamMatcher",
"paddle_api": "paddle.optimizer.RAdam",
"min_input_args": 1,
"args_list": [
"params",
"lr",
"betas",
"eps",
"weight_decay",
"decoupled_weight_decay",
"*",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"unsupport_args": [
"decoupled_weight_decay",
"foreach",
"maximize",
"capturable",
"differentiable"
],
"kwargs_change": {
"params": "parameters",
"lr": "learning_rate",
"eps": "epsilon"
},
"paddle_default_kwargs": {
"weight_decay": 0.0
}
},
"torch.optim.RMSprop": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.optimizer.RMSProp",
26 changes: 26 additions & 0 deletions paconvert/api_matcher.py
@@ -4281,6 +4281,32 @@ def generate_code(self, kwargs):
return GenericMatcher.generate_code(self, kwargs)


class SoftminMatcher(BaseMatcher):
    def generate_code(self, kwargs):
        self.paddle_api = "paddle_aux.Softmin"
        self.write_aux_code()
        return GenericMatcher.generate_code(self, kwargs)

    def generate_aux_code(self):
        CODE_TEMPLATE = textwrap.dedent(
            """
            def _get_softmax_dim(ndim: int) -> int:
                # Mirrors torch.nn.functional._get_softmax_dim: choose the
                # default softmax axis from the input's rank when dim is None.
                if ndim == 0 or ndim == 1 or ndim == 3:
                    ret = 0
                else:
                    ret = 1
                return ret

            class Softmin(paddle.nn.Softmax):
                # Softmin(x) == Softmax(-x), so negate the input and reuse softmax.
                def forward(self, x):
                    if self._axis is None:
                        return paddle.nn.functional.softmax(-x, _get_softmax_dim(x.ndim))
                    return paddle.nn.functional.softmax(-x, self._axis)
            """
        )
        return CODE_TEMPLATE
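
A sketch of the conversion this matcher produces, assuming write_aux_code emits the template above into the generated paddle_aux module:

# PyTorch source
m = torch.nn.Softmin(dim=1)

# Converted code: the aux class negates the input, then applies softmax
m = paddle_aux.Softmin(axis=1)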


class OptimOptimizerMatcher(BaseMatcher):
def generate_code(self, kwargs):
code = "paddle.optimizer.Optimizer(parameters={}, **{})".format(
99 changes: 99 additions & 0 deletions tests/test_nn_AdaptiveLogSoftmaxWithLoss.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.nn.AdaptiveLogSoftmaxWithLoss")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [2])
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [3], div_value=2.0)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(5, 4, [1], div_value=3.8, head_bias=True)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(in_features=5, n_classes=8, cutoffs=[5], div_value=3.8, head_bias=True)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
input = torch.tensor([[ 0.9368637 , -0.0361056 , -0.98917043, 0.06605113, 1.5254455 ],
[-1.0518035 , -1.0024613 , 0.18699688, -0.35807893, 0.25628588],
[-0.900478 , -0.41495147, 0.84707606, -1.7883497 , 1.3243382 ]])
target = torch.tensor([1, 1, 1])
asfm = torch.nn.AdaptiveLogSoftmaxWithLoss(n_classes=8, in_features=5, div_value=3.8, cutoffs=[5], head_bias=True)
out, loss = asfm(input,target)
"""
)
obj.run(pytorch_code, ["out", "loss"], check_value=False)