[cherry-pick] Verify the correctness of graph rewritten by GeneratePass (#36453)

* [WIP] Verify the correctness of graph rewritten by GeneratePass, test=develop

* add delete-subgraph handling and a unit test, test=develop

* check simple pass, test=develop

* fix coverage, test=develop

* limit shape/dtype with input_spec via the Paddle API, test=develop (usage sketch below)
Avin0323 authored Oct 15, 2021
1 parent fc429fe commit cc44965
Showing 3 changed files with 275 additions and 81 deletions.
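For orientation, a generated pass is written in Python as one or more (pattern, replace) callable pairs; the decorator serializes them into the MultiPassDesc that GeneratePass consumes on the C++ side, and the changes below verify the rewritten graph and add a delete-only path. A minimal sketch of such a pass, modeled on the add/add_n example in the RegisterPass docstring; the pass name and the fusion chosen are illustrative, not part of this commit:

import paddle
from paddle.fluid import ir

# The decorated function's name becomes the pass type; the returned pair
# describes the subgraph to match and the subgraph to build in its place.
@ir.RegisterPass
def generate_add_n():
    def pattern(x, y, z):
        # Match the chain add(add(x, y), z), expressed with ordinary Paddle APIs.
        return paddle.add(paddle.add(x, y), z)

    def replace(x, y, z):
        # Rewrite the matched chain into a single add_n.
        return paddle.add_n([x, y, z])

    return pattern, replace
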
117 changes: 106 additions & 11 deletions paddle/fluid/framework/ir/generate_pass.cc
@@ -21,6 +21,16 @@ namespace ir {

void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
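// Create a PDNode for every variable declared in the pattern block and
// require the matched variable's shape to equal the declared shape.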
for (const proto::VarDesc& var : block.vars()) {
PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput();
var_pdnode->assert_is_var();
var_pdnode->assert_more([&](Node* x) {
if (VarDesc(var).GetShape() == x->Var()->GetShape()) {
return true;
}
return false;
});
}
// Traverse all operators to create subgraph.
for (int index = 0; index < block.ops_size(); ++index) {
const proto::OpDesc& op = block.ops(index);
@@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
// Create PDNodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
for (const std::string& argument : var.arguments()) {
for (int n = 0; n < var.arguments_size(); ++n) {
const std::string& argument = var.arguments(n);
// The input may be the output of another operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsInput();
var_pdnode->assert_is_var();
} else if (var_pdnode->IsOutput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_input(op.type());
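// Additionally require that the variable is consumed by an op of this type
// under the expected input parameter.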
var_pdnode->assert_more([&](Node* x) {
for (auto* out : x->outputs) {
if (out->IsOp() && out->Op()->Type() == op.type()) {
const auto& inputs = out->Op()->Inputs();
const auto& iter = inputs.find(var.parameter());
if (inputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
pattern->AddEdge(var_pdnode, op_pdnode);
}
}
@@ -50,6 +77,24 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsOutput();
var_pdnode->assert_is_var();
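// Require that the variable is produced by an op of this type under the
// expected output parameter.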
var_pdnode->assert_more([&](Node* x) {
for (Node* input : x->inputs) {
if (input && input->IsOp() && input->Op() &&
input->Op()->Type() == op.type()) {
const auto& outputs = input->Op()->Outputs();
const auto& iter = outputs.find(var.parameter());
if (outputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
} else if (var_pdnode->IsInput()) {
var_pdnode->AsIntermediate();
}
@@ -73,18 +118,64 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
}
}

GraphPatternDetector::handle_t GetGenerateRewrite(
// Skip matched subgraphs whose nodes were already removed by an earlier, overlapping rewrite.
bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return true;
}
}
return false;
}

GraphPatternDetector::handle_t GetGenerateDelete(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t subgraph, Graph* graph) {
// There are some duplicate patterns.
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return;
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
// `var_node_maps` records the mapping from replace-graph variables to nodes of the matched pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
const auto& iter = var_node_maps.find(var_map.replace_var());
if (var_node_maps.end() == iter) {
// First occurrence: this pattern node is the input side of the replaced subgraph.
var_node_maps.insert({var_map.replace_var(), node});
} else {
// Later occurrence: output side; rewire its consumers to the recorded input node.
for (Node* s_node : node->outputs) {
iter->second->outputs.push_back(s_node);
std::replace(s_node->inputs.begin(), s_node->inputs.end(), node,
iter->second);
s_node->Op()->RenameInput(node->Name(), iter->second->Name());
}
}
}
// Remove nodes that are intermediate.
std::unordered_set<const Node*> remove_nodes;
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
remove_nodes.emplace(subgraph.at(pdnode.get()));
}
for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second);
}
GraphSafeRemoveNodes(graph, remove_nodes);
};
return handler;
}

GraphPatternDetector::handle_t GetGenerateRewrite(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
// `var_node_maps` records the mapping from replace-graph variables to nodes of the matched pattern subgraph.
std::map<std::string, Node*> var_node_maps;
@@ -175,7 +266,11 @@ void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
if (pass_desc.replace().blocks(0).ops_size() == 0) {
detector(graph, GetGenerateDelete(detector.pattern(), pass_desc));
} else {
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
}
// The rewritten graph needs to be verified. The current pass should be skipped
// if validation fails, since a rewrite applied to the original graph cannot be
// rolled back.
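The new GetGenerateDelete handler, together with the ops_size() == 0 dispatch added to ApplyImpl, covers passes whose replace subgraph contains no operators: the matched nodes are deleted and the consumers of the subgraph's output are rewired to its input variable. A minimal sketch of such a delete-style pass, assuming a redundant pair of assign ops as the pattern; the op choice and names are illustrative and not taken from this commit's unit test:

import paddle
from paddle.fluid import ir

@ir.RegisterPass
def generate_remove_double_assign():
    def pattern(x):
        # Match two chained assigns, which just copy the tensor twice.
        y = paddle.assign(x)
        return paddle.assign(y)

    def replace(x):
        # Returning the input unchanged leaves the replace block with zero
        # operators, so ApplyImpl routes this pass to GetGenerateDelete and
        # the matched assigns are removed from the graph.
        return x

    return pattern, replace
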
43 changes: 38 additions & 5 deletions python/paddle/fluid/ir.py
@@ -127,11 +127,13 @@ def apply_pass(name):


class RegisterPassHelper(object):
_register_helpers = list()

def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
if isinstance(input_specs, dict):
self._input_specs = input_specs
self._input_specs = input_specs
RegisterPassHelper._register_helpers.append(self)

def _get_args_from_func(self, func):
args = list()
@@ -148,6 +150,35 @@ def _get_args_from_func(self, func):
args.append(paddle.static.data(arg_name, [-1]))
return args

def _prune_program_desc(self, program_desc):
block_desc = program_desc.blocks[0]
# block_desc.ClearField("vars")
for var in [
var for var in block_desc.vars
if var.name not in self._input_specs
]:
block_desc.vars.remove(var)
for op_desc in block_desc.ops:
default_attrs = core.get_op_attrs_default_value(
paddle.compat.to_bytes(op_desc.type))
remove_attrs = list()
for attr in op_desc.attrs:
# attributes other than the bookkeeping ones below are checked against their defaults
if attr.name not in [
"op_namescope", "op_callstack", "op_device"
]:
attr_list_fields = attr.ListFields()
# attr format must be: name, type, value
if len(attr_list_fields) == 3:
attr_value = attr.ListFields()[-1][-1]
default_attr_value = default_attrs.get(attr.name)
# keep attributes whose value differs from the default
if default_attr_value != attr_value:
continue
remove_attrs.append(attr)
for attr in remove_attrs:
op_desc.attrs.remove(attr)

def _func_to_program_desc(self, func, program_desc, is_replace=False):
vars = list()
program = paddle.static.Program()
@@ -166,6 +197,7 @@ def _func_to_program_desc(self, func, program_desc, is_replace=False):
elif isinstance(out, paddle.fluid.framework.Variable):
vars.append(out.name)
program_desc.ParseFromString(program.desc.serialize_to_string())
self._prune_program_desc(program_desc)
if is_replace:
attrs = list()
for op in program.current_block().ops:
@@ -296,7 +328,7 @@ def Outputs(self):
OP = OpHelper()


def RegisterPass(function=None, input_specs=None):
def RegisterPass(function=None, input_specs=dict()):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and registers it into a core.Pass instance. Use name of function
@@ -305,11 +337,11 @@ def decorated(python_func):
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]|None): Dict of InputSpec to specific the shape/dtype
input_specs (dict[str, InputSpec]): Dict of InputSpec that specifies the shape/dtype
information of Tensors. Some operators restrict the shape and dtype of data when
creating subgraphs with Paddle APIs, so users need to specify InputSpec for the data to
ensure the subgraph is created correctly. Of course, this argument is not limited to the
matching subgraph. The default is None.
matching subgraph. The default is dict().
Returns:
callables: Callable pair(s).
@@ -351,6 +383,7 @@ def decorated(python_func):
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
core.register_pass(pass_type, helper.SerializeMultiPassDesc)
return python_func

if inspect.isfunction(function):
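As the updated docstring explains, pattern and replace arguments default to 1-D placeholders (paddle.static.data(arg_name, [-1]) in _get_args_from_func), which can break operators that need a concrete rank or dtype during tracing; input_specs pins that information. A sketch of a pass that relies on it, folding two transposes into one; the shape, perms, and names are illustrative. Note that after _prune_program_desc, only variables listed in input_specs keep a VarDesc in the pattern block, so the declared shape also acts as a matching constraint via the new shape assertion in InitGeneratePattern.

import paddle
from paddle.fluid import ir
from paddle.static import InputSpec

# transpose with a length-3 perm needs a rank-3 input, so the pattern cannot
# be traced from the default 1-D placeholder; pin the shape with input_specs.
@ir.RegisterPass(input_specs={"x": InputSpec([-1, 16, 32], "float32")})
def generate_fold_transpose():
    def pattern(x):
        y = paddle.transpose(x, perm=[0, 2, 1])
        return paddle.transpose(y, perm=[1, 0, 2])

    def replace(x):
        # A single transpose is equivalent: composing perm [0, 2, 1] with
        # [1, 0, 2] gives [2, 0, 1].
        return paddle.transpose(x, perm=[2, 0, 1])

    return pattern, replace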