From a26c439979a2b0921eb2d01a5ddf9a3dcece4d27 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 6 Dec 2022 18:16:44 +0800 Subject: [PATCH] Clear extra input (Bias, ResidualData) in OpMaker of conv2d (#47579) * delete Bias and ResidualData in OpMaker of conv2d * delete extra input of conv3d * refactor pass of conv_bias_fusion * fix mkldnn dependency * fix mkldnn compile * fix test_conv_bias_mkldnn_fuse_pass * police some code * remove useless log * fix analyzer_vit_ocr_tester * fix conv_activation_mkldnn_fuse_pass * fix test_analyzer_ocr * add fused_conv_sig * fix performence regression * fix performance regression --- paddle/fluid/framework/ir/graph_helper.cc | 2 +- .../framework/ir/graph_pattern_detector.cc | 13 +- .../framework/ir/graph_pattern_detector.h | 2 +- .../compute_propagate_scales_mkldnn_pass.cc | 4 +- .../conv_activation_mkldnn_fuse_pass.cc | 41 +- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 87 +++- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.h | 8 + .../conv_bias_mkldnn_fuse_pass_tester.cc | 3 +- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 11 +- .../framework/ir/mkldnn/cpu_quantize_pass.h | 4 +- .../ir/mkldnn/cpu_quantize_placement_pass.cc | 29 +- .../mkldnn/params_quantization_mkldnn_pass.cc | 2 +- paddle/fluid/framework/ir/pass.h | 2 +- paddle/fluid/framework/op_desc.cc | 2 +- .../inference/api/mkldnn_quantizer_config.cc | 6 + .../inference/api/paddle_analysis_config.h | 1 + .../tests/api/analyzer_vit_ocr_tester.cc | 2 +- .../fluid/operators/compat/fused_conv2d.pbtxt | 54 +++ .../fluid/operators/compat/fused_conv3d.pbtxt | 54 +++ paddle/fluid/operators/conv_op.cc | 6 - .../fluid/operators/fused/fused_conv2d_op.cc | 98 +++++ paddle/fluid/pybind/pybind.cc | 1 - paddle/phi/kernels/CMakeLists.txt | 1 + .../fusion/onednn/fused_conv_kernel.cc | 147 +++++++ paddle/phi/kernels/onednn/conv_function.h | 408 ++++++++++++++++++ paddle/phi/kernels/onednn/conv_kernel.cc | 381 +--------------- paddle/phi/ops/compat/fused_conv_sig.cc | 56 +++ .../test_conv_bias_mkldnn_fuse_pass.py | 6 +- 28 files changed, 1013 insertions(+), 418 deletions(-) create mode 100644 paddle/fluid/operators/compat/fused_conv2d.pbtxt create mode 100644 paddle/fluid/operators/compat/fused_conv3d.pbtxt create mode 100644 paddle/fluid/operators/fused/fused_conv2d_op.cc create mode 100644 paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc create mode 100644 paddle/phi/kernels/onednn/conv_function.h create mode 100644 paddle/phi/ops/compat/fused_conv_sig.cc diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index fe2c9adf68fec..7e6ef668fb398 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -713,7 +713,7 @@ static void GetGraphOpDesc(const std::vector &nodes, UpdateControlOpSkipEagerDeletionVars(*n, graph, graph_idx, n->Name()); } ops->emplace_back(*n->Op()); - VLOG(4) << n->ToString(); + VLOG(5) << n->ToString(); } // delete no OpDesc op } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 9912cee3838db..475792c5564b3 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -2068,8 +2068,9 @@ PDNode *patterns::Flatten2Matmul::operator()() { return matmul_out; } -PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { - auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d"); +PDNode *patterns::ConvResidual::operator()(const std::string &conv_type, + bool with_residual_data) { + auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op(conv_type); if (!with_residual_data) { conv_op->assert_more([&](Node *x) { @@ -2082,22 +2083,22 @@ PDNode *patterns::ConvResidual::operator()(bool with_residual_data) { auto input_var = pattern->NewNode(conv_input_repr()) ->AsInput() - ->assert_is_op_input("conv2d", "Input"); + ->assert_is_op_input(conv_type, "Input"); auto filter_var = pattern->NewNode(conv_filter_repr()) ->AsInput() - ->assert_is_op_input("conv2d", "Filter"); + ->assert_is_op_input(conv_type, "Filter"); auto output_var = pattern->NewNode(conv_output_repr()) ->AsOutput() - ->assert_is_op_output("conv2d", "Output"); + ->assert_is_op_output(conv_type, "Output"); std::vector links_from{input_var, filter_var}; if (with_residual_data) { auto res_conn_var = pattern->NewNode(conv_residual_data_repr()) ->AsInput() - ->assert_is_op_input("conv2d", "ResidualData"); + ->assert_is_op_input(conv_type, "ResidualData"); links_from.push_back(res_conn_var); } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 8263a19756b1d..1674cac012150 100755 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1057,7 +1057,7 @@ struct ConvResidual : public PatternBase { ConvResidual(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "conv_residual") {} - PDNode* operator()(bool with_residual_data); + PDNode* operator()(const std::string& conv_type, bool with_residual_data); PATTERN_DECL_NODE(conv_op); PATTERN_DECL_NODE(conv_input); diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc index f1686e445fc1b..3fb7636f06fd0 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc @@ -319,7 +319,7 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales( ir::Graph* graph, Scope* scope, StringPairMap* var_quant_scales) const { ComputeVarScales(graph, scope, - {"conv2d", "depthwise_conv2d"}, + {"conv2d", "depthwise_conv2d", "fused_conv2d"}, "Filter", 1, var_quant_scales); @@ -446,7 +446,7 @@ void ComputePropagateScalesMkldnnPass::UpdateReluOutputScales( if (op->Type() == "relu") { is_unsigned = true; } else { - if (op->Type() == "conv2d") { + if (op->Type() == "conv2d" || op->Type() == "fused_conv2d") { act_name = "fuse_activation"; output_name = "Output"; } else if (op->Type() == "fc") { diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 3cd3cc8f7b054..a673aafadccfc 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -26,7 +26,7 @@ using string::PrettyLogDetail; void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const { auto act_types = phi::funcs::GetSupportedActivations(); - std::vector conv_types = {"conv2d"}; + std::vector conv_types = {"conv2d", "fused_conv2d"}; for (auto& act_type : act_types) { FuseConvConcatAct(graph, act_type); @@ -218,6 +218,45 @@ ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() { .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); + AddOpCompat(OpCompat("fused_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .IsTensor() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsOptional() + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("concat")) .AddInput("X") .End() diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index bf88b82fc30a1..13cd875431603 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -61,6 +61,40 @@ ConvBiasFusePass::ConvBiasFusePass() { .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); + AddOpCompat(OpCompat("fused_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("elementwise_add")) .AddInput("X") .IsTensor() @@ -165,6 +199,40 @@ Conv3DBiasFusePass::Conv3DBiasFusePass() { .IsStringIn({"NDHWC", "NCDHW"}) .End(); + AddOpCompat(OpCompat("fused_conv3d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("elementwise_add")) .AddInput("X") .IsTensor() @@ -203,6 +271,16 @@ phi::DenseTensor tensor_apply_eltwise(const phi::DenseTensor& vec_a, } void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { + FuseConvBias(graph, type(), fused_type()); + if (type() != fused_type()) { + // Is the second pass useful? + FuseConvBias(graph, fused_type(), fused_type()); + } +} + +void ConvBiasFusePass::FuseConvBias(ir::Graph* graph, + const std::string& conv_type, + const std::string& fused_conv) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); FusePassBase::Init(name_scope_, graph); @@ -216,9 +294,9 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) ->AsInput() - ->assert_is_op_input(type(), "Input"); + ->assert_is_op_input(conv_type, "Input"); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); - conv_bias_pattern(conv_input, type()); + conv_bias_pattern(conv_input, conv_type); int found_conv_bias_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -249,7 +327,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { // check if fuse can be done and if MKL-DNN should be used FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) { - VLOG(3) << "do not perform " + type() + "+bias fuse"; + VLOG(3) << "do not perform " + conv_type + "+bias fuse"; return; } @@ -294,7 +372,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetInput("Filter", std::vector({conv_weight->Name()})); desc.SetInput("Bias", std::vector({eltwise_bias->Name()})); desc.SetOutput("Output", std::vector({eltwise_out->Name()})); - desc.SetType(type()); + desc.SetType(fused_conv); for (auto& attr : conv->Op()->GetAttrMap()) { desc.SetAttr(attr.first, attr.second); @@ -323,6 +401,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { type()); } } + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 18e09173491da..d4fb89f091c87 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -32,11 +32,17 @@ class ConvBiasFusePass : public FusePassBase { ConvBiasFusePass(); virtual ~ConvBiasFusePass() {} virtual std::string type() const { return "conv2d"; } + virtual std::string fused_type() const { return "fused_conv2d"; } protected: void ApplyImpl(ir::Graph* graph) const override; + void FuseConvBias(ir::Graph* graph, + const std::string& conv_type, + const std::string& fused_conv) const; + const std::string name_scope_{"conv_bias_mkldnn_fuse"}; }; + /* * Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp. */ @@ -44,12 +50,14 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { public: Conv2DTransposeBiasFusePass(); std::string type() const override { return "conv2d_transpose"; } + std::string fused_type() const override { return "conv2d_transpose"; } }; class Conv3DBiasFusePass : public ConvBiasFusePass { public: Conv3DBiasFusePass(); std::string type() const override { return "conv3d"; } + std::string fused_type() const override { return "fused_conv3d"; } }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 41aea6218b209..c5ee20b4b0162 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -139,7 +139,8 @@ void MainTest(bool convWithExistingBias) { int conv_bias_count = 0; for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { + if (node->IsOp() && (node->Op()->Type() == "conv2d" || + node->Op()->Type() == "fused_conv2d")) { auto* op = node->Op(); ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn"))); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index ac509aa604bd6..6995412d055c6 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -388,11 +388,12 @@ void CPUQuantizePass::GetQuantInfo(Graph* graph) const { } void CPUQuantizePass::QuantizeConv(Graph* graph, + const std::string& conv_type, bool with_residual_data) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::ConvResidual conv_pattern{pattern, name_scope_}; - conv_pattern(with_residual_data); + conv_pattern(conv_type, with_residual_data); int quantize_conv_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -510,7 +511,7 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, AddStatis(quantize_conv_count); LogQuantizedOpsCounter( - "conv2d", + conv_type, quantize_conv_count, ((with_residual_data) ? "with residual connection" : "")); } @@ -1247,8 +1248,10 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { platform::errors::InvalidArgument("Scope cannot be nullptr.")); GetQuantInfo(graph); - QuantizeConv(graph, false /* with_residual_data */); - QuantizeConv(graph, true /* with_residual_data */); + QuantizeConv(graph, "conv2d", false /* with_residual_data */); + QuantizeConv(graph, "conv2d", true /* with_residual_data */); + QuantizeConv(graph, "fused_conv2d", false /* with_residual_data */); + QuantizeConv(graph, "fused_conv2d", true /* with_residual_data */); QuantizePool(graph); QuantizeConcat(graph); QuantizePriorBox(graph); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index b3c5312197baf..59bf2ab2d4fd0 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -49,7 +49,9 @@ class CPUQuantizePass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const override; - void QuantizeConv(Graph* graph, bool with_residual_data) const; + void QuantizeConv(Graph* graph, + const std::string& conv_type, + bool with_residual_data) const; void QuantizeFc(Graph* graph, bool with_residual_data) const; void QuantizePool(Graph* graph) const; void QuantizeConcat(Graph* graph) const; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc index 70433772ce0c7..08b66d8f2f56e 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_placement_pass.cc @@ -25,25 +25,14 @@ class Graph; void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { VLOG(3) << "Marks operators which are to be quantized."; std::unordered_set supported_op_types = - std::unordered_set({"concat", - "conv2d", - "depthwise_conv2d", - "elementwise_add", - "elementwise_mul", - "elementwise_sub", - "fc", - "matmul", - "nearest_interp", - "nearest_interp_v2", - "pool2d", - "prior_box", - "reshape2", - "transpose2", - "fusion_gru", - "fusion_lstm", - "multi_gru", - "slice", - "split"}); + std::unordered_set( + {"concat", "conv2d", "depthwise_conv2d", + "fused_conv2d", "fused_conv3d", "elementwise_add", + "elementwise_mul", "elementwise_sub", "fc", + "matmul", "nearest_interp", "nearest_interp_v2", + "pool2d", "prior_box", "reshape2", + "transpose2", "fusion_gru", "fusion_lstm", + "multi_gru", "slice", "split"}); const auto& excluded_ids_list = Get>("quantize_excluded_op_ids"); const auto& op_types_list = @@ -71,7 +60,6 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_IR_NODE_FROM_SUBGRAPH(op, op, quantize_placement_pattern); - if (std::find(excluded_ids_list.begin(), excluded_ids_list.end(), op->id()) != excluded_ids_list.end()) { @@ -81,7 +69,6 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { if (op->Op()->GetAttrIfExists("skip_quant") == 1) { return; } - op->Op()->SetAttr("mkldnn_data_type", std::string("int8")); }; gpd(graph, handler); diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc index e2000708b8a84..5ad1e95cd79c0 100644 --- a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass.cc @@ -120,7 +120,7 @@ void ParamsQuantizationMkldnnPass::QuantizeConv(ir::Graph* graph, bool with_residual_data) const { GraphPatternDetector gpd; patterns::ConvResidual conv_pattern(gpd.mutable_pattern(), name_scope_); - conv_pattern(with_residual_data); + conv_pattern("conv2d", with_residual_data); int params_to_int8_conv_found = 0; diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index e0315f0b5b741..50d7434c7d97a 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -146,7 +146,7 @@ class Pass { } attrs_[attr_name] = attr; attr_dels_[attr_name] = [attr, attr_name]() { - VLOG(3) << "deleting " << attr_name; + VLOG(8) << "deleting " << attr_name; delete attr; }; } diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index dcc47058b6414..2b84fed6846a8 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -979,7 +979,7 @@ struct SetAttrDescVisitor { }; void OpDesc::Flush() { - VLOG(4) << "Flush " + VLOG(8) << "Flush " << " " << Type() << " " << need_update_; if (need_update_) { this->desc_.mutable_inputs()->Clear(); diff --git a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc index 0beac10903886..646e72fe2885a 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer_config.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer_config.cc @@ -26,6 +26,12 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { rules_["conv2d"]["ResidualData"] = ScaleAlgo::KL; rules_["conv2d"]["Output"] = ScaleAlgo::KL; + rules_["fused_conv2d"]["Input"] = ScaleAlgo::KL; + rules_["fused_conv2d"]["Filter"] = ScaleAlgo::MAX_CH; + rules_["fused_conv2d"]["Bias"] = ScaleAlgo::NONE; // do not compute scale + rules_["fused_conv2d"]["ResidualData"] = ScaleAlgo::KL; + rules_["fused_conv2d"]["Output"] = ScaleAlgo::KL; + rules_["pool2d"]["X"] = ScaleAlgo::KL; rules_["pool2d"]["Out"] = ScaleAlgo::KL; diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index f8ddcbdaa8f39..5521caee9f430 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -1172,6 +1172,7 @@ struct PD_INFER_DECL AnalysisConfig { "concat", "conv2d", "depthwise_conv2d", + "fused_conv2d", "elementwise_add", "elementwise_mul", "fc", diff --git a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc index 3870fde8b533a..8180d951050ce 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vit_ocr_tester.cc @@ -103,7 +103,7 @@ TEST(Analyzer_vit_ocr, fuse_status) { static_cast(predictor.get()), &num_ops); CHECK_EQ(fuse_statis.at("fc_mkldnn_pass"), 33); - CHECK_EQ(fuse_statis.at("conv2d_gelu_mkldnn_fuse_pass"), 2); + CHECK_EQ(fuse_statis.at("fused_conv2d_gelu_mkldnn_fuse_pass"), 2); CHECK_EQ(fuse_statis.at("fc_elementwise_add_mkldnn_fuse"), 16); } #endif diff --git a/paddle/fluid/operators/compat/fused_conv2d.pbtxt b/paddle/fluid/operators/compat/fused_conv2d.pbtxt new file mode 100644 index 0000000000000..c6bdc08f4649b --- /dev/null +++ b/paddle/fluid/operators/compat/fused_conv2d.pbtxt @@ -0,0 +1,54 @@ +type: "fused_conv2d" +def { + inputs { + name: "Input" + } + inputs { + name: "Filter" + } + inputs { + name: "Bias" + } + inputs { + name: "ResidualData" + } + outputs { + name: "Output" + } + attrs { + name: "strides" + type: INTS + } + attrs { + name: "paddings" + type: INTS + } + attrs { + name: "padding_algorithm" + type: STRING + } + attrs { + name: "groups" + type: INT + } + attrs { + name: "dilations" + type: INTS + } + attrs { + name: "data_format" + type: STRING + } + attrs { + name: "fuse_activation" + type: STRING + } + attrs { + name: "fuse_residual_connection" + type: BOOLEAN + } + attrs { + name: "force_fp32_output" + type: BOOLEAN + } +} diff --git a/paddle/fluid/operators/compat/fused_conv3d.pbtxt b/paddle/fluid/operators/compat/fused_conv3d.pbtxt new file mode 100644 index 0000000000000..038cabf5140de --- /dev/null +++ b/paddle/fluid/operators/compat/fused_conv3d.pbtxt @@ -0,0 +1,54 @@ +type: "fused_conv3d" +def { + inputs { + name: "Input" + } + inputs { + name: "Filter" + } + inputs { + name: "Bias" + } + inputs { + name: "ResidualData" + } + outputs { + name: "Output" + } + attrs { + name: "strides" + type: INTS + } + attrs { + name: "paddings" + type: INTS + } + attrs { + name: "padding_algorithm" + type: STRING + } + attrs { + name: "groups" + type: INT + } + attrs { + name: "dilations" + type: INTS + } + attrs { + name: "data_format" + type: STRING + } + attrs { + name: "fuse_activation" + type: STRING + } + attrs { + name: "fuse_residual_connection" + type: BOOLEAN + } + attrs { + name: "force_fp32_output" + type: BOOLEAN + } +} diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 50b90e56c03e0..107e3b5a3de49 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -364,12 +364,6 @@ void Conv3DOpMaker::Make() { "is the width of the filter." "If the groups attribute is greater than 1, C equals the number of " "input image channels divided by the groups."); - AddInput("ResidualData", - "(Tensor) Tensor with residual data " - "to which convolution output will be added." - "Used with fuse_residual_connection fusion.") - .AsDispensable() - .AsExtra(); AddOutput("Output", "(Tensor) The output tensor of convolution operator." "It has same data fromat and data type as the Input."); diff --git a/paddle/fluid/operators/fused/fused_conv2d_op.cc b/paddle/fluid/operators/fused/fused_conv2d_op.cc new file mode 100644 index 0000000000000..178c2a963e28f --- /dev/null +++ b/paddle/fluid/operators/fused/fused_conv2d_op.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/conv_op.h" + +namespace paddle { +namespace operators { + +class FusedConvOpMaker : public Conv2DOpMaker { + protected: + void Apply() override { + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "int8", "bfloat16"}); + AddAttr("fuse_activation", + "(string, default \"\") Only used in mkldnn kernel") + .SetDefault(""); + AddAttr("fuse_residual_connection", + "(bool, default false) Only used in mkldnn kernel. Used " + "whenever convolution output is as an input to residual " + "connection.") + .SetDefault(false); + AddAttr("force_fp32_output", + "(bool, default false) Force INT8 kernel output FP32, only " + "used in MKL-DNN INT8") + .SetDefault(false); + AddAttr("use_mkldnn", "(bool, default false) Used in mkldnn kernel") + .SetDefault(true); + AddComment(R"DOC( +Convolution Operator. + +The convolution operation calculates the output based on the input, filter +and strides, paddings, dilations, groups parameters. The size of each dimension of the +parameters is checked in the infer-shape. +Input(Input) and Output(Output) are in NCHW or NHWC format. Where N is batch +size, C is the number of channels, H is the height of the feature, and W is +the width of the feature. +Filters(Input) is MCHW format format. Where M is the number of output image channels, C is +the number of input image channels, H is the height of the filter, and W +is the width of the filter. +Parameters(strides, paddings, dilations) are two elements. These two elements represent +height and width, respectively. +The input(X) size and output(Out) size may be different. + +Example: + Input: + Input shape: $(N, C_{in}, H_{in}, W_{in})$ + Filter shape: $(C_{out}, C_{in}, H_f, W_f)$ + Output: + ? + Output shape: $(N, C_{out}, H_{out}, W_{out})$ + Where +$$ + H_{out}= \frac{(H_{in} + pad_height_top + pad_height_bottom - (dilations[0] * (H_f - 1) + 1))}{strides[0]}+ 1 \\ + W_{out}= \frac{(W_{in} + pad_width_left + pad_width_right - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1 +$$ +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +// fused_conv2d is only used for onednn inference. +REGISTER_OPERATOR( + fused_conv2d, + ops::ConvOp, + ops::FusedConvOpMaker, + ops::ConvOpInferVarType, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +// fused_conv3d is only used for onednn inference. +REGISTER_OPERATOR( + fused_conv3d, + ops::ConvOp, + ops::FusedConvOpMaker, + ops::ConvOpInferVarType, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index d85824a9237b2..be8355d023d25 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1197,7 +1197,6 @@ All parameter, weight, gradient are variables in Paddle. -> const paddle::framework::AttributeMap & { return operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type); }); - m.def( "get_attrtibute_type", [](const std::string &op_type, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index ef2231c059ad9..808b18bb02d45 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -118,6 +118,7 @@ if(WITH_MKLDNN) "strings/cpu/*.cc" "onednn/*.cc" "fusion/*.cc" + "fusion/onednn/*.cc" "fusion/cpu/*.cc") else() file( diff --git a/paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc b/paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc new file mode 100644 index 0000000000000..a49bf03eee6d4 --- /dev/null +++ b/paddle/phi/kernels/fusion/onednn/fused_conv_kernel.cc @@ -0,0 +1,147 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/onednn/conv_function.h" + +namespace phi { + +template +void FusedConv2DKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const paddle::optional& bias, + const paddle::optional& residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + const std::string& mkldnn_data_type, + const std::string& fuse_activation, + bool fuse_residual_conn, + bool force_fp32_output, + DenseTensor* out) { + bool is_BFLOAT16 = mkldnn_data_type == "bfloat16"; + + ConvOnednn(dev_ctx, + &input, + &filter, + bias.get_ptr(), + residual_param.get_ptr(), + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + true, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + out); +} + +template +void FusedDepthwiseConv2DKernel( + const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const paddle::optional& bias, + const paddle::optional& residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + const std::string& mkldnn_data_type, + const std::string& fuse_activation, + bool fuse_residual_conn, + bool force_fp32_output, + DenseTensor* out) { + bool is_BFLOAT16 = mkldnn_data_type == "bfloat16"; + + ConvOnednn(dev_ctx, + &input, + &filter, + bias.get_ptr(), + residual_param.get_ptr(), + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + true, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + out); +} + +template +void FusedConv3DKernel(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& filter, + const paddle::optional& bias, + const paddle::optional& residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + const std::string& mkldnn_data_type, + const std::string& fuse_activation, + bool fuse_residual_conn, + bool force_fp32_output, + DenseTensor* out) { + bool is_BFLOAT16 = mkldnn_data_type == "bfloat16"; + + ConvOnednn(dev_ctx, + &input, + &filter, + bias.get_ptr(), + residual_param.get_ptr(), + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + true, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(fused_conv2d, + OneDNN, + ONEDNN, + phi::FusedConv2DKernel, + float, + phi::dtype::bfloat16, + uint8_t, + int8_t) {} + +PD_REGISTER_KERNEL( + fused_conv3d, OneDNN, ONEDNN, phi::FusedConv3DKernel, float) {} diff --git a/paddle/phi/kernels/onednn/conv_function.h b/paddle/phi/kernels/onednn/conv_function.h new file mode 100644 index 0000000000000..4b3c4d58895cc --- /dev/null +++ b/paddle/phi/kernels/onednn/conv_function.h @@ -0,0 +1,408 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/funcs/data_layout_transform.h" +#include "paddle/phi/kernels/onednn/conv_handler.h" + +namespace phi { + +static dnnl::memory::data_type GetDstType( + bool is_int8, + bool is_bfloat16, + bool force_fp32_output, + std::string fuse_activation, + bool fuse_residual_conn, + const phi::DenseTensor* residual_param) { + auto dst_dt = dnnl::memory::data_type::f32; + if (is_int8) { + dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") + ? dnnl::memory::data_type::u8 + : dnnl::memory::data_type::s8; + if (force_fp32_output) { + dst_dt = dnnl::memory::data_type::f32; + } + if (fuse_residual_conn && residual_param) { + auto residual_dt = funcs::ToOneDNNDataType(residual_param->dtype()); + if (dst_dt != residual_dt) dst_dt = residual_dt; + } + } else { + if (!force_fp32_output && is_bfloat16) { + dst_dt = dnnl::memory::data_type::bf16; + if (fuse_residual_conn && residual_param) { + dst_dt = funcs::ToOneDNNDataType(residual_param->dtype()); + } + } + } + return dst_dt; +} + +#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +template +void ComputeFP32(const OneDNNContext& dev_ctx, + const DenseTensor* input, + const DenseTensor* filter, + const DenseTensor* bias, + const DenseTensor* residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + bool is_test, + bool is_BFLOAT16, + const std::string& fuse_activation, + bool fuse_residual_conn, + bool force_fp32_output, + DenseTensor* output) { + const auto& onednn_engine = dev_ctx.GetEngine(); + const bool is_conv3d = strides.size() == 3U; + const std::string& unique_name = + dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0]; + PD_VISIT_FLOAT_AND_INT8_TYPES( + filter->dtype(), "ConvOneDNNHandlerT", ([&] { + onednn::ConvOneDNNHandlerT handler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + input, + filter, + bias, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + output, + unique_name); + auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); + auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( + filter, groups, is_conv3d, is_test); + std::shared_ptr dst_memory_p; + if (fuse_residual_conn) { + dst_memory_p = + handler.AcquireDstMemoryWithResidual(output, residual_param); + } else { + dst_memory_p = handler.template AcquireDstMemory(output); + } + + auto conv_p = handler.AcquireForwardPrimitive(); + std::unordered_map args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + auto bias_memory_p = + handler.AcquireBiasMemoryWithReorder(bias, is_test); + args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + auto& astream = OneDNNContext::tls().get_stream(); + conv_p->execute(astream, args); + astream.wait(); + output->set_mem_desc(dst_memory_p->get_desc()); + })); +} + +template +void ComputeINT8(const OneDNNContext& dev_ctx, + const DenseTensor* input, + const DenseTensor* filter, + const DenseTensor* bias, + const DenseTensor* residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + bool is_test, + bool is_BFLOAT16, + const std::string& fuse_activation, + bool fuse_residual_conn, + bool force_fp32_output, + DenseTensor* output) { + const auto& onednn_engine = dev_ctx.GetEngine(); + const bool is_conv3d = strides.size() == 3U; + + bool unsigned_output = + (fuse_activation == "relu" || fuse_activation == "relu6"); + bool need_s8_to_u8 = false; + + PADDLE_ENFORCE_NE( + is_conv3d, + true, + phi::errors::Unimplemented( + "OneDNN int8 convolution does not support 3D inputs currently")); + PADDLE_ENFORCE_EQ( + fuse_residual_conn && force_fp32_output, + false, + phi::errors::Unimplemented( + "residual fusion does not support force output with fp32")); + const std::string& unique_name = + dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0]; + PD_VISIT_FLOAT_AND_INT8_TYPES( + filter->dtype(), "ConvMKLDNNHandlerT", ([&] { + onednn::ConvOneDNNHandlerT handler(dev_ctx, + onednn_engine, + dev_ctx.GetPlace(), + input, + filter, + bias, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + output, + unique_name); + + auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); + + const auto& scale_weights_data = + dev_ctx.HasDnnAttr("Scale_weights") + ? PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("Scale_weights")) + : std::vector{1.0f}; + const bool is_multi_channel = scale_weights_data.size() > 1; + int mask_reorder = is_multi_channel + ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) + : 0; + auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( + filter, groups, false, true, scale_weights_data, mask_reorder); + + std::shared_ptr dst_memory_p; + if (fuse_residual_conn) { + PADDLE_ENFORCE_EQ( + output->dims(), + residual_param->dims(), + phi::errors::InvalidArgument( + "Output and elementwise parameter need to have the " + "same dimension sizes, but got output's dimension = %d" + " and residual param's dimension =%d .", + output->dims().size(), + residual_param->dims().size())); + dst_memory_p = + handler.AcquireDstMemoryWithResidual(output, residual_param); + need_s8_to_u8 = (funcs::OneDNNGetDataType() == + dnnl::memory::data_type::s8) && + unsigned_output; + } else { + dst_memory_p = handler.template AcquireDstMemory(output); + } + + auto conv_p = handler.AcquireForwardPrimitive(); + + std::unordered_map args = { + {DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_WEIGHTS, *weights_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}; + + if (bias) { + std::vector bias_scales; + auto p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), + bias_scales)); + if (dev_ctx.HasDnnAttr("Bias_scales")) { + bias_scales = PADDLE_GET_CONST(std::vector, + dev_ctx.GetDnnAttr("Bias_scales")); + p_scales_tuple = + std::make_shared>>( + std::make_tuple(static_cast(mask_reorder), + bias_scales)); + } else { + p_scales_tuple = handler.get_int8_bias_scales( + filter, groups, scale_weights_data); + } + auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( + bias, + true, + std::get<1>(*p_scales_tuple), + std::get<0>(*p_scales_tuple)); + args.insert({DNNL_ARG_BIAS, *bias_memory_p}); + } + + auto& astream = OneDNNContext::tls().get_stream(); + conv_p->execute(astream, args); + astream.wait(); + + if (need_s8_to_u8) { + dev_ctx.Alloc(output); + } + + output->set_mem_desc(dst_memory_p->get_desc()); + })); +} + +template +void ConvOnednn(const Context& dev_ctx, + const DenseTensor* input, + const DenseTensor* filter, + const DenseTensor* bias, + const DenseTensor* residual_param, + const std::vector& strides, + const std::vector& paddings, + const std::string& padding_algorithm, + const std::vector& dilations, + int groups, + const std::string& data_format, + bool is_test, + bool is_bfloat16, + const std::string& fuse_activation, + bool fuse_residual_connection, + bool force_fp32_output, + DenseTensor* out) { + PADDLE_ENFORCE_EQ( + dev_ctx.GetPlace().GetType(), + AllocationType::CPU, + phi::errors::PreconditionNotMet("Operator DNNL Conv must use CPUPlace")); + + bool is_INT8 = funcs::is_int8(); + + auto dst_dt = GetDstType(is_INT8, + is_bfloat16, + force_fp32_output, + fuse_activation, + fuse_residual_connection, + residual_param); + if (!is_INT8) { + if (dst_dt == dnnl::memory::data_type::f32) { + ComputeFP32(dev_ctx, + input, + filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_bfloat16, + fuse_activation, + fuse_residual_connection, + force_fp32_output, + out); + } else if (dst_dt == dnnl::memory::data_type::bf16) { + ComputeFP32(dev_ctx, + input, + filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_bfloat16, + fuse_activation, + fuse_residual_connection, + force_fp32_output, + out); + } + } else { + if (dst_dt == dnnl::memory::data_type::f32) { + ComputeINT8(dev_ctx, + input, + filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_bfloat16, + fuse_activation, + fuse_residual_connection, + force_fp32_output, + out); + } else if (dst_dt == dnnl::memory::data_type::u8) { + ComputeINT8(dev_ctx, + input, + filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_bfloat16, + fuse_activation, + fuse_residual_connection, + force_fp32_output, + out); + } else if (dst_dt == dnnl::memory::data_type::s8) { + ComputeINT8(dev_ctx, + input, + filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_bfloat16, + fuse_activation, + fuse_residual_connection, + force_fp32_output, + out); + } + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/onednn/conv_kernel.cc b/paddle/phi/kernels/onednn/conv_kernel.cc index e2faaea6b023a..1e54ba0337e22 100644 --- a/paddle/phi/kernels/onednn/conv_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_kernel.cc @@ -17,265 +17,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/data_layout_transform.h" -#include "paddle/phi/kernels/onednn/conv_handler.h" +#include "paddle/phi/kernels/onednn/conv_function.h" namespace phi { -static dnnl::memory::data_type GetDstType( - bool is_int8, - bool is_bfloat16, - bool force_fp32_output, - std::string fuse_activation, - bool fuse_residual_conn, - const phi::DenseTensor* residual_param) { - auto dst_dt = dnnl::memory::data_type::f32; - if (is_int8) { - dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") - ? dnnl::memory::data_type::u8 - : dnnl::memory::data_type::s8; - if (force_fp32_output) { - dst_dt = dnnl::memory::data_type::f32; - } - if (fuse_residual_conn && residual_param) { - auto residual_dt = funcs::ToOneDNNDataType(residual_param->dtype()); - if (dst_dt != residual_dt) dst_dt = residual_dt; - } - } else { - if (!force_fp32_output && is_bfloat16) { - dst_dt = dnnl::memory::data_type::bf16; - if (fuse_residual_conn && residual_param) { - dst_dt = funcs::ToOneDNNDataType(residual_param->dtype()); - } - } - } - return dst_dt; -} - -#define PD_VISIT_FLOAT_AND_INT8_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() - -template -void ComputeFP32(const OneDNNContext& dev_ctx, - const DenseTensor* input, - const DenseTensor* filter, - const DenseTensor* bias, - const DenseTensor* residual_param, - const std::vector& strides, - const std::vector& paddings, - const std::string& padding_algorithm, - const std::vector& dilations, - int groups, - const std::string& data_format, - bool is_test, - bool is_BFLOAT16, - const std::string& fuse_activation, - bool fuse_residual_conn, - bool force_fp32_output, - DenseTensor* output) { - const auto& onednn_engine = dev_ctx.GetEngine(); - const bool is_conv3d = strides.size() == 3U; - const std::string& unique_name = - dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0]; - PD_VISIT_FLOAT_AND_INT8_TYPES( - filter->dtype(), "ConvOneDNNHandlerT", ([&] { - onednn::ConvOneDNNHandlerT handler(dev_ctx, - onednn_engine, - dev_ctx.GetPlace(), - input, - filter, - bias, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - output, - unique_name); - auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); - auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, is_conv3d, is_test); - std::shared_ptr dst_memory_p; - if (fuse_residual_conn) { - dst_memory_p = - handler.AcquireDstMemoryWithResidual(output, residual_param); - } else { - dst_memory_p = handler.template AcquireDstMemory(output); - } - - auto conv_p = handler.AcquireForwardPrimitive(); - std::unordered_map args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - if (bias) { - auto bias_memory_p = - handler.AcquireBiasMemoryWithReorder(bias, is_test); - args.insert({DNNL_ARG_BIAS, *bias_memory_p}); - } - - auto& astream = OneDNNContext::tls().get_stream(); - conv_p->execute(astream, args); - astream.wait(); - output->set_mem_desc(dst_memory_p->get_desc()); - })); -} - -template -void ComputeINT8(const OneDNNContext& dev_ctx, - const DenseTensor* input, - const DenseTensor* filter, - const DenseTensor* bias, - const DenseTensor* residual_param, - const std::vector& strides, - const std::vector& paddings, - const std::string& padding_algorithm, - const std::vector& dilations, - int groups, - const std::string& data_format, - bool is_test, - bool is_BFLOAT16, - const std::string& fuse_activation, - bool fuse_residual_conn, - bool force_fp32_output, - DenseTensor* output) { - const auto& onednn_engine = dev_ctx.GetEngine(); - const bool is_conv3d = strides.size() == 3U; - - bool unsigned_output = - (fuse_activation == "relu" || fuse_activation == "relu6"); - bool need_s8_to_u8 = false; - - PADDLE_ENFORCE_NE( - is_conv3d, - true, - phi::errors::Unimplemented( - "OneDNN int8 convolution does not support 3D inputs currently")); - PADDLE_ENFORCE_EQ( - fuse_residual_conn && force_fp32_output, - false, - phi::errors::Unimplemented( - "residual fusion does not support force output with fp32")); - const std::string& unique_name = - dev_ctx.GetInputsName("Input")[0] + dev_ctx.GetInputsName("Filter")[0]; - PD_VISIT_FLOAT_AND_INT8_TYPES( - filter->dtype(), "ConvMKLDNNHandlerT", ([&] { - onednn::ConvOneDNNHandlerT handler(dev_ctx, - onednn_engine, - dev_ctx.GetPlace(), - input, - filter, - bias, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - output, - unique_name); - - auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); - - const auto& scale_weights_data = - dev_ctx.HasDnnAttr("Scale_weights") - ? PADDLE_GET_CONST(std::vector, - dev_ctx.GetDnnAttr("Scale_weights")) - : std::vector{1.0f}; - const bool is_multi_channel = scale_weights_data.size() > 1; - int mask_reorder = is_multi_channel - ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) - : 0; - auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( - filter, groups, false, true, scale_weights_data, mask_reorder); - - std::shared_ptr dst_memory_p; - if (fuse_residual_conn) { - PADDLE_ENFORCE_EQ( - output->dims(), - residual_param->dims(), - phi::errors::InvalidArgument( - "Output and elementwise parameter need to have the " - "same dimension sizes, but got output's dimension = %d" - " and residual param's dimension =%d .", - output->dims().size(), - residual_param->dims().size())); - dst_memory_p = - handler.AcquireDstMemoryWithResidual(output, residual_param); - need_s8_to_u8 = (funcs::OneDNNGetDataType() == - dnnl::memory::data_type::s8) && - unsigned_output; - } else { - dst_memory_p = handler.template AcquireDstMemory(output); - } - - auto conv_p = handler.AcquireForwardPrimitive(); - - std::unordered_map args = { - {DNNL_ARG_SRC, *src_memory_p}, - {DNNL_ARG_WEIGHTS, *weights_memory_p}, - {DNNL_ARG_DST, *dst_memory_p}}; - - if (bias) { - std::vector bias_scales; - auto p_scales_tuple = - std::make_shared>>( - std::make_tuple(static_cast(mask_reorder), - bias_scales)); - if (dev_ctx.HasDnnAttr("Bias_scales")) { - bias_scales = PADDLE_GET_CONST(std::vector, - dev_ctx.GetDnnAttr("Bias_scales")); - p_scales_tuple = - std::make_shared>>( - std::make_tuple(static_cast(mask_reorder), - bias_scales)); - } else { - p_scales_tuple = handler.get_int8_bias_scales( - filter, groups, scale_weights_data); - } - auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( - bias, - true, - std::get<1>(*p_scales_tuple), - std::get<0>(*p_scales_tuple)); - args.insert({DNNL_ARG_BIAS, *bias_memory_p}); - } - - auto& astream = OneDNNContext::tls().get_stream(); - conv_p->execute(astream, args); - astream.wait(); - - if (need_s8_to_u8) { - dev_ctx.Alloc(output); - } - - output->set_mem_desc(dst_memory_p->get_desc()); - })); -} - template void ConvKernel(const Context& dev_ctx, const DenseTensor& input, @@ -287,12 +32,6 @@ void ConvKernel(const Context& dev_ctx, int groups, const std::string& data_format, DenseTensor* out) { - PADDLE_ENFORCE_EQ( - dev_ctx.GetPlace().GetType(), - AllocationType::CPU, - phi::errors::PreconditionNotMet("Operator DNNL Conv must use CPUPlace")); - bool is_INT8 = funcs::is_int8(); - bool is_test = dev_ctx.HasDnnAttr("is_test") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test")) : false; @@ -320,107 +59,23 @@ void ConvKernel(const Context& dev_ctx, dev_ctx.HasDnnAttr("force_fp32_output") ? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) : false; - auto dst_dt = GetDstType(is_INT8, - is_BFLOAT16, - force_fp32_output, - fuse_activation, - fuse_residual_conn, - residual_param); - if (!is_INT8) { - if (dst_dt == dnnl::memory::data_type::f32) { - ComputeFP32(dev_ctx, - &input, - &filter, - bias, - residual_param, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - out); - } else if (dst_dt == dnnl::memory::data_type::bf16) { - ComputeFP32(dev_ctx, - &input, - &filter, - bias, - residual_param, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - out); - } - } else { - if (dst_dt == dnnl::memory::data_type::f32) { - ComputeINT8(dev_ctx, - &input, - &filter, - bias, - residual_param, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - out); - } else if (dst_dt == dnnl::memory::data_type::u8) { - ComputeINT8(dev_ctx, - &input, - &filter, - bias, - residual_param, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - out); - } else if (dst_dt == dnnl::memory::data_type::s8) { - ComputeINT8(dev_ctx, - &input, - &filter, - bias, - residual_param, - strides, - paddings, - padding_algorithm, - dilations, - groups, - data_format, - is_test, - is_BFLOAT16, - fuse_activation, - fuse_residual_conn, - force_fp32_output, - out); - } - } + ConvOnednn(dev_ctx, + &input, + &filter, + bias, + residual_param, + strides, + paddings, + padding_algorithm, + dilations, + groups, + data_format, + is_test, + is_BFLOAT16, + fuse_activation, + fuse_residual_conn, + force_fp32_output, + out); } template diff --git a/paddle/phi/ops/compat/fused_conv_sig.cc b/paddle/phi/ops/compat/fused_conv_sig.cc new file mode 100644 index 0000000000000..0e0f4325232dc --- /dev/null +++ b/paddle/phi/ops/compat/fused_conv_sig.cc @@ -0,0 +1,56 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FusedConv2dOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_conv2d", + {"Input", "Filter", "Bias", "ResidualData"}, + {"strides", + "paddings", + "padding_algorithm", + "dilations", + "groups", + "data_format", + "mkldnn_data_type", + "fuse_activation", + "fuse_residual_connection", + "force_fp32_output"}, + {"Output"}); +} + +KernelSignature FusedConv3dOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("fused_conv3d", + {"Input", "Filter", "Bias", "ResidualData"}, + {"strides", + "paddings", + "padding_algorithm", + "dilations", + "groups", + "data_format", + "mkldnn_data_type", + "fuse_activation", + "fuse_residual_connection", + "force_fp32_output"}, + {"Output"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_conv2d, phi::FusedConv2dOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(fused_conv3d, phi::FusedConv3dOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py index 6e16970dcd2fe..7efea770bfa2a 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bias_mkldnn_fuse_pass.py @@ -36,7 +36,7 @@ def sample_predictor_configs(self, program_config): # MKLDNN config = self.create_inference_config(use_gpu=False) config.enable_mkldnn() - yield config, ["conv2d"], (1e-4, 1e-5) + yield config, ["fused_conv2d"], (1e-4, 1e-5) def is_program_valid(self, prog_config): paddings = prog_config.ops[0].attrs["paddings"] @@ -156,8 +156,10 @@ def sample_program_config(self, draw): inputs = dict() weights = dict() use_mkldnn = None + conv_type = "conv2d" if draw(st.booleans()): conv_bias_shape = [f_shape[0]] + conv_type = "fused_conv2d" inputs = { "Input": ["input_x"], "Filter": ["filter"], @@ -181,7 +183,7 @@ def sample_program_config(self, draw): use_mkldnn = False conv2d_op = OpConfig( - "conv2d", + conv_type, inputs=inputs, outputs={"Output": ["conv2d_out"]}, strides=strides,