[BYOC] Pattern Language MergeComposite (#5656)
* Pattern Language MergeComposite

* fix DNNL pattern

* Use builtin binary operator syntax for demo

* Improve unit test
comaniac authored May 26, 2020
1 parent 6100112 commit 81ad18e
Showing 5 changed files with 276 additions and 328 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/dataflow_pattern/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""The Relay Pattern Language and tooling."""
from tvm.relay import Expr
from tvm.relay.expr import RelayExpr as Expr
import tvm._ffi
from ...ir.base import Node
from ...ir import make_node
14 changes: 7 additions & 7 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -32,8 +32,8 @@
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
from ... import expr as _expr
from ... import op as _op
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table


@@ -68,15 +68,15 @@ def _func_wrapper(attrs, args):


def make_pattern(with_bias=True):
data = _expr.var("data")
weight = _expr.var("weight")
bias = _expr.var("bias")
conv = _op.nn.conv2d(data, weight)
data = wildcard()
weight = wildcard()
bias = wildcard()
conv = is_op('nn.conv2d')(data, weight)
if with_bias:
conv_out = _op.add(conv, bias)
conv_out = is_op('add')(conv, bias)
else:
conv_out = conv
return _op.nn.relu(conv_out)
return is_op('nn.relu')(conv_out)


@register_pattern_table("dnnl")
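The diff above replaces the old Relay-expression template (built from _expr.var and _op calls) with dataflow-pattern builders. Below is a minimal sketch, not part of the diff, of how such a pattern can be checked against a concrete Relay expression; the shapes, variable names, and the use of DFPattern.match are illustrative assumptions.

# Sketch: match the conv2d(+add)+relu pattern from make_pattern above against
# a hand-built Relay expression. Shapes and names are illustrative only.
from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op

def make_pattern(with_bias=True):
    data = wildcard()
    weight = wildcard()
    bias = wildcard()
    conv = is_op('nn.conv2d')(data, weight)
    conv_out = is_op('add')(conv, bias) if with_bias else conv
    return is_op('nn.relu')(conv_out)

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(16, 3, 3, 3))
bias = relay.var("bias", shape=(16, 1, 1))
out = relay.nn.relu(relay.add(relay.nn.conv2d(data, weight), bias))

assert make_pattern(with_bias=True).match(out)      # conv2d -> add -> relu matches
assert not make_pattern(with_bias=False).match(out)  # relu(conv2d(...)) does not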
2 changes: 1 addition & 1 deletion python/tvm/relay/transform/transform.py
@@ -380,7 +380,7 @@ def MergeComposite(pattern_table):
Parameters
----------
pattern_table : list(tuple)
pattern_table : List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Function]]
A list of (pattern_name, pattern, check) tuples.
The order of the patterns in the list will determine the order
of priority in which they are matched.
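The updated docstring types each table entry as (pattern_name, DFPattern, check). A hedged sketch of driving the pass with such a table follows; the module contents, the pattern name, and the always-true check function are illustrative assumptions.

# Sketch: run MergeComposite with a (name, DFPattern, check) table.
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import wildcard, is_op

conv_bias_relu = is_op('nn.relu')(
    is_op('add')(is_op('nn.conv2d')(wildcard(), wildcard()), wildcard()))

pattern_table = [
    # (pattern_name, pattern, check) -- the name and check are illustrative
    ("example.conv2d_bias_relu", conv_bias_relu, lambda extract: True),
]

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(16, 3, 3, 3))
bias = relay.var("bias", shape=(16, 1, 1))
out = relay.nn.relu(relay.add(relay.nn.conv2d(data, weight), bias))
mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))

# Matched regions are wrapped into functions carrying a "Composite" attribute.
mod = relay.transform.MergeComposite(pattern_table)(mod)
print(mod)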
193 changes: 10 additions & 183 deletions src/relay/transforms/merge_composite.cc
@@ -26,6 +26,7 @@
*/

#include <tvm/relay/analysis.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
@@ -35,199 +36,25 @@ namespace tvm {
namespace relay {
namespace merge_composite {

class MergeCompositeWrapper : public ExprMutator {
public:
explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern,
const PackedFunc& check)
: pattern_name_(pattern_name), pattern_(pattern), check_(check) {}

Expr ExtractPattern(const Var& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) {
if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root'
auto free_var = Var(pattern->name_hint(), root->checked_type());
free_var->checked_type_ = root->checked_type();
var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
return std::move(free_var);
} else {
// if we have encountered this var already, return the free var that was created
auto vars = (*var_map)[pattern->name_hint()];
auto free_var = vars[0];
auto graph_expr = vars[1];
// make sure to first check they both map to the same node in the graph
if (graph_expr != root) {
return Expr();
}
return (*var_map)[pattern->name_hint()][0];
}
}

Expr ExtractPattern(const Constant& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map) {
return root;
}

Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root,
Map<std::string, Array<Expr>>* var_map, Map<Expr, Expr>* call_map) {
if (!root->IsInstance<TupleGetItemNode>()) {
return Expr();
}
auto root_node = Downcast<TupleGetItem>(root);
if (pattern->index != root_node->index) {
return Expr();
}
if (pattern->tuple->IsInstance<CallNode>() && root_node->tuple->IsInstance<CallNode>()) {
Expr new_arg;
if (call_map->find(pattern->tuple) != call_map->end()) {
new_arg = (*call_map)[pattern->tuple];
} else {
new_arg = ExtractPattern(Downcast<Call>(pattern->tuple), Downcast<Call>(root_node->tuple),
var_map, call_map);
call_map->Set(pattern->tuple, new_arg);
}
return TupleGetItem(new_arg, root_node->index);
}
return Expr();
}

/*!
* \brief Try and extract a given pattern from a graph as a subgraph.
* \param pattern The pattern to extract.
* \param root The graph to extract from.
* \param var_map A map between free vars in the subgraph and nodes in the graph.
* \return The extracted subgraph.
*
* \note How does this work?
*
* A pattern consists of Relay expression containing only operator call nodes, constants
* and free variables. The free variables indicate where the pattern can 'attach' in your
* graph. This function takes the final call node of the pattern and the call node currently
* being traversed in the Relay graph. It traverses through the pattern in lockstep with call node
* from the graph (referred to as the 'root' node here) to check they're identical. If at any
* point they differ, an empty expression is returned to signify the extract failed. If a free var
* is reached in the pattern, the corresponding value in the root is associated with the name of
* the free var (via the var_map) so that when we construct the composite function, the inputs
* match up correctly with the rest of the graph. The return value of this function when
* successful is a new Relay expression ready to be wrapped into a composite function.
*/
Expr ExtractPattern(const Call& pattern, const Call& root, Map<std::string, Array<Expr>>* var_map,
Map<Expr, Expr>* call_map) {
// check to make sure both calls are to operators (not functions)
if (!root.defined()) return Expr();
if (!pattern->op->IsInstance<OpNode>() || !root->op->IsInstance<OpNode>()) return Expr();
if (pattern->op.as<OpNode>()->name != root->op.as<OpNode>()->name) return Expr();

unsigned int i = 0;
Array<Expr> new_args;
for (const auto& arg : pattern->args) {
Expr new_arg;
if (arg->IsInstance<CallNode>() && root->args[i]->IsInstance<CallNode>()) {
new_arg =
ExtractPattern(Downcast<Call>(arg), Downcast<Call>(root->args[i]), var_map, call_map);
// if we've already processed this call node, return the previous result
if (call_map->find(arg) != call_map->end() && new_arg.defined()) {
new_arg = (*call_map)[arg];
} else {
call_map->Set(arg, new_arg);
}
} else if (arg->IsInstance<VarNode>()) {
// if there's a var in the pattern, it must be a free var
// so call the function to update the var_map
new_arg = ExtractPattern(Downcast<Var>(arg), root->args[i], var_map);
} else if (arg->IsInstance<ConstantNode>()) {
// if there's a constant, simply get the corresponding
// value of the constant from the root
new_arg = ExtractPattern(Downcast<Constant>(arg), root->args[i], var_map);
} else if (arg->IsInstance<TupleGetItemNode>()) {
new_arg = ExtractPattern(Downcast<TupleGetItem>(arg), root->args[i], var_map, call_map);
}
if (!new_arg.defined()) {
return Expr();
}
new_args.push_back(new_arg);
i++;
}
Call new_call = Call(root->op, new_args, root->attrs);
new_call->checked_type_ = root->checked_type();
return std::move(new_call);
}

Expr VisitExpr_(const CallNode* cn) {
Call call = GetRef<Call>(cn);
if (call->op->IsInstance<FunctionNode>()) {
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto name_node = func->GetAttr<String>(attr::kComposite);
// don't step into existing composite functions
if (name_node.defined() && name_node != "") {
tvm::Array<tvm::relay::Expr> new_args;
for (const auto& arg : call->args) {
auto new_e = this->Mutate(arg);
new_args.push_back(new_e);
}
Call new_call = Call(call->op, new_args, call->attrs);
new_call->checked_type_ = call->checked_type();
return std::move(new_call);
}
}

Expr expr = ExprMutator::VisitExpr_(cn);
call = Downcast<Call>(expr);
call->checked_type_ = cn->checked_type();
if (!call->op->IsInstance<OpNode>()) return std::move(call);

// only call patterns are supported
Call pattern = Downcast<Call>(pattern_);
CHECK(pattern.defined());
Map<std::string, Array<Expr>> args_map;
Map<Expr, Expr> call_map;
auto extract = ExtractPattern(pattern, call, &args_map, &call_map);
if (extract.defined() && static_cast<bool>(check_(extract))) {
auto free_vars = FreeVars(extract);
// make the composite function
auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs());
f = WithAttr(std::move(f), attr::kComposite, runtime::String(pattern_name_));
// find the expressions associated with the free vars using the args_map
// this tells us which expressions should be given as inputs to the composite function
Array<Expr> args;
for (const auto& free_var : free_vars) {
args.push_back(args_map[free_var->name_hint()][1]);
}
auto new_call = Call(f, args);
new_call->checked_type_ = call->checked_type();
return std::move(new_call);
}
return std::move(call);
}

private:
/*! \brief The name of the pattern to match */
std::string pattern_name_;
/*! \brief The pattern to match */
Expr pattern_;
/*! \brief The function to check whether an extract is supported */
PackedFunc check_;
};

Expr MergeComposite(const Expr& expr, const Array<runtime::String>& pattern_names,
const Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
Expr MergeComposite(const Function& func, const Array<runtime::String>& pattern_names,
const Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
CHECK_EQ(pattern_names.size(), patterns.size());
Expr merged_expr = expr;
Expr merged_expr = func->body;
// merge the patterns one-by-one in order
for (size_t i = 0; i < patterns.size(); i++) {
merged_expr =
MergeCompositeWrapper(pattern_names[i], patterns[i], checks[i]).Mutate(merged_expr);
Map<std::string, ObjectRef> attrs;
attrs.Set("Composite", pattern_names[i]);
merged_expr = PartitionPattern(patterns[i], merged_expr, attrs, checks[i]);
}
return merged_expr;
return Function(func->params, merged_expr, func->ret_type, func->type_params, func->attrs);
}

} // namespace merge_composite

namespace transform {

Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,
const tvm::Array<Expr>& patterns, const std::vector<PackedFunc>& checks) {
const tvm::Array<DFPattern>& patterns, const std::vector<PackedFunc>& checks) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(
@@ -239,7 +66,7 @@ Pass MergeComposite(const tvm::Array<runtime::String>& pattern_names,

TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) {
tvm::Array<runtime::String> pattern_names = args[0];
tvm::Array<Expr> patterns = args[1];
tvm::Array<DFPattern> patterns = args[1];
std::vector<PackedFunc> checks;
for (int i = 2; i < args.size(); i++) {
checks.push_back(args[i]);
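The refactored C++ pass above now delegates each pattern to PartitionPattern, tagging every partitioned function with a "Composite" attribute and re-wrapping the original function's parameters. A hedged Python sketch of the same per-pattern loop is shown below, using tvm.relay.dataflow_pattern.partition; the exact partition() arguments (attrs and check) and the example pattern are assumptions, not taken from the diff.

# Sketch mirroring the new MergeComposite loop: partition each match into a
# function tagged {"Composite": <pattern name>}. partition()'s attrs/check
# arguments are assumed to mirror the C++ PartitionPattern call above.
from tvm import relay
from tvm.relay.dataflow_pattern import partition, wildcard, is_op

pattern_names = ["example.conv2d_relu"]  # illustrative
patterns = [is_op('nn.relu')(is_op('nn.conv2d')(wildcard(), wildcard()))]
checks = [lambda extract: True]

def merge_composite(func):
    body = func.body
    # merge the patterns one-by-one in order, as the C++ loop does
    for name, pat, check in zip(pattern_names, patterns, checks):
        body = partition(pat, body, {"Composite": name}, check)
    return relay.Function(func.params, body, func.ret_type,
                          func.type_params, func.attrs)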
