-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Paddle Inference] Add add onehot trt converter #48655
Changes from 9 commits
681a370
f6c918d
12eb774
db6104c
dfbb9d7
ca2c8cb
48c0ae6
9059fec
baa89a7
be73de8
9eaa050
d96f615
7e41019
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
class Scope; | ||
|
||
namespace proto { | ||
class OpDesc; | ||
} // namespace proto | ||
} // namespace framework | ||
} // namespace paddle | ||
|
||
namespace paddle { | ||
namespace inference { | ||
namespace tensorrt { | ||
|
||
/* | ||
* OneHot Op | ||
*/ | ||
class OneHotOpConverter : public OpConverter { | ||
public: | ||
void operator()(const framework::proto::OpDesc& op, | ||
const framework::Scope& scope, | ||
bool test_mode) override { | ||
#if IS_TRT_VERSION_GE(8510) | ||
VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; | ||
framework::OpDesc op_desc(op, nullptr); | ||
|
||
const auto indices_tensor = engine_->GetITensor(op_desc.Input("X").front()); | ||
const nvinfer1::ITensor* values_tensor; | ||
const nvinfer1::ITensor* depth_tensor; | ||
const int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); | ||
if (dtype == 2 || dtype == 3) { // int, int64 | ||
const std::vector<int> values_data = {0, 1}; | ||
values_tensor = Add1DConstantLayer<int>(values_data, "values_tensor"); | ||
if (dtype == 3) { // int64 | ||
VLOG(3) << "trt not support int64, so it is converted to int32."; | ||
} | ||
} else if (dtype == 5 || dtype == 6) { // float | ||
const std::vector<float> values_data = {0.0f, 1.0f}; | ||
values_tensor = Add1DConstantLayer<float>(values_data, "values_tensor"); | ||
if (dtype == 6) { // int64 | ||
VLOG(3) << "trt not support float64, so it is converted to float32."; | ||
} | ||
} | ||
|
||
auto depth_name = op_desc.Input("depth_tensor"); | ||
if (depth_name.size() == 0) { | ||
const int depth = PADDLE_GET_CONST(int, op_desc.GetAttr("depth")); | ||
depth_tensor = Add1DConstantLayer<int>(depth, "depth_tensor", true); | ||
} else { | ||
nvinfer1::Dims depth_dims; | ||
depth_dims.nbDims = 0; | ||
nvinfer1::ITensor* depth_tensor_paddle = | ||
engine_->GetITensor(depth_name.front()); | ||
auto shuffle_layer = | ||
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle); | ||
shuffle_layer->setReshapeDimensions(depth_dims); | ||
shuffle_layer->getOutput(0)->setName(depth_tensor_paddle->getName()); | ||
depth_tensor = shuffle_layer->getOutput(0); | ||
} | ||
auto layer = TRT_ENGINE_ADD_LAYER( | ||
engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1); | ||
|
||
auto output_name = op_desc.Output("Out").front(); | ||
RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); | ||
#else | ||
VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; | ||
#endif | ||
Comment on lines
+81
to
+83
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OP teller 里面做过滤后,这里就可以不用提示 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我在op teller里过滤后,还是会报错addOneHot不存在,感觉像是ci环境里trt的版本不是最新的 xly.bce.baidu.com/paddlepaddle/paddle/newipipe/detail/7324522/job/20859717/realTimeLog/465 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
嗯嗯 这里你之前写法没问题,我忽略这点了 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这部分已经改回去 |
||
} | ||
}; | ||
|
||
} // namespace tensorrt | ||
} // namespace inference | ||
} // namespace paddle | ||
|
||
REGISTER_TRT_OP_CONVERTER(one_hot, OneHotOpConverter); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. one_hot与one_hot_v2等价,可以同时加上 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,已修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. op teller 里, 没有改完整 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这部分已经修改 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1742,6 +1742,36 @@ struct SimpleOpTypeSetTeller : public Teller { | |
} | ||
} | ||
|
||
if (op_type == "one_hot") { | ||
#if IS_TRT_VERSION_LT(8510) | ||
VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; | ||
return false; | ||
#endif | ||
if (!with_dynamic_shape) { | ||
VLOG(3) << "the one_hot op does not support static shape yet"; | ||
return false; | ||
} | ||
if (desc.HasAttr("allow_out_of_range")) { | ||
VLOG(3) << "allow_out_of_range one_hot op is not supported now."; | ||
if (PADDLE_GET_CONST(bool, desc.GetAttr("allow_out_of_range"))) | ||
return false; | ||
} | ||
if (desc.HasAttr("dtype")) { | ||
const int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype")); | ||
if (dtype != 2 && dtype != 3 && dtype != 5) { | ||
VLOG(3) << "one_hot op only support int32, int64, float."; | ||
return false; | ||
} | ||
} | ||
if (desc.HasAttr("depth")) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 对于存在depth_tensor输入时,可以返回true There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,已修改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这部分不用删除,这个if前可以加入 Input depth_tensor输入判断,存在就返回true There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 好的,这部分已经修复 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里还是报错了, 参考这个写法 另外 teller_set、int8_teller_set 也漏添加one_hot_v2. |
||
const int depth = PADDLE_GET_CONST(int, desc.GetAttr("depth")); | ||
if (depth <= 0) { | ||
VLOG(3) << "depth only support positive in one_hot op."; | ||
return false; | ||
} | ||
} | ||
} | ||
|
||
if (op_type == "skip_layernorm") { | ||
if (!with_dynamic_shape) { | ||
VLOG(3) << "the skip_layernorm does not support static shape yet"; | ||
|
@@ -2391,6 +2421,7 @@ struct SimpleOpTypeSetTeller : public Teller { | |
"fc", | ||
"shuffle_channel", | ||
"where", | ||
"one_hot", | ||
"swish", | ||
"silu", | ||
"celu", | ||
|
@@ -2523,6 +2554,7 @@ struct SimpleOpTypeSetTeller : public Teller { | |
"fc", | ||
"shuffle_channel", | ||
"where", | ||
"one_hot", | ||
"swish", | ||
"silu", | ||
"celu", | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
from functools import partial | ||
from typing import List | ||
|
||
import numpy as np | ||
from program_config import ProgramConfig, TensorConfig | ||
from trt_layer_auto_scan_test import TrtLayerAutoScanTest | ||
|
||
import paddle.inference as paddle_infer | ||
|
||
|
||
class TrtConvertOneHotTest(TrtLayerAutoScanTest): | ||
def is_program_valid(self, program_config: ProgramConfig) -> bool: | ||
ver = paddle_infer.get_trt_compile_version() | ||
if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8510: | ||
return False | ||
return True | ||
|
||
def sample_program_configs(self): | ||
self.trt_param.workspace_size = 1073741824 | ||
|
||
def generate_indices(dims, batch): | ||
if dims == 2: | ||
return np.random.randint(0, 10, (batch, 4), dtype=np.int32) | ||
elif dims == 3: | ||
return np.random.randint(0, 10, (batch, 4, 6), dtype=np.int32) | ||
else: | ||
return np.random.randint( | ||
0, 10, (batch, 4, 6, 8), dtype=np.int32 | ||
) | ||
|
||
def generate_depth(dims, batch): | ||
return np.ones((1,), dtype=np.int32) * 10 | ||
|
||
for dims in [2, 3, 4]: | ||
for batch in [1, 2]: | ||
self.dims = dims | ||
dics = [{"dtype": 5, "depth": 10}, {}] | ||
ops_config = [ | ||
{ | ||
"op_type": "one_hot", | ||
"op_inputs": { | ||
"X": ["input_x_data"], | ||
"depth_tensor": ["input_depth_data"], | ||
}, | ||
"op_outputs": {"Out": ["output_data"]}, | ||
"op_attrs": dics[0], | ||
"outputs_dtype": {"output_data": np.int}, | ||
}, | ||
] | ||
ops = self.generate_op_config(ops_config) | ||
|
||
program_config = ProgramConfig( | ||
ops=ops, | ||
weights={ | ||
"depth_tensor": TensorConfig( | ||
data_gen=partial(generate_depth, dims, batch) | ||
), | ||
}, | ||
inputs={ | ||
"indices_tensor": TensorConfig( | ||
data_gen=partial(generate_indices, dims, batch) | ||
), | ||
}, | ||
outputs=["output_data"], | ||
) | ||
|
||
yield program_config | ||
|
||
def sample_predictor_configs( | ||
self, program_config | ||
) -> (paddle_infer.Config, List[int], float): | ||
def generate_dynamic_shape(attrs): | ||
if self.dims == 1: | ||
self.dynamic_shape.min_input_shape = { | ||
"input_x_data": [1], | ||
} | ||
self.dynamic_shape.max_input_shape = { | ||
"input_x_data": [2], | ||
} | ||
self.dynamic_shape.opt_input_shape = { | ||
"input_x_data": [1], | ||
} | ||
elif self.dims == 2: | ||
self.dynamic_shape.min_input_shape = { | ||
"input_x_data": [1, 4], | ||
} | ||
self.dynamic_shape.max_input_shape = { | ||
"input_x_data": [2, 4], | ||
} | ||
self.dynamic_shape.opt_input_shape = { | ||
"input_x_data": [1, 4], | ||
} | ||
elif self.dims == 3: | ||
self.dynamic_shape.min_input_shape = { | ||
"input_x_data": [1, 4, 6], | ||
} | ||
self.dynamic_shape.max_input_shape = { | ||
"input_x_data": [2, 4, 6], | ||
} | ||
self.dynamic_shape.opt_input_shape = { | ||
"input_x_data": [1, 4, 6], | ||
} | ||
elif self.dims == 4: | ||
self.dynamic_shape.min_input_shape = { | ||
"input_x_data": [1, 4, 6, 8], | ||
} | ||
self.dynamic_shape.max_input_shape = { | ||
"input_x_data": [2, 4, 6, 8], | ||
} | ||
self.dynamic_shape.opt_input_shape = { | ||
"input_x_data": [1, 4, 6, 8], | ||
} | ||
|
||
def clear_dynamic_shape(): | ||
self.dynamic_shape.min_input_shape = {} | ||
self.dynamic_shape.max_input_shape = {} | ||
self.dynamic_shape.opt_input_shape = {} | ||
|
||
def generate_trt_nodes_num(attrs, dynamic_shape): | ||
if not dynamic_shape: | ||
return 0, 3 | ||
return 1, 2 | ||
|
||
attrs = [op.attrs for op in program_config.ops] | ||
|
||
# for static_shape | ||
clear_dynamic_shape() | ||
self.trt_param.precision = paddle_infer.PrecisionType.Float32 | ||
yield self.create_inference_config(), generate_trt_nodes_num( | ||
attrs, False | ||
), 1e-5 | ||
self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
yield self.create_inference_config(), generate_trt_nodes_num( | ||
attrs, False | ||
), 1e-5 | ||
|
||
# for dynamic_shape | ||
generate_dynamic_shape(attrs) | ||
self.trt_param.precision = paddle_infer.PrecisionType.Float32 | ||
yield self.create_inference_config(), generate_trt_nodes_num( | ||
attrs, True | ||
), 1e-5 | ||
self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
yield self.create_inference_config(), generate_trt_nodes_num( | ||
attrs, True | ||
), 1e-5 | ||
|
||
def test(self): | ||
self.run_test() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里去掉const,否则会报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,已修改