From 762791c60e479b570ef280e05f4f0902891df8f3 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Fri, 13 Jan 2023 02:57:46 +0000 Subject: [PATCH 1/9] preln_residual 2 fused_bias_residual --- .../ir/preln_residual_bias_fuse_pass.cc | 30 ++++++++++++------- .../fluid/inference/api/analysis_predictor.cc | 2 +- .../tensorrt/convert/preln_residual_bias.cc | 25 +++++++++------- paddle/fluid/inference/tensorrt/op_teller.cc | 4 +-- ...sed_bias_dropout_residual_layer_norm_op.cc | 14 ++++----- ...sed_bias_dropout_residual_layer_norm_op.cu | 6 ++-- .../fused_layernorm_residual_dropout_bias.h | 7 +++-- 7 files changed, 51 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index 13b7b4ac72f96b..bc6001a934e77b 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -207,7 +207,7 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, // on each other, so we make below check to ensure only one // PrelnResidualBias pattern is delalted with. for (auto op : elementwise1_out->inputs) { - if (op->Name() == "preln_residual_bias") return; + if (op->Name() == "fused_bias_dropout_residual_layer_norm") return; } if (!IsCompat(subgraph, graph)) { @@ -218,20 +218,24 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, std::unordered_set del_node_set; // Create an PrelnResidualBias op node OpDesc new_desc; - new_desc.SetType("preln_residual_bias"); + new_desc.SetType("fused_bias_dropout_residual_layer_norm"); // inputs new_desc.SetInput("X", {subgraph.at(x)->Name()}); - new_desc.SetInput("Y", {subgraph.at(y)->Name()}); - new_desc.SetInput("Scale", {layer_norm_scale->Name()}); - new_desc.SetInput("Bias", {layer_norm_bias->Name()}); + new_desc.SetInput("Residual", {subgraph.at(y)->Name()}); + new_desc.SetInput("LnScale", {layer_norm_scale->Name()}); + new_desc.SetInput("LnBias", {layer_norm_bias->Name()}); if (with_bias) { - new_desc.SetInput("EleBias", {elementwise_bias->Name()}); + new_desc.SetInput("Bias", {elementwise_bias->Name()}); } // outputs - new_desc.SetOutput("Out_0", {layer_norm_out->Name()}); - new_desc.SetOutput("Out_1", {elementwise1_out->Name()}); + new_desc.SetOutput("Y", {layer_norm_out->Name()}); + new_desc.SetOutput("BiasDropoutResidualOut", {elementwise1_out->Name()}); + new_desc.SetOutput("LnMean",{layer_norm_mean->Name()}); + new_desc.SetOutput("LnVariance",{layer_norm_variance->Name()}); // attrs - new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("ln_epsilon", layer_norm->Op()->GetAttr("epsilon")); + new_desc.SetAttr("dropout_rate", 0.0f); + new_desc.SetAttr("is_test", true); new_desc.SetAttr("begin_norm_axis", layer_norm->Op()->GetAttr("begin_norm_axis")); auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. @@ -241,8 +245,8 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, } del_node_set.insert(elementwise1); del_node_set.insert(layer_norm); - del_node_set.insert(layer_norm_mean); - del_node_set.insert(layer_norm_variance); + // del_node_set.insert(layer_norm_mean); + // del_node_set.insert(layer_norm_variance); GraphSafeRemoveNodes(graph, del_node_set); IR_NODE_LINK_TO(subgraph.at(x), fused_node); IR_NODE_LINK_TO(subgraph.at(y), fused_node); @@ -253,6 +257,9 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, IR_NODE_LINK_TO(layer_norm_bias, fused_node); IR_NODE_LINK_TO(fused_node, layer_norm_out); IR_NODE_LINK_TO(fused_node, elementwise1_out); + IR_NODE_LINK_TO(fused_node, layer_norm_mean); + IR_NODE_LINK_TO(fused_node, layer_norm_variance); + found_subgraph_count++; }; @@ -261,6 +268,7 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, } void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { + VLOG(1)<<"Fuse PrelnResidualBias into fused_bias_dropout_residual_layer_norm op with dropout rate = 0"; PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("preln_residual_bias_fuse", graph); diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 0fb11279ebdf9c..742bfcc2b0f72f 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2386,7 +2386,7 @@ USE_TRT_CONVERTER(rsqrt); USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm); USE_TRT_CONVERTER(preln_skip_layernorm) -USE_TRT_CONVERTER(preln_residual_bias) +USE_TRT_CONVERTER(fused_bias_dropout_residual_layer_norm) USE_TRT_CONVERTER(c_allreduce_sum) USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc index 28847aa5b7a307..265fda2c5455c9 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc @@ -26,7 +26,7 @@ class PrelnResidualBiasOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(4) << "convert fused preln_residual_bias op to tensorrt layer"; + VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with drop_rate = 0 to preln_residual_bias tensorrt layer"; if (!engine_->with_dynamic_shape()) { PADDLE_THROW( platform::errors::Fatal("Unsupported static graph mode. Please set " @@ -35,7 +35,7 @@ class PrelnResidualBiasOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); - auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]); + auto* input2 = engine_->GetITensor(op_desc.Input("Residual")[0]); std::vector inputs; inputs.push_back(input1); inputs.push_back(input2); @@ -50,18 +50,23 @@ class PrelnResidualBiasOpConverter : public OpConverter { return temp_data; }; framework::DDim bias_dims, scale_dims, ele_bias_dims; - auto* bias = get_persistable_data("Bias", &bias_dims); - auto* scale = get_persistable_data("Scale", &scale_dims); + auto* bias = get_persistable_data("LnBias", &bias_dims); + auto* scale = get_persistable_data("LnScale", &scale_dims); auto const& vars = op_desc.Inputs(false); - bool has_bias = vars.find("EleBias") != vars.end(); + bool has_bias = vars.find("Bias") != vars.end(); float* ele_bias = - has_bias ? get_persistable_data("EleBias", &ele_bias_dims) : nullptr; + has_bias ? get_persistable_data("Bias", &ele_bias_dims) : nullptr; int bias_size = phi::product(bias_dims); int scale_size = phi::product(scale_dims); int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0; - float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("epsilon")); + float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon")); + float dropout_rate = PADDLE_GET_CONST(float, op_desc.GetAttr("dropout_rate")); + if (dropout_rate != 0.0f){ + VLOG(4)<<"preln_residual_bias trt layer can not work with fused_bias_dropout_residual_layer_norm op in which the dropout_rate != 0, stop convert"; + return; + } bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); if (engine_->precision() == AnalysisConfig::Precision::kInt8) { with_fp16 = true; @@ -102,8 +107,8 @@ class PrelnResidualBiasOpConverter : public OpConverter { plugin_inputs.emplace_back(input2); layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); std::vector output_names; - output_names.push_back(op_desc.Output("Out_0")[0]); - output_names.push_back(op_desc.Output("Out_1")[0]); + output_names.push_back(op_desc.Output("Y")[0]); + output_names.push_back(op_desc.Output("BiasDropoutResidualOut")[0]); RreplenishLayerAndOutput( layer, "preln_residual_bias", output_names, test_mode); } @@ -113,4 +118,4 @@ class PrelnResidualBiasOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(preln_residual_bias, PrelnResidualBiasOpConverter); +REGISTER_TRT_OP_CONVERTER(fused_bias_dropout_residual_layer_norm, PrelnResidualBiasOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index b49ebbff55d80d..8caa429d0a8665 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2535,7 +2535,7 @@ struct SimpleOpTypeSetTeller : public Teller { "slice", "strided_slice", "fused_preln_embedding_eltwise_layernorm", - "preln_residual_bias", + "fused_bias_dropout_residual_layer_norm", "c_allreduce_sum", "c_allreduce_min", "c_allreduce_max", @@ -2683,7 +2683,7 @@ struct SimpleOpTypeSetTeller : public Teller { "strided_slice", "fused_preln_embedding_eltwise_layernorm", "preln_skip_layernorm", - "preln_residual_bias", + "fused_bias_dropout_residual_layer_norm", "c_allreduce_sum", "c_allreduce_min", "c_allreduce_max", diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc index a6fa80a4939728..04eb576cb4b680 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -39,10 +39,10 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel { "Output", "BiasDropoutResidualOut", "FusedBiasDropoutResidualLnOp"); - OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), - "Output", - "DropoutMaskOut", - "FusedBiasDropoutResidualLnOp"); + // OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), + // "Output", + // "DropoutMaskOut", + // "FusedBiasDropoutResidualLnOp"); OP_INOUT_CHECK( ctx->HasOutput("Y"), "Output", "Y", "FusedBiasDropoutResidualLnOp"); auto x_dim = ctx->GetInputDim("X"); @@ -86,10 +86,10 @@ class FusedBiasDropoutResidualLnOpMaker AddOutput("BiasDropoutResidualOut", "Output of bias + dropout + residual.") .AsIntermediate(); AddOutput("DropoutMaskOut", "The random sampled dropout mask.") - .AsIntermediate(); - AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); + .AsDispensable(); + AddOutput("LnMean", "Mean of the current mini batch.").AsDispensable(); AddOutput("LnVariance", "Variance of the current mini batch.") - .AsIntermediate(); + .AsDispensable(); AddOutput("Y", "Result."); AddAttr("dropout_rate", "Probability of setting units to zero.") .SetDefault(.5f) diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu index 2562c2cc225756..4f5246df34f64b 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu @@ -48,14 +48,14 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel { auto *residual_data = (residual == nullptr) ? nullptr : residual->data(); auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *bias_dropout_residual_out_data = + auto *bias_dropout_residual_out_data = dev_ctx.Alloc(bias_dropout_residual_out, bias_dropout_residual_out->numel() * sizeof(T)); auto *ln_mean_data = dev_ctx.Alloc(ln_mean, ln_mean->numel() * sizeof(U)); auto *ln_var_data = dev_ctx.Alloc(ln_var, ln_var->numel() * sizeof(U)); - auto *dropout_mask_out_data = dev_ctx.Alloc( - dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)); + auto *dropout_mask_out_data = (dropout_mask_out == nullptr) ? nullptr : + dev_ctx.Alloc(dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)); auto *y_data = dev_ctx.Alloc(y, y->numel() * sizeof(T)); const auto input_x_dims = input_x->dims(); diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index c65364d2818d1a..bbf1205495f5a7 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -854,9 +854,10 @@ void LaunchLayernormResidualDropoutBias( residual, rows * cols * sizeof(T), ctx.stream()); - PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( - mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); - + if(mask_data!=nullptr){ + PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( + mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + } // call layernorm forward switch (GetDesiredBlockDim(cols)) { FIXED_BLOCK_DIM_CASE( From d8655ecb9fbc6ce8a995404bc24ff18ddfa5bb91 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Fri, 13 Jan 2023 06:37:36 +0000 Subject: [PATCH 2/9] skip layernorm fix and ut --- .../framework/ir/trt_skip_layernorm_fuse_pass.cc | 2 +- .../test_trt_convert_preln_residual_bias.py | 10 ++++++++++ .../test_trt_convert_preln_residual_no_bias.py | 12 ++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index db023746ac4c79..18ea8850dc5bfb 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -170,7 +170,7 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { // attrs new_desc.SetAttr("epsilon", layer_norm->Op()->GetAttr("epsilon")); - if (new_desc.HasAttr("begin_norm_axis")) { + if (layer_norm->Op()->HasAttr("begin_norm_axis")) { int32_t begin_norm_axis = PADDLE_GET_CONST( int32_t, layer_norm->Op()->GetAttr("begin_norm_axis")); int32_t input_rank = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py index 9e17b83ab9c1ef..9202fa1fcc1f06 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py @@ -163,6 +163,16 @@ def generate_trt_nodes_num(attrs, dynamic_shape): attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) ] + # for static_shape, fall back to fluid fused op + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 # atol=1e-2 while rtol is 1e-8 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 # atol=1e-2 while rtol is 1e-8 # just support dynamic_shape generate_dynamic_shape(attrs) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py index aef2142bf3e8ea..8ae65f5d478d6e 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py @@ -152,6 +152,18 @@ def generate_trt_nodes_num(attrs, dynamic_shape): program_config.ops[i].attrs for i in range(len(program_config.ops)) ] + # for static_shape, fall back to fluid fused op + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 # atol=1e-2 while rtol is 1e-8 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 # atol=1e-2 while rtol is 1e-8 + + # just support dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 From aa491604e0008e21d79cc6ca471b566d48779457 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Fri, 13 Jan 2023 06:43:09 +0000 Subject: [PATCH 3/9] code refine --- .../fluid/framework/ir/preln_residual_bias_fuse_pass.cc | 9 ++++----- .../fused/fused_bias_dropout_residual_layer_norm_op.cc | 4 ---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index bc6001a934e77b..bd0d27caa65f82 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -230,8 +230,8 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, // outputs new_desc.SetOutput("Y", {layer_norm_out->Name()}); new_desc.SetOutput("BiasDropoutResidualOut", {elementwise1_out->Name()}); - new_desc.SetOutput("LnMean",{layer_norm_mean->Name()}); - new_desc.SetOutput("LnVariance",{layer_norm_variance->Name()}); + new_desc.SetOutput("LnMean", {layer_norm_mean->Name()}); + new_desc.SetOutput("LnVariance", {layer_norm_variance->Name()}); // attrs new_desc.SetAttr("ln_epsilon", layer_norm->Op()->GetAttr("epsilon")); new_desc.SetAttr("dropout_rate", 0.0f); @@ -245,8 +245,6 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, } del_node_set.insert(elementwise1); del_node_set.insert(layer_norm); - // del_node_set.insert(layer_norm_mean); - // del_node_set.insert(layer_norm_variance); GraphSafeRemoveNodes(graph, del_node_set); IR_NODE_LINK_TO(subgraph.at(x), fused_node); IR_NODE_LINK_TO(subgraph.at(y), fused_node); @@ -268,7 +266,8 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, } void PrelnResidualBiasFusePass::ApplyImpl(ir::Graph *graph) const { - VLOG(1)<<"Fuse PrelnResidualBias into fused_bias_dropout_residual_layer_norm op with dropout rate = 0"; + VLOG(1) << "Fuse PrelnResidualBias into " + "fused_bias_dropout_residual_layer_norm op with dropout rate = 0"; PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); FusePassBase::Init("preln_residual_bias_fuse", graph); diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc index 04eb576cb4b680..6d69488ece0927 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -39,10 +39,6 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel { "Output", "BiasDropoutResidualOut", "FusedBiasDropoutResidualLnOp"); - // OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), - // "Output", - // "DropoutMaskOut", - // "FusedBiasDropoutResidualLnOp"); OP_INOUT_CHECK( ctx->HasOutput("Y"), "Output", "Y", "FusedBiasDropoutResidualLnOp"); auto x_dim = ctx->GetInputDim("X"); From 551d121e261f16967483174081f38d854a4d69c3 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Fri, 13 Jan 2023 06:58:11 +0000 Subject: [PATCH 4/9] code style refine --- .../tensorrt/convert/preln_residual_bias.cc | 15 ++++++++++----- .../fused_bias_dropout_residual_layer_norm_op.cu | 10 +++++++--- .../fused/fused_layernorm_residual_dropout_bias.h | 2 +- .../test_trt_convert_preln_residual_no_bias.py | 1 - 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc index 265fda2c5455c9..b98421c518365c 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc @@ -26,7 +26,8 @@ class PrelnResidualBiasOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with drop_rate = 0 to preln_residual_bias tensorrt layer"; + VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with " + "drop_rate = 0 to preln_residual_bias tensorrt layer"; if (!engine_->with_dynamic_shape()) { PADDLE_THROW( platform::errors::Fatal("Unsupported static graph mode. Please set " @@ -62,9 +63,12 @@ class PrelnResidualBiasOpConverter : public OpConverter { int scale_size = phi::product(scale_dims); int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0; float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon")); - float dropout_rate = PADDLE_GET_CONST(float, op_desc.GetAttr("dropout_rate")); - if (dropout_rate != 0.0f){ - VLOG(4)<<"preln_residual_bias trt layer can not work with fused_bias_dropout_residual_layer_norm op in which the dropout_rate != 0, stop convert"; + float dropout_rate = + PADDLE_GET_CONST(float, op_desc.GetAttr("dropout_rate")); + if (dropout_rate != 0.0f) { + VLOG(4) << "preln_residual_bias trt layer can not work with " + "fused_bias_dropout_residual_layer_norm op in which the " + "dropout_rate != 0, stop convert"; return; } bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); @@ -118,4 +122,5 @@ class PrelnResidualBiasOpConverter : public OpConverter { } // namespace inference } // namespace paddle -REGISTER_TRT_OP_CONVERTER(fused_bias_dropout_residual_layer_norm, PrelnResidualBiasOpConverter); +REGISTER_TRT_OP_CONVERTER(fused_bias_dropout_residual_layer_norm, + PrelnResidualBiasOpConverter); diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu index 4f5246df34f64b..01a233950b2793 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cu @@ -48,14 +48,18 @@ class FusedBiasDropoutResidualLnOpKernel : public framework::OpKernel { auto *residual_data = (residual == nullptr) ? nullptr : residual->data(); auto *ln_scale_data = (ln_scale == nullptr ? nullptr : ln_scale->data()); auto *ln_bias_data = (ln_bias == nullptr ? nullptr : ln_bias->data()); - auto *bias_dropout_residual_out_data = + auto *bias_dropout_residual_out_data = dev_ctx.Alloc(bias_dropout_residual_out, bias_dropout_residual_out->numel() * sizeof(T)); auto *ln_mean_data = dev_ctx.Alloc(ln_mean, ln_mean->numel() * sizeof(U)); auto *ln_var_data = dev_ctx.Alloc(ln_var, ln_var->numel() * sizeof(U)); - auto *dropout_mask_out_data = (dropout_mask_out == nullptr) ? nullptr : - dev_ctx.Alloc(dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)); + auto *dropout_mask_out_data = + (dropout_mask_out == nullptr) + ? nullptr + : dev_ctx.Alloc( + dropout_mask_out, + dropout_mask_out->numel() * sizeof(uint8_t)); auto *y_data = dev_ctx.Alloc(y, y->numel() * sizeof(T)); const auto input_x_dims = input_x->dims(); diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h index bbf1205495f5a7..0c4e10fa156f9f 100644 --- a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -854,7 +854,7 @@ void LaunchLayernormResidualDropoutBias( residual, rows * cols * sizeof(T), ctx.stream()); - if(mask_data!=nullptr){ + if (mask_data != nullptr) { PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync( mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py index 8ae65f5d478d6e..deffd92d0dbe35 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py @@ -163,7 +163,6 @@ def generate_trt_nodes_num(attrs, dynamic_shape): attrs, True ), 1e-2 # atol=1e-2 while rtol is 1e-8 - # just support dynamic_shape generate_dynamic_shape(attrs) self.trt_param.precision = paddle_infer.PrecisionType.Float32 From afaafa8751eb94eb9aa412aff7e9ce265f7d2ff5 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Fri, 13 Jan 2023 10:33:14 +0000 Subject: [PATCH 5/9] fix ut --- .../unittests/ir/test_ir_preln_residual_bias_fuse_pass.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py index 8f74ceebb65867..c66ee864532883 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_preln_residual_bias_fuse_pass.py @@ -38,7 +38,7 @@ def setUp(self): self.fetch_list = [out, elementwise_out] self.pass_names = "preln_residual_bias_fuse_pass" - self.fused_op_type = "preln_residual_bias" + self.fused_op_type = "fused_bias_dropout_residual_layer_norm" self.num_fused_ops = 1 # self.graph_attrs = { # "embedding_eltwise_layernorm_fuse_pass_flag": True, @@ -72,7 +72,7 @@ def setUp(self): self.fetch_list = [out, elementwise_out] self.pass_names = "preln_residual_bias_fuse_pass" - self.fused_op_type = "preln_residual_bias" + self.fused_op_type = "fused_bias_dropout_residual_layer_norm" self.num_fused_ops = 1 def test_check_program(self): From 7086f107f5cd76c550e4a5782c643970cf516015 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Mon, 16 Jan 2023 08:40:38 +0000 Subject: [PATCH 6/9] fix output --- .../framework/ir/preln_residual_bias_fuse_pass.cc | 14 ++++++++++++++ .../fused_bias_dropout_residual_layer_norm_op.cc | 6 +++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index bd0d27caa65f82..88a21b0fd362b5 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -129,6 +129,17 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) { } // namespace patterns +void addIntermediateOut(Node *op_node, + const std::string &out_name, + const std::string &scope_name, + Graph *graph) { + std::string new_name = scope_name + "/at." + out_name + ".new"; + VarDesc out_var(new_name); + out_var.SetPersistable(false); + auto *node_var = graph->CreateVarNode(&out_var); + IR_NODE_LINK_TO(op_node, node_var); +} + int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, bool with_bias) const { PADDLE_ENFORCE_NOT_NULL( @@ -239,6 +250,9 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, new_desc.SetAttr("begin_norm_axis", layer_norm->Op()->GetAttr("begin_norm_axis")); auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. + addIntermediateOut( + fused_node, "DropoutMaskOut", "preln_residual_bias_fuse", graph); + if (with_bias) { del_node_set.insert(elementwise0); del_node_set.insert(elementwise0_out); diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc index 6d69488ece0927..4a4133649234a5 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -82,10 +82,10 @@ class FusedBiasDropoutResidualLnOpMaker AddOutput("BiasDropoutResidualOut", "Output of bias + dropout + residual.") .AsIntermediate(); AddOutput("DropoutMaskOut", "The random sampled dropout mask.") - .AsDispensable(); - AddOutput("LnMean", "Mean of the current mini batch.").AsDispensable(); + .AsIntermediate(); + AddOutput("LnMean", "Mean of the current mini batch.").AsIntermediate(); AddOutput("LnVariance", "Variance of the current mini batch.") - .AsDispensable(); + .AsIntermediate(); AddOutput("Y", "Result."); AddAttr("dropout_rate", "Probability of setting units to zero.") .SetDefault(.5f) From 17dc79b2faf3056314e5af6cf41b7c5a60282264 Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Tue, 17 Jan 2023 07:14:33 +0000 Subject: [PATCH 7/9] add trt layer fall back info --- .../inference/tensorrt/convert/preln_residual_bias.cc | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc index b98421c518365c..c3a21b9c271ebe 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc @@ -29,9 +29,12 @@ class PrelnResidualBiasOpConverter : public OpConverter { VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with " "drop_rate = 0 to preln_residual_bias tensorrt layer"; if (!engine_->with_dynamic_shape()) { - PADDLE_THROW( - platform::errors::Fatal("Unsupported static graph mode. Please set " - "dynamic shape of inputs.")); + VLOG(0) << "preln_residual_bias trt layer can not work with static graph " + "mode. " + "preln_residual_bias trt layer will fall back to phi " + "fused_bias_dropout_residual_layer_norm op and cut the trt " + "subgraph" + "It is recommended to use dynamic shape trt mode instead."; } framework::OpDesc op_desc(op, nullptr); // Declare inputs From ffa9bef7b329233b583c16e0bc59374683d644fe Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Wed, 18 Jan 2023 10:58:23 +0000 Subject: [PATCH 8/9] refine op teller and ut --- .../tensorrt/convert/preln_residual_bias.cc | 16 ---------------- paddle/fluid/inference/tensorrt/op_teller.cc | 16 +++++++++++++++- .../tests/unittests/ir/inference/CMakeLists.txt | 9 --------- .../test_trt_convert_preln_residual_bias.py | 9 ++++++--- .../test_trt_convert_preln_residual_no_bias.py | 9 ++++++--- 5 files changed, 27 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc index c3a21b9c271ebe..85f9106b011488 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_residual_bias.cc @@ -28,14 +28,6 @@ class PrelnResidualBiasOpConverter : public OpConverter { bool test_mode) override { VLOG(4) << "convert fused_bias_dropout_residual_layer_norm op with " "drop_rate = 0 to preln_residual_bias tensorrt layer"; - if (!engine_->with_dynamic_shape()) { - VLOG(0) << "preln_residual_bias trt layer can not work with static graph " - "mode. " - "preln_residual_bias trt layer will fall back to phi " - "fused_bias_dropout_residual_layer_norm op and cut the trt " - "subgraph" - "It is recommended to use dynamic shape trt mode instead."; - } framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); @@ -66,14 +58,6 @@ class PrelnResidualBiasOpConverter : public OpConverter { int scale_size = phi::product(scale_dims); int ele_bias_size = has_bias ? phi::product(ele_bias_dims) : 0; float epsilon = PADDLE_GET_CONST(float, op_desc.GetAttr("ln_epsilon")); - float dropout_rate = - PADDLE_GET_CONST(float, op_desc.GetAttr("dropout_rate")); - if (dropout_rate != 0.0f) { - VLOG(4) << "preln_residual_bias trt layer can not work with " - "fused_bias_dropout_residual_layer_norm op in which the " - "dropout_rate != 0, stop convert"; - return; - } bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); if (engine_->precision() == AnalysisConfig::Precision::kInt8) { with_fp16 = true; diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 8caa429d0a8665..6e362f32829080 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -1451,7 +1451,21 @@ struct SimpleOpTypeSetTeller : public Teller { return false; } } - + if (op_type == "fused_bias_dropout_residual_layer_norm") { + if (!with_dynamic_shape) { + VLOG(3) << "fused_bias_dropout_residual_layer_norm should run on " + "dynamic shape mode."; + return false; + } + float dropout_rate = + PADDLE_GET_CONST(float, desc.GetAttr("dropout_rate")); + if (dropout_rate != 0.0f) { + VLOG(4) << "preln_residual_bias trt layer can not work with " + "fused_bias_dropout_residual_layer_norm op in which the " + "dropout_rate != 0, stop convert"; + return false; + } + } if (op_type == "fused_preln_embedding_eltwise_layernorm") { if (!with_dynamic_shape) { VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on " diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 754627a426cd2f..3bc04aa780b033 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -18,15 +18,6 @@ string(REPLACE ".py" "" TEST_TRT_CONVERTER "${TEST_TRT_CONVERTER}") if(NOT WITH_DISTRIBUTE) list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_delete_c_identity_op_pass") - list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES - "test_trt_convert_preln_residual_bias") - list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_bias") - list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_bias") - list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES - "test_trt_convert_preln_residual_no_bias") - list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_preln_residual_no_bias") - list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_preln_residual_no_bias") - list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES "test_trt_convert_c_allreduce") list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_c_allreduce") list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce") diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py index 9202fa1fcc1f06..a45ddfcae189e2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_bias.py @@ -158,7 +158,10 @@ def clear_dynamic_shape(): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - return 1, 4 + if dynamic_shape: + return 1, 4 + else: + return 0, 5 attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) @@ -167,11 +170,11 @@ def generate_trt_nodes_num(attrs, dynamic_shape): clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True + attrs, False ), 1e-2 # atol=1e-2 while rtol is 1e-8 self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True + attrs, False ), 1e-2 # atol=1e-2 while rtol is 1e-8 # just support dynamic_shape diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py index deffd92d0dbe35..fd3bdb64c7ede2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_preln_residual_no_bias.py @@ -146,7 +146,10 @@ def clear_dynamic_shape(): self.dynamic_shape.opt_input_shape = {} def generate_trt_nodes_num(attrs, dynamic_shape): - return 1, 4 + if dynamic_shape: + return 1, 4 + else: + return 0, 5 attrs = [ program_config.ops[i].attrs for i in range(len(program_config.ops)) @@ -156,11 +159,11 @@ def generate_trt_nodes_num(attrs, dynamic_shape): clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True + attrs, False ), 1e-2 # atol=1e-2 while rtol is 1e-8 self.trt_param.precision = paddle_infer.PrecisionType.Half yield self.create_inference_config(), generate_trt_nodes_num( - attrs, True + attrs, False ), 1e-2 # atol=1e-2 while rtol is 1e-8 # just support dynamic_shape From e0634667a189d184f4c86c01689999e4ec6ceb0d Mon Sep 17 00:00:00 2001 From: wwbitejotunn Date: Mon, 30 Jan 2023 10:33:53 +0000 Subject: [PATCH 9/9] DropoutMaskOut output fix --- .../fluid/framework/ir/preln_residual_bias_fuse_pass.cc | 8 ++++++++ .../fused/fused_bias_dropout_residual_layer_norm_op.cc | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc index 88a21b0fd362b5..48baf1f4b102fc 100644 --- a/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_residual_bias_fuse_pass.cc @@ -129,6 +129,13 @@ void PrelnResidualBias::operator()(PDNode *x, PDNode *y) { } // namespace patterns +void setIntermediateOut(OpDesc *desc, + const std::string &out_name, + const std::string &scope_name) { + std::string new_name = scope_name + "/at." + out_name + ".new"; + desc->SetOutput(out_name, {new_name}); +} + void addIntermediateOut(Node *op_node, const std::string &out_name, const std::string &scope_name, @@ -243,6 +250,7 @@ int PrelnResidualBiasFusePass::ApplyPattern(ir::Graph *graph, new_desc.SetOutput("BiasDropoutResidualOut", {elementwise1_out->Name()}); new_desc.SetOutput("LnMean", {layer_norm_mean->Name()}); new_desc.SetOutput("LnVariance", {layer_norm_variance->Name()}); + setIntermediateOut(&new_desc, "DropoutMaskOut", "preln_residual_bias_fuse"); // attrs new_desc.SetAttr("ln_epsilon", layer_norm->Op()->GetAttr("epsilon")); new_desc.SetAttr("dropout_rate", 0.0f); diff --git a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc index 4a4133649234a5..7f877867050ed4 100644 --- a/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc +++ b/paddle/fluid/operators/fused/fused_bias_dropout_residual_layer_norm_op.cc @@ -35,12 +35,17 @@ class FusedBiasDropoutResidualLnOp : public framework::OperatorWithKernel { "Output", "LnVariance", "FusedBiasDropoutResidualLnOp"); + OP_INOUT_CHECK(ctx->HasOutput("DropoutMaskOut"), + "Output", + "DropoutMaskOut", + "FusedBiasDropoutResidualLnOp"); OP_INOUT_CHECK(ctx->HasOutput("BiasDropoutResidualOut"), "Output", "BiasDropoutResidualOut", "FusedBiasDropoutResidualLnOp"); OP_INOUT_CHECK( ctx->HasOutput("Y"), "Output", "Y", "FusedBiasDropoutResidualLnOp"); + auto x_dim = ctx->GetInputDim("X"); int left = 1; for (int i = 0; i < x_dim.size() - 1; i++) {