Skip to content

Commit

Permalink
Fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Nov 12, 2018
1 parent aeef954 commit 2accc32
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/relay/pass/expr_subst.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class ExprSubstituter : public ExprMutator {
};

Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map) {
return ExprSubstituter(std::move(subst_map)).Mutate(expr);
return ExprSubstituter(std::move(subst_map)).Mutate(expr);
}

} // namespace relay
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/expr_subst.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace tvm {
namespace relay {

Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map);
Expr ExprSubst(const Expr& expr, tvm::Map<Expr, Expr> subst_map);

} // namespace relay
} // namespace tvm
Expand Down
47 changes: 24 additions & 23 deletions src/relay/pass/fold_conv2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
* This pass replaces convolutions that share the same input node and the same arguments (except
* that the number of output channels can be different) with a single convolution. The weight of
* the new 2d convolution is the concatenation of the original weights.
*
* This prevents launching multiple kernels in networks with multiple convolution branches, such
* as Inception block.
*/

#include <tvm/relay/pass.h>
Expand All @@ -22,8 +25,8 @@ namespace relay {

class SiblingConv2DFinder : public ExprVisitor {
public:
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual>
Find(const Expr& expr) {
std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> Find(
const Expr& expr) {
this->VisitExpr(expr);
return std::move(children_map_);
}
Expand All @@ -49,23 +52,23 @@ std::tuple<Expr, IndexExpr> TransformWeight(std::vector<const CallNode*> convolu
CHECK(channels);
num_filters += *channels;
}
return std::tuple<Expr, IndexExpr>{ MakeConcatenate(TupleNode::make(weights), 0),
MakeConstScalar(Int(32), num_filters) };
return std::tuple<Expr, IndexExpr>{MakeConcatenate(TupleNode::make(weights), 0),
MakeConstScalar(Int(32), num_filters)};
}

// Two 2d convolutions can be combined if they have the same attributes or only have
// different output channels.
bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) {
AttrsEqual eq;
return eq(a.strides, b.strides) &&
eq(a.padding, b.padding) &&
eq(a.dilation, b.dilation) &&
eq(a.groups, b.groups) &&
eq(a.kernel_size, b.kernel_size) &&
eq(a.data_layout, b.data_layout) &&
eq(a.weight_layout, b.weight_layout) &&
eq(a.out_dtype, b.out_dtype) &&
eq(a.out_layout, b.out_layout);
AttrsEqual eq;
return eq(a.strides, b.strides) &&
eq(a.padding, b.padding) &&
eq(a.dilation, b.dilation) &&
eq(a.groups, b.groups) &&
eq(a.kernel_size, b.kernel_size) &&
eq(a.data_layout, b.data_layout) &&
eq(a.weight_layout, b.weight_layout) &&
eq(a.out_dtype, b.out_dtype) &&
eq(a.out_layout, b.out_layout);
}

Expr MakeFoldedConv2D(const Expr& data, const std::vector<const CallNode*>& convolutions) {
Expand Down Expand Up @@ -101,8 +104,7 @@ Expr FoldConv2D(const Expr& expr) {
Expr data = pair.first;
std::vector<const CallNode*> children = pair.second;

if (children.size() < 2)
continue;
if (children.size() < 2) continue;

std::vector<size_t> group_ids(children.size());
std::vector<std::vector<const CallNode*>> groups;
Expand All @@ -112,13 +114,12 @@ Expr FoldConv2D(const Expr& expr) {
auto args = n->attrs.as<Conv2DAttrs>();

// assign a group id or create a new group for each conv2d
auto it =
std::find_if(groups.begin(), groups.end(),
[&](std::vector<const CallNode*> group) {
const CallNode* group_root = *(group.begin());
auto group_args = group_root->attrs.as<Conv2DAttrs>();
return IsCompatibleConv2D(*args, *group_args);
});
auto it = std::find_if(groups.begin(), groups.end(),
[&](const std::vector<const CallNode*>& group) {
const CallNode* group_root = *(group.begin());
auto group_args = group_root->attrs.as<Conv2DAttrs>();
return IsCompatibleConv2D(*args, *group_args);
});

if (it != groups.end()) {
auto group_id = std::distance(groups.begin(), it);
Expand Down

0 comments on commit 2accc32

Please sign in to comment.