Skip to content

Commit

Permalink
add the quant_linear_fuse_pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Wanglongzhi2001 committed Nov 7, 2023
1 parent 77e4ada commit 82c1f0d
Show file tree
Hide file tree
Showing 5 changed files with 510 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(trt_delete_weight_dequant_linear_op_pass inference)
pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
# Register the quant_linear fuse pass (matmul_v2 [+ elementwise_add] ->
# quant_linear) with the inference pass library.
pass_library(quant_linear_fuse_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_assign_op_pass inference)
pass_library(delete_dropout_op_pass inference)
Expand Down
40 changes: 40 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3331,6 +3331,46 @@ void patterns::DeleteWeightDequantLinearOpEncoderPattern::operator()() {
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}

// Builds the quant_linear fuse pattern rooted at activation input `x`:
//   x -- matmul_v2(w) --> mul_out [-- elementwise_add(bias) --> add_out]
// Returns the pattern's terminal output node (mul_out without bias,
// elementwise_add_out with bias).
PDNode *patterns::QuantLinearFusePattern::operator()(
    paddle::framework::ir::PDNode *x, bool with_bias) {
  // x must feed a matmul_v2 as its "X" input.
  x->assert_is_op_input("matmul_v2", "X");
  auto *matmul_op = pattern->NewNode(mul_repr())->assert_is_op("matmul_v2");

  // The weight is a persistable variable consumed as matmul_v2's "Y".
  auto *weight_var = pattern->NewNode(w_repr())
                         ->AsInput()
                         ->assert_is_persistable_var()
                         ->assert_is_op_input("matmul_v2", "Y");

  auto *matmul_out_var =
      pattern->NewNode(mul_out_repr())->assert_is_op_output("matmul_v2");

  matmul_op->LinksFrom({x, weight_var}).LinksTo({matmul_out_var});

  if (!with_bias) {
    // Without bias the pattern ends at the matmul output.
    return matmul_out_var;
  }

  // With bias: the matmul output is an intermediate feeding elementwise_add.
  matmul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
  auto *add_op = pattern->NewNode(elementwise_add_repr())
                     ->assert_is_op("elementwise_add");
  // The bias is a persistable input of the elementwise_add.
  auto *bias_var = pattern->NewNode(bias_repr())
                       ->assert_is_op_input("elementwise_add")
                       ->assert_is_persistable_var()
                       ->AsInput();

  auto *add_out_var = pattern->NewNode(elementwise_add_out_repr())
                          ->AsOutput()
                          ->assert_is_op_output("elementwise_add");

  add_op->LinksFrom({matmul_out_var, bias_var}).LinksTo({add_out_var});
  return add_out_var;
}

void patterns::DeleteWeightDequantLinearOpDecoderPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
Expand Down
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,24 @@ struct DeleteWeightDequantLinearOpEncoderPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2);
};

// Pattern for quant_linear_fuse_pass: matches
//   x -- matmul_v2(w) --> mul_out [-- elementwise_add(bias) --> add_out]
// so the subgraph can be fused into a single quant_linear op.
struct QuantLinearFusePattern : public PatternBase {
  QuantLinearFusePattern(PDPattern* pattern, const std::string& name_scope)
      // Use a pattern id that matches this pass. The original "fc" was a
      // copy-paste leftover from FCFusePattern and would mislabel the
      // generated pattern-node names (and could collide with a real fc
      // pattern registered under the same name_scope).
      : PatternBase(pattern, name_scope, "quant_linear_fuse") {}

  // Builds the pattern rooted at `x`; when `with_bias` is true it also
  // matches a trailing elementwise_add with a persistable bias input.
  // Returns the pattern's output node.
  PDNode* operator()(PDNode* x, bool with_bias);

  // declare operator node's name
  PATTERN_DECL_NODE(quant_linear);
  PATTERN_DECL_NODE(mul);
  PATTERN_DECL_NODE(elementwise_add);
  // NOTE(review): operator() never creates a relu node; this declaration
  // looks like a leftover from the fc fuse pattern — confirm whether the
  // pass's .cc uses it before removing.
  PATTERN_DECL_NODE(relu);
  // declare variable node's name
  PATTERN_DECL_NODE(w);
  PATTERN_DECL_NODE(mul_out);  // (x,w) -> mul_out
  PATTERN_DECL_NODE(bias);
  PATTERN_DECL_NODE(elementwise_add_out);
};

struct DeleteWeightDequantLinearOpDecoderPattern : public PatternBase {
DeleteWeightDequantLinearOpDecoderPattern(PDPattern* pattern,
const std::string& name_scope)
Expand Down
Loading

0 comments on commit 82c1f0d

Please sign in to comment.