Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] IndexedGraph improvements in preparation for Collage #11481

Merged
merged 5 commits into from
Jun 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 52 additions & 38 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace relay {

// Pattern Matcher
bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
VLOG(1) << "Match " << PrettyPrint(pattern) << " in:" << std::endl << PrettyPrint(expr);
memo_.clear();
matched_nodes_.clear();
return VisitDFPattern(pattern, expr);
Expand All @@ -58,6 +59,7 @@ bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr
if (out) {
memo_[pattern].push_back(expr);
matched_nodes_.push_back(pattern);
VLOG(1) << "Matched " << PrettyPrint(pattern) << " at:" << std::endl << PrettyPrint(expr);
} else {
ClearMap(watermark);
}
Expand Down Expand Up @@ -124,7 +126,6 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
if (!matches) {
return matches;
}
VLOG(1) << "considering AttrPatternNode at:\n" << PrettyPrint(expr);
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
if (const auto* op_node = expr.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
Expand Down Expand Up @@ -299,14 +300,18 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
// Recursively find the Dominator parent along all inputs paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
auto call_node = expr.as<CallNode>();
for (auto node : expr_graph_.node_map_.at(expr)->inputs_) {
if (!(call_node && node->ref_ == call_node->op)) {
auto index_node = expr_to_node(expr);
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
memoize_ = true;
if (VisitDFPattern(op->parent, node->ref_)) {
if (VisitDFPattern(op->parent, node->ref())) {
return true;
} else {
memoize_ = false;
if (!VisitDFPattern(op->path, node->ref_) || !MatchesPath(op, node->ref_)) {
if (!VisitDFPattern(op->path, node->ref())) {
return false;
}
if (!MatchesPath(op, node->ref())) {
return false;
}
}
Expand All @@ -318,19 +323,19 @@ bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& e
// Iteratively ensure that the parent is dominated somewhere by the child or the path
bool DFPatternMatcher::DominatesParent(const DominatorPatternNode* op, const Expr& expr) {
std::stack<Expr> stack;
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> visited;
std::unordered_set<const ExprNode*> visited;
stack.push(expr);
while (!stack.empty()) {
Expr current = stack.top();
stack.pop();
for (auto node : expr_graph_.node_map_.at(current)->dominator_children_) {
if (visited.count(node->ref_) == 0) {
if (VisitDFPattern(op->parent, node->ref_)) {
for (auto node : expr_to_node(current)->dominator_children_) {
if (visited.count(node->node_ref_) == 0) {
if (VisitDFPattern(op->parent, node->ref())) {
return true;
} else {
stack.push(node->ref_);
stack.push(node->ref());
}
visited.insert(node->ref_);
visited.insert(node->node_ref_);
}
}
}
Expand Down Expand Up @@ -500,7 +505,8 @@ bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr
}

bool MatchPattern(DFPattern pattern, Expr expr) {
return DFPatternMatcher(expr).Match(pattern, expr);
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(expr);
return DFPatternMatcher(expr_graph.get()).Match(pattern, expr);
}

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);
Expand Down Expand Up @@ -575,17 +581,18 @@ const std::unordered_map<int, PatternGrouper::Group>& PatternGrouper::GroupMatch

pattern_ = pattern;
pattern_graph_ = CreateIndexedGraph(pattern_);
auto matcher = DFPatternMatcher(pre);
std::unique_ptr<IndexedGraph<Expr>> expr_graph = CreateIndexedGraph(pre);
DFPatternMatcher matcher(expr_graph.get());
matcher_ = &matcher;
this->VisitExprs();
return this->groups_;
}

void PatternGrouper::VisitExprs() {
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> pre_partitioned;
for (size_t i = matcher_->expr_graph_.topological_order_.size(); i != 0; --i) {
size_t index = i - 1;
Expr current = matcher_->expr_graph_.topological_order_.at(index)->ref_;
for (PostDfsIndex i = matcher_->size(); i != 0; --i) {
PostDfsIndex index = i - 1;
const auto current = matcher_->index_to_node(index)->ref();
if (gid_assignments_.count(current) == 0) { // Don't visit nodes we've already grouped
if (auto op = current.as<FunctionNode>()) {
if (op->attrs.defined() && op->attrs->dict.count(attr::kPartitionedFromPattern) != 0) {
Expand All @@ -607,22 +614,24 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
auto node_map = matcher_->GetMemo();
// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
auto node = pattern_graph_->index_to_node(index);
// Don't treat fuzzy Dominator patterns input variables for partition
if (auto op = node->ref_.as<DominatorPatternNode>()) {
if (auto op = node->ref().as<DominatorPatternNode>()) {
for (auto fuzzy_op : {op->parent, op->path}) {
for (auto match : node_map[fuzzy_op]) {
fuzzy_matches.insert(match);
}
}
}
// Don't treat Function params or body as input variables for partition
if (node->ref_.as<FunctionPatternNode>()) {
auto matches = node_map[node->ref_];
if (node->ref().as<FunctionPatternNode>()) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (auto node : graph.topological_order_) {
fuzzy_matches.insert(node->ref_);
auto sub_graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
for (PostDfsIndex sub_index = 0; sub_index < sub_graph->size(); ++sub_index) {
auto sub_node = sub_graph->index_to_node(sub_index);
fuzzy_matches.insert(sub_node->ref());
}
}
}
Expand All @@ -636,10 +645,11 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
Array<Var> params;

for (auto node : pattern_graph_.topological_order_) {
for (PostDfsIndex index = 0; index < pattern_graph_->size(); ++index) {
auto node = pattern_graph_->index_to_node(index);
auto make_input = [&](const Expr& input) {
if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref_)) {
input.as<FunctionNode>() == nullptr && !EmbedConst(input, node->ref())) {
inputs[input] =
Var("FunctionVar_" + std::to_string(graph_number_) + "_" + std::to_string(var_number),
NullValue<Type>());
Expand All @@ -648,29 +658,29 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
var_number++;
}
};
auto tuple = node->ref_.as<TuplePatternNode>();
auto call = node->ref_.as<CallPatternNode>();
auto tuple = node->ref().as<TuplePatternNode>();
auto call = node->ref().as<CallPatternNode>();
if (tuple && !tuple->fields.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
for (auto input : match.as<TupleNode>()->fields) {
make_input(input);
}
}
}
} else if (call && !call->args.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
for (auto input : match.as<CallNode>()->args) {
make_input(input);
}
}
}
} else if (node->inputs_.size() == 0) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
if (node_map.count(node->ref())) {
auto matches = node_map[node->ref()];
for (auto match : matches) {
make_input(match);
}
Expand Down Expand Up @@ -708,13 +718,17 @@ void PatternGrouper::CreateGroup(const Expr& expr) {
return;
} else if (kv.second != body) {
// if the node isn't the output of the group
auto node = matcher_->expr_graph_.node_map_.at(kv.first);
auto node = matcher_->expr_to_node(kv.first);
for (auto* output : node->outputs_) {
// and the node is used by nodes outside of the group
if (memo.count(output->ref_) == 0 &&
!matcher_->expr_graph_.node_map_.at(expr)->Dominates(output)) {
// Exit because nodes in this pattern's body are used outside the pattern
// fusing it would be invalid
if (memo.count(output->ref()) == 0) {
// TODO(mbs): This condition used to also include the following test, which since
// the dominators relation is used back-to-front was always vacuously true. So the
// code is just rejecting the match if a strictly internal node happened to connect
// to an outside node.
ICHECK(!matcher_->expr_to_node(expr)->Dominates(output));
// Exit because nodes in this pattern's body are used outside the pattern, fusing it
// would be invalid
return;
}
}
Expand Down
19 changes: 16 additions & 3 deletions src/relay/ir/dataflow_matcher_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/dataflow_pattern_functor.h>
#include <tvm/relay/expr_functor.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
Expand All @@ -39,10 +41,20 @@ namespace relay {

class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
public:
explicit DFPatternMatcher(const Expr& root_expr) : expr_graph_(CreateIndexedGraph(root_expr)) {}
explicit DFPatternMatcher(const IndexedGraph<Expr>* expr_graph) : expr_graph_(expr_graph) {}
bool Match(const DFPattern& pattern, const Expr& expr);
Map<DFPattern, Array<Expr>> GetMemo() { return Map<DFPattern, Array<Expr>>(memo_); }
const IndexedGraph<Expr> expr_graph_;

const IndexedGraph<Expr>::Node* expr_to_node(const Expr& expr) const {
return expr_graph_->item_to_node(expr);
}
const IndexedGraph<Expr>::Node* index_to_node(size_t index) const {
return expr_graph_->index_to_node(index);
}
size_t size() const { return expr_graph_->size(); }
const std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual>& memo() const {
return memo_;
}

protected:
bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
Expand All @@ -67,6 +79,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool MatchesPath(const DominatorPatternNode* op, const Expr& expr);
bool DominatesParent(const DominatorPatternNode* op, const Expr& expr);

const IndexedGraph<Expr>* expr_graph_;
std::unordered_map<DFPattern, Array<Expr>, ObjectPtrHash, ObjectPtrEqual> memo_;
std::vector<DFPattern> matched_nodes_;
bool memoize_ = true;
Expand Down Expand Up @@ -131,7 +144,7 @@ class PatternGrouper {
std::unordered_map<int, Group> groups_;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> gid_assignments_;
DFPatternMatcher* matcher_ = nullptr;
IndexedGraph<DFPattern> pattern_graph_;
std::unique_ptr<IndexedGraph<DFPattern>> pattern_graph_;
int gid_ = 0;
int graph_number_ = 0;
};
Expand Down
Loading