
Commit 688f628
Merge pull request PaddlePaddle#8 from ckl117/ADFM-PNC
TRT: support p_norm and scatter_nd_add
ming1753 authored Jul 22, 2024
2 parents 69a1b06 + 58f08b6 commit 688f628
Showing 5 changed files with 168 additions and 5 deletions.
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -3367,6 +3367,8 @@ USE_TRT_CONVERTER(unbind)
USE_TRT_CONVERTER(argsort)
USE_TRT_CONVERTER(atan2)
USE_TRT_CONVERTER(index_put)
USE_TRT_CONVERTER(scatter_nd_add)
USE_TRT_CONVERTER(p_norm)
USE_TRT_CONVERTER(flip)
USE_TRT_CONVERTER(share_data)
#if IS_TRT_VERSION_GE(8522)
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -118,6 +118,8 @@ list(
quantize_linear_op.cc
dequantize_linear_op.cc
index_put_op.cc
scatter_nd_add_op.cc
p_norm_op.cc
share_data_op.cc)

if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
56 changes: 56 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/p_norm_op.cc
@@ -0,0 +1,56 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle::inference::tensorrt {

/*
 * p_norm Op
 */
class PNormOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(3) << "convert a p_norm op to tensorrt layer";

    framework::OpDesc op_desc(op, nullptr);
    std::string input_name = op_desc.Input("X").front();
    std::string output_name = op_desc.Output("Out").front();
    auto* input_tensor = engine_->GetITensor(input_name);
    int rank = input_tensor->getDimensions().nbDims;
    int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
    bool keepdim = PADDLE_GET_CONST(bool, op_desc.GetAttr("keepdim"));
    if (axis < 0) {
      axis += rank;
    }
    uint32_t axisMask = 1 << axis;
    auto* prod_tensor = Prod(input_tensor, input_tensor);
    auto* prod_layer = TRT_ENGINE_ADD_LAYER(engine_,
                                            Reduce,
                                            *prod_tensor,
                                            nvinfer1::ReduceOperation::kSUM,
                                            axisMask,
                                            keepdim);
    auto* reduce_tensor = prod_layer->getOutput(0);
    auto* sqrt_layer = TRT_ENGINE_ADD_LAYER(
        engine_, Unary, *reduce_tensor, nvinfer1::UnaryOperation::kSQRT);
    ReplenishLayerAndOutput(sqrt_layer, "p_norm", {output_name}, test_mode);
  }
};

} // namespace paddle::inference::tensorrt

REGISTER_TRT_OP_CONVERTER(p_norm, PNormOpConverter);
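A note on the converter above: p_norm is lowered as an elementwise square (Prod of the input with itself), a sum reduction over the resolved axis, and a square root, i.e. an L2 norm along that axis. Below is a minimal standalone sketch of the same arithmetic on a plain row-major buffer (illustrative only, not part of the commit; the function name and shapes are made up):

// Sketch: L2 norm over the last axis of a [rows x cols] matrix, keepdim=false.
// Mirrors Prod(x, x) -> ReduceOperation::kSUM -> UnaryOperation::kSQRT.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> l2_norm_last_axis(const std::vector<float>& x,
                                     int rows, int cols) {
  std::vector<float> out(rows, 0.f);
  for (int r = 0; r < rows; ++r) {
    float acc = 0.f;
    for (int c = 0; c < cols; ++c) {
      float v = x[r * cols + c];
      acc += v * v;  // square and accumulate along the reduced axis
    }
    out[r] = std::sqrt(acc);
  }
  return out;
}

int main() {
  std::vector<float> x = {3.f, 4.f, 0.f, 5.f, 12.f, 0.f};  // 2 x 3
  for (float v : l2_norm_last_axis(x, 2, 3)) std::printf("%g ", v);  // 5 13
  std::printf("\n");
  return 0;
}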
61 changes: 61 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/scatter_nd_add_op.cc
@@ -0,0 +1,61 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle::inference::tensorrt {

/*
 * scatter_nd_add Op
 */
class ScatterNdAddOpConverter : public OpConverter {
 public:
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(3) << "convert a scatter_nd_add op to tensorrt layer";

    framework::OpDesc op_desc(op, nullptr);
    std::string input_name = op_desc.Input("X").front();
    std::string index_name = op_desc.Input("Index").front();
    std::string update_name = op_desc.Input("Updates").front();
    std::string output_name = op_desc.Output("Out").front();

    auto* input_tensor = engine_->GetITensor(input_name);
    auto* index_tensor = engine_->GetITensor(index_name);
    auto* update_tensor = engine_->GetITensor(update_name);
    auto* input_shape_tensor = Shape(input_tensor);
    nvinfer1::Dims input_dims = input_tensor->getDimensions();
    auto rank = input_dims.nbDims;
    auto* zero_tensor = FillConstantLayer(input_shape_tensor, rank, 0.f);
    auto* value_layer = TRT_ENGINE_ADD_LAYER(engine_,
                                             Scatter,
                                             *zero_tensor,
                                             *index_tensor,
                                             *update_tensor,
                                             nvinfer1::ScatterMode::kND);
    value_layer->setAxis(0);
    auto* value_tensor = value_layer->getOutput(0);
    auto* layer = TRT_ENGINE_ADD_LAYER(engine_,
                                       ElementWise,
                                       *input_tensor,
                                       *value_tensor,
                                       nvinfer1::ElementWiseOperation::kSUM);
    ReplenishLayerAndOutput(layer, "scatter_nd_add", {output_name}, test_mode);
  }
};

} // namespace paddle::inference::tensorrt

REGISTER_TRT_OP_CONVERTER(scatter_nd_add, ScatterNdAddOpConverter);
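A note on the converter above: scatter_nd_add is lowered as a scatter of the updates into a zero-filled tensor of the input's shape (Scatter layer in kND mode), followed by an element-wise sum with the original input. Below is a minimal 1-D sketch of that decomposition (illustrative only, not part of the commit; it assumes the scatter indices are unique):

#include <cstdio>
#include <vector>

int main() {
  // 1-D example with unique indices; duplicate indices are not covered here.
  std::vector<float> x = {1, 2, 3, 4, 5};
  std::vector<int> index = {1, 3};
  std::vector<float> updates = {10, 20};

  // Step 1: scatter the updates into a zero-filled tensor shaped like x
  // (mirrors FillConstantLayer + Scatter in ScatterMode::kND).
  std::vector<float> scattered(x.size(), 0.f);
  for (int i = 0; i < static_cast<int>(index.size()); ++i) {
    scattered[index[i]] = updates[i];
  }

  // Step 2: add the scattered tensor to the original input element-wise
  // (mirrors ElementWiseOperation::kSUM).
  std::vector<float> out(x.size());
  for (int i = 0; i < static_cast<int>(x.size()); ++i) {
    out[i] = x[i] + scattered[i];
  }

  for (float v : out) std::printf("%g ", v);  // prints: 1 12 3 24 5
  std::printf("\n");
  return 0;
}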
52 changes: 47 additions & 5 deletions paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2824,11 +2824,6 @@ struct SimpleOpTypeSetTeller : public Teller {
auto* x_var_desc = block->FindVarRecursive(x_var_name);
std::vector<int64_t> shape = x_var_desc->GetShape();
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
// std::cout << "x shape = ";
// for (long unsigned int i = 0; i < shape.size(); ++i) {
// std::cout << shape[i] << ",";
// }
// std::cout << std::endl;
if (shape.size() <= 1) {
VLOG(3) << op_type << " op shape size <= 1.";
return false;
@@ -2879,6 +2874,49 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}

if (op_type == "scatter_nd_add") {
if (!with_dynamic_shape) {
VLOG(3) << "the scatter_nd_add does not support "
"static shape yet";
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
}

if (op_type == "p_norm") {
if (!with_dynamic_shape) {
VLOG(3) << "the p_norm does not support "
"static shape yet";
return false;
}
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}
if (!(desc.HasAttr("asvector") && desc.HasAttr("axis") &&
desc.HasAttr("porder") && desc.HasAttr("keepdim"))) {
VLOG(3) << op_type << " op need attrs asvector, porder, axis, keepdim.";
return false;
}
bool asvector = PADDLE_GET_CONST(bool, desc.GetAttr("asvector"));
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
float porder = PADDLE_GET_CONST(float, desc.GetAttr("porder"));
if (asvector || porder != 2.0f || axis != -1) {
VLOG(3) << op_type
<< " op only support asvector=False, porder=2.0, axis = -1.";
return false;
}
}

if (op_type == "temporal_shift") {
#if !IS_TRT_VERSION_GE(8200)
VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2";
@@ -3143,6 +3181,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"argsort",
"atan2",
"index_put",
"scatter_nd_add",
"p_norm",
"assign",
"flip",
"quantize_linear",
@@ -3318,6 +3358,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"argsort",
"atan2",
"index_put",
"scatter_nd_add",
"p_norm",
"assign",
"flip",
"quantize_linear",

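A note on the op_teller.cc changes above: both new ops are only offered to TensorRT under dynamic shape, and p_norm is additionally restricted to the configuration the converter implements (asvector=False, porder=2.0, axis=-1). A compact restatement of that p_norm gate as a standalone predicate (illustrative only, not part of the commit):

#include <cassert>

// Mirrors the p_norm checks added in op_teller.cc: conversion is attempted
// only for the plain last-axis L2 norm under dynamic shape.
bool p_norm_convertible(bool with_dynamic_shape,
                        bool asvector,
                        float porder,
                        int axis) {
  return with_dynamic_shape && !asvector && porder == 2.0f && axis == -1;
}

int main() {
  assert(p_norm_convertible(true, false, 2.0f, -1));    // supported case
  assert(!p_norm_convertible(true, false, 1.0f, -1));   // L1 norm: rejected
  assert(!p_norm_convertible(false, false, 2.0f, -1));  // static shape: rejected
  return 0;
}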