Skip to content

Commit

Permalink
use watermark-based memoization resets to fix diamond matching
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Apr 6, 2020
1 parent ea83939 commit 3af18d4
Showing 1 changed file with 53 additions and 27 deletions.
80 changes: 53 additions & 27 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,36 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

void ClearMap(size_t watermark);
std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
std::vector<DFPattern> matched_nodes_;
};

bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
memo_.clear();
return VisitDFPattern(pattern, expr);
}

void DFPatternMatcher::ClearMap(size_t watermark) {
for (size_t i = watermark; i < matched_nodes_.size(); ++i) {
memo_.erase(matched_nodes_[i]);
}
matched_nodes_.erase(matched_nodes_.begin() + watermark, matched_nodes_.end());
}
bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
// if (memo_.count(pattern)) {
// return expr.same_as(memo_[pattern]);
// } else {
auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
// if (out) {
// memo_[pattern] = expr;
// }
return out;
// }
if (memo_.count(pattern)) {
return expr.same_as(memo_[pattern]);
} else {
auto watermark = matched_nodes_.size();
auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
if (out) {
memo_[pattern] = expr;
matched_nodes_.push_back(pattern);
} else {
ClearMap(watermark);
}
return out;
}
}

bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
Expand Down Expand Up @@ -115,19 +127,7 @@ Array<DFPattern> reverse(const Array<DFPattern> args) {
}

bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
auto match_args = [this](const Array<DFPattern> pattern_args, const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
}
} else {
matches = false;
}
return matches;
};
auto watermark = matched_nodes_.size();

auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
if (op) {
Expand All @@ -141,6 +141,24 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
if (const auto* call_node = expr.as<CallNode>()) {
auto matches_op = VisitDFPattern(op->op, call_node->op);
if (matches_op) {
auto watermark2 = matched_nodes_.size();
auto match_args = [this, &watermark2](const Array<DFPattern> pattern_args, const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
}
} else {
matches = false;
}
if (!matches) {
ClearMap(watermark2);
}
return matches;
};

if (match_args(op->args, call_node->args)) {
return true;
}
Expand All @@ -156,12 +174,13 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
}
}
} else {
ClearMap(watermark);
if (const OpNode* op_node = get_op_node(op)) {
if (op_node->name == "divide") {
if (auto* arg_node = op->args[0].as<CallPatternNode>()) {
if (const OpNode* arg_op = get_op_node(arg_node)) {
if (arg_op->name == "multiply") {
auto associate_div_mul = [this, &op, &arg_node, &expr]() {
auto associate_div_mul = [this, &op, &arg_node, &expr, &watermark]() {
auto div1 = CallPatternNode::make(op->op, {arg_node->args[1], op->args[1]},
op->attrs, op->type_args);
auto mul1 = CallPatternNode::make(arg_node->op, {arg_node->args[0], div1},
Expand All @@ -170,7 +189,12 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
op->attrs, op->type_args);
auto mul2 = CallPatternNode::make(arg_node->op, {arg_node->args[1], div2},
arg_node->attrs, arg_node->type_args);
return VisitDFPattern(mul1, expr)|| VisitDFPattern(mul2, expr);
auto out = VisitDFPattern(mul1, expr);
if (!out) {
ClearMap(watermark);
out = VisitDFPattern(mul2, expr);
}
return out;
};

if (const OpNode* expr_op_node = call_node->op.as<OpNode>()) {
Expand Down Expand Up @@ -455,10 +479,12 @@ class PatternRewriter : protected MixedModeMutator {
Array<DFPatternCallback> callbacks_;
};

TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite")
.set_body_typed([](Array<DFPatternCallback> callbacks, Expr expr) {
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {
return PatternRewriter(callbacks).Rewrite(expr);
});
}

TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite")
.set_body_typed(RewritePatterns);

} // namespace relay
} // namespace tvm

0 comments on commit 3af18d4

Please sign in to comment.