Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 6th No.58】support model convert from fp32 to fp16 #1268

Merged
merged 22 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8f5d879
support model convert from fp32 to fp16
xiaoyewww Jun 2, 2024
53f9286
support model convert from fp32 to fp16
xiaoyewww Jun 4, 2024
ed4f7c2
support model convert from fp32 to fp16
xiaoyewww Jun 6, 2024
f3c6a26
Merge branch 'develop' into hackathon/half-support
Zheng-Bicheng Jun 11, 2024
17e6980
support model convert from fp32 to fp16
xiaoyewww Jun 12, 2024
0b89e74
support model convert from fp32 to fp16
xiaoyewww Jun 12, 2024
6f0b350
Merge branch 'develop' into hackathon/half-support
Zheng-Bicheng Jun 17, 2024
96f6c02
support model convert from fp32 to fp16
xiaoyewww Jun 17, 2024
04e830e
support model convert from fp32 to fp16
xiaoyewww Jun 17, 2024
92e9a58
support model convert from fp32 to fp16
xiaoyewww Jun 17, 2024
9af7778
support model convert from fp32 to fp16
xiaoyewww Jun 19, 2024
cfe4a63
support model convert from fp32 to fp16
xiaoyewww Jun 19, 2024
1120032
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
1b2e6f0
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
710c56a
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
ad627e1
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
2cfe42f
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
024143f
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
3f14872
support model convert from fp32 to fp16
xiaoyewww Jun 20, 2024
9cacb7b
support model convert from fp32 to fp16
xiaoyewww Jun 24, 2024
f239519
support model convert from fp32 to fp16
xiaoyewww Jun 26, 2024
37911d4
Merge branch 'develop' into hackathon/half-support
Zheng-Bicheng Jul 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Paddle2ONNX 本身不依赖其他组件,但是我们建议您在以下环境

- PaddlePaddle == 2.6.0
- onnxruntime >= 1.10.0
- numpy < 2.0.0

# 3 安装 Paddle2ONNX

Expand Down
7 changes: 6 additions & 1 deletion paddle2onnx/mapper/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,19 @@
// limitations under the License.
#pragma once
#include <vector>
#include <unordered_set>

#include "paddle2onnx/mapper/data_helper.h"
#include "paddle2onnx/mapper/onnx_helper.h"
#include "paddle2onnx/mapper/register_mapper.h"
#include "paddle2onnx/parser/parser.h"

namespace paddle2onnx {
class Mapper {

static const std::unordered_set<int32_t> kNoNeedCastTypes{P2ODataType::INT8, P2ODataType::FP16, P2ODataType::FP32};
Zheng-Bicheng marked this conversation as resolved.
Show resolved Hide resolved

class Mapper
{
public:
Mapper() {
}
Expand Down
59 changes: 43 additions & 16 deletions paddle2onnx/mapper/nn/pool2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,20 @@ void Pool2dMapper::AdaptivePool(const std::vector<TensorInfo>& input_info,
onnx_pool_type = iter->second[0];
}

std::shared_ptr<ONNX_NAMESPACE::NodeProto>* node_ptr;
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto node = helper_->MakeNode(onnx_pool_type, {input});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
std::shared_ptr<ONNX_NAMESPACE::NodeProto> node(nullptr);
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
node = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name});
}
else
{
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
node = helper_->MakeNode(onnx_pool_type, {input});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}

std::vector<int64_t> kernel_size = {kernel_h, kernel_w};
AddAttribute(node, "kernel_shape", kernel_size);
std::vector<int64_t> strides = {stride_h, stride_w};
Expand Down Expand Up @@ -165,8 +173,12 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,

int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_));
int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_));
auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
std::string input_x = input_info[0].name;
if (kNoNeedCastTypes.find(input_info[0].dtype) == kNoNeedCastTypes.end())
{
input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
}
if (max_ksize <= max_pads) {
std::vector<int64_t> onnx_paddings = {0, 0, pads_[0], pads_[1],
0, 0, pads_[2], pads_[3]};
Expand Down Expand Up @@ -199,9 +211,17 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
auto iter = op_mapper_.find(pooling_type_);
onnx_pool_type = iter->second[0];
}
auto node = helper_->MakeNode(onnx_pool_type, {input_x});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
std::shared_ptr<ONNX_NAMESPACE::NodeProto> node(nullptr);
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
node = helper_->MakeNode(onnx_pool_type, {input_x}, {output_info[0].name});
}
else
{
node = helper_->MakeNode(onnx_pool_type, {input_x});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}

AddAttribute(node, "kernel_shape", k_size_);
AddAttribute(node, "strides", strides_);
Expand Down Expand Up @@ -317,11 +337,18 @@ void Pool2dMapper::Opset7() {
auto iter = op_mapper_.find(pooling_type_);
onnx_pool_type = iter->second[1];
}
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
auto output = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name});
}
else
{
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}
} else if (adaptive_) {
AdaptivePool(input_info, output_info);
} else {
Expand Down
62 changes: 46 additions & 16 deletions paddle2onnx/mapper/nn/pool3d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,21 @@ void Pool3dMapper::AdaptivePool(const std::vector<TensorInfo>& input_info,
auto iter = op_mapper_.find(pooling_type_);
onnx_pool_type = iter->second[0];
}
std::shared_ptr<ONNX_NAMESPACE::NodeProto>* node_ptr;
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto node = helper_->MakeNode(onnx_pool_type, {input});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);

