Skip to content

Commit

Permalink
add batchnorm tests, commutivity extensions, realize I need to break …
Browse files Browse the repository at this point in the history
…the diamond for more complicated graphs
  • Loading branch information
mbrookhart committed Apr 6, 2020
1 parent 1eefb0d commit ea83939
Show file tree
Hide file tree
Showing 4 changed files with 322 additions and 68 deletions.
14 changes: 14 additions & 0 deletions python/tvm/relay/df_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ def __call__(self, *args):
def __or__(self, other):
return AltPattern(self, other)

def __add__(self, other):
return is_op("add")(self, other)

def __sub__(self, other):
return is_op("subtract")(self, other)

def __mul__(self, other):
return is_op("multiply")(self, other)

def __truediv__(self, other):
return is_op("divide")(self, other)

def has_attr(self, attr_name, attr_value):
attrs = make_node("DictAttrs", **{attr_name:attr_value})
return AttrPattern(self, attrs)
Expand Down Expand Up @@ -241,4 +253,6 @@ def match(pattern, expr):
return ffi.match(pattern, expr)

def rewrite(callbacks, expr):
if isinstance(callbacks, DFPatternCallback):
callbacks = [callbacks]
return ffi.rewrite(callbacks, expr)
213 changes: 167 additions & 46 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ bool DFPatternMatcher::Match(const DFPattern& pattern, const Expr& expr) {
}

bool DFPatternMatcher::VisitDFPattern(const DFPattern& pattern, const Expr& expr) {
if (memo_.count(pattern)) {
return expr.same_as(memo_[pattern]);
} else {
auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
if (out) {
memo_[pattern] = expr;
}
return out;
}
// if (memo_.count(pattern)) {
// return expr.same_as(memo_[pattern]);
// } else {
auto out = DFPatternFunctor::VisitDFPattern(pattern, expr);
// if (out) {
// memo_[pattern] = expr;
// }
return out;
// }
}

bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& expr) {
Expand Down Expand Up @@ -105,19 +105,165 @@ bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, cons
}
return matches;
}

Array<DFPattern> reverse(const Array<DFPattern> args) {
Array<DFPattern> new_args;
for (auto it = args.rbegin(); it != args.rend(); ++it) {
new_args.push_back(*it);
}
return new_args;
}

bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& expr) {
bool matches = false;
if (const auto* call_node = expr.as<CallNode>()) {
if (op->args.size() == call_node->args.size()) {
matches = VisitDFPattern(op->op, call_node->op);
size_t i = 0;
while (matches && i < op->args.size()) {
matches &= VisitDFPattern(op->args[i], call_node->args[i]);
auto match_args = [this](const Array<DFPattern> pattern_args, const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
if (pattern_args.size() == expr_args.size()) {
while (matches && i < pattern_args.size()) {
matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
++i;
}
} else {
matches = false;
}
return matches;
};

auto get_op_node = [](const CallPatternNode* op) -> const tvm::OpNode* {
if (op) {
if (auto* expr_pattern = op->op.as<ExprPatternNode>()) {
return expr_pattern->expr.as<OpNode>();
}
}
return nullptr;
};

if (const auto* call_node = expr.as<CallNode>()) {
auto matches_op = VisitDFPattern(op->op, call_node->op);
if (matches_op) {
if (match_args(op->args, call_node->args)) {
return true;
}
if (auto* op_node = get_op_node(op)) {
if ((op_node->name == "add")) {
if (match_args(reverse(op->args), call_node->args)) {
return true;
}
} else if ((op_node->name == "multiply")) {
if (match_args(reverse(op->args), call_node->args)) {
return true;
}
}
}
} else {
if (const OpNode* op_node = get_op_node(op)) {
if (op_node->name == "divide") {
if (auto* arg_node = op->args[0].as<CallPatternNode>()) {
if (const OpNode* arg_op = get_op_node(arg_node)) {
if (arg_op->name == "multiply") {
auto associate_div_mul = [this, &op, &arg_node, &expr]() {
auto div1 = CallPatternNode::make(op->op, {arg_node->args[1], op->args[1]},
op->attrs, op->type_args);
auto mul1 = CallPatternNode::make(arg_node->op, {arg_node->args[0], div1},
arg_node->attrs, arg_node->type_args);
auto div2 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]},
op->attrs, op->type_args);
auto mul2 = CallPatternNode::make(arg_node->op, {arg_node->args[1], div2},
arg_node->attrs, arg_node->type_args);
return VisitDFPattern(mul1, expr)|| VisitDFPattern(mul2, expr);
};

if (const OpNode* expr_op_node = call_node->op.as<OpNode>()) {
if (expr_op_node->name == "multiply") {
if (auto* input_call_node = call_node->args[0].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "divide") {
return associate_div_mul();
}
}
}
if (auto* input_call_node = call_node->args[1].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "divide") {
return associate_div_mul();
}
}
}
}
}
}
}
}
} else if (op_node->name == "multiply") {
if (auto* arg_node = op->args[0].as<CallPatternNode>()) {
if (const OpNode* arg_op = get_op_node(arg_node)) {
if (arg_op->name == "divide") {
auto associate_mul_div = [this, &op, &arg_node, &expr]() {
auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[1]},
op->attrs, op->type_args);
auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]},
arg_node->attrs, arg_node->type_args);
return VisitDFPattern(div1, expr);
};

