Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY][BYOC] Add support for composite functions in BYOC #5261

Merged
merged 3 commits into from
Apr 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,12 @@ def MergeComposite(pattern_table):
Parameters
----------
pattern_table : list(tuple)
A list of (pattern_name, pattern) tuples.
A list of (pattern_name, pattern, check) tuples.
The order of the patterns in the list will determine the order
of priority in which they are matched.
'check' is a function to check whether an extracted pattern matches.
It can be implemented by pattern writer but if not specified it will
always return True.

Returns
-------
Expand All @@ -390,11 +393,19 @@ def MergeComposite(pattern_table):
"""
pattern_names = []
patterns = []
for pattern_name, pattern in pattern_table:
checks = []
for tup in pattern_table:
if len(tup) == 2:
pattern_name, pattern = tup
check = lambda extract: True
elif len(tup) == 3:
pattern_name, pattern, check = tup

pattern_names.append(pattern_name)
patterns.append(pattern)
checks.append(check)

return _ffi_api.MergeComposite(pattern_names, patterns)
return _ffi_api.MergeComposite(pattern_names, patterns, *checks)


def MergeCompilerRegions():
Expand Down
40 changes: 31 additions & 9 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
if (call->op->IsInstance<OpNode>()) {
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
} else if (call->op->IsInstance<FunctionNode>()) {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
if (comp_name.defined()) {
size_t i = comp_name->value.find('.');
if (i != std::string::npos) {
std::string target = comp_name->value.substr(0, i);
if (target == target_) return true;
}
}
}
}
if (expr->IsInstance<TupleGetItemNode>()) {
Expand All @@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator {
}

Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn);

Call call = Downcast<Call>(new_e);
Expand Down Expand Up @@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator {
}
}

Expr VisitExpr_(const FunctionNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
Expr VisitExpr_(const FunctionNode* fn) {
Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e);
new_body = InsertEnd(func->body);
}

auto func = Downcast<Function>(new_e);
return Function(
func->params,
InsertEnd(func->body),
new_body,
func->ret_type,
func->type_params,
func->attrs);
Expand Down
90 changes: 44 additions & 46 deletions src/relay/transforms/merge_composite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,23 +25,24 @@
* Relay operators map to a single external operator.
*/

#include <tvm/te/operation.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

namespace tvm {
namespace relay {
namespace merge_composite {

class MergeCompositeWrapper : public ExprMutator {
public:
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern)
: pattern_name_(pattern_name), pattern_(pattern) {}
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern,
const PackedFunc& check)
: pattern_name_(pattern_name), pattern_(pattern), check_(check) {}

Expr ExtractPattern(const Var& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map) {
if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root'
Expand All @@ -62,27 +63,25 @@ class MergeCompositeWrapper : public ExprMutator {
}

Expr ExtractPattern(const Constant& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) {
Map<std::string, Array<Expr>>* var_map) {
return root;
}

Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
if (!root->IsInstance<TupleGetItemNode>()) {
return Expr();
}
auto root_node = Downcast<TupleGetItem>(root);
if (pattern->index != root_node->index) {
return Expr();
}
if (pattern->tuple->IsInstance<CallNode>() &&
root_node->tuple->IsInstance<CallNode>()) {
if (pattern->tuple->IsInstance<CallNode>() && root_node->tuple->IsInstance<CallNode>()) {
Expr new_arg;
if (call_map->find(pattern->tuple) != call_map->end()) {
new_arg = (*call_map)[pattern->tuple];
} else {
new_arg = ExtractPattern(Downcast<Call>(pattern->tuple),
Downcast<Call>(root_node->tuple),
new_arg = ExtractPattern(Downcast<Call>(pattern->tuple), Downcast<Call>(root_node->tuple),
var_map, call_map);
call_map->Set(pattern->tuple, new_arg);
}
Expand All @@ -104,20 +103,18 @@ class MergeCompositeWrapper : public ExprMutator {
* and free variables. The free variables indicate where the pattern can 'attach' in your
* graph. This function takes the final call node of the pattern and the call node currently
* being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
* from the graph (referred to as the 'root' node here) to check they're identical. If at any point
* they differ, an empty expression is returned to signify the extract failed. If a free var is
* reached in the pattern, the corresponding value in the root is associated with the name of the
* free var (via the var_map) so that when we construct the composite function, the inputs match
* up correctly with the rest of the graph. The return value of this function when successful is
* a new Relay expression ready to be wrapped into a composite function.
* from the graph (referred to as the 'root' node here) to check they're identical. If at any
* point they differ, an empty expression is returned to signify the extract failed. If a free var
* is reached in the pattern, the corresponding value in the root is associated with the name of
* the free var (via the var_map) so that when we construct the composite function, the inputs
* match up correctly with the rest of the graph. The return value of this function when
* successful is a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
Expr ExtractPattern(const Call& pattern, const Call& root, Map<std::string, Array<Expr>>* var_map,
Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>())
return Expr();
if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name)
return Expr();
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) return Expr();
if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name) return Expr();

unsigned int i = 0;
Array<Expr> new_args;
Expand All @@ -133,27 +130,20 @@ class MergeCompositeWrapper : public ExprMutator {
return Expr();
}
// if it's a call node, recursively call this function
new_arg = ExtractPattern(Downcast<Call>(arg),
Downcast<Call>(root->args[i]),
var_map, call_map);
new_arg =
ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
call_map->Set(arg, new_arg);
}
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
new_arg = ExtractPattern(Downcast<Var>(arg),
root->args[i],
var_map);
new_arg = ExtractPattern(Downcast<Var>(arg), root->args[i], var_map);
} else if (arg->IsInstance<ConstantNode>()) {
// if there's a constant, simply get the corresponding
// value of the constant from the root
new_arg = ExtractPattern(Downcast<Constant>(arg),
root->args[i],
var_map);
new_arg = ExtractPattern(Downcast<Constant>(arg), root->args[i], var_map);
} else if (arg->IsInstance<TupleGetItemNode>()) {
new_arg = ExtractPattern(Downcast<TupleGetItem>(arg),
root->args[i],
var_map, call_map);
new_arg = ExtractPattern(Downcast<TupleGetItem>(arg), root->args[i], var_map, call_map);
}
if (!new_arg.defined()) {
return Expr();
Expand All @@ -169,8 +159,7 @@ class MergeCompositeWrapper : public ExprMutator {
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
const auto name_node =
func->GetAttr<tir::StringImm>(attr::kComposite);
const auto name_node = func->GetAttr<tir::StringImm>(attr::kComposite);
// don't step into existing composite functions
if (name_node.defined() && name_node->value != "") {
tvm::Array<tvm::relay::Expr> new_args;
Expand All @@ -184,16 +173,15 @@ class MergeCompositeWrapper : public ExprMutator {

Expr expr = ExprMutator::VisitExpr_(cn);
call = Downcast<Call>(expr);
if (!call->op->IsInstance<OpNode>())
return std::move(call);
if (!call->op->IsInstance<OpNode>()) return std::move(call);

// only call patterns are supported
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined()) {
if (extract.defined() && static_cast<bool>(check_(extract))) {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
Expand All @@ -215,17 +203,20 @@ class MergeCompositeWrapper : public ExprMutator {
std::string pattern_name_;
/*! \brief The pattern to match */
Expr pattern_;
/*! \brief The function to check whether an extract is supported */
PackedFunc check_;
};

Expr MergeComposite(const Expr& expr,
const Array<tir::StringImm>& pattern_names, const Array<Expr>& patterns) {
Expr MergeComposite(const Expr& expr, const Array<tir::StringImm>& pattern_names,
const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
std::string pattern_name = pattern_names[i]->value;
Expr pattern = patterns[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr);
PackedFunc check = checks[i];
merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr);
}
return merged_expr;
}
Expand All @@ -235,18 +226,25 @@ Expr MergeComposite(const Expr& expr,
namespace transform {

Pass MergeComposite(const tvm::Array<tir::StringImm>& pattern_names,
const tvm::Array<Expr>& patterns) {
const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
relay::merge_composite::MergeComposite(f, pattern_names, patterns));
relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks));
};
auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {});
return func_pass;
}

TVM_REGISTER_GLOBAL("relay._transform.MergeComposite")
.set_body_typed(MergeComposite);
TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<tir::StringImm> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
checks.push_back(args[i]);
}
*rv = MergeComposite(pattern_names, patterns, checks);
});

} // namespace transform

Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,53 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_composite_function():
def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))

# merged function
r = relay.Call(add_relu, [a, b])
f = relay.Function([a, b], r)
mod = tvm.IRModule.from_expr(f)
return mod

def after():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))

# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
cb_2 = relay.annotation.compiler_begin(b, "test")
r = relay.Call(add_relu, [cb_1, cb_2])
ce_1 = relay.annotation.compiler_end(r, "test")
f = relay.Function([a, b], ce_1)
mod = tvm.IRModule.from_expr(f)
return mod

result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
test_composite_function()
Loading