Merge pull request PaddlePaddle#32 from gongshaotian/drr_dfs_test
Upgrade the subgraph matching algorithm to PatternGraphMatchV2
yuanlehome authored Sep 18, 2023
2 parents 9f59cdd + 46a734a commit 5239e7d
Showing 6 changed files with 310 additions and 36 deletions.
274 changes: 267 additions & 7 deletions paddle/fluid/pir/drr/drr_rewrite_pattern.h
@@ -16,6 +16,7 @@

#include <memory>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

@@ -58,7 +59,7 @@ class DrrRewritePattern : public pir::RewritePattern {
PatternRewriter& rewriter) const override { // NOLINT
std::shared_ptr<MatchContextImpl> src_match_ctx =
std::make_shared<MatchContextImpl>();
if (PatternGraphMatch(op, src_match_ctx)) {
if (PatternGraphMatchV2(op, src_match_ctx.get())) {
VLOG(4) << "DRR pattern (" << pir::get_type_name<DrrPattern>()
<< ") is matched in program.";
PatternGraphRewrite(*src_match_ctx, rewriter);
@@ -123,7 +124,6 @@ class DrrRewritePattern : public pir::RewritePattern {
source_pattern_match_ctx->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_input_value));

// Input tensor is optional(or none)
if (!ir_input_value) {
if (drr_brother_ops.size() != 1) { // Only used by current op
@@ -134,7 +134,6 @@
}
continue;
}

if (drr_brother_ops.size() != ir_input_value.use_count()) {
matched = false;
VLOG(6) << " --- match false: " << drr_brother_ops.size()
@@ -265,7 +264,7 @@ class DrrRewritePattern : public pir::RewritePattern {
VLOG(6) << "step: " << step
<< " CountOfOpCalls: " << source_pattern_graph_->CountOfOpCalls();
IR_ENFORCE(step == source_pattern_graph_->CountOfOpCalls(),
"step not equal to count of opcalls");
"step not equal to count of OpCalls");
} else {
VLOG(6) << " --- match false: " << op->name();
return matched;
@@ -285,6 +284,270 @@ class DrrRewritePattern : public pir::RewritePattern {
return matched;
}

std::unordered_map<const OpCall*, std::unordered_set<pir::Operation*>>
FindCandidateIrOutputOp(
pir::Operation* op,
const OpCall* anchor,
const SourcePatternGraph& source_pattern_graph) const {
// get source pattern output op
std::unordered_set<const OpCall*> drr_output_op_set =
source_pattern_graph.OutputNodes();
std::unordered_map<const OpCall*, std::unordered_set<pir::Operation*>>
output_op_bind_map{{anchor, {op}}};
if (drr_output_op_set.size() == 1) {
return output_op_bind_map;
}
std::unordered_set<const OpCall*> drr_visited_ops{anchor};
DfsVisitor(
anchor, op, drr_output_op_set, &drr_visited_ops, &output_op_bind_map);
if (output_op_bind_map.size() != drr_output_op_set.size()) {
return {};
}
return output_op_bind_map;
}

void DfsVisitor(
const OpCall* drr_op,
pir::Operation* ir_op,
const std::unordered_set<const OpCall*>& drr_output_op_set,
std::unordered_set<const OpCall*>* drr_visited_ops,
std::unordered_map<const OpCall*, std::unordered_set<pir::Operation*>>*
output_op_bind_map) const {
VLOG(6) << "DfsVisitor Start: drr op(" << drr_op->name() << ")"
<< "ir op(" << ir_op->name() << ")";
if (drr_op->name() != ir_op->name()) {
return;
}
// check input's size
const auto& drr_op_input_tensors = drr_op->inputs();
auto ir_op_input_value_size = ir_op->num_operands();
if (drr_op_input_tensors.size() != ir_op_input_value_size) {
return;
}
// check output's size
const auto& drr_op_output_tensors = drr_op->outputs();
auto ir_op_output_value_size = ir_op->num_results();
if (drr_op_output_tensors.size() != ir_op_output_value_size) {
return;
}
// check producer op
for (size_t i = 0; i < drr_op_input_tensors.size(); ++i) {
// case 1: drr_op_input_tensor is the input tensor of source pattern
if (drr_op_input_tensors[i]->producer() == nullptr) {
// dfs the other consumer ops of the source pattern input tensor
auto ir_input_tensor = ir_op->operand(i).source();
for (auto drr_bro_op : drr_op_input_tensors[i]->consumers()) {
if (drr_visited_ops->count(drr_bro_op)) {
continue;
}
for (auto it = ir_input_tensor.use_begin();
it != ir_input_tensor.use_end();
++it) {
auto* ir_bro_op = it.owner();
if (drr_bro_op->name() == ir_bro_op->name()) {
drr_visited_ops->insert(drr_bro_op);
DfsVisitor(drr_bro_op,
ir_bro_op,
drr_output_op_set,
drr_visited_ops,
output_op_bind_map);
drr_visited_ops->erase(drr_bro_op);
}
}
}
continue;
}
// case 2: have producer op
const auto& drr_producer_op = drr_op_input_tensors[i]->producer();
if (drr_visited_ops->count(drr_producer_op)) {
continue;
}
auto ir_operand_value = ir_op->operand(i).source();
if (drr_op_input_tensors[i]->consumers().size() !=
ir_operand_value.use_count()) {
return;
}
auto* ir_producer_op = ir_operand_value.GetDefiningOp();
drr_visited_ops->insert(drr_producer_op);
DfsVisitor(drr_producer_op,
ir_producer_op,
drr_output_op_set,
drr_visited_ops,
output_op_bind_map);

drr_visited_ops->erase(drr_producer_op);
}
if (drr_output_op_set.count(drr_op)) {
(*output_op_bind_map)[drr_op].insert(ir_op);
return;
}
// check child ops
for (size_t i = 0; i < drr_op_output_tensors.size(); ++i) {
const auto& drr_child_ops = drr_op_output_tensors[i]->consumers();
auto ir_output_value = ir_op->result(i);
if (drr_child_ops.size() != ir_output_value.use_count()) {
return;
}
for (auto* drr_child_op : drr_child_ops) {
for (auto it = ir_output_value.use_begin();
it != ir_output_value.use_end();
++it) {
auto* ir_child_op = it.owner();
if (drr_child_op->name() == ir_child_op->name()) {
if (drr_visited_ops->count(drr_child_op)) {
continue;
}
drr_visited_ops->insert(drr_child_op);
DfsVisitor(drr_child_op,
ir_child_op,
drr_output_op_set,
drr_visited_ops,
output_op_bind_map);
drr_visited_ops->erase(drr_child_op);
}
}
}
} // check child ops
return;
}

bool MatchFromOutputToInput(
std::unordered_map<const OpCall*, Operation*> output_op_map,
const SourcePatternGraph& source_pattern_graph,
const std::shared_ptr<MatchContextImpl>& source_pattern_match_ctx) const {
VLOG(6) << "MatchFromOutputToInput Start";
std::unordered_set<const OpCall*> drr_visited;
std::unordered_set<Operation*> ir_visited;
std::queue<const OpCall*> drr_q;
std::queue<pir::Operation*> ir_q;
bool matched = true;
size_t step = 0;
for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) {
drr_q.push(it->first);
drr_visited.insert(it->first);
ir_q.push(it->second);
ir_visited.insert(it->second);
}
while (!drr_q.empty()) {
if (!matched) break;
auto* drr_node = drr_q.front();
auto* ir_node = ir_q.front();
drr_q.pop();
ir_q.pop();
if (drr_node->name() != ir_node->name()) {
matched = false;
break;
}
const auto& drr_input_tensors = drr_node->inputs();
auto ir_input_value_size = ir_node->num_operands();
if (drr_input_tensors.size() != ir_input_value_size) {
matched = false;
break;
}
if (drr_node->outputs().size() != ir_node->num_results()) {
matched = false;
break;
}
source_pattern_match_ctx->BindIrOperation(
drr_node, std::make_shared<IrOperation>(ir_node));
// binding input_tensor of current_op
for (size_t i = 0; i < drr_input_tensors.size(); ++i) {
source_pattern_match_ctx->BindIrValue(
drr_input_tensors[i]->name(),
std::make_shared<IrValue>(ir_node->operand(i).source()));
auto* drr_producer_op = drr_input_tensors[i]->producer();
auto* ir_producer_op = ir_node->operand(i).source().GetDefiningOp();
if (drr_producer_op == nullptr) {
continue;
}
if (drr_input_tensors[i]->consumers().size() !=
ir_node->operand(i).source().use_count()) {
matched = false;
break;
}
// bfs producer_op of current_op
if (!drr_visited.count(drr_producer_op)) {
drr_q.push(drr_producer_op);
ir_q.push(ir_producer_op);
drr_visited.insert(drr_producer_op);
ir_visited.insert(ir_producer_op);
}
}
// binding output tensor of current_op
auto drr_op_output_tensor = drr_node->outputs();
for (size_t j = 0; j < drr_op_output_tensor.size(); j++) {
source_pattern_match_ctx->BindIrValue(
drr_op_output_tensor[j]->name(),
std::make_shared<IrValue>(ir_node->result(j)));
}
++step;
}

if (matched) {
IR_ENFORCE(step == source_pattern_graph.CountOfOpCalls());
} else {
return matched;
}

MatchContext match_context{source_pattern_match_ctx};
for (const auto& constraint : constraints_) {
matched = constraint(match_context);
if (!matched) break;
}

return matched;
}

bool PatternGraphMatchV2(pir::Operation* op,
MatchContextImpl* source_pattern_match_ctx) const {
VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")";
const OpCall* anchor = source_pattern_graph_->AnchorNode();
std::unordered_map<const OpCall*, std::unordered_set<pir::Operation*>>
bind_map =
FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get()));
if (bind_map.empty()) {
return false;
}
std::vector<const OpCall*> drr_output_sequence;
std::vector<Operation*> ir_output_sequence;
std::unordered_map<const OpCall*, Operation*> output_op_map;
for (auto pair : bind_map) {
drr_output_sequence.push_back(pair.first);
}
// using dfs to obtain the arrangement of all candidate ir ops
auto permute = [&](auto&& permute, size_t index) -> bool {
if (index == drr_output_sequence.size()) {
// new match_ctx
std::shared_ptr<MatchContextImpl> match_ctx =
std::make_shared<MatchContextImpl>();
// create output op map
std::transform(drr_output_sequence.begin(),
drr_output_sequence.end(),
ir_output_sequence.begin(),
std::inserter(output_op_map, output_op_map.end()),
[](const OpCall* drr_op, Operation* ir_op) {
return std::make_pair(drr_op, ir_op);
});
if (MatchFromOutputToInput(
output_op_map, *(source_pattern_graph_.get()), match_ctx)) {
*source_pattern_match_ctx = *match_ctx;
return true;
}
return false;
}
for (auto* ir_op : bind_map[drr_output_sequence[index]]) {
ir_output_sequence.push_back(ir_op);
if (permute(permute, index + 1)) {
return true;
}
ir_output_sequence.pop_back();
}
return false;
};

return permute(permute, 0);
}

void PatternGraphRewrite(const MatchContextImpl& source_pattern_match_ctx,
pir::PatternRewriter& rewriter) const { // NOLINT
VLOG(6) << "Create Operations in result_pattern_graph";
@@ -322,7 +585,6 @@ class DrrRewritePattern : public pir::RewritePattern {
std::make_shared<IrValue>(src_match_ctx.GetIrValue(in_tensor)));
}
}

// set insert point
for (const auto& output : result_pattern_graph.output_tensors()) {
if (source_pattern_graph.id2owend_tensor().count(output)) {
@@ -333,7 +595,6 @@ class DrrRewritePattern : public pir::RewritePattern {
}
}
}

// topo order visit result_pattern_graph
GraphTopo graph_topo_visit(&result_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
@@ -455,7 +716,6 @@ class DrrRewritePattern : public pir::RewritePattern {
rewriter.EraseOp(op);
}
}

const std::shared_ptr<SourcePatternGraph> source_pattern_graph_;
const std::vector<Constraint> constraints_;
const std::shared_ptr<ResultPatternGraph> result_pattern_graph_;
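A note on the new matching flow in drr_rewrite_pattern.h: PatternGraphMatchV2 first calls FindCandidateIrOutputOp, whose DfsVisitor collects, for every output op of the source pattern, the set of IR operations it could bind to; the recursive permute lambda then tries one candidate per output op and hands each complete assignment to MatchFromOutputToInput, which walks from those outputs back toward the inputs while checking op names, operand/result counts, and use counts, and finally evaluates the registered constraints. Below is a minimal, self-contained sketch of just the enumeration step, with std::string standing in for OpCall* and pir::Operation* and a caller-supplied verifier standing in for MatchFromOutputToInput; it is an illustration, not the Paddle code.

// Sketch only (not the Paddle implementation): enumerate every assignment of
// candidate IR ops to the pattern's output ops and stop at the first one the
// verifier accepts. std::string stands in for OpCall* / pir::Operation*.
#include <functional>
#include <map>
#include <set>
#include <string>
#include <vector>

using PatternOp = std::string;
using IrOp = std::string;

bool FindBinding(
    const std::map<PatternOp, std::set<IrOp>>& candidates,
    const std::function<bool(const std::map<PatternOp, IrOp>&)>& verify) {
  std::vector<PatternOp> keys;
  keys.reserve(candidates.size());
  for (const auto& kv : candidates) keys.push_back(kv.first);

  std::map<PatternOp, IrOp> assignment;
  // Depth-first enumeration, mirroring the recursive permute lambda above.
  std::function<bool(size_t)> dfs = [&](size_t index) -> bool {
    if (index == keys.size()) return verify(assignment);  // complete assignment
    for (const IrOp& ir_op : candidates.at(keys[index])) {
      assignment[keys[index]] = ir_op;
      if (dfs(index + 1)) return true;  // first accepted assignment wins
      assignment.erase(keys[index]);    // backtrack and try the next candidate
    }
    return false;
  };
  return dfs(0);
}

In the diff above, verify corresponds to MatchFromOutputToInput plus the constraints; the first assignment that passes has its freshly built MatchContextImpl copied into source_pattern_match_ctx.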
18 changes: 10 additions & 8 deletions paddle/fluid/pir/drr/ir_value.h
@@ -53,14 +53,16 @@ class IrValue : public TensorInterface {
public:
explicit IrValue(const pir::Value& value)
: value_(value),
shape_(value ? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()
: nullptr),
dtype_(value ? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()
: nullptr) {}
shape_((value && value.type())
? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims()
: nullptr),
dtype_((value && value.type())
? &value.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()
: nullptr) {}

ShapeInterface Shape() const override { return ShapeInterface(&shape_); }
DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); }
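A note on the IrValue change above: the constructor now reads shape and dtype only when both the value and its type are present, which keeps the optional (none) inputs tolerated by the matcher from dereferencing an unset type. Below is a rough stand-alone illustration of the same guard, using std::optional and toy types rather than pir::Value; the names are hypothetical.

// Sketch only: check the value and its type before using them.
// Toy types; these are not the Paddle classes.
#include <array>
#include <optional>

struct DenseTensorType {
  std::array<int, 4> dims{};
};

struct Value {
  std::optional<DenseTensorType> type;  // empty for a "none" input
};

// Mirrors the (value && value.type()) condition added to IrValue: when the
// optional input has no type, hand back nullptr instead of dereferencing.
const std::array<int, 4>* DimsOrNull(const Value* value) {
  if (value != nullptr && value->type.has_value()) {
    return &value->type->dims;
  }
  return nullptr;
}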
5 changes: 5 additions & 0 deletions paddle/fluid/pir/drr/match_context_impl.h
@@ -83,6 +83,11 @@ class MatchContextImpl final {
return attr_map_;
}

const std::unordered_map<std::string, std::shared_ptr<IrValue>>& tensor_map()
const {
return tensor_map_;
}

void BindIrValue(const std::string& value_name,
const std::shared_ptr<IrValue>& value) {
tensor_map_.emplace(value_name, value);
15 changes: 11 additions & 4 deletions paddle/fluid/pir/drr/pattern_graph.cc
@@ -126,12 +126,19 @@ OpCall *SourcePatternGraph::AnchorNode() const {
return id2owned_tensor_.at(*output_tensors_.begin())->producer();
}

std::vector<OpCall *> SourcePatternGraph::OutputNodes() const {
std::vector<OpCall *> output_nodes;
std::unordered_set<const OpCall *> SourcePatternGraph::OutputNodes() const {
std::unordered_set<const OpCall *> output_op_set;
for (const auto &output_tensor : output_tensors_) {
output_nodes.push_back(id2owned_tensor_.at(output_tensor)->producer());
OpCall *output_op_candidate =
id2owned_tensor_.at(output_tensor)->producer();
if (std::all_of(output_op_candidate->outputs().begin(),
output_op_candidate->outputs().end(),
[this](const Tensor *output) -> bool {
return this->output_tensors().count(output->name());
}))
output_op_set.insert(output_op_candidate);
}
return output_nodes;
return output_op_set;
}

void ResultPatternGraph::AssignTensor(const Tensor &from, const Tensor &to) {
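A note on the OutputNodes change above: the method now returns an unordered set and keeps a producer only when every one of its output tensors is also a pattern output tensor, so an op that additionally feeds another op inside the pattern is no longer treated as an output anchor. A small self-contained sketch of that std::all_of filter, with toy types in place of Tensor and OpCall:

// Sketch only: keep a producer as an "output op" only when every one of its
// output tensors belongs to the pattern's output-tensor set. Toy types; not
// the Paddle data structures.
#include <algorithm>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  std::string name;
  std::vector<std::string> output_tensors;
};

// The returned pointers refer into `producers`, which must outlive the result.
std::unordered_set<const Op*> OutputOps(
    const std::vector<Op>& producers,
    const std::set<std::string>& pattern_outputs) {
  std::unordered_set<const Op*> result;
  for (const Op& op : producers) {
    if (std::all_of(op.output_tensors.begin(), op.output_tensors.end(),
                    [&](const std::string& tensor) {
                      return pattern_outputs.count(tensor) > 0;
                    })) {
      result.insert(&op);
    }
  }
  return result;
}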