std::shared_ptr<ONNX_NAMESPACE::NodeProto> node;
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
node = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name});
}
else
{
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
node = helper_->MakeNode(onnx_pool_type, {input});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}

std::vector<int64_t> kernel_size = {kernel_d, kernel_h, kernel_w};
AddAttribute(node, "kernel_shape", kernel_size);
std::vector<int64_t> strides = {stride_d, stride_h, stride_w};
Expand Down Expand Up @@ -109,8 +118,13 @@ void Pool3dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,

int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_));
int64_t max_pads = *std::max_element(std::begin(pads_), std::end(pads_));
auto input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto input_x = input_info[0].name;
if (kNoNeedCastTypes.find(input_info[0].dtype) == kNoNeedCastTypes.end())
{
input_x = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
}

if (max_ksize <= max_pads) {
std::vector<int64_t> onnx_paddings = {0, 0, pads_[0], pads_[1], pads_[2],
0, 0, pads_[3], pads_[4], pads_[5]};
Expand Down Expand Up @@ -143,9 +157,17 @@ void Pool3dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
auto iter = op_mapper_.find(pooling_type_);
onnx_pool_type = iter->second[0];
}
auto node = helper_->MakeNode(onnx_pool_type, {input_x});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
std::shared_ptr<ONNX_NAMESPACE::NodeProto> node(nullptr);
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
node = helper_->MakeNode(onnx_pool_type, {input_x}, {output_info[0].name});
}
else
{
node = helper_->MakeNode(onnx_pool_type, {input_x});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}

AddAttribute(node, "kernel_shape", k_size_);
AddAttribute(node, "strides", strides_);
Expand Down Expand Up @@ -247,11 +269,19 @@ void Pool3dMapper::Opset7() {
auto iter = op_mapper_.find(pooling_type_);
onnx_pool_type = iter->second[1];
}
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);

