Skip to content

Commit

Permalink
Pattern Rewriter
Browse files Browse the repository at this point in the history
  • Loading branch information
mbrookhart committed Apr 6, 2020
1 parent a115887 commit 1eefb0d
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 12 deletions.
27 changes: 27 additions & 0 deletions include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,33 @@ class DFPatternMutator : public DFPatternFunctor<DFPattern(const DFPattern&)> {
std::unordered_map<DFPattern, DFPattern, ObjectHash, ObjectEqual> memo_;
};

class DFPatternCallback;
/*!
* \brief Base type of all dataflow pattern callbacks.
* \sa DFPatternCallback
*/
class DFPatternCallbackNode : public Object {
public:
DFPattern pattern_;
PackedFunc function_;

void VisitAttrs(tvm::AttrVisitor* v) {}

TVM_DLL static DFPatternCallback make(DFPattern pattern, PackedFunc callback);

static constexpr const char* _type_key = "DFPatternCallbackNode";
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
};

/*!
* \brief Managed reference to dataflow pattern callbacks.
* \sa DFPatternCallbackNode
*/
class DFPatternCallback : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

} // namespace relay
} // namespace tvm

Expand Down
15 changes: 12 additions & 3 deletions python/tvm/relay/df_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
"""The Relay Pattern Language and tooling."""
from ...ir.base import Node
from ...ir import make_node
from ...runtime import Object
from ... import _ffi as tvm_ffi
from ..op import get
from . import _ffi as ffi

def match(pattern, expr):
return ffi.match(pattern, expr)

def register_df_node(type_key=None):
"""Register a Relay node type.
Expand Down Expand Up @@ -214,6 +212,11 @@ def __init__(self, pattern, attrs):
self.__init_handle_by_constructor__(
ffi.AttrPattern, pattern, attrs)

class DFPatternCallback(Object):
def __init__(self, pattern, callback):
self.__init_handle_by_constructor__(
ffi.DFPatternCallback, pattern, callback)

def is_input(name="") -> DFPattern:
return VarPattern(name)

Expand All @@ -233,3 +236,9 @@ def has_attr(attr_name, attr_value, pattern=None):
if pattern is None:
pattern = wildcard()
return patter.has_attr(attr_name, attr_value)

def match(pattern, expr):
return ffi.match(pattern, expr)

def rewrite(callbacks, expr):
return ffi.rewrite(callbacks, expr)
71 changes: 62 additions & 9 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,21 @@
*/

#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/transform.h>

namespace tvm {
namespace relay {

// Pattern Matcher

class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Expr&)> {
public:
bool Match(const DFPattern& pattern, const Expr& expr);

protected:
bool VisitDFPattern(const DFPattern& pattern, const Expr& expr) override;
bool VisitDFPattern_(const AltPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const AttrPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const CallPatternNode* op, const Expr& expr) override;
Expand All @@ -42,15 +48,19 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const DFPattern&, const Ex
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) override;

protected:
std::unordered_map<DFPattern, Expr, ObjectHash, ObjectEqual> memo_;
};

bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
memo_.clear();
return VisitDFPattern(pattern, expr);
}

bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
if (memo_.count(pattern)) {
return expr.same_as(memo_[pattern]);
} else {
auto out = VisitDFPattern(pattern, expr);
auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
if (out) {
memo_[pattern] = expr;
}
Expand All @@ -59,7 +69,7 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
}

bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
return Match(op->left, expr) || Match(op->right, expr);
return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
}
bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
bool matches = false;
Expand Down Expand Up @@ -99,10 +109,10 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex
bool matches = false;
if (const auto* call_node = expr.as<CallNode>()) {
if (op->args.size() == call_node->args.size()) {
matches = Match(op->op, call_node->op);
matches = VisitDFPattern(op->op, call_node->op);
size_t i = 0;
while (matches && i < op->args.size()) {
matches &= Match(op->args[i], call_node->args[i]);
matches &= VisitDFPattern(op->args[i], call_node->args[i]);
++i;
}
}
Expand All @@ -115,8 +125,8 @@ bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& ex
bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
matches =
(op->index == tuple_get_item_node->index) && Match(op->tuple, tuple_get_item_node->tuple);
matches = (op->index == tuple_get_item_node->index) &&
VisitDFPattern(op->tuple, tuple_get_item_node->tuple);
}
return matches;
}
Expand All @@ -127,7 +137,7 @@ bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr& e
matches = true;
size_t i = 0;
while (matches && i < op->fields.size()) {
matches &= Match(op->fields[i], tuple_node->fields[i]);
matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
++i;
}
}
Expand All @@ -145,7 +155,7 @@ Expr InferType(const Expr& expr) {
}
bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) {
auto expr_type = InferType(expr).as<ExprNode>()->checked_type();
return (StructuralEqual()(op->type, expr_type)) && Match(op->pattern, expr);
return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr);
}
bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) {
bool matches = false;
Expand Down Expand Up @@ -286,5 +296,48 @@ TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern patter
return DFPatternMatcher().Match(DFPatternPrepare().Prepare(pattern), expr);
});

// Rewrite

DFPatternCallback DFPatternCallbackNode::make(DFPattern pattern, PackedFunc function) {
ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
n->pattern_ = std::move(pattern);
n->function_ = std::move(function);
return DFPatternCallback(n);
}

TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
.set_body_typed(DFPatternCallbackNode::make);

class PatternRewriter : public ExprMutator {
public:
PatternRewriter(const Array<DFPatternCallback>& callbacks) : callbacks_(callbacks) {}
Expr Rewrite(const Expr& pre) {
return this->VisitExpr(pre);
}

protected:
Expr VisitExpr(const Expr& pre) override {
auto post = ExprMutator::VisitExpr(pre);
Expr out = post;
for (auto& callback : callbacks_) {
if (auto* callback_node = callback.as<DFPatternCallbackNode>()) {
if (matcher_.Match(callback_node->pattern_, out)) {
out = callback_node->function_(pre, out);
}
}
}
return out;
}
DFPatternMatcher matcher_;
Array<DFPatternCallback> callbacks_;
};

TVM_REGISTER_GLOBAL("relay.df_pattern.rewrite")
.set_body_typed([](Array<DFPatternCallback> callbacks, Expr expr) {
return PatternRewriter(callbacks).Rewrite(expr);
});

} // namespace relay
} // namespace tvm
12 changes: 12 additions & 0 deletions tests/python/relay/test_df_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,17 @@ def test_match_fake_diamond():
# Check
assert not diamond.match(out)

def test_match_rewrite():
x = relay.var('x')
y = relay.var('y')
add_pattern = is_op('add')(wildcard(), wildcard())
sub_pattern = is_op('subtract')(wildcard(), wildcard())
def add_to_sub(pre, post):
return post.args[0] - post.args[1]
out = rewrite([DFPatternCallback(add_pattern, add_to_sub)], x + y)
assert sub_pattern.match(out)


if __name__ == "__main__":
test_match_op()
test_no_match_op()
Expand All @@ -242,3 +253,4 @@ def test_match_fake_diamond():
test_match_diamond()
test_no_match_diamond()
test_match_fake_diamond()
test_match_rewrite()

0 comments on commit 1eefb0d

Please sign in to comment.