if (const OpNode* expr_op_node = call_node->op.as<OpNode>()) {
if (expr_op_node->name == "divide") {
if (auto* input_call_node = call_node->args[0].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "multiply") {
return associate_mul_div();
}
}
}
if (auto* input_call_node = call_node->args[1].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "multiply") {
return associate_mul_div();
}
}
}
}
}
}
}
}
if (auto* arg_node = op->args[1].as<CallPatternNode>()) {
if (const OpNode* arg_op = get_op_node(arg_node)) {
if (arg_op->name == "divide") {
auto associate_mul_div = [this, &op, &arg_node, &expr]() {
auto mul1 = CallPatternNode::make(op->op, {arg_node->args[0], op->args[0]},
op->attrs, op->type_args);
auto div1 = CallPatternNode::make(arg_node->op, {mul1, arg_node->args[1]},
arg_node->attrs, arg_node->type_args);
return VisitDFPattern(div1, expr);
};

if (const OpNode* expr_op_node = call_node->op.as<OpNode>()) {
if (expr_op_node->name == "divide") {
if (auto* input_call_node = call_node->args[0].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "multiply") {
return associate_mul_div();
}
}
}
if (auto* input_call_node = call_node->args[1].as<CallNode>()) {
if (const OpNode* input_op_node = input_call_node->op.as<OpNode>()) {
if (input_op_node->name == "multiply") {
return associate_mul_div();
}
}
}
}
}
}
}
}
}
}
}
}
return matches;
return false;
}
bool DFPatternMatcher::VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) {
return op->expr == expr;
Expand Down Expand Up @@ -267,33 +413,8 @@ DFPattern DFPatternMutator::VisitDFPattern_(const WildcardPatternNode* op) {
return GetRef<DFPattern>(op);
}

// Prepare

class DFPatternPrepare : protected DFPatternMutator {
public:
DFPattern Prepare(const DFPattern& pattern) { return Mutate(pattern); }
DFPattern VisitDFPattern_(const CallPatternNode* op) {
auto post = DFPatternMutator::VisitDFPattern_(op);
auto* post_node = post.as<CallPatternNode>();
if (auto* expr_pattern = post_node->op.as<ExprPatternNode>()) {
if (auto* op_node = expr_pattern->expr.as<OpNode>()) {
if ((op_node->name == "add") || (op_node->name == "multiply")) {
tvm::Array<DFPattern> call_args;
for (auto it = post_node->args.rbegin(); it != post_node->args.rend(); ++it) {
call_args.push_back(*it);
}
return AltPatternNode::make(
post, CallPatternNode::make(post_node->op, call_args, post_node->attrs,
post_node->type_args));
}
}
}
return post;
}
};

TVM_REGISTER_GLOBAL("relay.df_pattern.match").set_body_typed([](DFPattern pattern, Expr expr) {
return DFPatternMatcher().Match(DFPatternPrepare().Prepare(pattern), expr);
return DFPatternMatcher().Match(pattern, expr);
});

// Rewrite
Expand All @@ -310,16 +431,16 @@ TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);
TVM_REGISTER_GLOBAL("relay.df_pattern.DFPatternCallback")
.set_body_typed(DFPatternCallbackNode::make);

class PatternRewriter : public ExprMutator {
class PatternRewriter : protected MixedModeMutator {
public:
PatternRewriter(const Array<DFPatternCallback>& callbacks) : callbacks_(callbacks) {}
Expr Rewrite(const Expr& pre) {
return this->VisitExpr(pre);
}

protected:
Expr VisitExpr(const Expr& pre) override {
auto post = ExprMutator::VisitExpr(pre);
Expr DispatchVisitExpr(const Expr& pre) override {
auto post = MixedModeMutator::DispatchVisitExpr(pre);
Expr out = post;
for (auto& callback : callbacks_) {
if (auto* callback_node = callback.as<DFPatternCallbackNode>()) {
Expand Down
7 changes: 3 additions & 4 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ void MixedModeVisitor::VisitExpr_(const TupleGetItemNode* op) {}

void MixedModeMutator::VisitLeaf(const Expr& expr) {
if (!memo_.count(expr)) {
this->DispatchVisitExpr(expr);
Expr ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
}
}

Expand All @@ -165,9 +166,7 @@ Expr MixedModeMutator::VisitExpr(const Expr& expr) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
Expr ret = this->DispatchVisitExpr(expr);
memo_[expr] = ret;
return ret;
return memo_[expr];
}
}

Expand Down
Loading

0 comments on commit ea83939

Please sign in to comment.