From 33b940a4c20eda9f0a62dfbf360bffcf1e2c8a1a Mon Sep 17 00:00:00 2001 From: wozna Date: Wed, 10 Aug 2022 20:03:14 +0800 Subject: [PATCH 1/2] Add int8 support for matmul+elementwiae_add fuse --- .../framework/ir/graph_pattern_detector.cc | 26 ++++++++-- .../framework/ir/graph_pattern_detector.h | 3 +- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 52 ++++++++++++++++--- .../framework/ir/mkldnn/cpu_quantize_pass.h | 2 +- .../ir/mkldnn/cpu_quantize_pass_tester.cc | 40 ++++++++++++-- .../inference/api/paddle_pass_builder.cc | 6 ++- paddle/fluid/platform/mkldnn_reuse.h | 8 ++- .../quantization/quant2_int8_mkldnn_pass.py | 15 ++++-- 8 files changed, 126 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 85b3bdb874d4f..d0a18f568af41 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1892,11 +1892,21 @@ PDNode *patterns::Reshape2Matmul::operator()() { return matmul_out; } -PDNode *patterns::MatmulWithInputOps::operator()() { +PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) { auto prev_op_x = pattern->NewNode(prev_op_x_repr())->assert_is_op(); auto prev_op_y = pattern->NewNode(prev_op_y_repr())->assert_is_op(); auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul"); + + if (!with_residual) { + matmul_op->assert_more([&](Node *x) { + if (!HasInput(x, "ResidualData") || + x->Op()->Input("ResidualData").size() == 0) + return true; + return false; + }); + } + auto matmul_in_x = pattern->NewNode(matmul_in_x_repr()) ->AsInput() ->assert_is_op_input("matmul", "X"); @@ -1905,11 +1915,21 @@ PDNode *patterns::MatmulWithInputOps::operator()() { ->assert_is_op_input("matmul", "Y"); auto matmul_out = pattern->NewNode(matmul_out_repr()) ->AsOutput() - ->assert_is_op_output("matmul", "Out"); + ->assert_is_op_output("matmul", "Out") + ->assert_is_only_output_of_op("matmul"); + std::vector links_from{matmul_in_x, matmul_in_y}; + + if (with_residual) { + auto matmul_residual_data = + pattern->NewNode(matmul_residual_data_repr()) + ->AsInput() + ->assert_is_op_input("matmul", "ResidualData"); + links_from.push_back(matmul_residual_data); + } prev_op_x->LinksTo({matmul_in_x}); prev_op_y->LinksTo({matmul_in_y}); - matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out}); + matmul_op->LinksFrom(links_from).LinksTo({matmul_out}); return matmul_out; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index f0f7282683b71..751b82713da0d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1191,12 +1191,13 @@ struct MatmulWithInputOps : public PatternBase { MatmulWithInputOps(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "matmul_with_input_ops") {} - PDNode* operator()(); + PDNode* operator()(bool with_residual); PATTERN_DECL_NODE(prev_op_x); PATTERN_DECL_NODE(prev_op_y); PATTERN_DECL_NODE(matmul_in_x); PATTERN_DECL_NODE(matmul_in_y); PATTERN_DECL_NODE(matmul_op); + PATTERN_DECL_NODE(matmul_residual_data); PATTERN_DECL_NODE(matmul_out); }; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 2c99322e565e9..a05f405206581 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -733,11 +733,11 @@ void CPUQuantizePass::QuantizeImmutable(Graph* graph, LogQuantizedOpsCounter(immutable_type, quantize_immutable_count); } -void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { +void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::MatmulWithInputOps matmul_pattern{pattern, name_scope_}; - matmul_pattern(); + matmul_pattern(with_residual); int quantize_matmul_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -754,7 +754,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(prev_op_y, prev_op_y, matmul_pattern); // skip if prev ops are not quantized - if (!IsOpDequantized(prev_op_x) || !IsOpDequantized(prev_op_y)) { + if (!IsOpDequantized(prev_op_x) && !IsOpDequantized(prev_op_y)) { MarkAndLogCannotQuantizeOp(matmul_op, "No other quantizable operators nearby"); return; @@ -763,9 +763,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); - if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { - MarkAndLogCannotQuantizeOp(matmul_op, - "No scale available for the operator"); + auto has_output_scale = AreScalesPresentForNodes({matmul_out}); + if (with_residual && !has_output_scale) { + MarkAndLogCannotQuantizeOp( + matmul_op, + "Matmul op with ResidualData input cannot be quantized " + "without output scale."); return; } @@ -780,6 +783,36 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { "are different: x(%d), y(%d).", is_x_unsigned, is_y_unsigned)); + + if (with_residual) { + GET_IR_NODE_FROM_SUBGRAPH( + matmul_residual_data, matmul_residual_data, matmul_pattern); + if (!AreScalesPresentForNodes( + {matmul_in_x, matmul_in_y, matmul_residual_data})) { + MarkAndLogCannotQuantizeOp(matmul_op, + "No scale available for the operator"); + return; + } + bool is_residual_unsigned{false}; + auto residual_scale = + GetScaleValueForNode(matmul_residual_data, &is_residual_unsigned); + + QuantizeInput(g, + matmul_op, + matmul_residual_data, + "ResidualData", + residual_scale, + is_residual_unsigned, + "Scale_in_eltwise"); + + } else { + if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { + MarkAndLogCannotQuantizeOp(matmul_op, + "No scale available for the operator"); + return; + } + } + QuantizeInput(g, matmul_op, matmul_in_x, @@ -814,7 +847,9 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph) const { }; gpd(graph, handler); AddStatis(quantize_matmul_count); - LogQuantizedOpsCounter("matmul", quantize_matmul_count); + LogQuantizedOpsCounter("matmul", + quantize_matmul_count, + (with_residual ? "with residual connection" : "")); } void CPUQuantizePass::QuantizeElementwise( @@ -1132,7 +1167,8 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { QuantizeConcat(graph); QuantizePriorBox(graph); QuantizeFc(graph); - QuantizeMatmul(graph); + QuantizeMatmul(graph, false /* with_residual_data */); + QuantizeMatmul(graph, true /* with_residual_data */); QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "transpose2", "X"); QuantizeImmutable(graph, "slice", "Input"); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h index 56909b7fe7fb5..f26d8bfc84c15 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h @@ -54,7 +54,7 @@ class CPUQuantizePass : public FusePassBase { void QuantizePool(Graph* graph) const; void QuantizeConcat(Graph* graph) const; void QuantizePriorBox(Graph* graph) const; - void QuantizeMatmul(Graph* graph) const; + void QuantizeMatmul(Graph* graph, bool with_residual) const; void QuantizeElementwise(Graph* graph, const std::string& elementwise_type) const; void QuantizeFusionGru(Graph* graph) const; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index fdeaeccdf94ed..4dabdd6bed0bd 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -90,6 +90,7 @@ void SetOp(ProgramDesc* prog, } else if (type == "matmul") { op->SetInput("X", {inputs[0]}); if (inputs.size() > 1) op->SetInput("Y", {inputs[1]}); + if (inputs.size() > 2) op->SetInput("ResidualData", {inputs[2]}); op->SetOutput("Out", {outputs[0]}); op->SetAttr("Scale_x", 1.0f); op->SetAttr("Scale_y", 1.0f); @@ -180,6 +181,11 @@ void CheckScales(const OpDesc* op, float scale, float shift) { scale_names.push_back("Scale_x"); scale_names.push_back("Scale_y"); scale_names.push_back("Scale_out"); + if (type == "matmul") { + auto const& names = op->InputNames(); + if (std::find(names.begin(), names.end(), "ResidualData") != names.end()) + scale_names.push_back("Scale_in_eltwise"); + } } else if (type == "fusion_gru" || type == "fusion_lstm") { EXPECT_EQ(op->GetAttrIfExists("Shift_data"), shift); EXPECT_EQ(op->GetAttrIfExists>("Scale_weights")[0], @@ -579,7 +585,7 @@ INSTANTIATE_TEST_CASE_P( }); static const std::initializer_list variable_names_matmul = { - "a", "b", "c", "d", "e", "f"}; + "a", "b", "c", "d", "e", "f", "g", "h"}; ProgramDesc BuildProgramDescMatmul() { ProgramDesc prog; @@ -599,14 +605,28 @@ ProgramDesc BuildProgramDescMatmulNotQuantized() { for (auto& v : variable_names_matmul) { prog.MutableBlock(0)->Var(v); } - SetOp(&prog, "dropout", "Dropout", {"a"}, {"b"}, false); - SetOp(&prog, "dequantize", "Dequantize", {"c"}, {"d"}, true); + SetOp(&prog, "dropout", "Dropout1", {"a"}, {"b"}, false); + SetOp(&prog, "dropout", "Dropout2", {"c"}, {"d"}, false); SetOp(&prog, "matmul", "Matmul", {"b", "d"}, {"e"}, true, "int8"); SetOp(&prog, "dropout", "Dropout", {"e"}, {"f"}, true, "float32"); return prog; } +ProgramDesc BuildProgramDescMatmulResidual() { + ProgramDesc prog; + for (auto& v : variable_names_matmul) { + prog.MutableBlock(0)->Var(v); + } + SetOp(&prog, "dequantize", "Dequantize1", {"a"}, {"b"}, true); + SetOp(&prog, "dequantize", "Dequantize2", {"c"}, {"d"}, true); + SetOp(&prog, "dequantize", "Dequantize3", {"e"}, {"f"}, true); + SetOp(&prog, "matmul", "Matmul", {"b", "d", "f"}, {"g"}, true, "int8"); + SetOp(&prog, "dropout", "Dropout", {"g"}, {"h"}, true, "float32"); + + return prog; +} + TEST(CpuQuantizePass, matmul) { // 2 Quant + 2 IN + 1 DeQuant + 1 OUT int added_nodes = 6; @@ -623,7 +643,7 @@ TEST(CpuQuantizePass, matmul_not_quantized) { // nothing change int added_nodes = 0; std::unordered_map expected_operators = { - {"matmul", 1}, {"quantize", 0}, {"dequantize", 1}}; + {"matmul", 1}, {"quantize", 0}, {"dequantize", 0}}; MainTest(BuildProgramDescMatmulNotQuantized(), variable_names_matmul, expected_operators, @@ -631,6 +651,18 @@ TEST(CpuQuantizePass, matmul_not_quantized) { 1.0f); } +TEST(CpuQuantizePass, matmul_residual) { + // 3 Quant + 3 IN + 1 DeQuant + 1 OUT + int added_nodes = 8; + std::unordered_map expected_operators = { + {"matmul", 1}, {"quantize", 3}, {"dequantize", 4}}; + MainTest(BuildProgramDescMatmulResidual(), + variable_names_matmul, + expected_operators, + added_nodes, + SCALE * S8_MAX); +} + static const std::initializer_list variable_names_elementwise = { "a", "b", "c", "d", "e", "f"}; diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 6119714c38cc1..92020f9f1d89e 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -361,6 +361,8 @@ void CpuPassStrategy::EnableMkldnnInt8() { if (!use_mkldnn_int8_) { passes_.clear(); passes_.push_back("quant_dequant_mkldnn_pass"); + passes_.push_back("mkldnn_placement_pass"); + passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("layer_norm_fuse_pass"); passes_.push_back("attention_lstm_fuse_pass"); passes_.push_back("seqconv_eltadd_relu_fuse_pass"); @@ -382,10 +384,10 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("matmul_scale_fuse_pass"); passes_.push_back("gpu_cpu_map_matmul_to_mul_pass"); passes_.push_back("repeated_fc_relu_fuse_pass"); - passes_.push_back("mkldnn_placement_pass"); passes_.push_back("depthwise_conv_mkldnn_pass"); passes_.push_back("conv_bn_fuse_pass"); passes_.push_back("conv_eltwiseadd_bn_fuse_pass"); + passes_.push_back("conv_affine_channel_mkldnn_fuse_pass"); passes_.push_back("conv_transpose_bn_fuse_pass"); passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass"); passes_.push_back("conv_bias_mkldnn_fuse_pass"); @@ -403,10 +405,10 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("compute_propagate_scales_mkldnn_pass"); passes_.push_back("scale_matmul_fuse_pass"); passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass"); + passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass"); passes_.push_back("cpu_quantize_placement_pass"); passes_.push_back("cpu_quantize_pass"); passes_.push_back("cpu_quantize_squash_pass"); - passes_.push_back("simplify_with_basic_ops_pass"); passes_.push_back("mkldnn_inplace_pass"); passes_.push_back("runtime_context_cache_pass"); } diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index a88dc2a24863b..7f6f4ff31f0b9 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -779,10 +779,14 @@ class MatMulV2MKLDNNHandler auto* residual_data = ctx.Input("ResidualData"); auto residual_data_tz = phi::vectorize(residual_data->dims()); auto residual_data_md = memory::desc(residual_data_tz, - dnnl::memory::data_type::f32, - dnnl::memory::format_tag::abcd); + MKLDNNGetDataType(), + dnnl::memory::format_tag::any); post_operations.append_binary(dnnl::algorithm::binary_add, residual_data_md); + if (ctx.HasAttr("Scale_in_eltwise")) { + float sum_scale = scale_out / ctx.Attr("Scale_in_eltwise"); + post_operations.append_sum(sum_scale); + } } AppendActivation(ctx, post_operations); diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 9fb14e4e72021..2d0dcc6ee6a6d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -403,11 +403,15 @@ def _remove_ctrl_vars(self, graph): def _optimize_fp32_graph(self, graph): graph = self._update_activations(graph) graph = self._remove_ctrl_vars(graph) + graph = self._apply_pass(graph, 'mkldnn_placement_pass', + ['mkldnn_enabled_op_types'], [set()]) + # remove dropout ops + graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass') graph = self._apply_pass(graph, 'layer_norm_fuse_pass') graph = self._apply_pass(graph, 'attention_lstm_fuse_pass') graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass') # graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass') - graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass') + # graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass') # graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass') graph = self._apply_pass(graph, 'fc_lstm_fuse_pass') graph = self._apply_pass(graph, 'mul_lstm_fuse_pass') @@ -427,8 +431,6 @@ def _optimize_fp32_graph(self, graph): graph = self._apply_pass(graph, 'matmul_scale_fuse_pass') graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass') graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass') - graph = self._apply_pass(graph, 'mkldnn_placement_pass', - ['mkldnn_enabled_op_types'], [set()]) graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass') graph = self._apply_pass(graph, 'conv_bn_fuse_pass') graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass') @@ -452,6 +454,11 @@ def _optimize_fp32_graph(self, graph): 'matmul_transpose_reshape_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'batch_norm_act_fuse_pass') graph = self._apply_pass(graph, 'softplus_activation_mkldnn_fuse_pass') + graph = self._apply_pass(graph, 'scale_matmul_fuse_pass') + graph = self._apply_pass(graph, + 'reshape_transpose_matmul_mkldnn_fuse_pass') + graph = self._apply_pass(graph, + 'matmul_elementwise_add_mkldnn_fuse_pass') # the following pass should be the last one since it will work on all fused ops. graph = self._apply_pass(graph, 'runtime_context_cache_pass') return graph @@ -477,8 +484,6 @@ def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None): return graph def _final_optimizations(self, graph): - # remove dropout ops - graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass') # make some MKL-DNN ops working inplace graph = self._apply_pass(graph, 'mkldnn_inplace_pass') return graph From 268285d44bc4d12112f4898b993aa0d4352ed96b Mon Sep 17 00:00:00 2001 From: wozna Date: Mon, 22 Aug 2022 13:19:30 +0200 Subject: [PATCH 2/2] Corrections after review and ernie test fix --- .../fluid/framework/ir/graph_pattern_detector.cc | 6 ++---- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 16 +++++++--------- .../slim/quantization/quant2_int8_mkldnn_pass.py | 3 --- 3 files changed, 9 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index d0a18f568af41..f4cb91a7edd9f 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1900,10 +1900,8 @@ PDNode *patterns::MatmulWithInputOps::operator()(bool with_residual) { if (!with_residual) { matmul_op->assert_more([&](Node *x) { - if (!HasInput(x, "ResidualData") || - x->Op()->Input("ResidualData").size() == 0) - return true; - return false; + return (!HasInput(x, "ResidualData") || + x->Op()->Input("ResidualData").size() == 0); }); } diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index a05f405206581..92351d5067f6b 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -772,6 +772,12 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const { return; } + if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { + MarkAndLogCannotQuantizeOp(matmul_op, + "No scale available for the operator"); + return; + } + bool is_x_unsigned{false}, is_y_unsigned{false}; auto input_x_scale = GetScaleValueForNode(matmul_in_x, &is_x_unsigned); auto input_y_scale = GetScaleValueForNode(matmul_in_y, &is_y_unsigned); @@ -787,8 +793,7 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const { if (with_residual) { GET_IR_NODE_FROM_SUBGRAPH( matmul_residual_data, matmul_residual_data, matmul_pattern); - if (!AreScalesPresentForNodes( - {matmul_in_x, matmul_in_y, matmul_residual_data})) { + if (!AreScalesPresentForNodes({matmul_residual_data})) { MarkAndLogCannotQuantizeOp(matmul_op, "No scale available for the operator"); return; @@ -804,13 +809,6 @@ void CPUQuantizePass::QuantizeMatmul(Graph* graph, bool with_residual) const { residual_scale, is_residual_unsigned, "Scale_in_eltwise"); - - } else { - if (!AreScalesPresentForNodes({matmul_in_x, matmul_in_y})) { - MarkAndLogCannotQuantizeOp(matmul_op, - "No scale available for the operator"); - return; - } } QuantizeInput(g, diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 2d0dcc6ee6a6d..c79613293553a 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -410,9 +410,6 @@ def _optimize_fp32_graph(self, graph): graph = self._apply_pass(graph, 'layer_norm_fuse_pass') graph = self._apply_pass(graph, 'attention_lstm_fuse_pass') graph = self._apply_pass(graph, 'seqconv_eltadd_relu_fuse_pass') - # graph = self._apply_pass(graph, 'seqpool_concat_fuse_pass') - # graph = self._apply_pass(graph, 'seqpool_cvm_concat_fuse_pass') - # graph = self._apply_pass(graph, 'embedding_fc_lstm_fuse_pass') graph = self._apply_pass(graph, 'fc_lstm_fuse_pass') graph = self._apply_pass(graph, 'mul_lstm_fuse_pass') graph = self._apply_pass(graph, 'fc_gru_fuse_pass')