Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick]mish trt plugin #38866

Merged
merged 9 commits into from
Jan 11, 2022
Merged
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/is_test_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const {
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign", "silu"};
"softsign", "silu", "mish"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,7 @@ USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(pool3d);
USE_TRT_CONVERTER(mish);
#endif

namespace paddle_infer {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ nv_library(tensorrt_converter
tile_op.cc
conv3d_op.cc
pool3d_op.cc
mish_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)

nv_test(test_op_converter SRCS test_op_converter.cc DEPS
Expand Down
74 changes: 74 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/mish_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/mish_op_plugin.h"

namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
* Mish converter from fluid to tensorRT.
*/
class MishOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert fluid Mish op to tensorrt Mish plugin";

framework::OpDesc op_desc(op, nullptr);
// Declare inputs
int input_num = op_desc.Input("X").size();
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);

const float threshold =
op_desc.HasAttr("threshold")
? BOOST_GET_CONST(float, op_desc.GetAttr("threshold"))
: 20.0f;

nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::MishPluginDynamic* plugin =
new plugin::MishPluginDynamic(threshold, with_fp16);
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::MishPlugin* plugin = new plugin::MishPlugin(threshold, with_fp16);
layer = engine_->AddPlugin(&input, input_num, plugin);
}

auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "mish", {output_name}, test_mode);
}
};

} // namespace tensorrt
} // namespace inference
} // namespace paddle

REGISTER_TRT_OP_CONVERTER(mish, MishOpConverter);
47 changes: 47 additions & 0 deletions paddle/fluid/inference/tensorrt/convert/test_mish_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace inference {
namespace tensorrt {

TEST(mish_op, test_mish) {
std::unordered_set<std::string> parameters;
framework::Scope scope;
TRTConvertValidation validator(10, parameters, scope, 1000);
validator.DeclInputVar("mish-X", nvinfer1::Dims3(3, 2, 2));
validator.DeclOutputVar("mish-Out", nvinfer1::Dims3(3, 2, 2));

// Prepare Op description
framework::OpDesc desc;
desc.SetType("mish");
desc.SetInput("X", {"mish-X"});
desc.SetOutput("Out", {"mish-Out"});

desc.SetAttr("threshold", 20.0f);

validator.SetOp(*desc.Proto());

validator.Execute(1);
}

} // namespace tensorrt
} // namespace inference
} // namespace paddle

USE_OP(mish);
41 changes: 40 additions & 1 deletion paddle/fluid/inference/tensorrt/op_teller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"reduce_mean",
"conv3d",
"conv3d_transpose",
"pool3d"};
"pool3d",
"mish"};
};

bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
Expand Down Expand Up @@ -1160,6 +1161,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
#endif
}

if (op_type == "mish") {
if (desc.Input("X").size() != 1) {
VLOG(3) << "Invalid input X's size of mish TRT converter. "
"Expected 1, received "
<< desc.Input("X").size() << ".";
return false;
}
if (desc.Output("Out").size() != 1) {
VLOG(3) << "Invalid output Out's size of mish TRT converter. "
"Expected 1, received "
<< desc.Output("Out").size() << ".";
return false;
}

auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
"Developers need to check whether block_desc is passed in "
"the pass.";
return false;
}

auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
if (x_shape.size() == 1) {
VLOG(3) << "mish op does not support input's dim is 1 in tensorrt.";
return false;
}

if (!with_dynamic_shape) {
if (x_shape.size() == 2) {
VLOG(3) << "mish op does not support input's dim is 2 in tensorrt.";
return false;
}
}
}

if (op_type == "roi_align") {
if (!with_dynamic_shape) {
VLOG(3) << "TRT roi align plugin only accept the dynamic shape, "
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ nv_library(tensorrt_plugin
roi_align_op_plugin.cu
gather_nd_op_plugin.cu
pool3d_op_plugin.cu
mish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)

nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
Expand Down
Loading