Commit f86bd96

limit with input_spec via Paddle API, test=develop

Avin0323 committed Oct 15, 2021
1 parent e70abde commit f86bd96
Showing 3 changed files with 89 additions and 33 deletions.
31 changes: 29 additions & 2 deletions paddle/fluid/framework/ir/generate_pass.cc
@@ -21,6 +21,16 @@ namespace ir {

 void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
   const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
+  for (const proto::VarDesc& var : block.vars()) {
+    PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput();
+    var_pdnode->assert_is_var();
+    var_pdnode->assert_more([&](Node* x) {
+      if (VarDesc(var).GetShape() == x->Var()->GetShape()) {
+        return true;
+      }
+      return false;
+    });
+  }
   // Traverse all operators to create subgraph.
   for (int index = 0; index < block.ops_size(); ++index) {
     const proto::OpDesc& op = block.ops(index);
@@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
     pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
     // Create PDNodes for inputs of current operator.
     for (const proto::OpDesc::Var& var : op.inputs()) {
-      for (const std::string& argument : var.arguments()) {
+      for (int n = 0; n < var.arguments_size(); ++n) {
+        const std::string& argument = var.arguments(n);
         // The input may be the output of other operator.
         PDNode* var_pdnode = pattern->RetrieveNode(argument);
         if (nullptr == var_pdnode) {
           var_pdnode = pattern->NewNode(argument)->AsInput();
           var_pdnode->assert_is_var();
         } else if (var_pdnode->IsOutput()) {
           var_pdnode->AsIntermediate();
         }
+        var_pdnode->assert_is_op_input(op.type());
+        var_pdnode->assert_more([&](Node* x) {
+          for (auto* out : x->outputs) {
+            if (out->IsOp() && out->Op()->Type() == op.type()) {
+              const auto& inputs = out->Op()->Inputs();
+              const auto& iter = inputs.find(var.parameter());
+              if (inputs.end() != iter) {
+                if (iter->second.end() != std::find(iter->second.begin(),
+                                                    iter->second.end(),
+                                                    x->Name())) {
+                  return true;
+                }
+              }
+            }
+          }
+          return false;
+        });
         pattern->AddEdge(var_pdnode, op_pdnode);
       }
     }
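Taken together, the additions above make pattern matching shape-aware: every variable kept in the pattern block (those listed in input_specs) becomes an input PDNode that only matches graph variables with exactly the declared shape, and each operator input must additionally appear under the expected parameter name of a downstream op of the right type. A minimal sketch of this predicate-gated matching in plain Python follows; PatternNode, GraphVar, and shape_equals are illustrative stand-ins, not Paddle APIs.

# Hedged sketch: the predicate-gated matching idea in plain Python.
# PatternNode, GraphVar, and shape_equals are illustrative, not Paddle APIs.
from typing import Callable, List

class GraphVar:
    def __init__(self, name: str, shape: List[int]):
        self.name = name
        self.shape = shape

class PatternNode:
    """Accepts a candidate variable only if every registered predicate
    passes, mirroring PDNode::assert_more in generate_pass.cc."""

    def __init__(self, name: str):
        self.name = name
        self._preds: List[Callable[[GraphVar], bool]] = []

    def assert_more(self, pred: Callable[[GraphVar], bool]) -> "PatternNode":
        self._preds.append(pred)
        return self

    def matches(self, var: GraphVar) -> bool:
        return all(pred(var) for pred in self._preds)

def shape_equals(expected: List[int]) -> Callable[[GraphVar], bool]:
    # Counterpart of the C++ lambda comparing VarDesc(var).GetShape()
    # with x->Var()->GetShape().
    return lambda v: v.shape == expected

node = PatternNode("x").assert_more(shape_equals([16, 32]))
assert node.matches(GraphVar("x", [16, 32]))
assert not node.matches(GraphVar("x", [16, 16]))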
27 changes: 25 additions & 2 deletions python/paddle/fluid/ir.py
@@ -152,9 +152,32 @@ def _get_args_from_func(self, func):

     def _prune_program_desc(self, program_desc):
         block_desc = program_desc.blocks[0]
-        block_desc.ClearField("vars")
+        # block_desc.ClearField("vars")
+        for var in [
+                var for var in block_desc.vars
+                if var.name not in self._input_specs
+        ]:
+            block_desc.vars.remove(var)
         for op_desc in block_desc.ops:
-            op_desc.ClearField("attrs")
+            default_attrs = core.get_op_attrs_default_value(
+                paddle.compat.to_bytes(op_desc.type))
+            remove_attrs = list()
+            for attr in op_desc.attrs:
+                # debug attributes are always pruned
+                if attr.name not in [
+                        "op_namescope", "op_callstack", "op_device"
+                ]:
+                    attr_list_fields = attr.ListFields()
+                    # a set attribute has three fields: name, type, value
+                    if len(attr_list_fields) == 3:
+                        attr_value = attr_list_fields[-1][-1]
+                        default_attr_value = default_attrs.get(attr.name)
+                        # keep attributes whose value differs from the default
+                        if default_attr_value != attr_value:
+                            continue
+                remove_attrs.append(attr)
+            for attr in remove_attrs:
+                op_desc.attrs.remove(attr)

     def _func_to_program_desc(self, func, program_desc, is_replace=False):
         vars = list()
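With this change, _prune_program_desc keeps the variables named in input_specs instead of clearing them all, and removes only attributes that carry no matching information: the debug attributes op_namescope, op_callstack, and op_device, plus any attribute whose value still equals its registered default. A minimal sketch of that pruning rule over plain dicts, assuming illustrative names rather than the protobuf descriptors the real code walks:

# Hedged sketch of the pruning rule over plain dicts; DEBUG_ATTRS and
# prune_attrs are illustrative names, not Paddle APIs.
DEBUG_ATTRS = {"op_namescope", "op_callstack", "op_device"}

def prune_attrs(op_attrs: dict, default_attrs: dict) -> dict:
    """Keep only attributes that constrain matching: drop debug attributes
    and values that are still at their registered default."""
    kept = {}
    for name, value in op_attrs.items():
        if name in DEBUG_ATTRS:
            continue  # always pruned
        if default_attrs.get(name) == value:
            continue  # default value imposes no constraint
        kept[name] = value
    return kept

attrs = {"axis": -1, "op_device": "gpu:0", "use_mkldnn": False}
defaults = {"axis": 0, "use_mkldnn": False}
assert prune_attrs(attrs, defaults) == {"axis": -1}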
64 changes: 35 additions & 29 deletions python/paddle/fluid/tests/unittests/ir/test_ir_generate_pass.py
@@ -73,9 +73,9 @@ def multi_add_to_sum_v3():

 # mul(x, y1), mul(x, y2) => slice(mul(x, concat(y1, y2)))
 @ir.RegisterPass(input_specs={
-    'x': InputSpec([1, 1]),
-    'y1': InputSpec([1, 1]),
-    'y2': InputSpec([1, 1])
+    'x': InputSpec([16, 32]),
+    'y1': InputSpec([32, 12]),
+    'y2': InputSpec([32, 48])
 })
 def generate_combine_mul_v1():
     def pattern(x, y1, y2):
@@ -86,8 +86,8 @@ def pattern(x, y1, y2):
     def replace(x, y1, y2):
         concat_out = paddle.concat([y1, y2], axis=-1)
         mul_out = paddle.matmul(x, concat_out)
-        out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[1])
-        out2 = paddle.slice(mul_out, axes=[1], starts=[1], ends=[2])
+        out1 = paddle.slice(mul_out, axes=[1], starts=[0], ends=[12])
+        out2 = paddle.slice(mul_out, axes=[1], starts=[12], ends=[60])
         return out1, out2

     return pattern, replace
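The updated specs replace the degenerate [1, 1] shapes with realistic, mutually consistent ones: concat(y1, y2) on the last axis yields [32, 60], matmul(x, concat_out) yields [16, 60], and the slices [0, 12) and [12, 60) on axis 1 recover the [16, 12] and [16, 48] outputs of the two original matmuls. A quick numpy check of that equivalence, using plain arrays outside Paddle:

# Quick numpy check (assumption: plain float32 arrays, outside Paddle) that
# slicing matmul(x, concat(y1, y2)) reproduces the two separate matmuls.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 32)).astype("float32")
y1 = rng.standard_normal((32, 12)).astype("float32")
y2 = rng.standard_normal((32, 48)).astype("float32")

fused = x @ np.concatenate([y1, y2], axis=-1)  # shape (16, 60)
assert np.allclose(fused[:, 0:12], x @ y1, atol=1e-4)
assert np.allclose(fused[:, 12:60], x @ y2, atol=1e-4)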
@@ -111,11 +111,11 @@ def replace(x, y1, y2):


 # transpose(transpose(x)) => x
-@ir.RegisterPass(input_specs={'x': InputSpec([-1, 16, 16, 16])})
+@ir.RegisterPass(input_specs={'x': InputSpec([10, 16, 16])})
 def generate_simplify_inference_v1():
     def pattern(x):
-        transpose = paddle.transpose(x, [0, 3, 1, 2])
-        return paddle.transpose(transpose, [0, 3, 1, 2])
+        transpose = paddle.transpose(x, [0, 2, 1])
+        return paddle.transpose(transpose, [0, 2, 1])

     return pattern, lambda x: x
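The rewritten pattern is now genuinely an identity: applying the permutation [0, 2, 1] twice returns the original layout, whereas the old [0, 3, 1, 2] composed with itself gives [0, 2, 3, 1], which is not the identity, so the old pattern/replace pair was not value-preserving. A short numpy check:

# Sanity check (plain numpy): the rewritten double transpose is the identity
# for the rank-3 input declared by InputSpec([10, 16, 16]).
import numpy as np

x = np.arange(10 * 16 * 16).reshape(10, 16, 16)
assert (x.transpose(0, 2, 1).transpose(0, 2, 1) == x).all()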

@@ -217,28 +217,34 @@ def test_multi_add_to_sum(self):
         self.check_multi_add_to_sum("multi_add_to_sum_v3")

     def test_generate_combine_mul_v1(self):
-        input_specs = {
-            'x': InputSpec([1, 1]),
-            'y1': InputSpec([1, 1]),
-            'y2': InputSpec([1, 1])
+        paddle.enable_static()
+        program = paddle.static.Program()
+        startup_program = paddle.static.Program()
+        with paddle.static.program_guard(program, startup_program):
+            x = paddle.static.data("x", [16, 32])
+            y = paddle.static.data("y", [32, 12])
+            z = paddle.static.data("z", [32, 48])
+            out1 = paddle.matmul(x, y)
+            out2 = paddle.matmul(x, z)
+        graph = core.Graph(program.desc)
+        before_node_nums = len(graph.nodes())
+        core.get_pass("generate_combine_mul_v1").apply(graph)
+        after_node_nums = len(graph.nodes())
+        self.assertEqual(after_node_nums, before_node_nums + 4)
+        after_program = paddle.fluid.framework.IrGraph(graph).to_program()
+        executor = paddle.static.Executor(paddle.CPUPlace())
+        executor.run(startup_program)
+        feed = {
+            "x": np.random.random([16, 32]).astype("float32"),
+            "y": np.random.random([32, 12]).astype("float32"),
+            "z": np.random.random([32, 48]).astype("float32")
         }
-        helper = ir.RegisterPassHelper(
-            [generate_combine_mul_v1()], input_specs=input_specs)
-        s = helper.SerializeMultiPassDesc()
-        multi_pass_desc = get_multi_pass_desc_from_str(s)
-        self.assertEqual(len(multi_pass_desc.pass_descs), 1)
-        pass_desc = multi_pass_desc.pass_descs[0]
-        self.assertEqual(len(pass_desc.var_maps), 5)
-        self.assertEqual(len(pass_desc.pattern.blocks[0].ops), 2)
-        self.assertEqual(len(pass_desc.replace.blocks[0].ops), 4)
-        pattern_op_dicts = self.convert_ops_to_op_dicts(
-            pass_desc.pattern.blocks[0].ops)
-        replace_op_dicts = self.convert_ops_to_op_dicts(
-            pass_desc.replace.blocks[0].ops)
-        self.assertEqual(len(pattern_op_dicts.get("matmul_v2", [])), 2)
-        self.assertEqual(len(replace_op_dicts.get("concat", [])), 1)
-        self.assertEqual(len(replace_op_dicts.get("matmul_v2", [])), 1)
-        self.assertEqual(len(replace_op_dicts.get("slice", [])), 2)
+        before_out1, before_out2 = executor.run(
+            program, feed=feed, fetch_list=[out1.name, out2.name])
+        after_out1, after_out2 = executor.run(
+            after_program, feed=feed, fetch_list=[out1.name, out2.name])
+        self.assertTrue(np.allclose(before_out1, after_out1))
+        self.assertTrue(np.allclose(before_out2, after_out2))

     def test_generate_combine_mul_v2(self):
         helper = ir.RegisterPassHelper([generate_combine_mul_v2()])
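The rewritten test exercises the pass end to end instead of inspecting the serialized MultiPassDesc: it builds a static program containing the two matmuls, applies the registered pass to the program's core.Graph, asserts the node-count delta (+4 is consistent with two extra op nodes, concat plus two slices minus one fused matmul, and the two new intermediate variables concat_out and mul_out), and finally checks numerical equivalence of the outputs before and after the rewrite. The apply-and-compare structure can be factored into a small helper; this hedged sketch uses only the Paddle calls already present in the test, with run_pass_and_compare as an illustrative name:

# Hedged sketch: the test's apply-and-compare structure as a reusable helper.
# run_pass_and_compare is an illustrative name; the Paddle calls below are
# exactly the ones used by the test above.
import numpy as np
import paddle
from paddle.fluid import core

def run_pass_and_compare(pass_name, program, startup_program, feed, fetch_list):
    # Apply the registered pass to the program's IR graph.
    graph = core.Graph(program.desc)
    core.get_pass(pass_name).apply(graph)
    after_program = paddle.fluid.framework.IrGraph(graph).to_program()
    # Run both programs on the same feed and compare the fetched outputs.
    executor = paddle.static.Executor(paddle.CPUPlace())
    executor.run(startup_program)
    before = executor.run(program, feed=feed, fetch_list=fetch_list)
    after = executor.run(after_program, feed=feed, fetch_list=fetch_list)
    for b, a in zip(before, after):
        np.testing.assert_allclose(b, a, rtol=1e-5, atol=1e-6)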

1 comment on commit f86bd96

@paddle-bot-old

Congratulations! Your pull request passed all required CI. You can now ask the reviewer(s) to approve and merge. 🎉
