Skip to content

Commit

Permalink
[RELAY][PASS] CombineParallelConv2D (apache#2089)
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 authored and tqchen committed Nov 22, 2018
1 parent ef02bec commit de004c5
Show file tree
Hide file tree
Showing 7 changed files with 558 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# List of optimization pass and level when switch on
OPT_PASS_LEVEL = {
"SimplifyInference": 0,
"CombineParallelConv2D": 1,
"OpFusion": 1,
"FoldConstant": 2,
"FoldScaleAxis": 3,
Expand Down Expand Up @@ -144,6 +145,10 @@ def optimize(func, params=None):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)

if cfg.pass_enabled("CombineParallelConv2D"):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)

if cfg.pass_enabled("FoldScaleAxis"):
func = ir_pass.infer_type(func)
func = ir_pass.backward_fold_scale_axis(func)
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,19 @@ def fuse_ops(expr, opt_level=1):
Transformed expression, containing fused result.
"""
return _ir_pass.FuseOps(expr, opt_level)


def combine_parallel_conv2d(expr):
"""Fold multiple conv2d into one.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
transformed_expr : tvm.relay.Expr
Transformed expression
"""
return _ir_pass.CombineParallelConv2D(expr)
328 changes: 328 additions & 0 deletions src/relay/pass/combine_parallel_conv2d.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,328 @@
/*!
* Copyright (c) 2018 by Contributors
*
* \file combine_parallel_conv2d.cc
* \brief Combine parallel 2d convolutions into a single convolution.
*
* This pass replaces convolutions that share the same input node and the same
* arguments (except that the number of output channels can be different) with a
* single convolution. The weight of the new 2d convolution is the concatenation
* of the original weights. Elemwise and broadcast ops following conv2d are also
* combined if possible.
*
* This prevents launching multiple kernels in networks with multiple
* convolution branches, such as Inception block.
*/

#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op_attr_types.h>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"


namespace tvm {
namespace relay {

using Branch = std::vector<const CallNode*>;
using Group = std::vector<Branch>;

/*
Find parallel branches starting with conv2d as shown below and then group branches by kernel
shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops.
Intermediate nodes have exactly one successor. It is possible that branches meet at a point,
which should be handled in ParallelConv2DCombiner.
data
/ \
conv2d conv2d
| |
op op
| |
*/
class BranchGroupFinder : private ExprVisitor {
public:
std::vector<Group> Find(const Expr& expr) {
this->VisitExpr(expr);

std::vector<Group> groups;
for (const auto& root : conv_roots_) {
const auto& convs = children_map_.at(root);
for (const CallNode* conv : convs) {
auto&& branch = CreateBranch(conv);
// add the branch to a group, or create a new group
auto it = std::find_if(groups.begin(), groups.end(), [&](const Group& group) {
CHECK(!group.empty() && !group[0].empty());
return IsCompatibleConv2D(conv, group[0][0]);
});
if (it != groups.end()) {
it->push_back(branch);
} else {
groups.emplace_back();
// each group has at least one branch
groups.back().push_back(branch);
}
}
}
return groups;
}

private:
std::unordered_set<Expr, NodeHash, NodeEqual> conv_roots_;
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_;

// Two 2d convolutions can be combined if they have the same attributes or
// only have different output channels.
bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) {
AttrsEqual eq;
static const Layout kOIHW("OIHW");
const auto* attrs_a = a->attrs.as<Conv2DAttrs>();
const auto* attrs_b = b->attrs.as<Conv2DAttrs>();
CHECK(attrs_a);
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = ConvertLayout(tweight_a->shape, attrs_a->weight_layout, kOIHW);
const auto shape_b = ConvertLayout(tweight_b->shape, attrs_b->weight_layout, kOIHW);

return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
eq(attrs_a->dilation, attrs_b->dilation) && eq(attrs_a->groups, attrs_b->groups) &&
eq(attrs_a->data_layout, attrs_b->data_layout) &&
eq(attrs_a->weight_layout, attrs_b->weight_layout) &&
eq(attrs_a->out_dtype, attrs_b->out_dtype) &&
eq(attrs_a->out_layout, attrs_b->out_layout) && eq(shape_a[2], shape_b[2]) &&
eq(shape_a[3], shape_b[3]);
}

// Create a branch starting from conv2d.
Branch CreateBranch(const CallNode* conv) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
// each branch has at least one element, the first element is always conv2d
Branch branch{conv};
auto it = children_map_.find(GetRef<Expr>(branch.back()));
while (it != children_map_.end() && it->second.size() == 1) {
const CallNode* call = it->second[0];
auto pattern = fpattern[Downcast<Op>(call->op)];
if (pattern <= kBroadcast) {
branch.push_back(it->second[0]);
it = children_map_.find(GetRef<Expr>(branch.back()));
} else {
break;
}
}
return branch;
}

void VisitExpr_(const CallNode* n) final {
static const Op& conv2d = Op::Get("nn.conv2d");
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(conv2d) && n->attrs.as<Conv2DAttrs>()->groups == 1) {
conv_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
for (size_t i = 0; i < n->args.size(); i++) {
children_map_[n->args[i]].push_back(n);
}
}
}
};

class ParallelConv2DCombiner {
public:
Expr Combine(const Expr& expr) {
auto groups = BranchGroupFinder().Find(expr);
for (const Group& group : groups) {
if (group.size() < 2) continue;
CombineBranches(group);
}
return ExprSubst(expr, std::move(subst_map_));
}

private:
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> subst_map_;

std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t num_filters = 0; // number of filters of the transformed weight
Array<Expr> weights;
for (const auto& branch : branches) {
auto conv2d = branch[0];
weights.push_back(conv2d->args[1]);
auto channels = GetConv2DSuperChannelsDim(conv2d);
num_filters += channels;
}
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->weight_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
MakeConstScalar(Int(32), num_filters));
}

Call MakeCombinedConv2D(const Group& branches) {
static const Op& conv2d = Op::Get("nn.conv2d");
Expr data = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_channels;
std::tie(new_weight, new_channels) = TransformWeight(branches);

const CallNode* group_root = branches[0][0];
const auto* attrs = group_root->attrs.as<Conv2DAttrs>();
CHECK(attrs);
const auto new_attrs = make_node<Conv2DAttrs>();
new_attrs->strides = attrs->strides;
new_attrs->padding = attrs->padding;
new_attrs->dilation = attrs->dilation;
new_attrs->groups = attrs->groups;
new_attrs->kernel_size = attrs->kernel_size;
new_attrs->data_layout = attrs->data_layout;
new_attrs->weight_layout = attrs->weight_layout;
new_attrs->out_layout = attrs->out_layout;
new_attrs->out_dtype = attrs->out_dtype;
new_attrs->channels = new_channels;

return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
}

bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) {
AttrsEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();

if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size())
return false;

// Position of the 'C' dimension in the argument
size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size();

// Channel super-dimension shoule be present and not broadcasted
if ((arg_channel_pos > channel_pos) || // size_t overflow
!eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) ||
!eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos]))
return false;

for (size_t i = 0; i < ta->shape.size(); i++) {
if (i == arg_channel_pos) continue;
if (!eq(ta->shape[i], tb->shape[i]))
return false;
}
return true;
}

// Check if ops in depth-th level can be combined
bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) {
const CallNode* call = branches[0][depth];
AttrsEqual attrs_equal;
// check if all branches in current depth can be combined
for (auto it = branches.begin() + 1; it != branches.end(); it++) {
const Branch& branch = *it;
if (!branch[depth]->op.same_as(call->op) ||
!attrs_equal(branch[depth]->attrs, call->attrs) ||
branch[depth]->args.size() != call->args.size()) {
return false;
}

if (branch[depth]->args[parent_index].get() != branch[depth - 1])
return false;

// Check args
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) continue;

if (!IsArgCompatible(call, branch[depth], i, channel_pos) ||
!attrs_equal(call->attrs, branch[depth]->attrs)) {
return false;
}
}
}
return true;
}

// Combine args and make the combined CallNode
Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
size_t ndim = call->type_as<TensorTypeNode>()->shape.size();

for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t arg_channel_pos = channel_pos - ndim + arg_ndim;
Array<Expr> tuple;
for (const auto& branch : branches) {
tuple.push_back(branch[depth]->args[i]);
}
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
new_args.push_back(std::move(concat));
}
return CallNode::make(call->op, new_args, call->attrs, {});
}

// Replace output of each branch with slices of the combined output
void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
size_t channel_pos) {
int64_t index = 0;
for (const auto& branch : branches) {
const CallNode* conv2d = branch[0];
int64_t channels = GetConv2DSuperChannelsDim(conv2d);
Array<Integer> begin;
Array<Integer> end;
for (size_t i = 0; i < channel_pos; i++) {
begin.push_back(0);
end.push_back(NullValue<Integer>());
}
begin.push_back(index);
index += channels;
end.push_back(index);
auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array<Integer>{});
subst_map_[GetRef<Expr>(branch[depth])] = slice;
}
}

// Combine branches in a group. Conv2d in different branches in the same group are safe to
// combine. Subsequent ops may or may not be combined. We start from conv2d and try to
// combine ops from all branches in the same depth.
void CombineBranches(const Group& branches) {
Call combined = MakeCombinedConv2D(branches);
auto conv_param = combined->attrs.as<Conv2DAttrs>();
const std::string& layout =
conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout;
size_t channel_pos = layout.find('C');
CHECK_NE(channel_pos, std::string::npos);
auto it = std::min_element(branches.begin(), branches.end(),
[](const Branch& branch_a,
const Branch& branch_b) {
return branch_a.size() < branch_b.size();
});
size_t depth = it->size();
size_t i;
// starting from 1 to skip the conv2d
for (i = 1; i < depth; i++) {
size_t parent_index;
for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) {
if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break;
}
CHECK_NE(parent_index, branches[0][i]->args.size());
if (!CheckLevel(branches, i, channel_pos, parent_index)) break;
combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index);
}
UpdateGroupOutput(combined, branches, i - 1, channel_pos);
}
};

Expr CombineParallelConv2D(const Expr& expr) { return ParallelConv2DCombiner().Combine(expr); }

TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CombineParallelConv2D(args[0]);
});

} // namespace relay
} // namespace tvm
Loading

0 comments on commit de004c5

Please sign in to comment.