diff --git a/src/relay/pass/expr_subst.cc b/src/relay/pass/expr_subst.cc index 3e342dee50618..bac66bc0acf1c 100644 --- a/src/relay/pass/expr_subst.cc +++ b/src/relay/pass/expr_subst.cc @@ -27,7 +27,7 @@ class ExprSubstituter : public ExprMutator { }; Expr ExprSubst(const Expr& expr, tvm::Map subst_map) { - return ExprSubstituter(std::move(subst_map)).Mutate(expr); + return ExprSubstituter(std::move(subst_map)).Mutate(expr); } } // namespace relay diff --git a/src/relay/pass/expr_subst.h b/src/relay/pass/expr_subst.h index 7656baba1fa62..02f4179dae66e 100644 --- a/src/relay/pass/expr_subst.h +++ b/src/relay/pass/expr_subst.h @@ -10,7 +10,7 @@ namespace tvm { namespace relay { - Expr ExprSubst(const Expr& expr, tvm::Map subst_map); +Expr ExprSubst(const Expr& expr, tvm::Map subst_map); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fold_conv2d.cc b/src/relay/pass/fold_conv2d.cc index e4ee45d3dbe05..e49ba6706f01e 100644 --- a/src/relay/pass/fold_conv2d.cc +++ b/src/relay/pass/fold_conv2d.cc @@ -8,6 +8,9 @@ * This pass replaces convolutions that share the same input node and the same arguments (except * that the number of output channels can be different) with a single convolution. The weight of * the new 2d convolution is the concatenation of the original weights. + * + * This prevents launching multiple kernels in networks with multiple convolution branches, such + * as Inception blocks. 
*/ #include @@ -22,8 +25,8 @@ namespace relay { class SiblingConv2DFinder : public ExprVisitor { public: - std::unordered_map, NodeHash, NodeEqual> - Find(const Expr& expr) { + std::unordered_map, NodeHash, NodeEqual> Find( + const Expr& expr) { this->VisitExpr(expr); return std::move(children_map_); } @@ -49,23 +52,23 @@ std::tuple TransformWeight(std::vector convolu CHECK(channels); num_filters += *channels; } - return std::tuple{ MakeConcatenate(TupleNode::make(weights), 0), - MakeConstScalar(Int(32), num_filters) }; + return std::tuple{MakeConcatenate(TupleNode::make(weights), 0), + MakeConstScalar(Int(32), num_filters)}; } // Two 2d convolutions can be combined if they have the same attributes or only have // different output channels. bool IsCompatibleConv2D(const Conv2DAttrs& a, const Conv2DAttrs& b) { - AttrsEqual eq; - return eq(a.strides, b.strides) && - eq(a.padding, b.padding) && - eq(a.dilation, b.dilation) && - eq(a.groups, b.groups) && - eq(a.kernel_size, b.kernel_size) && - eq(a.data_layout, b.data_layout) && - eq(a.weight_layout, b.weight_layout) && - eq(a.out_dtype, b.out_dtype) && - eq(a.out_layout, b.out_layout); + AttrsEqual eq; + return eq(a.strides, b.strides) && + eq(a.padding, b.padding) && + eq(a.dilation, b.dilation) && + eq(a.groups, b.groups) && + eq(a.kernel_size, b.kernel_size) && + eq(a.data_layout, b.data_layout) && + eq(a.weight_layout, b.weight_layout) && + eq(a.out_dtype, b.out_dtype) && + eq(a.out_layout, b.out_layout); } Expr MakeFoldedConv2D(const Expr& data, const std::vector& convolutions) { @@ -101,8 +104,7 @@ Expr FoldConv2D(const Expr& expr) { Expr data = pair.first; std::vector children = pair.second; - if (children.size() < 2) - continue; + if (children.size() < 2) continue; std::vector group_ids(children.size()); std::vector> groups; @@ -112,13 +114,12 @@ Expr FoldConv2D(const Expr& expr) { auto args = n->attrs.as(); // assign a group id or create a new group for each conv2d - auto it = - std::find_if(groups.begin(), 
groups.end(), - [&](std::vector group) { - const CallNode* group_root = *(group.begin()); - auto group_args = group_root->attrs.as(); - return IsCompatibleConv2D(*args, *group_args); - }); + auto it = std::find_if(groups.begin(), groups.end(), + [&](const std::vector& group) { + const CallNode* group_root = *(group.begin()); + auto group_args = group_root->attrs.as(); + return IsCompatibleConv2D(*args, *group_args); + }); if (it != groups.end()) { auto group_id = std::distance(groups.begin(), it);