From 8024396bfb8c33e96410c1ce2ca1736684978a55 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Fri, 7 Jan 2022 19:58:15 +0800 Subject: [PATCH 01/11] add scatter mapper --- .../API_docs/ops/README.md | 2 +- .../API_docs/ops/torch.scatter.md | 75 +++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) create mode 100644 docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md diff --git a/docs/pytorch_project_convertor/API_docs/ops/README.md b/docs/pytorch_project_convertor/API_docs/ops/README.md index 086c1f41a..00fc26dde 100644 --- a/docs/pytorch_project_convertor/API_docs/ops/README.md +++ b/docs/pytorch_project_convertor/API_docs/ops/README.md @@ -152,7 +152,7 @@ | 147 | [torch.matmul](https://pytorch.org/docs/stable/generated/torch.matmul.html?highlight=matmul#torch.matmul) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.matmul.md) | | 148 | [torch.mm](https://pytorch.org/docs/stable/generated/torch.mm.html?highlight=mm#torch.mm) | [paddle.matmul](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/matmul_cn.html) | [差异对比](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mm.md) | | 149 | [torch.mv](https://pytorch.org/docs/stable/generated/torch.mv.html?highlight=mv#torch.mv) | 无对应实现 | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.mv.md) | - +| 150 | [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter) | [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html) | [组合实现](https://github.com/PaddlePaddle/X2Paddle/tree/develop/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md) | diff --git a/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md new file mode 100644 index 000000000..2eb4bb9a8 --- /dev/null +++ b/docs/pytorch_project_convertor/API_docs/ops/torch.scatter.md @@ -0,0 +1,75 @@ +## torch.scatter +### [torch.scatter](https://pytorch.org/docs/stable/generated/torch.scatter.html?highlight=scatter#torch.scatter) + +```python +torch.scatter(tensor, + dim, + index, + src) +``` + +### [paddle.scatter_nd_add](https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/scatter_nd_add_cn.html) + +```python +paddle.scatter_nd_add(x, + index, + updates, + name=None) +``` + +### 参数差异 +| PyTorch | PaddlePaddle | 备注 | +| ------------- | ------------ | ------------------------------------------------------ | +| tensor | x | 表示输入Tensor。 | +| dim | - | 表示在哪一个维度scatter,Paddle无此参数 | +| index | index | 输入的索引张量 | +| src | updates | 输入的更新张量 | + + + +### 功能差异 + +#### 使用方式 +因 torch.scatter 与 paddle.scatter_nd_add 差异较大,必须使用 paddle.flatten + paddle.meshgrid + paddle.scatter_nd_add 组合实现,看如下例子 + + +### 代码示例 +``` python +# PyTorch 示例: +src = torch.arange(1, 11).reshape((2, 5)) +# 输出 +# tensor([[ 1, 2, 3, 4, 5], +# [ 6, 7, 8, 9, 10]]) +index = torch.tensor([[0, 1, 2], [0, 1, 4]]) +torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) +# 输出 +# tensor([[1, 2, 3, 0, 0], +# [6, 7, 0, 0, 8], +# [0, 0, 0, 0, 0]]) +``` + +``` python +# PaddlePaddle 组合实现: +x = paddle.zeros([3, 5], dtype="int64") +updates = paddle.arange(1, 11).reshape([2,5]) +# 输出 +# Tensor(shape=[2, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True, +# [[1 , 2 , 3 , 4 , 5 ], +# [6 , 7 , 8 , 9 , 10]]) +index = paddle.to_tensor([[0, 1, 2], [0, 1, 4]]) +i, j = index.shape +grid_x , grid_y = paddle.meshgrid(paddle.arange(i), paddle.arange(j)) +# 若 PyTorch 的 dim 取 0 +# index = paddle.stack([index.flatten(), grid_y.flatten()], axis=1) +# 若 PyTorch 的 dim 取 1 +index = paddle.stack([grid_x.flatten(), index.flatten()], axis=1) +# PaddlePaddle updates 的 shape 大小必须与 index 对应 +updates_index = paddle.stack([grid_x.flatten(), grid_y.flatten()], axis=1) +updates = paddle.gather_nd(updates, index=updates_index) +paddle.scatter_nd_add(x, index, updates) +# 输出 +# Tensor(shape=[3, 5], dtype=int64, place=CUDAPlace(0), stop_gradient=True, +# [[1, 2, 3, 0, 0], +# [6, 7, 0, 0, 8], +# [0, 0, 0, 0, 0]]) +``` From a2346f043154486eb2789f7f9b11e49fcc1309db Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 12 Jan 2022 16:49:46 +0800 Subject: [PATCH 02/11] solve same param is used for multiple OPs --- .../op_mapper/onnx2paddle/opset9/opset.py | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 6fb18a03e..bcde0a103 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -62,7 +62,6 @@ def _rename_or_remove_weight(weights, if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) if is_remove: - # TODO There may be problems when the same data is used as an argument to multiple OPs. # remove weight data = weights.pop(origin_name) else: @@ -182,6 +181,8 @@ def __init__(self, decoder, paddle_graph): self.weights = dict() self.nn_name2id = dict() self.done_weight_list = list() + # solve for same data is used as an argument to multiple OPs. + self.rename_mapper = dict() @print_mapping_info def directly_map(self, node, *args, **kwargs): @@ -1680,13 +1681,39 @@ def BatchNormalization(self, node): epsilon = node.get_attr('epsilon', 1e-5) c = val_x.out_shapes[0][1] - _rename_or_remove_weight(self.weights, val_scale.name, - op_name + '.weight') - _rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias') - _rename_or_remove_weight(self.weights, val_var.name, - op_name + '._variance') - _rename_or_remove_weight(self.weights, val_mean.name, - op_name + '._mean') + # solved the same data is used as an argument to multiple OPs. + if val_scale.name in self.rename_mapper: + new_name = self.rename_mapper[val_scale.name] + _rename_or_remove_weight(self.weights, new_name, + op_name + '.weight', False) + else: + _rename_or_remove_weight(self.weights, val_scale.name, + op_name + '.weight') + self.rename_mapper[val_scale.name] = op_name + '.weight' + if val_b.name in self.rename_mapper: + new_name = self.rename_mapper[val_b.name] + _rename_or_remove_weight(self.weights, new_name, op_name + '.bias', + False) + else: + _rename_or_remove_weight(self.weights, val_b.name, + op_name + '.bias') + self.rename_mapper[val_b.name] = op_name + '.bias' + if val_var.name in self.rename_mapper: + new_name = self.rename_mapper[val_var.name] + _rename_or_remove_weight(self.weights, new_name, + op_name + '._variance', False) + else: + _rename_or_remove_weight(self.weights, val_var.name, + op_name + '._variance') + self.rename_mapper[val_var.name] = op_name + '._variance' + if val_mean.name in self.rename_mapper: + new_name = self.rename_mapper[val_mean.name] + _rename_or_remove_weight(self.weights, new_name, op_name + '._mean', + False) + else: + _rename_or_remove_weight(self.weights, val_mean.name, + op_name + '._mean') + self.rename_mapper[val_mean.name] = op_name + '._mean' # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial')) From 53f8175d013bbfafdbaa10930a40fa094e6aa409 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 12 Jan 2022 16:57:56 +0800 Subject: [PATCH 03/11] Add unique_name support --- x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py index 289ce190f..842eeedbe 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py @@ -15,7 +15,7 @@ import math from functools import reduce import paddle -from paddle.fluid import framework +from paddle.fluid import framework, unique_name from paddle.fluid.core import VarDesc from paddle.fluid.initializer import XavierInitializer, MSRAInitializer from paddle.fluid.data_feeder import check_variable_and_dtype From acbb1cac795ae6112785ee77fbc1f03b8f528ce6 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 17:08:05 +0800 Subject: [PATCH 04/11] Simplified code --- .../op_mapper/onnx2paddle/opset9/opset.py | 88 ++++++++++--------- 1 file changed, 47 insertions(+), 41 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index bcde0a103..9c02f344a 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False): def _rename_or_remove_weight(weights, origin_name, target_name=None, - is_remove=True): + is_remove=True, + rename_mapper=None): ''' Rename parameters by Paddle's naming rule of parameters. @@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights, {target_name:weights[origin_name]} to weights, and target_name must follow paddle's naming rule of parameters. Default: None. is_remove: if is_remove is True, remove origin key-value pair. Default: True. + rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name. Returns: None ''' + if origin_name in rename_mapper: + origin_name = rename_mapper[origin_name] + is_remove = False if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) if is_remove: @@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights, if target_name is not None: # rename weight weights[target_name] = data + rename_mapper[origin_name] = target_name def _is_static_shape(shape): @@ -1682,38 +1688,26 @@ def BatchNormalization(self, node): c = val_x.out_shapes[0][1] # solved the same data is used as an argument to multiple OPs. - if val_scale.name in self.rename_mapper: - new_name = self.rename_mapper[val_scale.name] - _rename_or_remove_weight(self.weights, new_name, - op_name + '.weight', False) - else: - _rename_or_remove_weight(self.weights, val_scale.name, - op_name + '.weight') - self.rename_mapper[val_scale.name] = op_name + '.weight' - if val_b.name in self.rename_mapper: - new_name = self.rename_mapper[val_b.name] - _rename_or_remove_weight(self.weights, new_name, op_name + '.bias', - False) - else: - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias') - self.rename_mapper[val_b.name] = op_name + '.bias' - if val_var.name in self.rename_mapper: - new_name = self.rename_mapper[val_var.name] - _rename_or_remove_weight(self.weights, new_name, - op_name + '._variance', False) - else: - _rename_or_remove_weight(self.weights, val_var.name, - op_name + '._variance') - self.rename_mapper[val_var.name] = op_name + '._variance' - if val_mean.name in self.rename_mapper: - new_name = self.rename_mapper[val_mean.name] - _rename_or_remove_weight(self.weights, new_name, op_name + '._mean', - False) - else: - _rename_or_remove_weight(self.weights, val_mean.name, - op_name + '._mean') - self.rename_mapper[val_mean.name] = op_name + '._mean' + _rename_or_remove_weight( + self.weights, + val_scale.name, + op_name + '.weight', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_var.name, + op_name + '._variance', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_mean.name, + op_name + '._mean', + rename_mapper=self.rename_mapper) # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial')) @@ -2255,14 +2249,22 @@ def Conv(self, node): remove_weight = True if val_w.name in self.done_weight_list else False if remove_weight: self.done_weight_list.append(val_w.name) - _rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight', - remove_weight) + _rename_or_remove_weight( + self.weights, + val_w.name, + op_name + '.weight', + remove_weight, + rename_mapper=self.rename_mapper) if has_bias: remove_bias = True if val_b.name in self.done_weight_list else False if remove_bias: - self.done_weight_list.append(val_b_name) - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias', remove_bias) + self.done_weight_list.append(val_b.name) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + remove_bias, + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False if reduce(lambda x, y: x * y, @@ -2382,10 +2384,14 @@ def ConvTranspose(self, node): _rename_or_remove_weight( self.weights, val_w.name, - op_name + '.weight', ) + op_name + '.weight', + rename_mapper=self.rename_mapper) if val_b is not None: - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias') + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False self.paddle_graph.add_layer( From 0acebeefb812ceeba22a36a3a3f65e3574249005 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 17:23:28 +0800 Subject: [PATCH 05/11] Add PR link --- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 1 + 1 file changed, 1 insertion(+) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 9c02f344a..e7c2b8572 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -188,6 +188,7 @@ def __init__(self, decoder, paddle_graph): self.nn_name2id = dict() self.done_weight_list = list() # solve for same data is used as an argument to multiple OPs. + # PR link(wangjunjie06): https://github.com/PaddlePaddle/X2Paddle/pull/728 self.rename_mapper = dict() @print_mapping_info From 2a4708eb49febd7ea6adbfb4e5b86297232e4ad5 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 18:15:30 +0800 Subject: [PATCH 06/11] fixed bug for CI --- x2paddle/op_mapper/onnx2paddle/opset9/opset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index e7c2b8572..3dec5b5e8 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -61,7 +61,7 @@ def _rename_or_remove_weight(weights, Returns: None ''' - if origin_name in rename_mapper: + if rename_mapper is not None and origin_name in rename_mapper: origin_name = rename_mapper[origin_name] is_remove = False if origin_name not in weights: From b793e5b3ec4a901e89316c957129f0698b470d61 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 21:15:25 +0800 Subject: [PATCH 07/11] Add dynamic shape --- .../pytorch2paddle.md | 54 +++++++++++++++++-- 1 file changed, 51 insertions(+), 3 deletions(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index 74976b591..f4840fbbb 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -14,7 +14,7 @@ treelib ## 使用方式 -``` python +```python from x2paddle.convert import pytorch2paddle pytorch2paddle(module=torch_module, save_dir="./pd_model", @@ -27,11 +27,14 @@ pytorch2paddle(module=torch_module, ``` **注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - 当jit_type为"script"时",input_examples不为None时,才可以进行动转静。 + + 当jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 ## 使用示例 -``` python +### Trace 模式 + +```python import torch import numpy as np from torchvision.models import AlexNet @@ -51,3 +54,48 @@ pytorch2paddle(torch_module, jit_type="trace", input_examples=[torch.tensor(input_data)]) ``` + +### Script 模式动态 shape 导出 + +```python +import torch +import numpy as np +from torchvision.models import AlexNet +from torchvision.models.utils import load_state_dict_from_url + +# 获取PyTorch Module +torch_module = AlexNet() +torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth') +torch_module.load_state_dict(torch_state_dict) +# 设置为eval模式 +torch_module.eval() +# 进行转换 +from x2paddle.convert import pytorch2paddle +pytorch2paddle(torch_module, + save_dir="pd_model_script", + jit_type="script", + input_examples=None) +``` + +在自动生成的x2paddle_code.py中添加如下代码: + +```python +def main(x0): + # There are 0 inputs. + paddle.disable_static() + params = paddle.load('model.pdparams') + model = AlexNet() + model.set_dict(params) + model.eval() + ## convert to jit + sepc_list = list() + sepc_list.append( + paddle.static.InputSpec( + shape=[-1, 3, -1, -1], name="x0", dtype="float32")) + static_model = paddle.jit.to_static(model, input_spec=sepc_list) + paddle.jit.save(static_model, "pd_model_script/inference_model/model") + out = model(x0) + return out +``` + +运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export) From d8ecf555e05c04a84dfc4a6201d7ea3230d545a0 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 21:19:11 +0800 Subject: [PATCH 08/11] Fixed readme --- docs/inference_model_convertor/pytorch2paddle.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index f4840fbbb..f1c34be2a 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -27,7 +27,6 @@ pytorch2paddle(module=torch_module, ``` **注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - 当jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 ## 使用示例 From 0e7312fe007af9674a45ab1d3b082b46c0de9bfc Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 21:20:41 +0800 Subject: [PATCH 09/11] Fixed readme --- docs/inference_model_convertor/pytorch2paddle.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index f1c34be2a..76aac394a 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -26,8 +26,9 @@ pytorch2paddle(module=torch_module, # input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。 ``` -**注意:** 当jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - 当jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 +**注意:** +- jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; +- jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 ## 使用示例 From e15ca80a44ffb8e2dcc63b824f78a60860e5a67c Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 18 Jan 2022 11:21:23 +0800 Subject: [PATCH 10/11] Update readme --- .../pytorch2paddle.md | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index 76aac394a..2430886bb 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -26,7 +26,7 @@ pytorch2paddle(module=torch_module, # input_examples (list[torch.tensor]): torch.nn.Module的输入示例,list的长度必须与输入的长度一致。默认为None。 ``` -**注意:** +**注意:** - jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; - jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 @@ -55,11 +55,12 @@ pytorch2paddle(torch_module, input_examples=[torch.tensor(input_data)]) ``` -### Script 模式动态 shape 导出 +### 动态 shape 导出 + +#### 方式一:PyTorch->ONNX->Paddle ```python import torch -import numpy as np from torchvision.models import AlexNet from torchvision.models.utils import load_state_dict_from_url @@ -69,15 +70,27 @@ torch_state_dict = load_state_dict_from_url('https://download.pytorch.org/models torch_module.load_state_dict(torch_state_dict) # 设置为eval模式 torch_module.eval() -# 进行转换 -from x2paddle.convert import pytorch2paddle -pytorch2paddle(torch_module, - save_dir="pd_model_script", - jit_type="script", - input_examples=None) +input_names = ["input_0"] +output_names = ["output_0"] + +x = torch.randn((1, 3, 224, 224)) +y = torch.randn((1, 1000)) + +torch.onnx.export(torch_module, x, 'model.onnx', opset_version=11, input_names=input_names, + output_names=output_names, dynamic_axes={'input_0': [0], 'output_0': [0]}) +``` + +导出 ONNX 动态 shape 模型,更多细节参考[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export) + +然后通过 X2Paddle 命令导出 Paddle 模型 + +```shell +x2paddle --framework=onnx --model=model.onnx --save_dir=pd_model_dynamic ``` -在自动生成的x2paddle_code.py中添加如下代码: +#### 方式二:手动动转静 + +在自动生成的 x2paddle_code.py 中添加如下代码: ```python def main(x0): @@ -91,11 +104,11 @@ def main(x0): sepc_list = list() sepc_list.append( paddle.static.InputSpec( - shape=[-1, 3, -1, -1], name="x0", dtype="float32")) + shape=[-1, 3, 224, 224], name="x0", dtype="float32")) static_model = paddle.jit.to_static(model, input_spec=sepc_list) - paddle.jit.save(static_model, "pd_model_script/inference_model/model") + paddle.jit.save(static_model, "pd_model_trace/inference_model/model") out = model(x0) return out ``` -运行main函数导出动态shape的静态图模型,若导出失败,可尝试动态shape导出onnx,再从onnx转到paddle,[相关文档](https://pytorch.org/docs/stable/onnx.html?highlight=onnx%20export#torch.onnx.export) +然后运行 main 函数导出动态 shape 模型 From 86127ec2f5c23b84088aebd9b67692b373830be6 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 18 Jan 2022 11:23:13 +0800 Subject: [PATCH 11/11] Update readme --- docs/inference_model_convertor/pytorch2paddle.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/inference_model_convertor/pytorch2paddle.md b/docs/inference_model_convertor/pytorch2paddle.md index 2430886bb..3ac21eaa4 100644 --- a/docs/inference_model_convertor/pytorch2paddle.md +++ b/docs/inference_model_convertor/pytorch2paddle.md @@ -28,7 +28,7 @@ pytorch2paddle(module=torch_module, **注意:** - jit_type为"trace"时,input_examples不可为None,转换后自动进行动转静; -- jit_type为"script"时",当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 +- jit_type为"script"时,当input_examples为None时,只生成动态图代码;当input_examples不为None时,才能自动动转静。 ## 使用示例