[cherry-pick] Verify the correctness of graph rewritten by GeneratePass #36453

Merged
117 changes: 106 additions & 11 deletions paddle/fluid/framework/ir/generate_pass.cc
@@ -21,6 +21,16 @@ namespace ir {

void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
const proto::BlockDesc& block = pass_desc.pattern().blocks(0);
for (const proto::VarDesc& var : block.vars()) {
PDNode* var_pdnode = pattern->NewNode(var.name())->AsInput();
var_pdnode->assert_is_var();
var_pdnode->assert_more([&](Node* x) {
if (VarDesc(var).GetShape() == x->Var()->GetShape()) {
return true;
}
return false;
});
}
// Traverse all operators to create subgraph.
for (int index = 0; index < block.ops_size(); ++index) {
const proto::OpDesc& op = block.ops(index);
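With the shape assertion above, a pattern variable now matches only graph variables of identical shape. On the Python side that shape comes from the `input_specs` argument of `RegisterPass`; a minimal sketch (pass name and shapes are illustrative, modeled on the framework's add_n fusion test):

```python
import paddle
from paddle.static import InputSpec
from paddle.fluid import ir

paddle.enable_static()

# Hypothetical pass: InputSpec fixes the shapes recorded in the pattern
# block, so the shape assertion above rejects variables of any other shape.
@ir.RegisterPass(input_specs={
    "x": InputSpec([10, 10]),
    "y": InputSpec([10, 10]),
    "z": InputSpec([10, 10]),
})
def generate_add_n_demo():
    def pattern(x, y, z):
        # The subgraph to search for: two chained elementwise adds.
        return paddle.add(paddle.add(x, y), z)

    def replace(x, y, z):
        # The subgraph to substitute on every match.
        return paddle.add_n([x, y, z])

    return pattern, replace
```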
@@ -31,15 +41,32 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
pattern->NewNode(std::to_string(index))->assert_is_op(op.type());
// Create PDNodes for inputs of current operator.
for (const proto::OpDesc::Var& var : op.inputs()) {
for (int n = 0; n < var.arguments_size(); ++n) {
const std::string& argument = var.arguments(n);
// The input may be the output of another operator.
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsInput();
var_pdnode->assert_is_var();
} else if (var_pdnode->IsOutput()) {
var_pdnode->AsIntermediate();
}
var_pdnode->assert_is_op_input(op.type());
var_pdnode->assert_more([&](Node* x) {
for (auto* out : x->outputs) {
if (out->IsOp() && out->Op()->Type() == op.type()) {
const auto& inputs = out->Op()->Inputs();
const auto& iter = inputs.find(var.parameter());
if (inputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
pattern->AddEdge(var_pdnode, op_pdnode);
}
}
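The `assert_more` predicate added above checks more than the op type: the candidate variable must actually be listed under the matched input parameter of a consumer of that type. A pure-Python paraphrase of the predicate, with illustrative accessor names standing in for the C++ `Node`/`OpDesc` API:

```python
# Sketch only: `outputs`, `is_op`, `op_type`, `inputs` and `name` are
# stand-ins for the C++ Node/OpDesc accessors used in assert_more above.
def is_real_op_input(var_node, op_type, parameter):
    for consumer in var_node.outputs:  # ops that consume this variable
        if consumer.is_op and consumer.op_type == op_type:
            arguments = consumer.inputs.get(parameter, [])
            if var_node.name in arguments:  # listed under that parameter?
                return True
    return False
```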
@@ -50,6 +77,24 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
PDNode* var_pdnode = pattern->RetrieveNode(argument);
if (nullptr == var_pdnode) {
var_pdnode = pattern->NewNode(argument)->AsOutput();
var_pdnode->assert_is_var();
var_pdnode->assert_more([&](Node* x) {
for (Node* input : x->inputs) {
if (input && input->IsOp() && input->Op() &&
input->Op()->Type() == op.type()) {
const auto& outputs = input->Op()->Outputs();
const auto& iter = outputs.find(var.parameter());
if (outputs.end() != iter) {
if (iter->second.end() != std::find(iter->second.begin(),
iter->second.end(),
x->Name())) {
return true;
}
}
}
}
return false;
});
} else if (var_pdnode->IsInput()) {
var_pdnode->AsIntermediate();
}
@@ -73,18 +118,64 @@ void InitGeneratePattern(const proto::PassDesc& pass_desc, PDPattern* pattern) {
}
}

// There are some duplicate patterns.
bool IsDuplicatePattern(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
for (auto iter : subgraph) {
if (nullptr == graph->RetrieveNode(iter.second->id())) {
VLOG(3) << "Node [" << iter.second->Name()
<< "] of subgraph has been removed. So skip this optimize.";
return true;
}
}
return false;
}
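Hoisting the duplicate check into `IsDuplicatePattern` lets both handlers below share it. The logic amounts to: if any node of the match is no longer retrievable from the graph, an earlier overlapping rewrite already consumed it. A one-function Python sketch (accessor names are illustrative):

```python
# Sketch only: retrieve_node/id mirror Graph::RetrieveNode and Node::id.
def is_duplicate_pattern(subgraph, graph):
    return any(graph.retrieve_node(node.id) is None
               for node in subgraph.values())
```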

GraphPatternDetector::handle_t GetGenerateDelete(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
// `var_node_maps` records the mapping from each replace-subgraph variable
// to its node in the matched pattern subgraph.
std::map<std::string, Node*> var_node_maps;
for (const proto::PassDesc::VarMap& var_map : pass_desc.var_maps()) {
Node* node = subgraph.at(pattern.RetrieveNode(var_map.pattern_var()));
const auto& iter = var_node_maps.find(var_map.replace_var());
if (var_node_maps.end() == iter) {
// First occurrence: record the node as the surviving input.
var_node_maps.insert({var_map.replace_var(), node});
} else {
// Later occurrence: an output node; reroute its consumers to the input.
for (Node* s_node : node->outputs) {
iter->second->outputs.push_back(s_node);
std::replace(s_node->inputs.begin(), s_node->inputs.end(), node,
iter->second);
s_node->Op()->RenameInput(node->Name(), iter->second->Name());
}
}
}
// Remove nodes that are intermediate.
std::unordered_set<const Node*> remove_nodes;
for (const std::unique_ptr<PDNode>& pdnode : pattern.nodes()) {
remove_nodes.emplace(subgraph.at(pdnode.get()));
}
for (auto iter : var_node_maps) {
remove_nodes.erase(iter.second);
}
GraphSafeRemoveNodes(graph, remove_nodes);
};
return handler;
}
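`GetGenerateDelete` covers the degenerate rewrite in which the replace block contains no operators: the matched subgraph is removed and its consumers are rewired to its input. A hedged sketch of a pass that should take this branch, assuming a replace function that simply forwards its input serializes to an op-free block:

```python
import paddle
from paddle.fluid import ir

paddle.enable_static()

# Hypothetical delete-style pass: `replace` returns its input unchanged,
# so the replace block holds no ops and GetGenerateDelete is selected.
@ir.RegisterPass
def remove_double_transpose():
    def pattern(x):
        # Transposing twice with the same perm is an identity on 2-D input.
        return paddle.transpose(paddle.transpose(x, [1, 0]), [1, 0])

    def replace(x):
        return x

    return pattern, replace
```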

GraphPatternDetector::handle_t GetGenerateRewrite(
const PDPattern& pattern, const proto::PassDesc& pass_desc) {
GraphPatternDetector::handle_t handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
if (IsDuplicatePattern(subgraph, graph)) {
return;
}
const proto::BlockDesc& block = pass_desc.replace().blocks(0);
// `var_node_maps` records the mapping from each replace-subgraph variable
// to its node in the matched pattern subgraph.
std::map<std::string, Node*> var_node_maps;
@@ -175,7 +266,11 @@ void GeneratePass::ApplyImpl(Graph* graph) const {
for (const proto::PassDesc& pass_desc : multi_pass_desc_.pass_descs()) {
GraphPatternDetector detector;
InitGeneratePattern(pass_desc, detector.mutable_pattern());
if (pass_desc.replace().blocks(0).ops_size() == 0) {
detector(graph, GetGenerateDelete(detector.pattern(), pass_desc));
} else {
detector(graph, GetGenerateRewrite(detector.pattern(), pass_desc));
}
// The rewritten graph needs to be verified. The current pass should be
// skipped if validation fails, because a rewrite applied to the original
// graph cannot be rolled back.
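End to end, a registered pass is fetched by name and applied to a `core.Graph`; `ApplyImpl` then picks the delete or rewrite handler per pass descriptor. A sketch in the style of the framework's GeneratePass unit tests, reusing the hypothetical `remove_double_transpose` pass from above:

```python
import paddle
from paddle.fluid import core

paddle.enable_static()

# Build a small program containing the double-transpose pattern.
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    x = paddle.static.data("x", [4, 4], dtype="float32")
    out = paddle.transpose(paddle.transpose(x, [1, 0]), [1, 0])

graph = core.Graph(main_program.desc)
before_node_nums = len(graph.nodes())
core.get_pass("remove_double_transpose").apply(graph)
after_node_nums = len(graph.nodes())  # should shrink if the pattern matched
```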
43 changes: 38 additions & 5 deletions python/paddle/fluid/ir.py
@@ -127,11 +127,13 @@ def apply_pass(name):


class RegisterPassHelper(object):
_register_helpers = list()

def __init__(self, pass_pairs, pass_type=str(), input_specs=dict()):
self._pass_type = pass_type
self._pass_pairs = pass_pairs
self._input_specs = input_specs
RegisterPassHelper._register_helpers.append(self)

def _get_args_from_func(self, func):
args = list()
@@ -148,6 +150,35 @@ def _get_args_from_func(self, func):
args.append(paddle.static.data(arg_name, [-1]))
return args

def _prune_program_desc(self, program_desc):
block_desc = program_desc.blocks[0]
# block_desc.ClearField("vars")
for var in [
var for var in block_desc.vars
if var.name not in self._input_specs
]:
block_desc.vars.remove(var)
for op_desc in block_desc.ops:
default_attrs = core.get_op_attrs_default_value(
paddle.compat.to_bytes(op_desc.type))
remove_attrs = list()
for attr in op_desc.attrs:
# Framework-internal attrs are always pruned; other attrs are
# pruned only when they still hold their default value.
if attr.name not in [
"op_namescope", "op_callstack", "op_device"
]:
attr_list_fields = attr.ListFields()
# attr format must be: name, type, value
if len(attr_list_fields) == 3:
attr_value = attr.ListFields()[-1][-1]
default_attr_value = default_attrs.get(attr.name)
# Keep the attr if its value differs from the default.
if default_attr_value != attr_value:
continue
remove_attrs.append(attr)
for attr in remove_attrs:
op_desc.attrs.remove(attr)
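Pruning matters because a program serialized from the Python pattern records every attribute, defaults included; keeping them would over-constrain matching against graphs whose ops never set those attributes explicitly. The defaults consulted above can be inspected directly (the `scale` op is just an example):

```python
import paddle
from paddle.fluid import core

# Dict-like mapping of attr name -> default value for an op type, as
# consulted by _prune_program_desc above; attrs still at these values
# are dropped from the serialized pattern.
defaults = core.get_op_attrs_default_value(paddle.compat.to_bytes("scale"))
print(defaults)
```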

def _func_to_program_desc(self, func, program_desc, is_replace=False):
vars = list()
program = paddle.static.Program()
@@ -166,6 +197,7 @@ def _func_to_program_desc(self, func, program_desc, is_replace=False):
elif isinstance(out, paddle.fluid.framework.Variable):
vars.append(out.name)
program_desc.ParseFromString(program.desc.serialize_to_string())
self._prune_program_desc(program_desc)
if is_replace:
attrs = list()
for op in program.current_block().ops:
@@ -296,7 +328,7 @@ def Outputs(self):
OP = OpHelper()


def RegisterPass(function=None, input_specs=dict()):
"""
The function decorator of Register Pass. Decorator @RegisterPass handles
the function and registers it into a core.Pass instance. Use the name of the function
@@ -305,11 +337,11 @@ def RegisterPass(function=None, input_specs=None):
Args:
function (callable): The function with return of callable pair(s) that
represents the pattern subgraph and the replace subgraph.
input_specs (dict[str, InputSpec]): Dict of InputSpec used to specify the shape/dtype
information of Tensors. Some operators limit the shape and dtype of input data when
a subgraph is created with Paddle APIs, so users need to specify the InputSpec of such
data to ensure the subgraph is built correctly. This argument is not limited to
matching the subgraph. The default is dict().

Returns:
callables: Callable pair(s).
@@ -351,6 +383,7 @@ def decorated(python_func):
"Return value of Pass function must be (callable, callable)."
)
helper = RegisterPassHelper(pass_pairs, pass_type, input_specs)
core.register_pass(pass_type, helper.SerializeMultiPassDesc)
return python_func

if inspect.isfunction(function):