diff --git a/paddle/fluid/framework/ir/generate_pass.cc b/paddle/fluid/framework/ir/generate_pass.cc index 658b6a79a7448..b261cbeb08e3b 100644 --- a/paddle/fluid/framework/ir/generate_pass.cc +++ b/paddle/fluid/framework/ir/generate_pass.cc @@ -21,6 +21,16 @@ namespace ir { void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { const proto::BlockDesc& block = pass_desc.pattern().blocks(0); + for (const proto::VarDesc& var : block.vars()) { + PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput(); + var_pdnode->assert_is_var(); + var_pdnode->assert_more([&](Node* x) { + if (VarDesc(var).GetShape() == x->Var()->GetShape()) { + return true; + } + return false; + }); + } // Traverse all operators to create subgraph. for (int index = 0; index < block.ops_size(); ++index) { const proto::OpDesc& op = block.ops(index); @@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) { pattern->NewNode(std::to_string(index))->assert_is_op(op.type()); // Create PDNodes for inputs of current operator. for (const proto::OpDesc::Var& var : op.inputs()) { - for (const std::string& argument : var.arguments()) { + for (int n = 0; n < var.arguments_size(); ++n) { + const std::string& argument = var.arguments(n); // The input may be the output of other operator. PDNode* var_pdnode = pattern->RetrieveNode(argument); if (nullptr == var_pdnode) { var_pdnode = pattern->NewNode(argument)->AsInput(); + var_pdnode->assert_is_var(); } else if (var_pdnode->IsOutput()) { var_pdnode->AsIntermediate(); } - var_pdnode->assert_is_op_input(op.type()); + var_pdnode->assert_more([&](Node* x) { + for (auto* out : x->outputs) { + if (out->IsOp() && out->Op()->Type() == op.type()) { + const auto& inputs = out->Op()->Inputs(); + const auto& iter = inputs.find(var.parameter()); + if (inputs.end() != iter) { + if (iter->second.end() != std::find(iter->second.begin(), + iter->second.end(), + x->Name())) { + return true; + } + } + } + } + return false; + }); pattern->AddEdge(var_pdnode, op_pdnode); } } diff --git a/python/paddle/fluid/ir.py b/python/paddle/fluid/ir.py index 496fc9b2ff9ec..3c7c8879fd420 100644 --- a/python/paddle/fluid/ir.py +++ b/python/paddle/fluid/ir.py @@ -152,9 +152,32 @@ def _get_args_from_func(self, func): def _prune_program_desc(self, program_desc): block_desc = program_desc.blocks[0] - block_desc.ClearField("vars") + # block_desc.ClearField("vars") + for var in [ + var for var in block_desc.vars + if var.name not in self._input_specs + ]: + block_desc.vars.remove(var) for op_desc in block_desc.ops: - op_desc.ClearField("attrs") + default_attrs = core.get_op_attrs_default_value( + paddle.compat.to_bytes(op_desc.type)) + remove_attrs = list() + for attr in op_desc.attrs: + # attr must not in + if attr.name not in [ + "op_namescope", "op_callstack", "op_device" + ]: + attr_list_fields = attr.ListFields() + # attr format must be: name, type, value + if len(attr_list_fields) == 3: + attr_value = attr.ListFields()[-1][-1] + default_attr_value = default_attrs.get(attr.name) + # value must not default + if default_attr_value != attr_value: + continue + remove_attrs.append(attr) + for attr in remove_attrs: + op_desc.attrs.remove(attr) def _func_to_program_desc(self, func, program_desc, is_replace=False): vars = list() diff --git a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py index c56fef9217ba2..61bd554ad2616 100644 --- a/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py @@ -73,9 +73,9 @@ def multi_add_to_sum_v3(): # mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2))) @ir.RegisterPass(input_specs={ - 'x': InputSpec([1, 1]), - 'y1': InputSpec([1, 1]), - 'y2': InputSpec([1, 1]) + 'x': InputSpec([16, 32]), + 'y1': InputSpec([32, 12]), + 'y2': InputSpec([32, 48]) }) def generate_combine_mul_v1(): def pattern(x, y1, y2): @@ -86,8 +86,8 @@ def pattern(x, y1, y2): def replace(x, y1, y2): concat_out = paddle.concat([y1, y2], axis=-1) mul_out = paddle.matmul(x, concat_out) - out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1]) - out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2]) + out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[12]) + out2 = paddle.slice(mul_out, axes=[1], starts=[12], ends=[60]) return out1, out2 return pattern, replace @@ -111,11 +111,11 @@ def replace(x, y1, y2): # reshape(reshape(x)) => x -@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])}) +@ir.RegisterPass(input_specs={'x': InputSpec([10, 16, 16])}) def generate_simplify_inference_v1(): def pattern(x): - transpose = paddle.transpose(x, [0, 3, 1, 2]) - return paddle.transpose(transpose, [0, 3, 1, 2]) + transpose = paddle.transpose(x, [0, 2, 1]) + return paddle.transpose(transpose, [0, 2, 1]) return pattern, lambda x: x @@ -217,28 +217,34 @@ def test_multi_add_to_sum(self): self.check_multi_add_to_sum("multi_add_to_sum_v3") def test_generate_combine_mul_v1(self): - input_specs = { - 'x': InputSpec([1, 1]), - 'y1': InputSpec([1, 1]), - 'y2': InputSpec([1, 1]) + paddle.enable_static() + program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(program, startup_program): + x = paddle.static.data("x", [16, 32]) + y = paddle.static.data("y", [32, 12]) + z = paddle.static.data("z", [32, 48]) + out1 = paddle.matmul(x, y) + out2 = paddle.matmul(x, z) + graph = core.Graph(program.desc) + before_node_nums = len(graph.nodes()) + core.get_pass("generate_combine_mul_v1").apply(graph) + after_node_nums = len(graph.nodes()) + self.assertEqual(after_node_nums, before_node_nums + 4) + after_program = paddle.fluid.framework.IrGraph(graph).to_program() + executor = paddle.static.Executor(paddle.CPUPlace()) + executor.run(startup_program) + feed = { + "x": np.random.random([16, 32]).astype("float32"), + "y": np.random.random([32, 12]).astype("float32"), + "z": np.random.random([32, 48]).astype("float32") } - helper = ir.RegisterPassHelper( - [generate_combine_mul_v1()], input_specs=input_specs) - s = helper.SerializeMultiPassDesc() - multi_pass_desc = get_multi_pass_desc_from_str(s) - self.assertEqual(len(multi_pass_desc.pass_descs), 1) - pass_desc = multi_pass_desc.pass_descs[0] - self.assertEqual(len(pass_desc.var_maps), 5) - self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2) - self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4) - pattern_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.pattern.blocks[0].ops) - replace_op_dicts = self.convert_ops_to_op_dicts( - pass_desc.replace.blocks[0].ops) - self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2) - self.assertEqual(len(replace_op_dicts.get("concat", [])), 1) - self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1) - self.assertEqual(len(replace_op_dicts.get("slice", [])), 2) + before_out1, before_out2 = executor.run( + program, feed=feed, fetch_list=[out1.name, out2.name]) + after_out1, after_out2 = executor.run( + after_program, feed=feed, fetch_list=[out1.name, out2.name]) + self.assertTrue(np.allclose(before_out1, after_out1)) + self.assertTrue(np.allclose(before_out2, after_out2)) def test_generate_combine_mul_v2(self): helper = ir.RegisterPassHelper([generate_combine_mul_v2()])