if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end())
{
auto output = helper_->MakeNode(onnx_pool_type, {input_info[0].name}, {output_info[0].name});
}
else
{
auto input = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
auto output = helper_->MakeNode(onnx_pool_type, {input})->output(0);
helper_->AutoCast(output, output_info[0].name, P2ODataType::FP32,
output_info[0].dtype);
}
} else if (adaptive_) {
AdaptivePool(input_info, output_info);
} else {
Expand Down
5 changes: 5 additions & 0 deletions paddle2onnx/mapper/tensor/assign.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ REGISTER_MAPPER(share_data, AssignMapper)
void AssignMapper::Opset7() {
auto input_info = GetInput("X");
auto output_info = GetOutput("Out");


if (block_idx_ != 0 && OpType() != "share_data") {
// Here's a trick for tensorrt
// Consider remove this trick
Expand All @@ -43,6 +45,9 @@ void AssignMapper::Opset7() {
} else {
helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
}
std::cout << "use assign...\n";
std::cout << "use input_info dtype: " << input_info[0].dtype << std::endl;;
std::cout << "use output_info dtype: " << output_info[0].dtype << std::endl;
}

} // namespace paddle2onnx
24 changes: 19 additions & 5 deletions paddle2onnx/mapper/tensor/fill_constant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,16 @@ void FillConstantMapper::Opset7() {
float value = GetFillValue();
if (HasInput("ValueTensor")) {
auto value_info = GetInput("ValueTensor");
auto value_tensor = helper_->AutoCast(value_info[0].name, value_info[0].dtype, out_info[0].dtype);
auto out = helper_->Constant(shape, GetOnnxDtype(out_info[0].dtype), float(0.0));
helper_->MakeNode("Add", {out, value_tensor}, {out_info[0].name});
if (kNoNeedCastTypes.find(value_info[0].dtype) != kNoNeedCastTypes.end())
{
helper_->MakeNode("Add", {out, value_info[0].name}, {out_info[0].name});
}
else
{
auto value_tensor = helper_->AutoCast(value_info[0].name, value_info[0].dtype, out_info[0].dtype);
helper_->MakeNode("Add", {out, value_tensor}, {out_info[0].name});
}
} else {
helper_->Constant(out_info[0].name, shape, GetOnnxDtype(out_info[0].dtype), value);
}
Expand Down Expand Up @@ -149,9 +156,16 @@ void FillConstantMapper::Opset9() {
}
if (value_is_tensor) {
auto value_info = GetInput("ValueTensor");
std::string cast_value = helper_->AutoCast(
value_info[0].name, value_info[0].dtype, out_info[0].dtype);
helper_->MakeNode("Add", {out, cast_value}, {out_info[0].name});
if (kNoNeedCastTypes.find(value_info[0].dtype) != kNoNeedCastTypes.end())
{
helper_->MakeNode("Add", {out, value_info[0].name}, {out_info[0].name});
}
else
{
std::string cast_value = helper_->AutoCast(
value_info[0].name, value_info[0].dtype, out_info[0].dtype);
helper_->MakeNode("Add", {out, cast_value}, {out_info[0].name});
}
} else {
helper_->MakeNode("Identity", {out}, {out_info[0].name});
}
Expand Down
25 changes: 22 additions & 3 deletions paddle2onnx/mapper/tensor/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ REGISTER_MAPPER(matmul, MatmulMapper)

std::string MatmulMapper::GetTrans(std::vector<TensorInfo>& input_info) {
std::string castd_name = input_info[0].name;
if (input_info[0].dtype == P2ODataType::FP64) {
if (kNoNeedCastTypes.find(input_info[0].dtype) != kNoNeedCastTypes.end()) {
castd_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
P2ODataType::FP32);
}
Expand All @@ -43,11 +43,30 @@ void MatmulMapper::Opset7() {
if (transpose_Y_) {
input_y = GetTrans(input_y_info);
}
if (fabs(alpha_ - 1.0) < 1e-6) {

if (kNoNeedCastTypes.find(input_x_info[0].dtype) != kNoNeedCastTypes.end())
{
if (fabs(alpha_ - 1.0) < 1e-6)
{
auto node = helper_->MakeNode("MatMul", {input_x, input_y}, {output_info[0].name});
}
else
{
auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y});
std::string scale_node =
helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_);
auto mul_node =
helper_->MakeNode("Mul", {mutmul_node->output(0), scale_node}, {output_info[0].name});
}
}
else if (fabs(alpha_ - 1.0) < 1e-6)
{
auto node = helper_->MakeNode("MatMul", {input_x, input_y});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
input_y_info[0].dtype);
} else {
}
else
{
auto mutmul_node = helper_->MakeNode("MatMul", {input_x, input_y});
std::string scale_node =
helper_->Constant({1}, GetOnnxDtype(input_x_info[0].dtype), alpha_);
Expand Down
22 changes: 17 additions & 5 deletions paddle2onnx/mapper/tensor/matmul_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ namespace paddle2onnx {
REGISTER_MAPPER(matmul_v2, MatmulV2Mapper)

std::string MatmulV2Mapper::GetTrans(std::vector<TensorInfo>& input_info) {
std::string castd_name = helper_->AutoCast(
input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
std::string castd_name = input_info[0].name;
if (kNoNeedCastTypes.find(input_info[0].dtype) == kNoNeedCastTypes.end())
{
castd_name = helper_->AutoCast(
input_info[0].name, input_info[0].dtype, P2ODataType::FP32);
}

std::vector<int64_t> perm = Arange(0, input_info[0].Rank());
std::swap(perm[perm.size() - 1], perm[perm.size() - 2]);
auto transpose_node = helper_->MakeNode("Transpose", {castd_name});
Expand All @@ -43,9 +48,16 @@ void MatmulV2Mapper::Opset7() {
if (trans_y_) {
input_y = GetTrans(input_y_info);
}
auto node = helper_->MakeNode("MatMul", {input_x, input_y});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
input_y_info[0].dtype);
if (kNoNeedCastTypes.find(input_y_info[0].dtype) != kNoNeedCastTypes.end())
{
auto node = helper_->MakeNode("MatMul", {input_x, input_y}, {output_info[0].name});
}
else
{
auto node = helper_->MakeNode("MatMul", {input_x, input_y});
helper_->AutoCast(node->output(0), output_info[0].name, P2ODataType::FP32,
input_y_info[0].dtype);
}
}

} // namespace paddle2onnx
3 changes: 2 additions & 1 deletion paddle2onnx/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,6 @@ void PaddleParser::GetGlobalBlockInputOutputInfo() {
}

int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
Assert(paddle_dtype != FP16, "Float16 is not supported.");
if (paddle_dtype == P2ODataType::BOOL) {
return sizeof(bool);
} else if (paddle_dtype == P2ODataType::INT8) {
Expand All @@ -828,6 +827,8 @@ int32_t PaddleDataTypeSize(int32_t paddle_dtype) {
return sizeof(int64_t);
} else if (paddle_dtype == P2ODataType::FP32) {
return sizeof(float);
} else if (paddle_dtype == P2ODataType::FP16) {
return sizeof(int16_t);
} else if (paddle_dtype == P2ODataType::FP64) {
return sizeof(double);
} else if (paddle_dtype == P2ODataType::UINT8) {
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ license = {text = "Apache License v2.0"}
requires-python = ">=3.8"
dependencies = [
"onnxruntime>=1.10.0",
    "numpy<2.0.0", # numpy >= 2.0.0 is not yet supported by paddle2onnx
]

[project.scripts]
Expand Down
3 changes: 2 additions & 1 deletion tests/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ ignore="test_auto_scan_multiclass_nms.py
test_unsqueeze.py \
test_quantize_model.py \
test_quantize_model_minist.py \
test_quantize_model_speedup.py"
test_quantize_model_speedup.py \
test_resnet_fp16.py"
bug=0

# Install Python Packet
Expand Down
Loading