From 9f1b2b1f26575e8812ee20ecddbab44138ff7c5c Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:20:33 +0800 Subject: [PATCH] rewrite delete_weight_dequant_linear_op_encoder/decoder pass (#48650) * rewrite delete_weight_deqquant_linear_op_encoder/decoder pass --- paddle/fluid/framework/ir/CMakeLists.txt | 7 +- ...e_weight_dequant_linear_op_decoder_pass.cc | 373 ------------ ...te_weight_dequant_linear_op_encoder_pass.h | 34 -- .../delete_weight_dequant_linear_op_pass.cc | 550 +++++------------- .../ir/delete_weight_dequant_linear_op_pass.h | 35 +- ...te_weight_dequant_linear_op_pass_tester.cc | 141 +++++ paddle/fluid/framework/ir/pass.cc | 4 +- .../fluid/framework/ir/pass_tester_helper.h | 17 + ...t_delete_weight_dequant_linear_op_pass.cc} | 141 +++-- ...rt_delete_weight_dequant_linear_op_pass.h} | 7 +- .../inference/api/paddle_pass_builder.cc | 26 +- 11 files changed, 442 insertions(+), 893 deletions(-) delete mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc delete mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h create mode 100644 paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc rename paddle/fluid/framework/ir/{delete_weight_dequant_linear_op_encoder_pass.cc => trt_delete_weight_dequant_linear_op_pass.cc} (63%) rename paddle/fluid/framework/ir/{delete_weight_dequant_linear_op_decoder_pass.h => trt_delete_weight_dequant_linear_op_pass.h} (82%) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 51f111314c1cec..f65db53893038c 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -95,9 +95,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass 
inference) +pass_library(trt_delete_weight_dequant_linear_op_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) -pass_library(delete_weight_dequant_linear_op_encoder_pass inference) -pass_library(delete_weight_dequant_linear_op_decoder_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) pass_library(delete_c_identity_op_pass inference) @@ -359,6 +358,10 @@ cc_test( test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc DEPS delete_dropout_op_pass) +cc_test( + test_delete_dequant_weight_linear_op_pass + SRCS delete_weight_dequant_linear_op_pass_tester.cc + DEPS delete_weight_dequant_linear_op_pass) if(WITH_GPU OR WITH_ROCM) cc_test( test_embedding_eltwise_layernorm_fuse_pass diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc deleted file mode 100644 index fe692d01928f72..00000000000000 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.cc +++ /dev/null @@ -1,373 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h" - -#include -#include -#include -#include -#include - -namespace paddle { -namespace framework { -namespace ir { - -#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); -#define GET_NODES \ - GET_IR_NODE(weight_dequantize_linear_op_x); \ - GET_IR_NODE(weight_dequantize_linear_op_scale); \ - GET_IR_NODE(weight_dequantize_linear_op); \ - GET_IR_NODE(weight_dequantize_linear_op_out); \ - GET_IR_NODE(any_op2); - -DeleteWeightDequantLinearOpDecoderPass:: - DeleteWeightDequantLinearOpDecoderPass() { - AddOpCompat(OpCompat("quantize_linear")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Scale") - .IsTensor() - .End() - .AddInput("ZeroPoint") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Y") - .IsTensor() - .End() - .AddAttr("bit_length") - .IsType() - .End() - .AddAttr("quant_axis") - .IsType() - .End() - .AddAttr("round_type") - .IsOptional() - .IsType() - .End(); - AddOpCompat(OpCompat("dequantize_linear")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Scale") - .IsTensor() - .End() - .AddInput("ZeroPoint") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Y") - .IsTensor() - .End() - .AddAttr("bit_length") - .IsType() - .End() - .AddAttr("quant_axis") - .IsType() - .End() - .AddAttr("round_type") - .IsOptional() - .IsType() - .End(); - AddOpCompat(OpCompat("conv2d")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddInput("ResidualData") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - .AddAttr("dilations") - .IsType>() - .End() - .AddAttr("data_format") - 
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); - AddOpCompat(OpCompat("depthwise_conv2d")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddInput("ResidualData") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - .AddAttr("dilations") - .IsType>() - .End() - .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); - AddOpCompat(OpCompat("mul")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("x_num_col_dims") - .IsNumGE(1) - .End() - .AddAttr("y_num_col_dims") - .IsNumEQ(1) - .End(); - AddOpCompat(OpCompat("matmul_v2")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("trans_x") - .IsBoolEQ(false) - .End() - .AddAttr("trans_y") - .IsBoolEQ(false) - .End(); - AddOpCompat(OpCompat("matmul")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("alpha") - .IsNumGE(0.99f) - .IsNumLE(1.01f) - .End() - .AddAttr("transpose_X") - .IsBoolEQ(false) - .End() - .AddAttr("transpose_Y") - .IsBoolEQ(false) - .End(); - AddOpCompat(OpCompat("fc")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("W") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("in_num_col_dims") - .IsNumGE(1) - .End() - .AddAttr("activation_type") - .IsStringIn({"relu", ""}) - .End(); - AddOpCompat(OpCompat("conv2d_transpose")) - .AddInput("Input") - .IsTensor() - .End() - 
.AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("output_padding") - .IsType>() - .IsOptional() - .End() - .AddAttr("output_size") - .IsType>() - .IsOptional() - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - .AddAttr("dilations") - .IsType>() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); -} -// Delete dequantize_linear_op, then dequantize weight -void DeleteWeightDequantLinearOpDecoderPass::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = - "delete_weight_dequant_linear_op_decoder_pattern"; - FusePassBase::Init(pattern_name, graph); - - GraphPatternDetector gpd; - auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL(scope, - platform::errors::InvalidArgument( - "Scope in DeleteWeightDequantLinearOpDecoderPass " - "should not be null.")); - // Create pattern - patterns::DeleteWeightDequantLinearOpDecoderPattern pattern( - gpd.mutable_pattern(), pattern_name); - pattern(); - int found_count = 0; - bool is_int8 = false; - - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_NODES; - /* - if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "delete_weight_dequant_linear_op_pass " - "compat check failed."; - return; - } - */ - is_int8 = true; - std::unordered_set nodes2rm = {}; - - auto* any_op2_desc = any_op2->Op(); - - // Get weight scale - std::vector weight_scale; - auto* weight_scale_tensor = - scope->GetVar(weight_dequantize_linear_op_scale->Name()) - ->GetMutable(); - auto weight_scale_nums = weight_scale_tensor->numel(); - - if (weight_scale_tensor->dtype() == - paddle::experimental::DataType::FLOAT32) { - float* weight_scale_data = weight_scale_tensor->data(); 
- for (int i = 0; i < weight_scale_nums; i++) { - weight_scale.push_back(weight_scale_data[i]); - } - } else if (weight_scale_tensor->dtype() == - paddle::experimental::DataType::FLOAT16) { - phi::dtype::float16* weight_scale_data = - weight_scale_tensor->data(); - for (int i = 0; i < weight_scale_nums; i++) { - weight_scale.push_back(static_cast(weight_scale_data[i])); - } - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "%d is not supported.", weight_scale_tensor->dtype())); - } - - int quant_axis = PADDLE_GET_CONST( - int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); - if (quant_axis == -1) { // per_layer quant_dequant: all OP - PADDLE_ENFORCE_EQ(weight_scale_nums, - 1, - platform::errors::InvalidArgument( - "When quant_axis == -1 means use per_layer " - "quant_dequant, weight_scale'number should be 1.")); - - // Add attr to anyop 2 - any_op2_desc->SetAttr("weight_scale", weight_scale[0]); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Delete Weight Dequant Linear Op Encoder Pass is not supported for " - "per-channel quantization")); - } - - nodes2rm.insert(weight_dequantize_linear_op_scale); - nodes2rm.insert(weight_dequantize_linear_op); - nodes2rm.insert(weight_dequantize_linear_op_out); - - // relink weight to any_op2 - any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(), - weight_dequantize_linear_op_x->Var()->Name()); - any_op2_desc->Flush(); - IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2); - GraphSafeRemoveNodes(graph, nodes2rm); - found_count++; - }; - gpd(graph, handler); - if (is_int8) { - auto& enable_int8 = graph->Get("enable_int8"); - enable_int8 = true; - } - AddStatis(found_count); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(delete_weight_dequant_linear_op_decoder_pass, - paddle::framework::ir::DeleteWeightDequantLinearOpDecoderPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h 
b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h deleted file mode 100644 index 8aead6bd5cc583..00000000000000 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/ir/fuse_pass_base.h" - -namespace paddle { -namespace framework { -namespace ir { - -class DeleteWeightDequantLinearOpEncoderPass : public FusePassBase { - public: - DeleteWeightDequantLinearOpEncoderPass(); - virtual ~DeleteWeightDequantLinearOpEncoderPass() {} - - protected: - void ApplyImpl(ir::Graph* graph) const override; -}; - -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc index cc9ce6d0b2419c..5a6796550619b9 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.cc @@ -1,429 +1,163 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include -#include -#include -#include -#include +#include "glog/logging.h" namespace paddle { namespace framework { namespace ir { -#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); -#define GET_NODES \ - GET_IR_NODE(weight_dequantize_linear_op_x); \ - GET_IR_NODE(weight_dequantize_linear_op_scale); \ - GET_IR_NODE(weight_dequantize_linear_op); \ - GET_IR_NODE(weight_dequantize_linear_op_out); \ - GET_IR_NODE(any_op2); - -DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() { - AddOpCompat(OpCompat("quantize_linear")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Scale") - .IsTensor() - .End() - .AddInput("ZeroPoint") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Y") - .IsTensor() - .End() - .AddAttr("bit_length") - .IsType() - .End() - .AddAttr("quant_axis") - .IsType() - .End() - .AddAttr("round_type") - .IsOptional() - .IsType() - .End(); - AddOpCompat(OpCompat("dequantize_linear")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Scale") - .IsTensor() - .End() - .AddInput("ZeroPoint") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Y") - .IsTensor() - .End() - .AddAttr("bit_length") - .IsType() - .End() - .AddAttr("quant_axis") - .IsType() - .End() - .AddAttr("round_type") - .IsOptional() - .IsType() - .End(); - AddOpCompat(OpCompat("conv2d")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddInput("ResidualData") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - 
.AddAttr("dilations") - .IsType>() - .End() - .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); - AddOpCompat(OpCompat("depthwise_conv2d")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddInput("ResidualData") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - .AddAttr("dilations") - .IsType>() - .End() - .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); - AddOpCompat(OpCompat("mul")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("x_num_col_dims") - .IsNumGE(1) - .End() - .AddAttr("y_num_col_dims") - .IsNumEQ(1) - .End(); - AddOpCompat(OpCompat("matmul_v2")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("trans_x") - .IsBoolEQ(false) - .End() - .AddAttr("trans_y") - .IsBoolEQ(false) - .End(); - AddOpCompat(OpCompat("matmul")) - .AddInput("X") - .IsTensor() - .End() - .AddInput("Y") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("alpha") - .IsNumGE(0.99f) - .IsNumLE(1.01f) - .End() - .AddAttr("transpose_X") - .IsBoolEQ(false) - .End() - .AddAttr("transpose_Y") - .IsBoolEQ(false) - .End(); - AddOpCompat(OpCompat("fc")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("W") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .End() - .AddOutput("Out") - .IsTensor() - .End() - .AddAttr("in_num_col_dims") - .IsNumGE(1) - .End() - .AddAttr("activation_type") - .IsStringIn({"relu", ""}) - .End(); - 
AddOpCompat(OpCompat("conv2d_transpose")) - .AddInput("Input") - .IsTensor() - .End() - .AddInput("Filter") - .IsTensor() - .End() - .AddInput("Bias") - .IsTensor() - .IsOptional() - .End() - .AddOutput("Output") - .IsTensor() - .End() - .AddAttr("output_padding") - .IsType>() - .IsOptional() - .End() - .AddAttr("output_size") - .IsType>() - .IsOptional() - .End() - .AddAttr("groups") - .IsNumGE(1) - .End() - .AddAttr("dilations") - .IsType>() - .End() - .AddAttr("strides") - .IsType>() - .End() - .AddAttr("paddings") - .IsType>() - .End() - .AddAttr("padding_algorithm") - .IsOptional() - .IsStringIn({"EXPLICIT", "SAME", "VALID"}) - .End() - .AddAttr("data_format") - .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) - .End(); -} -// Delete dequantize_linear_op, then dequantize weight -void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = - "delete_weight_quantdequant_linear_op_pattern"; - FusePassBase::Init(pattern_name, graph); - - GraphPatternDetector gpd; - auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL( - scope, - platform::errors::InvalidArgument( - "Scope in DeleteWeightQuantDequantLinearOpPass should not be null.")); - // Create pattern - patterns::DeleteWeightQuantDequantLinearOpPattern pattern( - gpd.mutable_pattern(), pattern_name); - pattern(); - int found_count = 0; - - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - GET_NODES; - /* - if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "delete_weight_dequant_linear_op_pass " - "compat check failed."; - return; - } - */ - std::unordered_set nodes2rm = {}; - int bit_length = PADDLE_GET_CONST( - int, weight_dequantize_linear_op->Op()->GetAttr("bit_length")); - int range = ((1 << (bit_length - 1)) - 1); - - auto* any_op2_desc = any_op2->Op(); - - // get weight tensor - auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name()) - ->GetMutable(); - int8_t* quantized_weight_data = - 
weight_tensor->mutable_data(platform::CPUPlace()); - auto w_dims = weight_tensor->dims(); - - // Get weight scale - std::vector weight_scale; - auto* weight_scale_tensor = - scope->GetVar(weight_dequantize_linear_op_scale->Name()) - ->GetMutable(); - float* weight_scale_data = - weight_scale_tensor->mutable_data(platform::CPUPlace()); - - auto weight_scale_nums = weight_scale_tensor->numel(); - for (int i = 0; i < weight_scale_nums; i++) { - weight_scale.push_back(weight_scale_data[i] / range); - } - - // dequant weight - std::vector weight_data_tmp; - weight_data_tmp.reserve(weight_tensor->numel()); - - int quant_axis = PADDLE_GET_CONST( - int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); - if (quant_axis == -1) { // per_layer quant_dequant: all OP - PADDLE_ENFORCE_EQ(weight_scale_nums, +class Graph; + +void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { + std::unordered_set op_list = {"matmul_v2", + "matmul", + "mul", + "fc", + "depthwise_conv2d", + "conv2d", + "conv2d_transpose"}; + PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), + true, + platform::errors::InvalidArgument( + "Graph must have kParamScopeAttr attribute.")); + + auto& scope = graph->Get(kParamScopeAttr); + bool is_int8 = false; + + std::unordered_set nodes2rm; + + for (const Node* n : graph->Nodes()) { + if (n->IsOp()) { + auto* op = n->Op(); + if (op->Type() == "dequantize_linear") { + Node *weight_var_node, *dequantized_weight_var_node, *scale_var_node, + *calcu_op_node, *while_op_node; + // 1. Judge whether for dequant weight and find + // weight_var_node/scale_var_node + for (auto* input_node : n->inputs) { + if (input_node->IsVar() && input_node->Var()->Persistable()) { + is_int8 = true; + if (input_node->Var()->Name() == op->Input("X")[0]) { + weight_var_node = input_node; + } else if (input_node->Var()->Name() == op->Input("Scale")[0]) { + scale_var_node = input_node; + } + } else { + return; + } + } + // 2. 
Find next_op_node + // For while op: delete its input which is related to dequantized + // For calculation op: set weight scale as their attributes + for (auto* output_node : n->outputs) { + if (output_node->IsVar() && + output_node->Var()->Name() == op->Output("Y")[0]) { + dequantized_weight_var_node = output_node; + for (auto* next_op_node : output_node->outputs) { + if (next_op_node->IsOp()) { + if (next_op_node->Op()->Type() == "while") { + while_op_node = next_op_node; + auto while_op_desc = while_op_node->Op(); + auto while_Xs = while_op_desc->Input("X"); + while_Xs.erase(std::remove(std::begin(while_Xs), + std::end(while_Xs), + output_node->Var()->Name()), + std::end(while_Xs)); + while_op_node->Op()->SetInput("X", while_Xs); + } else if (op_list.count(next_op_node->Op()->Type()) != 0) { + calcu_op_node = next_op_node; + auto* calcu_op_desc = calcu_op_node->Op(); + + std::vector weight_scale; + auto* weight_scale_tensor = + scope.GetVar(scale_var_node->Name()) + ->GetMutable(); + auto weight_scale_nums = weight_scale_tensor->numel(); + + if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT32) { + float* weight_scale_data = + weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(weight_scale_data[i]); + } + } else if (weight_scale_tensor->dtype() == + paddle::experimental::DataType::FLOAT16) { + phi::dtype::float16* weight_scale_data = + weight_scale_tensor->data(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back( + static_cast(weight_scale_data[i])); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "The dtype of quantization scale must be FP32/16, " + "but received %d, which is not supported.", + weight_scale_tensor->dtype())); + } + + int quant_axis = + PADDLE_GET_CONST(int, op->GetAttr("quant_axis")); + if (quant_axis == -1) { // per_layer quant_dequant: all OP + PADDLE_ENFORCE_EQ( + weight_scale_nums, 1, platform::errors::InvalidArgument( - 
"When quant_axis == -1 means use per_layer " - "quant_dequant, weight_scale'number should be 1.")); - - // float(weight) * scale - for (int i = 0; i < weight_tensor->numel(); i++) { - weight_data_tmp[i] = - static_cast(quantized_weight_data[i]) * weight_scale[0]; - } - } else if (quant_axis == 0) { // per_channel quant_dequant: conv2d, - // depthwise_conv2d, conv2d_fusion - PADDLE_ENFORCE_EQ( - weight_scale_nums, - w_dims[quant_axis], - platform::errors::InvalidArgument( - "When quant_axis == 0 means use per_channel quant_dequant, " - "weight_scale'numbers should be equal channels.")); - PADDLE_ENFORCE_EQ(w_dims.size(), - 4, - platform::errors::InvalidArgument( - "When quant_axis == 0 means use per_channel " - "quant_dequant, (conv2d, depthwise_conv2d, " - "conv2d_fusion)'s weight dims should be 4.")); - - for (int i = 0; i < weight_tensor->numel(); i++) { - int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; - weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * - weight_scale[i / inner_size]; - } - } else if (quant_axis == 1) { - PADDLE_ENFORCE_EQ( - weight_scale_nums, - w_dims[quant_axis], - platform::errors::InvalidArgument( - "When quant_axis == 1 means use per_channel quant_dequant, " - "weight_scale'numbers should be equal channels.")); - - if (w_dims.size() == 4) { // conv2d_transpose - std::string quantized_op_type = any_op2->Op()->Type(); - PADDLE_ENFORCE_EQ( - quantized_op_type, - "conv2d_transpose", - platform::errors::InvalidArgument( - "When quant_axis == 1 means use per_channel quant_dequant, " - "only conv2d_transpose weight dims equal 4.")); - for (int i = 0; i < weight_tensor->numel(); i++) { - int inner_size = w_dims[2] * w_dims[3]; - weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * - weight_scale[(i / inner_size) % w_dims[1]]; + "When quant_axis == -1, it means using per_layer " + "dequantization. 
In this situation, the number of " + "weight_scale should be 1, but received %d.", + weight_scale_nums)); + + calcu_op_desc->SetAttr("weight_scale", weight_scale[0]); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Delete Weight Dequant Linear Op Pass is not supported " + "for " + "per-channel quantization")); + } + calcu_op_desc->RenameInput( + dequantized_weight_var_node->Var()->Name(), + weight_var_node->Var()->Name()); + } + } + } + } } - } else if (w_dims.size() == 2) { - for (int i = 0; i < weight_tensor->numel(); i++) { - weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * - weight_scale[i % w_dims[1]]; + + // 3. Delete dequant op + IR_NODE_LINK_TO(weight_var_node, calcu_op_node); + std::vector nodes2rm_local{ + dequantized_weight_var_node, scale_var_node, n}; + for (auto* node2rm : nodes2rm_local) { + if (node2rm) { + nodes2rm.insert(node2rm); + } } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "When quant_axis == 1 , weight dims should be 2 or 4, please check " - "your model ")); } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "quant_axis should be -1 or 0 or 1, please check your model " - "OP'attribute ")); } - weight_tensor->clear(); // clear int weight - weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims))); - float* new_quantized_weight_data = - weight_tensor->mutable_data(platform::CPUPlace()); - memcpy(new_quantized_weight_data, - weight_data_tmp.data(), - weight_tensor->numel() * sizeof(float)); - - nodes2rm.insert(weight_dequantize_linear_op_scale); - nodes2rm.insert(weight_dequantize_linear_op); - nodes2rm.insert(weight_dequantize_linear_op_out); + } - // relink weight to any_op2 - any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(), - weight_dequantize_linear_op_x->Var()->Name()); - any_op2_desc->Flush(); - IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2); - GraphSafeRemoveNodes(graph, nodes2rm); - found_count++; - }; - gpd(graph, handler); - 
AddStatis(found_count); + GraphSafeRemoveNodes(graph, nodes2rm); + graph->Set("enable_int8", new bool(is_int8)); } - } // namespace ir } // namespace framework } // namespace paddle REGISTER_PASS(delete_weight_dequant_linear_op_pass, - paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass); + paddle::framework::ir::DeleteWeightDequantLinearOpPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h index e240b6212b84a3..dd7187946d053a 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h @@ -1,31 +1,28 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ #pragma once -#include -#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { namespace ir { -class DeleteWeightQuantDequantLinearOpPass : public FusePassBase { - public: - DeleteWeightQuantDequantLinearOpPass(); - virtual ~DeleteWeightQuantDequantLinearOpPass() {} +class Graph; +class DeleteWeightDequantLinearOpPass : public Pass { protected: void ApplyImpl(ir::Graph* graph) const override; }; diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc new file mode 100644 index 00000000000000..603bbf0872d52d --- /dev/null +++ b/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass_tester.cc @@ -0,0 +1,141 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include + +#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +template +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); + dev_ctx->HostAlloc(tensor, tensor->numel() * sizeof(T)); +} + +template +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "scale", {1}); + + return param_scope; +} + +TEST(DeleteWeightDequantLinearOpPass, basic) { + // inputs operator output + // -------------------------------------------------------------------- + // (weight, scale) dequantize_linear -> dequantized_weight + // (x, dequantized_weight) matmul/fc/conv -> matmul_out + // (dequantized_weight) while -> [optional] + + Layers layers; + + auto* x = layers.data("x", {1, 128, 768}); + auto* weight = layers.data("weight", {768, 768}, true); + auto* scale = layers.data("scale", {1}, true); + auto* zero_point = layers.data("zero_point", {1}, true); + auto* dequantized_weight = + layers.dequantize_linear(weight, scale, zero_point); + layers.matmul_v2(x, dequantized_weight); + layers.while_loop({dequantized_weight}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + + graph->Set("__param_scope__", CreateParamScope()); + auto pass = + PassRegistry::Instance().Get("delete_weight_dequant_linear_op_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_dequant_nodes_after = GetNumOpNodes(graph, "dequantize_linear"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 3, + 
 platform::errors::InvalidArgument( + "After pass, the number of nodes should be reduced by 3, but the " + "number before pass is %d, after pass is %d.", + num_nodes_before, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_dequant_nodes_after, + 0, + platform::errors::InvalidArgument( + "After pass, the number of nodes of type " + "'dequantize_linear' should be 0, not %d.", + num_dequant_nodes_after)); +} + +TEST(DeleteWeightDequantLinearOpPass, basic_fp16) { + // inputs operator output + // -------------------------------------------------------------------- + // (weight, scale) dequantize_linear -> dequantized_weight + // (x, dequantized_weight) matmul/fc/conv -> matmul_out + // (dequantized_weight) while -> [optional] + + Layers layers; + + auto* x = layers.data("x", {1, 128, 768}); + auto* weight = layers.data("weight", {768, 768}, true); + auto* scale = layers.data("scale", {1}, true); + auto* zero_point = layers.data("zero_point", {1}, true); + auto* dequantized_weight = + layers.dequantize_linear(weight, scale, zero_point); + layers.matmul_v2(x, dequantized_weight); + layers.while_loop({dequantized_weight}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + + graph->Set("__param_scope__", CreateParamScope()); + auto pass = + PassRegistry::Instance().Get("delete_weight_dequant_linear_op_pass"); + int num_nodes_before = graph->Nodes().size(); + VLOG(3) << DebugString(graph); + + graph.reset(pass->Apply(graph.release())); + int num_nodes_after = graph->Nodes().size(); + int num_dequant_nodes_after = GetNumOpNodes(graph, "dequantize_linear"); + VLOG(3) << DebugString(graph); + + PADDLE_ENFORCE_EQ( + num_nodes_before, + num_nodes_after + 3, + platform::errors::InvalidArgument( + "After pass, the number of nodes should be reduced by 3, but the " + "number before pass is %d, after pass is %d.", + num_nodes_before, + num_nodes_after)); + PADDLE_ENFORCE_EQ(num_dequant_nodes_after, + 0, + platform::errors::InvalidArgument( + "After pass, the number of 
nodes of type " + "'dequantize_linear' should be 0, not %d.", + num_dequant_nodes_after)); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(delete_weight_dequant_linear_op_pass); diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index 5c05a9f27c1fc1..fbe2b3e748d40c 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -48,8 +48,8 @@ static const std::vector support_subgraph_passes = { "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass", "fuse_multi_transformer_layer_pass", "delete_quant_dequant_linear_op_pass", - "delete_weight_dequant_linear_op_encoder_pass", - "delete_weight_dequant_linear_op_decoder_pass"}; + "delete_weight_dequant_linear_op_pass", +}; Graph *Pass::Apply(Graph *graph) const { VLOG(10) << "start to apply pass " << Type() << " to graph"; diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h index 48f8cb37d60a98..589cf4d0d192d9 100644 --- a/paddle/fluid/framework/ir/pass_tester_helper.h +++ b/paddle/fluid/framework/ir/pass_tester_helper.h @@ -641,6 +641,23 @@ struct Layers { return out; } + VarDesc* dequantize_linear(VarDesc* x, + VarDesc* scale, + VarDesc* zero_point, + int bit_length = 8, + int quant_axis = -1) { + VarDesc* out = lod_tensor(unique_name()); + OpDesc* op = program_.MutableBlock(0)->AppendOp(); + op->SetType("dequantize_linear"); + op->SetInput("X", {x->Name()}); + op->SetInput("Scale", {scale->Name()}); + op->SetInput("ZeroPoint", {zero_point->Name()}); + op->SetAttr("bit_length", bit_length); + op->SetAttr("quant_axis", quant_axis); + op->SetOutput("Y", {out->Name()}); + return out; + } + void backward(std::vector targets) { // This function is designed to simulate the structure of training program, // but is constructed differently as the actual program. 
diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc similarity index 63% rename from paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc rename to paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc index 0cffcd38b3466a..7614a9eda8e53c 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.cc +++ b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h" +#include "paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h" #include #include @@ -32,8 +32,8 @@ namespace ir { GET_IR_NODE(weight_dequantize_linear_op_out); \ GET_IR_NODE(any_op2); -DeleteWeightDequantLinearOpEncoderPass:: - DeleteWeightDequantLinearOpEncoderPass() { +TrtDeleteWeightQuantDequantLinearOpPass:: + TrtDeleteWeightQuantDequantLinearOpPass() { AddOpCompat(OpCompat("quantize_linear")) .AddInput("X") .IsTensor() @@ -270,64 +270,69 @@ DeleteWeightDequantLinearOpEncoderPass:: .End(); } // Delete dequantize_linear_op, then dequantize weight -void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const { +void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl( + ir::Graph* graph) const { const std::string pattern_name = - "delete_weight_dequant_linear_op_encoder_pattern"; + "delete_weight_quantdequant_linear_op_pattern"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; auto* scope = param_scope(); - PADDLE_ENFORCE_NOT_NULL(scope, - platform::errors::InvalidArgument( - "Scope in DeleteWeightDequantLinearOpEncoderPass " - "should not be null.")); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::InvalidArgument( + "Scope in 
TrtDeleteWeightQuantDequantLinearOpPass should not be " + "null.")); // Create pattern - patterns::DeleteWeightDequantLinearOpEncoderPattern pattern( + patterns::DeleteWeightQuantDequantLinearOpPattern pattern( gpd.mutable_pattern(), pattern_name); pattern(); int found_count = 0; - bool is_int8 = false; + + // Device context + auto* dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(platform::CPUPlace())); auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { GET_NODES; /* if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "delete_weight_dequant_linear_op_pass " + LOG(WARNING) << "trt_delete_weight_dequant_linear_op_pass " "compat check failed."; return; } */ - is_int8 = true; std::unordered_set nodes2rm = {}; + int bit_length = PADDLE_GET_CONST( + int, weight_dequantize_linear_op->Op()->GetAttr("bit_length")); + int range = ((1 << (bit_length - 1)) - 1); auto* any_op2_desc = any_op2->Op(); + // get weight tensor + auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name()) + ->GetMutable(); + int8_t* quantized_weight_data = weight_tensor->data(); + + auto w_dims = weight_tensor->dims(); + // Get weight scale std::vector weight_scale; auto* weight_scale_tensor = scope->GetVar(weight_dequantize_linear_op_scale->Name()) ->GetMutable(); - auto weight_scale_nums = weight_scale_tensor->numel(); + float* weight_scale_data = weight_scale_tensor->data(); - if (weight_scale_tensor->dtype() == - paddle::experimental::DataType::FLOAT32) { - float* weight_scale_data = weight_scale_tensor->data(); - for (int i = 0; i < weight_scale_nums; i++) { - weight_scale.push_back(weight_scale_data[i]); - } - } else if (weight_scale_tensor->dtype() == - paddle::experimental::DataType::FLOAT16) { - phi::dtype::float16* weight_scale_data = - weight_scale_tensor->data(); - for (int i = 0; i < weight_scale_nums; i++) { - weight_scale.push_back(static_cast(weight_scale_data[i])); - } - } else { - 
PADDLE_THROW(platform::errors::Unimplemented( - "%d is not supported.", weight_scale_tensor->dtype())); + auto weight_scale_nums = weight_scale_tensor->numel(); + for (int i = 0; i < weight_scale_nums; i++) { + weight_scale.push_back(weight_scale_data[i] / range); } + // dequant weight + std::vector weight_data_tmp; + weight_data_tmp.reserve(weight_tensor->numel()); + int quant_axis = PADDLE_GET_CONST( int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis")); if (quant_axis == -1) { // per_layer quant_dequant: all OP @@ -337,13 +342,74 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const { "When quant_axis == -1 means use per_layer " "quant_dequant, weight_scale'number should be 1.")); - // Add attr to anyop 2 - any_op2_desc->SetAttr("weight_scale", weight_scale[0]); + // float(weight) * scale + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data_tmp[i] = + static_cast(quantized_weight_data[i]) * weight_scale[0]; + } + } else if (quant_axis == 0) { // per_channel quant_dequant: conv2d, + // depthwise_conv2d, conv2d_fusion + PADDLE_ENFORCE_EQ( + weight_scale_nums, + w_dims[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis == 0 means use per_channel quant_dequant, " + "weight_scale'numbers should be equal channels.")); + PADDLE_ENFORCE_EQ(w_dims.size(), + 4, + platform::errors::InvalidArgument( + "When quant_axis == 0 means use per_channel " + "quant_dequant, (conv2d, depthwise_conv2d, " + "conv2d_fusion)'s weight dims should be 4.")); + + for (int i = 0; i < weight_tensor->numel(); i++) { + int inner_size = w_dims[1] * w_dims[2] * w_dims[3]; + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[i / inner_size]; + } + } else if (quant_axis == 1) { + PADDLE_ENFORCE_EQ( + weight_scale_nums, + w_dims[quant_axis], + platform::errors::InvalidArgument( + "When quant_axis == 1 means use per_channel quant_dequant, " + "weight_scale'numbers should be equal channels.")); + + if 
(w_dims.size() == 4) { // conv2d_transpose + std::string quantized_op_type = any_op2->Op()->Type(); + PADDLE_ENFORCE_EQ( + quantized_op_type, + "conv2d_transpose", + platform::errors::InvalidArgument( + "When quant_axis == 1 means use per_channel quant_dequant, " + "only conv2d_transpose weight dims equal 4.")); + for (int i = 0; i < weight_tensor->numel(); i++) { + int inner_size = w_dims[2] * w_dims[3]; + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[(i / inner_size) % w_dims[1]]; + } + } else if (w_dims.size() == 2) { + for (int i = 0; i < weight_tensor->numel(); i++) { + weight_data_tmp[i] = static_cast(quantized_weight_data[i]) * + weight_scale[i % w_dims[1]]; + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "When quant_axis == 1 , weight dims should be 2 or 4, please check " + "your model ")); + } } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Delete Weight Dequant Linear Op Encoder Pass is not supported for " - "per-channel quantization")); + PADDLE_THROW(platform::errors::InvalidArgument( + "quant_axis should be -1 or 0 or 1, please check your model " + "OP'attribute ")); } + weight_tensor->clear(); // clear int weight + weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims))); + float* new_quantized_weight_data = dev_ctx->HostAlloc( + weight_tensor, weight_tensor->numel() * sizeof(float)); + memcpy(new_quantized_weight_data, + weight_data_tmp.data(), + weight_tensor->numel() * sizeof(float)); nodes2rm.insert(weight_dequantize_linear_op_scale); nodes2rm.insert(weight_dequantize_linear_op); @@ -358,7 +424,6 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const { found_count++; }; gpd(graph, handler); - graph->Set("enable_int8", new bool(is_int8)); AddStatis(found_count); } @@ -366,5 +431,5 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const { } // namespace framework } // namespace paddle 
-REGISTER_PASS(delete_weight_dequant_linear_op_encoder_pass, - paddle::framework::ir::DeleteWeightDequantLinearOpEncoderPass); +REGISTER_PASS(trt_delete_weight_dequant_linear_op_pass, + paddle::framework::ir::TrtDeleteWeightQuantDequantLinearOpPass); diff --git a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h similarity index 82% rename from paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h rename to paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h index 866bfb7b736543..da5f4ffb75bd24 100644 --- a/paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h +++ b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include #include "paddle/fluid/framework/ir/fuse_pass_base.h" @@ -20,10 +21,10 @@ namespace paddle { namespace framework { namespace ir { -class DeleteWeightDequantLinearOpDecoderPass : public FusePassBase { +class TrtDeleteWeightQuantDequantLinearOpPass : public FusePassBase { public: - DeleteWeightDequantLinearOpDecoderPass(); - virtual ~DeleteWeightDequantLinearOpDecoderPass() {} + TrtDeleteWeightQuantDequantLinearOpPass(); + virtual ~TrtDeleteWeightQuantDequantLinearOpPass() {} protected: void ApplyImpl(ir::Graph* graph) const override; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index be96d754acb995..f7ce5b39ed9015 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -84,16 +84,16 @@ void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { void PaddlePassBuilder::ClearPasses() { passes_.clear(); } const std::vector kTRTSubgraphPasses({ - "adaptive_pool2d_convert_global_pass", // - "shuffle_channel_detect_pass", // - "quant_conv2d_dequant_fuse_pass", // - 
"delete_fill_constant_op_pass", // - "delete_quant_dequant_op_pass", // - "delete_quant_dequant_filter_op_pass", // - "delete_weight_dequant_linear_op_pass", // - "delete_quant_dequant_linear_op_pass", // - "identity_scale_op_clean_pass", // - "add_support_int8_pass", // + "adaptive_pool2d_convert_global_pass", // + "shuffle_channel_detect_pass", // + "quant_conv2d_dequant_fuse_pass", // + "delete_fill_constant_op_pass", // + "delete_quant_dequant_op_pass", // + "delete_quant_dequant_filter_op_pass", // + "trt_delete_weight_dequant_linear_op_pass", // + "delete_quant_dequant_linear_op_pass", // + "identity_scale_op_clean_pass", // + "add_support_int8_pass", // // "fc_fuse_pass", // "simplify_with_basic_ops_pass", // "trt_embedding_eltwise_layernorm_fuse_pass", // @@ -161,8 +161,7 @@ const std::vector kLiteSubgraphPasses({ const std::vector kGpuLowerPrecisionPasses{ "simplify_with_basic_ops_pass", "delete_quant_dequant_linear_op_pass", - "delete_weight_dequant_linear_op_encoder_pass", - "delete_weight_dequant_linear_op_decoder_pass", + "delete_weight_dequant_linear_op_pass", "map_depthwise_conv_to_conv_pass", "conv_bn_fuse_pass", "conv_eltwiseadd_bn_fuse_pass", @@ -210,8 +209,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "is_test_pass", // "simplify_with_basic_ops_pass", // "delete_quant_dequant_linear_op_pass", // - "delete_weight_dequant_linear_op_encoder_pass", // - "delete_weight_dequant_linear_op_decoder_pass", // + "delete_weight_dequant_linear_op_pass", // "map_depthwise_conv_to_conv_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //