Add int8 support for matmul+elementwise_add fuse pass #45077

Merged (2 commits) on Aug 22, 2022
24 changes: 21 additions & 3 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1892,11 +1892,19 @@ PDNode *patterns::Reshape2Matmul::operator()() {
return matmul_out;
}

PDNode *patterns::MatmulWithInputOps::operator()() {
PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) {
auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op();
auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op();

auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");

if (!with_residual) {
matmul_op->assert_more([&](Node *x) {
return (!HasInput(x, "ResidualData") ||
x->Op()->Input("ResidualData").size() == 0);
});
}

auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul", "X");
@@ -1905,11 +1913,21 @@ PDNode *patterns::MatmulWithInputOps::operator()() {
->assert_is_op_input("matmul", "Y");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
->assert_is_op_output("matmul", "Out")
->assert_is_only_output_of_op("matmul");
std::vector<PDNode *> links_from{matmul_in_x, matmul_in_y};

if (with_residual) {
auto matmul_residual_data =
pattern->NewNode(matmul_residual_data_repr())
->AsInput()
->assert_is_op_input("matmul", "ResidualData");
links_from.push_back(matmul_residual_data);
}

prev_op_x->LinksTo({matmul_in_x});
prev_op_y->LinksTo({matmul_in_y});
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
matmul_op->LinksFrom(links_from).LinksTo({matmul_out});
return matmul_out;
}

3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1191,12 +1191,13 @@ struct MatmulWithInputOps : public PatternBase {
MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_with_input_ops") {}

PDNode* operator()();
PDNode* operator()(bool with_residual);
PATTERN_DECL_NODE(prev_op_x);
PATTERN_DECL_NODE(prev_op_y);
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_residual_data);
PATTERN_DECL_NODE(matmul_out);
};

44 changes: 39 additions & 5 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -733,11 +733,11 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph,
LogQuantizedOpsCounter(immutable_type, quantize_immutable_count);
}

void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_};
matmul_pattern();
matmul_pattern(with_residual);

int quantize_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -754,7 +754,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern);

// skip if prev ops are not quantized
if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) {
if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) {
MarkAndLogCannotQuantizeOp(matmul_op,
"No other quantizable operators nearby");
return;
@@ -763,6 +763,15 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern);

auto has_output_scale = AreScalesPresentForNodes({matmul_out});
if (with_residual && !has_output_scale) {
MarkAndLogCannotQuantizeOp(
matmul_op,
"Matmul op with ResidualData input cannot be quantized "
"without output scale.");
return;
}

if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) {
MarkAndLogCannotQuantizeOp(matmul_op,
"No scale available for the operator");
@@ -780,6 +789,28 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
"are different: x(%d), y(%d).",
is_x_unsigned,
is_y_unsigned));

if (with_residual) {
GET_IR_NODE_FROM_SUBGRAPH(
matmul_residual_data, matmul_residual_data, matmul_pattern);
if (!AreScalesPresentForNodes({matmul_residual_data})) {
MarkAndLogCannotQuantizeOp(matmul_op,
"No scale available for the operator");
return;
}
bool is_residual_unsigned{false};
auto residual_scale =
GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned);

QuantizeInput(g,
matmul_op,
matmul_residual_data,
"ResidualData",
residual_scale,
is_residual_unsigned,
"Scale_in_eltwise");
}

QuantizeInput(g,
matmul_op,
matmul_in_x,
@@ -814,7 +845,9 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const {
};
gpd(graph, handler);
AddStatis(quantize_matmul_count);
LogQuantizedOpsCounter("matmul", quantize_matmul_count);
LogQuantizedOpsCounter("matmul",
quantize_matmul_count,
(with_residual ? "with residual connection" : ""));
}

void CPUQuantizePass::QuantizeElementwise(
@@ -1132,7 +1165,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeConcat(graph);
QuantizePriorBox(graph);
QuantizeFc(graph);
QuantizeMatmul(graph);
QuantizeMatmul(graph, false /* with_residual_data */);
QuantizeMatmul(graph, true /* with_residual_data */);
QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X");
QuantizeImmutable(graph, "slice", "Input");
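An aside for reviewers tracing the new residual branch in CPUQuantizePass::QuantizeMatmul: the sketch below condenses its decision order into a standalone toy program. Scales must be known for the output, for both matmul inputs, and for the residual tensor, and the residual scale ends up recorded as the Scale_in_eltwise attribute. ToyMatmul, ScaleMap and QuantizeMatmulSketch are invented illustration names, not Paddle APIs; only the order of the checks mirrors the diff above.

// Simplified sketch of the residual branch added to QuantizeMatmul.
// Self-contained; compiles with any C++17 compiler, no Paddle dependency.
#include <iostream>
#include <map>
#include <string>

using ScaleMap = std::map<std::string, float>;  // tensor name -> quant scale

struct ToyMatmul {
  std::string x, y, residual, out;     // tensor names ("" means absent)
  std::map<std::string, float> attrs;  // attributes the pass would write
};

// Returns false in the same situations in which the real handler calls
// MarkAndLogCannotQuantizeOp and leaves the op in fp32.
bool QuantizeMatmulSketch(ToyMatmul* op, const ScaleMap& scales,
                          bool with_residual) {
  const bool has_residual = with_residual && !op->residual.empty();
  if (has_residual && !scales.count(op->out)) return false;  // needs out scale
  if (!scales.count(op->x) || !scales.count(op->y)) return false;
  if (has_residual) {
    if (!scales.count(op->residual)) return false;
    // The residual input is quantized with its own scale, stored as
    // "Scale_in_eltwise" (mirrors QuantizeInput(..., "ResidualData", ...)).
    op->attrs["Scale_in_eltwise"] = scales.at(op->residual);
  }
  op->attrs["Scale_x"] = scales.at(op->x);
  op->attrs["Scale_y"] = scales.at(op->y);
  if (scales.count(op->out)) op->attrs["Scale_out"] = scales.at(op->out);
  return true;
}

int main() {
  ToyMatmul op{"b", "d", "f", "g", {}};
  ScaleMap scales{{"b", 127.f}, {"d", 127.f}, {"f", 63.5f}, {"g", 127.f}};
  std::cout << std::boolalpha
            << QuantizeMatmulSketch(&op, scales, /*with_residual=*/true)
            << " Scale_in_eltwise=" << op.attrs["Scale_in_eltwise"] << "\n";
  // Prints: true Scale_in_eltwise=63.5
}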
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h
@@ -54,7 +54,7 @@ class CPUQuantizePass : public FusePassBase {
void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const;
void QuantizeMatmul(Graph* graph, bool with_residual) const;
void QuantizeElementwise(Graph* graph,
const std::string& elementwise_type) const;
void QuantizeFusionGru(Graph* graph) const;
40 changes: 36 additions & 4 deletions paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc
@@ -90,6 +90,7 @@ void SetOp(ProgramDesc* prog,
} else if (type == "matmul") {
op->SetInput("X", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Y", {inputs[1]});
if (inputs.size() > 2) op->SetInput("ResidualData", {inputs[2]});
op->SetOutput("Out", {outputs[0]});
op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f);
@@ -180,6 +181,11 @@ void CheckScales(const OpDesc* op, float scale, float shift) {
scale_names.push_back("Scale_x");
scale_names.push_back("Scale_y");
scale_names.push_back("Scale_out");
if (type == "matmul") {
auto const& names = op->InputNames();
if (std::find(names.begin(), names.end(), "ResidualData") != names.end())
scale_names.push_back("Scale_in_eltwise");
}
} else if (type == "fusion_gru" || type == "fusion_lstm") {
EXPECT_EQ(op->GetAttrIfExists<float>("Shift_data"), shift);
EXPECT_EQ(op->GetAttrIfExists<std::vector<float>>("Scale_weights")[0],
@@ -579,7 +585,7 @@ INSTANTIATE_TEST_CASE_P(
});

static const std::initializer_list<std::string> variable_names_matmul = {
"a", "b", "c", "d", "e", "f"};
"a", "b", "c", "d", "e", "f", "g", "h"};

ProgramDesc BuildProgramDescMatmul() {
ProgramDesc prog;
@@ -599,14 +605,28 @@ ProgramDesc BuildProgramDescMatmulNotQuantized() {
for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false);
SetOp(&prog, "dequantize", "Dequantize", {"c"}, {"d"}, true);
SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, false);
SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, false);
SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8");
SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32");

return prog;
}

ProgramDesc BuildProgramDescMatmulResidual() {
ProgramDesc prog;
for (auto& v : variable_names_matmul) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true);
SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true);
SetOp(&prog, "dequantize", "Dequantize3", {"e"}, {"f"}, true);
SetOp(&prog, "matmul", "Matmul", {"b", "d", "f"}, {"g"}, true, "int8");
SetOp(&prog, "dropout", "Dropout", {"g"}, {"h"}, true, "float32");

return prog;
}

TEST(CpuQuantizePass, matmul) {
// 2 Quant + 2 IN + 1 DeQuant + 1 OUT
int added_nodes = 6;
@@ -623,14 +643,26 @@ TEST(CpuQuantizePass, matmul_not_quantized) {
// nothing change
int added_nodes = 0;
std::unordered_map<std::string, int> expected_operators = {
{"matmul", 1}, {"quantize", 0}, {"dequantize", 1}};
{"matmul", 1}, {"quantize", 0}, {"dequantize", 0}};
MainTest(BuildProgramDescMatmulNotQuantized(),
variable_names_matmul,
expected_operators,
added_nodes,
1.0f);
}

TEST(CpuQuantizePass, matmul_residual) {
// 3 Quant + 3 IN + 1 DeQuant + 1 OUT
int added_nodes = 8;
std::unordered_map<std::string, int> expected_operators = {
{"matmul", 1}, {"quantize", 3}, {"dequantize", 4}};
MainTest(BuildProgramDescMatmulResidual(),
variable_names_matmul,
expected_operators,
added_nodes,
SCALE * S8_MAX);
}

static const std::initializer_list<std::string> variable_names_elementwise = {
"a", "b", "c", "d", "e", "f"};

6 changes: 4 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -361,6 +361,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
if (!use_mkldnn_int8_) {
passes_.clear();
passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass");
@@ -382,10 +384,10 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("matmul_scale_fuse_pass");
passes_.push_back("gpu_cpu_map_matmul_to_mul_pass");
passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("depthwise_conv_mkldnn_pass");
passes_.push_back("conv_bn_fuse_pass");
passes_.push_back("conv_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_affine_channel_mkldnn_fuse_pass");
passes_.push_back("conv_transpose_bn_fuse_pass");
passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_bias_mkldnn_fuse_pass");
@@ -403,10 +405,10 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("mkldnn_inplace_pass");
passes_.push_back("runtime_context_cache_pass");
}
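These CpuPassStrategy entries are what run when int8 oneDNN inference is enabled from the C++ side. A minimal usage sketch follows; it assumes the standard Paddle Inference config API (only CpuPassStrategy::EnableMkldnnInt8 itself appears in this diff), so treat the exact method names, header, and model path as illustrative rather than verified against this revision.

// Usage sketch only - assumes the public Paddle Inference C++ API
// (paddle_infer::Config::EnableMkldnnInt8); not part of this diff.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("/path/to/quantized_model");  // hypothetical path
  config.EnableMKLDNN();
  config.EnableMkldnnInt8();  // installs the pass list edited above, including
                              // matmul_elementwise_add_mkldnn_fuse_pass before
                              // cpu_quantize_pass
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}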
8 changes: 6 additions & 2 deletions paddle/fluid/platform/mkldnn_reuse.h
@@ -779,10 +779,14 @@ class MatMulV2MKLDNNHandler
auto* residual_data = ctx.Input<Tensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
dnnl::memory::data_type::f32,
dnnl::memory::format_tag::abcd);
MKLDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (ctx.HasAttr("Scale_in_eltwise")) {
float sum_scale = scale_out / ctx.Attr<float>("Scale_in_eltwise");
post_operations.append_sum(sum_scale);
}
}

AppendActivation(ctx, post_operations);
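The sum_scale computed above is the factor that moves the already-quantized residual from its own scale domain into the output's: if q_res = round(real_res * Scale_in_eltwise) and the matmul output is quantized with Scale_out, then adding (Scale_out / Scale_in_eltwise) * q_res to the quantized output is equivalent to adding real_res in fp32. A small self-contained check of that identity with made-up numbers (not Paddle code):

// Toy numerical check of sum_scale = Scale_out / Scale_in_eltwise.
#include <cmath>
#include <cstdio>

int main() {
  const float scale_out = 127.0f;        // output quantization scale
  const float scale_in_eltwise = 63.5f;  // residual quantization scale
  const float sum_scale = scale_out / scale_in_eltwise;  // == 2.0f

  const float real_out = 0.40f;  // fp32 matmul result
  const float real_res = 0.25f;  // fp32 residual being added

  const float q_out = std::round(real_out * scale_out);         // 51
  const float q_res = std::round(real_res * scale_in_eltwise);  // 16

  // Integer-domain accumulation performed by the sum post-op.
  const float q_fused = q_out + sum_scale * q_res;  // 51 + 2 * 16 = 83
  std::printf("dequantized = %.4f, fp32 reference = %.4f\n",
              q_fused / scale_out, real_out + real_res);  // ~0.6535 vs 0.65
  return 0;
}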
@@ -403,12 +403,13 @@ def _remove_ctrl_vars(self, graph):
def _optimize_fp32_graph(self, graph):
graph = self._update_activations(graph)
graph = self._remove_ctrl_vars(graph)
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
# remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
graph = self._apply_pass(graph, 'layer_norm_fuse_pass')
graph = self._apply_pass(graph, 'attention_lstm_fuse_pass')
graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass')
# graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass')
graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass')
# graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass')
graph = self._apply_pass(graph, 'fc_lstm_fuse_pass')
graph = self._apply_pass(graph, 'mul_lstm_fuse_pass')
graph = self._apply_pass(graph, 'fc_gru_fuse_pass')
@@ -427,8 +428,6 @@ def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
@@ -452,6 +451,11 @@ def _optimize_fp32_graph(self, graph):
'matmul_transpose_reshape_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass')
graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(graph,
'matmul_elementwise_add_mkldnn_fuse_pass')
# the following pass should be the last one since it will work on all fused ops.
graph = self._apply_pass(graph, 'runtime_context_cache_pass')
return graph
@@ -477,8 +481,6 @@ def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
return graph

def _final_optimizations(self, graph):
# remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
# make some MKL-DNN ops working inplace
graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
return graph