From 194879749026eff82dfdc07fcd799decafb965b4 Mon Sep 17 00:00:00 2001
From: baoachun <962571062@qq.com>
Date: Tue, 28 Dec 2021 06:00:32 +0000
Subject: [PATCH] update config

---
 .../mkldnn/quant_dequant_mkldnn_fuse_pass.cc | 207 ++++++++++++++----
 .../mkldnn/quant_dequant_mkldnn_fuse_pass.h  |   8 +-
 .../inference/api/paddle_pass_builder.cc     |  35 +++
 3 files changed, 199 insertions(+), 51 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.cc
index c6d19a4224a5f..d8b7da1504ecf 100644
--- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.cc
@@ -16,7 +16,6 @@
 #include <vector>
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
@@ -55,7 +54,7 @@ void QuantDequantMkldnnFusePass::MarkSkipQuantizedOps(
 void QuantDequantMkldnnFusePass::GatherInfoFromFake(
     ir::Graph* graph, Scope* scope,
     std::unordered_set<std::string> fake_dequantize_types,
-    std::unordered_map<std::string, std::vector<float>> weight_thresholds)
+    std::unordered_map<std::string, std::vector<float>>& weight_thresholds)
     const {
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
@@ -65,15 +64,19 @@ void QuantDequantMkldnnFusePass::GatherInfoFromFake(
     auto* op_desc = op_node->Op();
     auto x_var_name = op_desc->Input("X")[0];
     if (op_desc->HasAttr("max_range")) {
-      float max_range = BOOST_GET_CONST(float, op_desc->GetAttr("max_range"));
+      const float max_range =
+          BOOST_GET_CONST(float, op_desc->GetAttr("max_range"));
       weight_thresholds[x_var_name].push_back(127 * 127 / max_range);
     } else {
       auto scale_name = op_desc->Input("Scales")[0];
-      // Should scope->FindVar(scale_name) be null-checked?
-      const LoDTensor& scale_tensor =
-          scope->FindVar(scale_name)->Get<LoDTensor>();
-      const float* scale_data = scale_tensor.data<float>();
-      for (int i = 0; i < scale_tensor.numel(); i++) {
+      auto* var = scope->FindVar(scale_name);
+      PADDLE_ENFORCE_NOT_NULL(
+          var, platform::errors::NotFound(
+                   "The Scales variable of dequantize op is not found."));
+
+      auto* scale_tensor = var->GetMutable<LoDTensor>();
+      auto scale_data =
+          scale_tensor->mutable_data<float>(platform::CPUPlace());
+      for (int i = 0; i < scale_tensor->numel(); i++) {
         weight_thresholds[x_var_name].push_back(scale_data[i]);
       }
     }
@@ -84,7 +87,7 @@ void QuantDequantMkldnnFusePass::GatherInfoFromFake(
 void QuantDequantMkldnnFusePass::GatherInputScalesFromFake(
     ir::Graph* graph, Scope* scope,
     std::unordered_set<std::string> fake_quantize_types,
-    std::unordered_map<std::string, std::pair<bool, std::vector<float>>>
+    std::unordered_map<std::string, std::pair<bool, std::vector<float>>>&
         var_quant_scales) const {
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
@@ -93,7 +96,8 @@ void QuantDequantMkldnnFusePass::GatherInputScalesFromFake(
     if (op_node->Name() == "fake_quantize_dequantize_moving_average_abs_max" ||
         fake_quantize_types.count(op_node->Name())) {
       auto* op_desc = op_node->Op();
-      int bit_length = BOOST_GET_CONST(int, op_desc->GetAttr("bit_length"));
+      const int bit_length =
+          BOOST_GET_CONST(int, op_desc->GetAttr("bit_length"));
       PADDLE_ENFORCE_EQ(bit_length, 8, platform::errors::InvalidArgument(
                                            "Unsupported number quantization "
                                            "bits: %d, only 8 is supported now.",
                                            bit_length));
@@ -102,7 +106,11 @@
       auto x_var_name = op_desc->Input("X")[0];
       auto scale_name = op_desc->Input("InScale")[0];
       auto out_var_name = op_desc->Output("Out")[0];
-      auto* scale_tensor = scope->FindVar(scale_name)->GetMutable<LoDTensor>();
+      auto* var = scope->FindVar(scale_name);
+      PADDLE_ENFORCE_NOT_NULL(
+          var, platform::errors::NotFound(
+                   "The InScale variable of quantize op is not found."));
+
+      auto* scale_tensor = var->GetMutable<LoDTensor>();
       auto scale_data =
           scale_tensor->mutable_data<float>(platform::CPUPlace());
       float scale = 1.0 / scale_data[0];
@@ -123,7 +131,7 @@ void QuantDequantMkldnnFusePass::GatherInputScalesFromFake(
 void QuantDequantMkldnnFusePass::GatherOutputScalesFromAttr(
     ir::Graph* graph,
-    std::unordered_map<std::string, std::pair<bool, std::vector<float>>>
+    std::unordered_map<std::string, std::pair<bool, std::vector<float>>>&
         var_quant_scales) const {
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
@@ -131,7 +139,7 @@ void QuantDequantMkldnnFusePass::GatherOutputScalesFromAttr(
     auto* op_desc = op_node->Op();
     if (op_desc->HasAttr("out_threshold")) {
-      float attr_scale =
+      const float attr_scale =
           BOOST_GET_CONST(float, op_desc->GetAttr("out_threshold"));
       if (attr_scale == 0.0) continue;
       float scale = 1.0 / attr_scale;
@@ -180,6 +188,10 @@ void QuantDequantMkldnnFusePass::RemoveFakeOps(
       }
     }
 
+    PADDLE_ENFORCE_NOT_NULL(
+        fake_quant_in, platform::errors::NotFound(
+                           "The input var of quantize op is not found."));
+    PADDLE_ENFORCE_NOT_NULL(
+        fake_quant_out, platform::errors::NotFound(
+                            "The output var of quantize op is not found."));
     std::string input_act_name = fake_quant_in->Var()->Name();
     std::string output_act_name = fake_quant_out->Var()->Name();
     auto outlinks = fake_quant_out->outputs;
@@ -215,6 +227,10 @@ void QuantDequantMkldnnFusePass::RemoveFakeOps(
       }
     }
 
+    PADDLE_ENFORCE_NOT_NULL(
+        fake_dequant_in, platform::errors::NotFound(
+                             "The input var of dequantize op is not found."));
+    PADDLE_ENFORCE_NOT_NULL(
+        fake_dequant_out, platform::errors::NotFound(
+                              "The output var of dequantize op is not found."));
     std::string input_act_name = fake_dequant_in->Var()->Name();
     std::string output_act_name = fake_dequant_out->Var()->Name();
     auto outlinks = fake_dequant_out->outputs;
@@ -246,65 +262,160 @@ void QuantDequantMkldnnFusePass::RemoveFakeOps(
 void QuantDequantMkldnnFusePass::DequantizeWeights(
     ir::Graph* graph, Scope* scope,
-    std::unordered_map<std::string, std::vector<float>> weight_thresholds)
+    std::unordered_map<std::string, std::vector<float>>& weight_thresholds)
     const {
   auto is_int8_weights = [&](Node* op_node, Scope* scope,
                              std::string weight_name) -> bool {
     auto* op_desc = op_node->Op();
     auto var_name = op_desc->Input(weight_name)[0];
-    std::cout << "var_name: " << var_name << std::endl;
-    if (scope->FindVar(var_name) == nullptr) {
-      std::cout << "eeeeeeeeeeeeeee" << std::endl;
-      return false;
+    auto* var = scope->FindVar(var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        var, platform::errors::NotFound(
+                 "The input persistable var of %s op is not found.",
+                 op_desc->Type()));
+
+    // The quantized model stores weights as float; treat them as int8
+    // only if every value is integral.
+    auto* weight_tensor = var->GetMutable<LoDTensor>();
+    auto weight_data =
+        weight_tensor->mutable_data<float>(platform::CPUPlace());
+    bool is_int8 = true;
+    for (int i = 0; i < weight_tensor->numel(); i++) {
+      if (weight_data[i] - static_cast<int>(weight_data[i]) != 0) {
+        is_int8 = false;
+        break;
+      }
+    }
+    return is_int8;
+  };
+
+  auto transpose_weight = [&](Tensor* input) {
+    const auto input_dims = input->dims();
+    const int num_axes = input_dims.size();
+
+    // Reverse all axes: the transposed tensor has the reversed shape.
+    std::vector<int> orders;
+    std::vector<int64_t> out_dims_v;
+    for (int i = num_axes - 1; i >= 0; i--) {
+      orders.push_back(i);
+      out_dims_v.push_back(input_dims[i]);
+    }
+    const auto out_dims = framework::make_ddim(out_dims_v);
+
+    Tensor trans_tensor;
+    trans_tensor.Resize(out_dims);
+    float* trans_data = trans_tensor.mutable_data<float>(platform::CPUPlace());
+    float* in_data = input->mutable_data<float>(platform::CPUPlace());
+
+    // Row-major strides of the input layout and of the reversed layout.
+    std::vector<int> old_steps(num_axes, 1);
+    std::vector<int> new_steps(num_axes, 1);
+    for (int i = num_axes - 2; i >= 0; i--) {
+      old_steps[i] = old_steps[i + 1] * static_cast<int>(input_dims[i + 1]);
+      new_steps[i] = new_steps[i + 1] * static_cast<int>(out_dims[i + 1]);
+    }
+
+    // Decompose each output index with the new strides and recompose it
+    // with the stride of the matching (reversed) input axis.
+    const int count = static_cast<int>(input->numel());
+    for (int i = 0; i < count; ++i) {
+      int old_idx = 0;
+      int idx = i;
+      for (int j = 0; j < num_axes; ++j) {
+        int order = orders[j];
+        old_idx += (idx / new_steps[j]) * old_steps[order];
+        idx %= new_steps[j];
+      }
+      trans_data[i] = in_data[old_idx];
+    }
+
+    for (int i = 0; i < count; i++) {
+      in_data[i] = trans_data[i];
     }
-    auto* weight_tensor = scope->FindVar(var_name)->GetMutable<LoDTensor>();
-    return weight_tensor->type() == framework::proto::VarType::INT8;
   };
 
-  auto dequantize_op_weights = [&](Node* op_node, Scope* scope,
-                                   std::string weight_name,
-                                   std::string output_name) {
+  auto dequantize_op_weights = [&](
+      Node* op_node, Scope* scope, std::string weight_name,
+      std::string output_name,
+      std::unordered_map<std::string, std::vector<float>>&
+          weight_thresholds) {
     auto* op_desc = op_node->Op();
-    auto weight_var_name = op_desc->Input(weight_name)[0];
-    auto output_var_name = op_desc->Output(output_name)[0];
+    std::string weight_var_name = op_desc->Input(weight_name)[0];
+    std::string output_var_name = op_desc->Output(output_name)[0];
     std::vector<float> scales = weight_thresholds[output_var_name];
     auto* weight_tensor =
         scope->FindVar(weight_var_name)->GetMutable<LoDTensor>();
+    const auto weight_dims = weight_tensor->dims();
+
+    const int size = scales.size();
+    if (size == 1 || size == weight_dims[0]) {
+      auto weight_data =
+          weight_tensor->mutable_data<float>(platform::CPUPlace());
+      for (int i = 0; i < weight_tensor->numel(); i++) {
+        weight_data[i] /= 127;
+      }
+
+      transpose_weight(weight_tensor);
 
-    int size = scales.size();
-    if (size == 1 || size == weight_tensor->dims()[0]) {
+      if (size == 1) {
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          weight_data[i] *= scales[0];
+        }
+      } else {
+        // After reversing the axes the output-channel dimension is the
+        // innermost one, so element i belongs to channel i % size.
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          weight_data[i] *= scales[i % size];
+        }
+      }
+
+      transpose_weight(weight_tensor);
+    } else if (weight_dims.size() > 1 && size == weight_dims[1]) {
       auto weight_data =
           weight_tensor->mutable_data<float>(platform::CPUPlace());
       for (int i = 0; i < weight_tensor->numel(); i++) {
         weight_data[i] /= 127;
       }
-      // } else if (weight_tensor->dims().size() > 1 && scales.size() ==
-      // weight_tensor->dims()[1]) {
+
+      // Scales are per input channel (dim 1): apply scales[k] to the k-th
+      // channel slice inside every output-channel block.
+      int step_n = 1;
+      for (int i = 1; i < weight_dims.size(); i++) {
+        step_n *= weight_dims[i];
+      }
+      const int step_c = step_n / size;
+      for (int i = 0; i < weight_dims[0]; i++) {
+        const int begin_n = i * step_n;
+        for (int k = 0; k < size; k++) {
+          const int begin_c = begin_n + k * step_c;
+          for (int m = begin_c; m < begin_c + step_c; m++) {
+            weight_data[m] *= scales[k];
+          }
+        }
+      }
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "The size of weight scales vector (%d) does not "
          "match the dimensions (%d) of the weights tensor %s.",
          size, weight_tensor->dims().size(), weight_var_name));
     }
+
+    weight_tensor->Resize(weight_dims);
   };
 
-  std::cout << "scope11111: " << static_cast<void*>(scope) << std::endl;
   for (auto* op_node :
        ir::TopologyVarientSort(*graph, static_cast<ir::SortKind>(0))) {
     if (!op_node->IsOp()) continue;
-    std::cout << "8888888888" << std::endl;
 
     if (op_node->Name() == "conv2d" || op_node->Name() == "depthwise_conv2d") {
-      std::cout << "9999999999" << std::endl;
       if (is_int8_weights(op_node, scope, "Filter")) {
-        std::cout << "95555555555" << std::endl;
-        dequantize_op_weights(op_node, scope, "Filter", "Output");
+        dequantize_op_weights(op_node, scope, "Filter", "Output",
+                              weight_thresholds);
       }
     } else if (op_node->Name() == "mul" || op_node->Name() == "matmul" ||
                op_node->Name() == "matmul_v2") {
-      std::cout << "qqqqqqqqq" << std::endl;
       if (is_int8_weights(op_node, scope, "Y")) {
-        dequantize_op_weights(op_node, scope, "Y", "Out");
+        dequantize_op_weights(op_node, scope, "Y", "Out", weight_thresholds);
       }
     }
   }
@@ -320,10 +431,11 @@ void QuantDequantMkldnnFusePass::UpdateActivations(ir::Graph* graph) const {
     if (!op_desc->HasAttr("fuse_activation")) {
       std::string activation;
       if (op_desc->HasAttr("fuse_relu")) {
-        bool fuse_relu = BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
+        const bool fuse_relu =
+            BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
         if (fuse_relu) activation = "relu";
       } else if (op_desc->HasAttr("fuse_brelu")) {
-        bool fuse_brelu =
-            BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_relu"));
+        const bool fuse_brelu =
+            BOOST_GET_CONST(bool, op_desc->GetAttr("fuse_brelu"));
         if (fuse_brelu) {
           activation = "relu6";
@@ -354,6 +466,7 @@ void QuantDequantMkldnnFusePass::RemoveCtrlVars(ir::Graph* graph) const {
 }
 
 void QuantDequantMkldnnFusePass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(3) << "Convert paddle slim quantized model to mkldnn quantized model.";
   const std::string pattern_name = "quant_dequant_mkldnn_fuse_pass";
   FusePassBase::Init(pattern_name, graph);
 
@@ -373,24 +486,24 @@ void QuantDequantMkldnnFusePass::ApplyImpl(ir::Graph* graph) const {
       var_quant_scales;
 
   auto* scope = param_scope();
-  std::cout << "scope: " << static_cast<void*>(scope) << std::endl;
-
   MarkSkipQuantizedOps(graph, skip_ops);
-  std::cout << "11111111" << std::endl;
   GatherInfoFromFake(graph, scope, fake_dequantize_types, weight_thresholds);
-  std::cout << "2222222" << std::endl;
   GatherInputScalesFromFake(graph, scope, fake_quantize_types,
                             var_quant_scales);
-  std::cout << "333333333" << std::endl;
   GatherOutputScalesFromAttr(graph, var_quant_scales);
-  std::cout << "444444444" << std::endl;
   RemoveFakeOps(graph, fake_quantize_types, fake_dequantize_types,
                 fake_quantize_dequantize_types);
-  std::cout << "555555555" << std::endl;
   DequantizeWeights(graph, scope, weight_thresholds);
-  std::cout << "666666666" << std::endl;
   UpdateActivations(graph);
-  std::cout << "77777777" << std::endl;
   RemoveCtrlVars(graph);
 }

diff --git a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.h
index 67deb731e086c..5a8366353d1ba 100644
--- a/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/quant_dequant_mkldnn_fuse_pass.h
@@ -35,18 +35,18 @@ class QuantDequantMkldnnFusePass : public FusePassBase {
 
   void GatherInfoFromFake(ir::Graph* graph, Scope* scope,
                           std::unordered_set<std::string> fake_dequantize_types,
-                          std::unordered_map<std::string, std::vector<float>>
+                          std::unordered_map<std::string, std::vector<float>>&
                               weight_thresholds) const;
 
   void GatherInputScalesFromFake(
       ir::Graph* graph, Scope* scope,
       std::unordered_set<std::string> fake_quantize_types,
-      std::unordered_map<std::string, std::pair<bool, std::vector<float>>>
+      std::unordered_map<std::string, std::pair<bool, std::vector<float>>>&
          var_quant_scales) const;
 
   void GatherOutputScalesFromAttr(
       ir::Graph* graph,
-      std::unordered_map<std::string, std::pair<bool, std::vector<float>>>
+      std::unordered_map<std::string, std::pair<bool, std::vector<float>>>&
          var_quant_scales) const;
 
   void RemoveFakeOps(
@@ -55,7 +55,7 @@ class QuantDequantMkldnnFusePass : public FusePassBase {
       std::unordered_set<std::string> fake_quantize_dequantize_types) const;
 
   void DequantizeWeights(ir::Graph* graph, Scope* scope,
-                         std::unordered_map<std::string, std::vector<float>>
+                         std::unordered_map<std::string, std::vector<float>>&
                              weight_thresholds) const;
 
   void UpdateActivations(ir::Graph* graph) const;

diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index af1a986bf2819..f21d7d45f7678 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -313,6 +313,41 @@ void CpuPassStrategy::EnableMkldnnInt8() {
   if (!use_mkldnn_int8_) {
     passes_.clear();
     passes_.push_back("quant_dequant_mkldnn_fuse_pass");
+    passes_.push_back("attention_lstm_fuse_pass");
+    passes_.push_back("seqconv_eltadd_relu_fuse_pass");
+    passes_.push_back("seqpool_cvm_concat_fuse_pass");
+    passes_.push_back("fc_lstm_fuse_pass");
+    passes_.push_back("mul_lstm_fuse_pass");
+    passes_.push_back("fc_gru_fuse_pass");
+    passes_.push_back("mul_gru_fuse_pass");
+    passes_.push_back("multi_gru_fuse_pass");
+    passes_.push_back("multi_gru_seq_fuse_pass");
+    passes_.push_back("seq_concat_fc_fuse_pass");
+    passes_.push_back("squared_mat_sub_fuse_pass");
+    passes_.push_back("is_test_pass");
+    passes_.push_back("map_matmul_v2_to_mul_pass");
+    passes_.push_back("map_matmul_v2_to_matmul_pass");
+    passes_.push_back("map_matmul_to_mul_pass");
+    // Does this pass need extra parameters?
+    // passes_.push_back("mkldnn_placement_pass");
+    passes_.push_back("depthwise_conv_mkldnn_pass");
+    passes_.push_back("conv_bn_fuse_pass");
+    passes_.push_back("conv_eltwiseadd_bn_fuse_pass");
+    passes_.push_back("conv_transpose_bn_fuse_pass");
+    passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass");
+    passes_.push_back("conv_bias_mkldnn_fuse_pass");
+    passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass");
+    passes_.push_back("conv_relu_mkldnn_fuse_pass");
+    passes_.push_back("conv_relu6_mkldnn_fuse_pass");
+    // Does this pass need extra parameters?
+    // passes_.push_back("fc_fuse_pass");
+    passes_.push_back("repeated_fc_relu_fuse_pass");
+    // Should these be enabled?
+    // passes_.push_back("fc_mkldnn_pass");
+    // passes_.push_back("fc_act_mkldnn_fuse_pass");
+    passes_.push_back("matmul_transpose_reshape_fuse_pass");
+    passes_.push_back("matmul_v2_transpose_reshape_fuse_pass");
+    passes_.push_back("runtime_context_cache_pass");
   }
   use_mkldnn_int8_ = true;
 #else
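
---

Reviewer note on the weight dequantization math: for weights quantized per
output channel, DequantizeWeights reduces every element to
w_fp32 = w_int8 / 127 * scale[channel] (127 being the int8 quantization
range). Below is a minimal standalone C++ sketch of that arithmetic using
plain arrays instead of LoDTensor; the function name and shapes are
illustrative only, not the pass code:

    #include <cstdio>
    #include <vector>

    // Dequantize an int8-valued [n, c] weight (stored as float, as in the
    // pass). scales has 1 entry (per tensor) or n entries (per channel,
    // along dim 0).
    void DequantizeWeightsSketch(std::vector<float>* w, int n, int c,
                                 const std::vector<float>& scales) {
      const int step = c;  // elements per output channel
      for (int i = 0; i < n; i++) {
        const float s = scales.size() == 1 ? scales[0] : scales[i];
        for (int j = 0; j < step; j++) {
          (*w)[i * step + j] = (*w)[i * step + j] / 127.0f * s;
        }
      }
    }

    int main() {
      // 2x3 weight: channel 0 quantized with scale 0.5, channel 1 with 2.0.
      std::vector<float> w = {127.f, -127.f, 64.f, 127.f, 0.f, -64.f};
      DequantizeWeightsSketch(&w, 2, 3, {0.5f, 2.0f});
      for (float v : w) std::printf("%g ", v);
      std::printf("\n");  // 0.5 -0.5 0.251969 2 0 -1.00787
    }

The pass itself interleaves this scaling with two axis-reversal transposes;
since the reversal is its own inverse, the tensor layout is unchanged
afterwards, and between the two calls the output-channel axis is innermost,
which is what makes the scales[i % size] indexing valid.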
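The stride bookkeeping inside transpose_weight is the subtle part of the
patch. For reference, a self-contained sketch of the same reverse-axis
transpose (hypothetical names, plain vectors, no Paddle types), which can be
checked against a small matrix:

    #include <cstdio>
    #include <vector>

    // Reverse-axis transpose of a row-major tensor, e.g. [N,C,H,W] ->
    // [W,H,C,N]. Each output index is decomposed with the strides of the
    // reversed shape and recomposed with the input stride of the matching
    // (reversed) axis.
    std::vector<float> ReverseTranspose(const std::vector<float>& in,
                                        const std::vector<int>& dims) {
      const int num_axes = static_cast<int>(dims.size());
      std::vector<int> out_dims(dims.rbegin(), dims.rend());
      std::vector<int> old_steps(num_axes, 1), new_steps(num_axes, 1);
      for (int i = num_axes - 2; i >= 0; i--) {
        old_steps[i] = old_steps[i + 1] * dims[i + 1];
        new_steps[i] = new_steps[i + 1] * out_dims[i + 1];
      }
      std::vector<float> out(in.size());
      for (int i = 0; i < static_cast<int>(in.size()); ++i) {
        int old_idx = 0, idx = i;
        for (int j = 0; j < num_axes; ++j) {
          old_idx += (idx / new_steps[j]) * old_steps[num_axes - 1 - j];
          idx %= new_steps[j];
        }
        out[i] = in[old_idx];
      }
      return out;
    }

    int main() {
      // A 2x3 matrix comes back as its 3x2 transpose.
      std::vector<float> m = {1, 2, 3, 4, 5, 6};
      for (float v : ReverseTranspose(m, {2, 3})) std::printf("%g ", v);
      std::printf("\n");  // 1 4 2 5 3 6
    }

Note the new strides must be computed from the reversed shape; computing
them from the input shape (as an earlier revision of this lambda did) reads
out of bounds whenever the tensor dimensions are not palindromic.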