diff --git a/docs/api/python/relay/transform.rst b/docs/api/python/relay/transform.rst index 3c0a6dcf22f6..346152b9c769 100644 --- a/docs/api/python/relay/transform.rst +++ b/docs/api/python/relay/transform.rst @@ -46,6 +46,8 @@ tvm.relay.transform .. autofunction:: tvm.relay.transform.CombineParallelConv2D +.. autofunction:: tvm.relay.transform.CombineParallelDense + .. autofunction:: tvm.relay.transform.AlterOpLayout .. autofunction:: tvm.relay.transform.Legalize diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 4bd59302f0d8..14f25cf90726 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -482,6 +482,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr); */ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3); +/*! + * \brief Combine parallel dense ops into a single batch_matmul if the + * number of branches of this dense operator is not less than + * `min_num_branches`. + * + * \param min_num_branches The minimum number of branches. + * + * \return The pass. + */ +TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3); + /*! * \brief Backward fold axis scaling into weights of conv/dense operators. * diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index ccdf00ed64e3..58bf17efd387 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -138,6 +138,7 @@ def build_config(opt_level=2, "CanonicalizeCast": 3, "EliminateCommonSubexpr": 3, "CombineParallelConv2D": 4, + "CombineParallelDense": 4, } fallback_device : int, str, or tvm.TVMContext, optional @@ -400,6 +401,35 @@ def CombineParallelConv2D(min_num_branches=3): return _transform.CombineParallelConv2D(min_num_branches) +def CombineParallelDense(min_num_branches=3): + """Combine multiple dense operators into one. For example: + + data + / \ + dense (2,2) dense (2,2) + | | + elemwise/bcast (2,2) elemwise/bcast (2,2) + + Would become: + + data + | + batch_matmul+elemwise/bcast (2,2,2) + + Parameters + ---------- + min_num_branches : int + The minimum number of required parallel branches for performing this + optimization. + + Returns + ------- + ret: tvm.relay.Pass + The registered pass that combines parallel dense operators. + """ + return _transform.CombineParallelDense(min_num_branches) + + def AlterOpLayout(): """Alternate the layouts of operators or replace primitive operators with other expressions. diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index f757dad520ef..278ef43dd177 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode { }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); pass_seqs.push_back(transform::CombineParallelConv2D(3)); + pass_seqs.push_back(transform::CombineParallelDense(3)); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeCast()); diff --git a/src/relay/pass/combine_parallel_conv2d.cc b/src/relay/pass/combine_parallel_conv2d.cc index d72705c8ce47..bc9685f815cb 100644 --- a/src/relay/pass/combine_parallel_conv2d.cc +++ b/src/relay/pass/combine_parallel_conv2d.cc @@ -18,7 +18,7 @@ */ /*! - * Copyright (c) 2018 by Contributors + * Copyright (c) 2019 by Contributors * * \file combine_parallel_conv2d.cc * \brief Combine parallel 2d convolutions into a single convolution.
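The docstring for `CombineParallelDense` above describes the rewrite at the IR level. Before the implementation diffs that follow, here is a minimal usage sketch; it assumes the Relay API of this era (`relay.Module.from_expr` and callable pass objects, as exercised by the new test file at the end of this diff), and the network and shapes are hypothetical:

```python
from tvm import relay
from tvm.relay import transform

# Three parallel dense branches sharing the same input x,
# all with identically shaped (5, 4) weights.
x = relay.var("x", shape=(3, 4))
ws = [relay.var("w%d" % i, shape=(5, 4)) for i in range(3)]
out = relay.Tuple([relay.nn.dense(x, w) for w in ws])
func = relay.Function([x] + ws, out)

mod = relay.Module.from_expr(func)
# min_num_branches=3 matches the default wired into build_module.cc above.
mod = transform.CombineParallelDense(min_num_branches=3)(mod)
print(mod["main"])  # one nn.batch_matmul followed by split/squeeze per branch
```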
@@ -43,68 +43,25 @@ #include #include "./expr_subst.h" #include "./pattern_util.h" - +#include "./combine_parallel_op.h" namespace tvm { namespace relay { -using Branch = std::vector; -using Group = std::vector; - -/* - Find parallel branches starting with conv2d as shown below and then group branches by kernel - shape and attributes of conv2d. Conv2d can be followed by zero or more elemwise or broadcast ops. - Intermediate nodes have exactly one successor. It is possible that branches meet at a point, - which should be handled in ParallelConv2DCombiner. - - data - / \ - conv2d conv2d - | | - op op - | | -*/ -class BranchGroupFinder : private ExprVisitor { +class ParallelConv2DCombiner : public ParallelOpCombiner { public: - std::vector Find(const Expr& expr) { - static const Op& conv2d = Op::Get("nn.conv2d"); - - this->VisitExpr(expr); - - std::vector groups; - for (const auto& root : conv_roots_) { - const auto& children = children_map_.at(root); - size_t ngroups = groups.size(); - for (const CallNode* child : children) { - if (!child->op.same_as(conv2d)) continue; - - auto&& branch = CreateBranch(child); - // add the branch to a group, or create a new group - auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { - CHECK(!group.empty() && !group[0].empty()); - return IsCompatibleConv2D(child, group[0][0]); - }); - if (it != groups.end()) { - it->push_back(branch); - } else { - groups.emplace_back(); - // each group has at least one branch - groups.back().push_back(branch); - } - } - } - return groups; + explicit ParallelConv2DCombiner(uint64_t min_num_branches) + : ParallelOpCombiner("nn.conv2d", min_num_branches) { } - private: - std::unordered_set conv_roots_; - std::unordered_map, NodeHash, NodeEqual> children_map_; + protected: + bool IsSupportedOp(const CallNode* n) { + return n->attrs.as()->groups == 1; + } - // Two 2d convolutions can be combined if they have the same attributes or - // only have different output channels. - bool IsCompatibleConv2D(const CallNode* a, const CallNode* b) { + bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { AttrsEqual eq; - static const Layout kOIHW("OIHW"); + const Layout kOIHW("OIHW"); const auto* attrs_a = a->attrs.as(); const auto* attrs_b = b->attrs.as(); CHECK(attrs_a); @@ -125,76 +82,8 @@ class BranchGroupFinder : private ExprVisitor { eq(shape_a[3], shape_b[3]); } - // Create a branch starting from conv2d. 
- Branch CreateBranch(const CallNode* conv) { - static auto fpattern = Op::GetAttr("TOpPattern"); - // each branch has at least one element, the first element is always conv2d - Branch branch{conv}; - auto it = children_map_.find(GetRef(branch.back())); - while (it != children_map_.end() && it->second.size() == 1) { - const CallNode* call = it->second[0]; - auto pattern = fpattern[Downcast(call->op)]; - if (pattern <= kBroadcast) { - branch.push_back(call); - it = children_map_.find(GetRef(branch.back())); - } else { - break; - } - } - return branch; - } - - void VisitExpr_(const CallNode* n) final { - static const Op& conv2d = Op::Get("nn.conv2d"); - ExprVisitor::VisitExpr_(n); - if (n->op.same_as(conv2d) && n->attrs.as()->groups == 1) { - conv_roots_.insert(n->args[0]); - children_map_[n->args[0]].push_back(n); - } else { - for (size_t i = 0; i < n->args.size(); i++) { - children_map_[n->args[i]].push_back(n); - } - } - } -}; - -class ParallelConv2DCombiner { - public: - explicit ParallelConv2DCombiner(uint64_t min_num_branches) : min_num_branches_(min_num_branches) { - } - - Expr Combine(const Expr& expr) { - auto groups = BranchGroupFinder().Find(expr); - for (const Group& group : groups) { - if (group.size() < min_num_branches_) { - continue; - } - CombineBranches(group); - } - return ExprSubst(expr, std::move(subst_map_)); - } - - private: - std::unordered_map subst_map_; - uint64_t min_num_branches_; - - std::tuple TransformWeight(const Group& branches) { - int64_t num_filters = 0; // number of filters of the transformed weight - Array weights; - for (const auto& branch : branches) { - auto conv2d = branch[0]; - weights.push_back(conv2d->args[1]); - auto channels = GetConv2DSuperChannelsDim(conv2d); - num_filters += channels; - } - auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); - CHECK_NE(index, std::string::npos); - return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), - MakeConstScalar(Int(32), num_filters)); - } - - Call MakeCombinedConv2D(const Group& branches) { - static const Op& conv2d = Op::Get("nn.conv2d"); + Call MakeCombinedOp(const Group& branches) { + const Op& conv2d = Op::Get("nn.conv2d"); Expr data = branches[0][0]->args[0]; Expr new_weight; IndexExpr new_channels; @@ -215,10 +104,15 @@ class ParallelConv2DCombiner { new_attrs->out_dtype = attrs->out_dtype; new_attrs->channels = new_channels; + const std::string& layout = + new_attrs->out_layout == "" ? 
new_attrs->data_layout : new_attrs->out_layout; + channel_pos_ = layout.find('C'); + CHECK_NE(channel_pos_, std::string::npos); + return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {}); } - bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index, size_t channel_pos) { + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { AttrsEqual eq; auto ta = a->args[index]->type_as(); auto tb = b->args[index]->type_as(); @@ -229,12 +123,12 @@ class ParallelConv2DCombiner { return false; // Position of the 'C' dimension in the argument - size_t arg_channel_pos = channel_pos - toutput_a->shape.size() + ta->shape.size(); + size_t arg_channel_pos = channel_pos_ - toutput_a->shape.size() + ta->shape.size(); // Channel super-dimension shoule be present and not broadcasted - if ((arg_channel_pos > channel_pos) || // size_t overflow - !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos]) || - !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos])) + if ((arg_channel_pos > channel_pos_) || // size_t overflow + !eq(ta->shape[arg_channel_pos], toutput_a->shape[channel_pos_]) || + !eq(tb->shape[arg_channel_pos], toutput_b->shape[channel_pos_])) return false; for (size_t i = 0; i < ta->shape.size(); i++) { @@ -245,38 +139,10 @@ class ParallelConv2DCombiner { return true; } - // Check if ops in depth-th level can be combined - bool CheckLevel(const Group& branches, size_t depth, size_t channel_pos, size_t parent_index) { - const CallNode* call = branches[0][depth]; - AttrsEqual attrs_equal; - // check if all branches in current depth can be combined - for (auto it = branches.begin() + 1; it != branches.end(); it++) { - const Branch& branch = *it; - if (!branch[depth]->op.same_as(call->op) || - !attrs_equal(branch[depth]->attrs, call->attrs) || - branch[depth]->args.size() != call->args.size()) { - return false; - } - - if (branch[depth]->args[parent_index].get() != branch[depth - 1]) - return false; - - // Check args - for (size_t i = 0; i < call->args.size(); i++) { - if (i == parent_index) continue; - - if (!IsArgCompatible(call, branch[depth], i, channel_pos) || - !attrs_equal(call->attrs, branch[depth]->attrs)) { - return false; - } - } - } - return true; - } - - // Combine args and make the combined CallNode - Call MakeCombinedCall(const Expr& data, const Group& branches, size_t depth, size_t channel_pos, - size_t parent_index) { + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { Array new_args; const CallNode* call = branches[0][depth]; size_t ndim = call->type_as()->shape.size(); @@ -286,28 +152,32 @@ class ParallelConv2DCombiner { new_args.push_back(data); continue; } + size_t arg_ndim = call->args[i]->type_as()->shape.size(); - size_t arg_channel_pos = channel_pos - ndim + arg_ndim; + size_t arg_channel_pos = channel_pos_ - ndim + arg_ndim; Array tuple; for (const auto& branch : branches) { tuple.push_back(branch[depth]->args[i]); } + auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos); new_args.push_back(std::move(concat)); } + return CallNode::make(call->op, new_args, call->attrs, {}); } - // Replace output of each branch with slices of the combined output - void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth, - size_t channel_pos) { + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) { int64_t index = 0; for (const auto& branch : branches) { const CallNode* 
conv2d = branch[0]; int64_t channels = GetConv2DSuperChannelsDim(conv2d); Array begin; Array end; - for (size_t i = 0; i < channel_pos; i++) { + for (size_t i = 0; i < channel_pos_; i++) { begin.push_back(0); end.push_back(NullValue()); } @@ -315,38 +185,27 @@ class ParallelConv2DCombiner { index += channels; end.push_back(index); auto slice = MakeStridedSlice(data, std::move(begin), std::move(end), Array{}); - subst_map_[GetRef(branch[depth])] = slice; + subst_map->insert({GetRef(branch[depth]), slice}); } } - // Combine branches in a group. Conv2d in different branches in the same group are safe to - // combine. Subsequent ops may or may not be combined. We start from conv2d and try to - // combine ops from all branches in the same depth. - void CombineBranches(const Group& branches) { - Call combined = MakeCombinedConv2D(branches); - auto conv_param = combined->attrs.as(); - const std::string& layout = - conv_param->out_layout == "" ? conv_param->data_layout : conv_param->out_layout; - size_t channel_pos = layout.find('C'); - CHECK_NE(channel_pos, std::string::npos); - auto it = std::min_element(branches.begin(), branches.end(), - [](const Branch& branch_a, - const Branch& branch_b) { - return branch_a.size() < branch_b.size(); - }); - size_t depth = it->size(); - size_t i; - // starting from 1 to skip the conv2d - for (i = 1; i < depth; i++) { - size_t parent_index; - for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { - if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; - } - CHECK_NE(parent_index, branches[0][i]->args.size()); - if (!CheckLevel(branches, i, channel_pos, parent_index)) break; - combined = MakeCombinedCall(combined, branches, i, channel_pos, parent_index); + private: + /* \brief index of channel dimension */ + size_t channel_pos_; + + std::tuple TransformWeight(const Group& branches) { + int64_t num_filters = 0; // number of filters of the transformed weight + Array weights; + for (const auto& branch : branches) { + auto conv2d = branch[0]; + weights.push_back(conv2d->args[1]); + auto channels = GetConv2DSuperChannelsDim(conv2d); + num_filters += channels; } - UpdateGroupOutput(combined, branches, i - 1, channel_pos); + auto index = branches[0][0]->attrs.as()->kernel_layout.find('O'); + CHECK_NE(index, std::string::npos); + return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index), + MakeConstScalar(Int(32), num_filters)); } }; diff --git a/src/relay/pass/combine_parallel_dense.cc b/src/relay/pass/combine_parallel_dense.cc new file mode 100644 index 000000000000..7b00fef9bd36 --- /dev/null +++ b/src/relay/pass/combine_parallel_dense.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_dense.cc + * \brief Combine parallel dense ops into a single batch_matmul. + * + * This pass replaces dense ops that share the same input node and the same shape, + * and that don't have "units" defined, with a single batch matrix multiplication. + * The inputs of the new batch_matmul are the original inputs stacked along a new batch axis. + * Elemwise and broadcast ops following dense are also combined if possible. + * + * This prevents launching multiple kernels in networks with multiple + * dense branches, such as BERT. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op_batch.h" + +namespace tvm { +namespace relay { + +class ParallelDenseCombiner : public ParallelOpBatchCombiner { + public: + explicit ParallelDenseCombiner(uint64_t min_num_branches) + : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) { + } + + protected: + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) { + AttrsEqual eq; + const auto* attrs_a = a->attrs.as<DenseAttrs>(); + const auto* attrs_b = b->attrs.as<DenseAttrs>(); + CHECK(attrs_a); + CHECK(attrs_b); + const auto* weight_a = a->args[1]->type_as<TensorTypeNode>(); + const auto* weight_b = b->args[1]->type_as<TensorTypeNode>(); + + return eq(attrs_a->out_dtype, attrs_b->out_dtype) && + eq(weight_a->shape[0], weight_b->shape[0]) && + eq(weight_a->shape[1], weight_b->shape[1]); + } +}; + +/*! \brief Combine parallel dense if number of branches >= min_num_branches */ +Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) { + return ParallelDenseCombiner(min_num_branches).Combine(expr); +} + +namespace transform { + +Pass CombineParallelDense(uint64_t min_num_branches) { + runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast<Function>(CombineParallelDense(f, min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelDense", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelDense") +.set_body_typed(CombineParallelDense); + +} // namespace transform + +} // namespace relay +} // namespace tvm
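The numerical identity behind `ParallelDenseCombiner` can be sketched in numpy. `nn.dense` computes `x @ w.T`, and `nn.batch_matmul` treats its second operand as transposed within each batch, so stacking the shared input and the per-branch weights reproduces every branch in one call (a sketch with hypothetical shapes, not part of the diff):

```python
import numpy as np

x = np.random.rand(3, 4).astype("float32")
w1 = np.random.rand(5, 4).astype("float32")
w2 = np.random.rand(5, 4).astype("float32")

# Per-branch dense results: y_i = x @ w_i.T
y1, y2 = x @ w1.T, x @ w2.T

# Combined: nn.batch_matmul(A, B)[b] = A[b] @ B[b].T
xb = np.stack([x, x], axis=0)     # (2, 3, 4)
wb = np.stack([w1, w2], axis=0)   # (2, 5, 4)
yb = np.einsum("bik,bjk->bij", xb, wb)

assert np.allclose(yb[0], y1) and np.allclose(yb[1], y2)
```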
diff --git a/src/relay/pass/combine_parallel_op.cc b/src/relay/pass/combine_parallel_op.cc new file mode 100644 index 000000000000..35e5bff6d63c --- /dev/null +++ b/src/relay/pass/combine_parallel_op.cc @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op.cc + * \brief Abstract class to combine parallel ops and their successive element-wise ops. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + + +namespace tvm { +namespace relay { + +BranchGroupFinder::BranchGroupFinder(const std::string& op_name, + FIsSupportedOp fis_supported_op, + FAreCompatibleOps fare_compatible_ops) + : op_name_(op_name), + fis_supported_op_(fis_supported_op), + fare_compatible_ops_(fare_compatible_ops) { +} + +std::vector<Group> BranchGroupFinder::Find(const Expr& expr) { + const Op& op = Op::Get(op_name_); + + this->VisitExpr(expr); + + std::vector<Group> groups; + for (const auto& root : op_roots_) { + const auto& children = children_map_.at(root); + size_t ngroups = groups.size(); + for (const CallNode* child : children) { + if (!child->op.same_as(op)) continue; + + auto&& branch = CreateBranch(child); + // add the branch to a group, or create a new group + auto it = std::find_if(groups.begin() + ngroups, groups.end(), [&](const Group& group) { + CHECK(!group.empty() && !group[0].empty()); + return fare_compatible_ops_(child, group[0][0]); + }); + if (it != groups.end()) { + it->push_back(branch); + } else { + groups.emplace_back(); + // each group has at least one branch + groups.back().push_back(branch); + } + } + } + return groups; +} + +// Create a branch starting from op. +Branch BranchGroupFinder::CreateBranch(const CallNode* op) { + auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern"); + // each branch has at least one element, the first element is always op + Branch branch{op}; + auto it = children_map_.find(GetRef<Expr>(branch.back())); + while (it != children_map_.end() && it->second.size() == 1) { + const CallNode* call = it->second[0]; + auto pattern = fpattern[Downcast<Op>(call->op)]; + if (pattern <= kBroadcast) { + branch.push_back(call); + it = children_map_.find(GetRef<Expr>(branch.back())); + } else { + break; + } + } + return branch; +} + +void BranchGroupFinder::VisitExpr_(const CallNode* n) { + const Op& op = Op::Get(op_name_); + ExprVisitor::VisitExpr_(n); + if (n->op.same_as(op) && fis_supported_op_(n)) { + op_roots_.insert(n->args[0]); + children_map_[n->args[0]].push_back(n); + } else { + for (size_t i = 0; i < n->args.size(); i++) { + children_map_[n->args[i]].push_back(n); + } + } +} + +ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches) + : op_name_(op_name), + min_num_branches_(min_num_branches) { +} + +Expr ParallelOpCombiner::Combine(const Expr& expr) { + auto groups = BranchGroupFinder(op_name_, + [&](const CallNode* n) { + return IsSupportedOp(n); + }, + [&](const CallNode* a, const CallNode* b) { + return CanOpsBeCombined(a, b); + }).Find(expr); + for (const Group& group : groups) { + if (group.size() < min_num_branches_) { + continue; + } + CombineBranches(group); + } + return ExprSubst(expr, std::move(subst_map_)); +} + +void ParallelOpCombiner::CombineBranches(const Group& branches) { + Call combined = MakeCombinedOp(branches); + auto it = std::min_element(branches.begin(), branches.end(), + [](const Branch& branch_a, + const Branch& branch_b) { + return branch_a.size() < branch_b.size(); + }); + size_t depth = it->size(); + size_t i; + // starting from 1 to skip the op + for (i = 1; i < depth; i++) { + size_t parent_index; + for (parent_index = 0; parent_index < branches[0][i]->args.size(); parent_index++) { + if (branches[0][i]->args[parent_index].get() == branches[0][i - 1]) break; + } + CHECK_NE(parent_index, branches[0][i]->args.size()); + if (!CheckLevel(branches, i, parent_index))
break; + combined = MakeCombinedCallFromFollowingOps(combined, branches, i, parent_index); + } + UpdateGroupOutput(combined, branches, i - 1, &subst_map_); +} + +bool ParallelOpCombiner::CheckLevel(const Group& branches, size_t depth, size_t parent_index) { + const CallNode* call = branches[0][depth]; + AttrsEqual attrs_equal; + // check if all branches in current depth can be combined + for (auto it = branches.begin() + 1; it != branches.end(); it++) { + const Branch& branch = *it; + if (!branch[depth]->op.same_as(call->op) || + !attrs_equal(branch[depth]->attrs, call->attrs) || + branch[depth]->args.size() != call->args.size()) { + return false; + } + + if (branch[depth]->args[parent_index].get() != branch[depth - 1]) + return false; + + // Check args + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) continue; + + if (!IsArgCompatible(call, branch[depth], i) || + !attrs_equal(call->attrs, branch[depth]->attrs)) { + return false; + } + } + } + return true; +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/combine_parallel_op.h b/src/relay/pass/combine_parallel_op.h new file mode 100644 index 000000000000..756dba98a707 --- /dev/null +++ b/src/relay/pass/combine_parallel_op.h @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op.h + * \brief Abstract class to combine parallel ops and their successive element-wise ops. + */ +#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ +#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" + + +namespace tvm { +namespace relay { + +using Branch = std::vector<const CallNode*>; +using Group = std::vector<Branch>; +using FIsSupportedOp = std::function<bool(const CallNode*)>; +using FAreCompatibleOps = std::function<bool(const CallNode*, const CallNode*)>; +using ExprSubstMap = std::unordered_map<Expr, Expr, NodeHash, NodeEqual>; + +/* + * Class that finds parallel branches starting with the given op and + * groups together branches that can be combined. Branches are eligible + * to be combined if they share the same input data. + * The op can be followed by zero or more elemwise or broadcast ops, + * which are included in the group. + * Intermediate nodes have exactly one successor. It is possible that branches meet at a point, + * which should be handled in ParallelOpCombiner.
+ * + * data + * / \ + * op op + * | | + * elem-wise elem-wise + * | | + */ +class BranchGroupFinder : private ExprVisitor { + public: + /* + * \brief Constructor + * \param op_name name of op to start each group + * \param fis_supported_op function that returns true if op + * is supported for combining + * \param fare_compatible_ops function that returns true if + * two ops are compatible for combining + */ + BranchGroupFinder(const std::string& op_name, + FIsSupportedOp fis_supported_op, + FAreCompatibleOps fare_compatible_ops); + + /* + * \brief Finds all groups that can be combined. + * \param expr Relay expression that represents function + * to look at for groups to be combined + * \return Vector of groups which can be combined. + */ + std::vector<Group> Find(const Expr& expr); + + private: + /* \brief name of op to find parallel branches for */ + std::string op_name_; + + /* \brief function to return true if op is eligible to be combined, + * false otherwise + */ + FIsSupportedOp fis_supported_op_; + + /* \brief function to return true if two parallel ops are eligible + * to be combined, false otherwise + */ + FAreCompatibleOps fare_compatible_ops_; + + /* \brief ops that are on the first (logically, leftmost) branch + * of parallel ops and are eligible to be combined + */ + std::unordered_set<Expr, NodeHash, NodeEqual> op_roots_; + + /* \brief map of Expr to CallNodes that follow it */ + std::unordered_map<Expr, std::vector<const CallNode*>, NodeHash, NodeEqual> children_map_; + + /* + * \brief Creates new branch from op and its children that have + * elementwise or broadcast patterns + * \return New branch + */ + Branch CreateBranch(const CallNode* op); + + /* + * \brief Expression visitor function + */ + void VisitExpr_(const CallNode* n) final; +}; + +/* + * Abstract class to find and combine parallel ops and the elementwise ops that follow. + */ +class ParallelOpCombiner { + public: + /* + * \brief Constructor. + * \param op_name name of op to combine + * \param min_num_branches min number of parallel branches beginning with op + * to start combining + */ + explicit ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches); + + /* + * \brief Combines ops and following elementwise or broadcast ops + * \param expr function to modify + * \return new function with combined ops + */ + Expr Combine(const Expr& expr); + + protected: + /* + * \brief Checks if node is supported to be combined + * \param n node in question + * \return True if the op represented by n is supported to be the root of a branch + * to be combined. False otherwise. + */ + virtual bool IsSupportedOp(const CallNode* n) = 0; + + /* + * \brief Checks if two ops can be combined + * \param a node a + * \param b node b + * \return True if a and b can be combined. False otherwise. + */ + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b) = 0; + + /* + * \brief Makes combined op from parallel ops in branches. This usually involves + * concatenating or stacking inputs, then creating a new call. + * \param branches branches that are to be combined + * \return new call with branches combined.
+ */ + virtual Call MakeCombinedOp(const Group& branches) = 0; + + /* + * \brief Checks if argument of op following combined ops are able to be combined + * \param a node a + * \param b node b + * \param index index of argument in question + * \return True if argument of a and b and index can be combined + */ + virtual bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) = 0; + + /* + * \brief Create combined call from ops that follow the initial combined op at the depth-th level. + * This usually involves concatenating or stacking inputs, then creating a new call. + * Only called if IsArgCompatbile returns true for each arg. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to combine ops + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return new combined call + */ + virtual Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) = 0; + + /* + * \brief Updates map of expr to substitute with combined expr. This usually involves + * slicing or splitting data. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to substitute + * \param subst_map map of Expr to replace with Expr to replace it with + */ + virtual void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) = 0; + + private: + /* \brief name of op to be combined */ + std::string op_name_; + + /* \brief minimum number of parallel branches to combine */ + uint64_t min_num_branches_; + + /* \brief map of Expr to Expr to substitute it with after running pass */ + ExprSubstMap subst_map_; + + /* + * \brief Combine parallel branches and updates subst_map_ with Exprs + * to be substituted + * \param branches branches to be combined + */ + void CombineBranches(const Group& branches); + + /* + * \brief Combine parallel branches and updates subst_map_ with Exprs + * to be substituted + * \param branches parallel branches to potentially be combined + * \param depth depth at which to look at op + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return true if parallel ops at depth can be combined, false otherwise + */ + bool CheckLevel(const Group& branches, size_t depth, size_t parent_index); +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_H_ diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc new file mode 100644 index 000000000000..235b230dfb31 --- /dev/null +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
diff --git a/src/relay/pass/combine_parallel_op_batch.cc b/src/relay/pass/combine_parallel_op_batch.cc new file mode 100644 index 000000000000..235b230dfb31 --- /dev/null +++ b/src/relay/pass/combine_parallel_op_batch.cc @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * + * \file combine_parallel_op_batch.cc + * \brief Combine parallel ops into a single batch op. + * + * This pass replaces ops that share the same input node and same shape + * with a single op that takes in batched input. The inputs of the new + * batched op are the original inputs stacked along a new batch axis. + * Elementwise and broadcast ops following the original op are also stacked + * and fused if possible. For example: + * + * data + * / \ + * add (2,2) add (2,2) + * | | + * elemwise (2,2) elemwise (2,2) + * | | + * + * Would become: + * + * data + * | + * add+elemwise (2,2,2) + * / \ + * + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" +#include "./combine_parallel_op_batch.h" + +namespace tvm { +namespace relay { + +ParallelOpBatchCombiner::ParallelOpBatchCombiner(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) + : ParallelOpCombiner(op_name, min_num_branches), + batch_op_name_(batch_op_name) { +} + +bool ParallelOpBatchCombiner::IsSupportedOp(const CallNode* n) { + return true; +} + +bool ParallelOpBatchCombiner::CanOpsBeCombined(const CallNode* a, const CallNode* b) { + if (a->args.size() != b->args.size()) { + return false; + } + + AttrsEqual eq; + for (size_t i = 0; i < a->args.size(); i++) { + auto ta = a->args[i]->type_as<TensorTypeNode>(); + auto tb = b->args[i]->type_as<TensorTypeNode>(); + if (ta->shape.size() != tb->shape.size() || !eq(ta->dtype, tb->dtype)) { + return false; + } + + for (size_t j = 0; j < ta->shape.size(); j++) { + if (!eq(ta->shape[j], tb->shape[j])) { + return false; + } + } + } + + return true; +} + +Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) { + const Op& batch_op = Op::Get(batch_op_name_); + + Array<Expr> new_args; + size_t num_args = branches[0][0]->args.size(); + for (size_t i = 0; i < num_args; i++) { + Array<Expr> arg_from_all_branches; + for (const auto& branch : branches) { + arg_from_all_branches.push_back(branch[0]->args[i]); + } + + new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0)); + } + + return CallNode::make(batch_op, new_args, Attrs(), {}); +} + +bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) { + AttrsEqual eq; + auto ta = a->args[index]->type_as<TensorTypeNode>(); + auto tb = b->args[index]->type_as<TensorTypeNode>(); + + if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) + return false; + + for (size_t i = 0; i < ta->shape.size(); i++) { + if (!eq(ta->shape[i], tb->shape[i])) + return false; + } + return true; +} + +Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) { + Array<Expr> new_args; + const CallNode* call = branches[0][depth]; + + for (size_t i = 0; i < call->args.size(); i++) { + if (i == parent_index) { + new_args.push_back(data); + continue; + } + + Array<Expr> tuple; + for (const auto& branch : branches) { + // if the arg is of shape (j,), expand it to (1,j) + // so it can be properly broadcast after stacking.
+ Expr arg = branch[depth]->args[i]; + const TensorTypeNode* arg_tensor = arg->type_as<TensorTypeNode>(); + if (arg_tensor->shape.size() == 1) { + Expr expanded_arg = MakeExpandDims(arg, 0, 1); + tuple.push_back(expanded_arg); + } else { + tuple.push_back(arg); + } + } + + auto stack = MakeStack(TupleNode::make(tuple), 0); + new_args.push_back(std::move(stack)); + } + + return CallNode::make(call->op, new_args, call->attrs, {}); +} + +void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) { + int index = 0; + auto split = MakeSplit(data, Integer(branches.size()), 0); + for (const auto& branch : branches) { + auto split_data = TupleGetItemNode::make(split, index++); + auto squeezed_data = MakeSqueeze(split_data, {0}); + subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data}); + } +} + +/*! \brief Combine parallel op into batched op if number of branches >= min_num_branches */ +Expr CombineParallelOpBatch(const Expr& expr, + const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) { + return ParallelOpBatchCombiner(op_name, batch_op_name, min_num_branches).Combine(expr); +} + +namespace transform { + +Pass CombineParallelOpBatch(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches) { + runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast<Function>(CombineParallelOpBatch(f, + op_name, + batch_op_name, + min_num_branches)); + }; + return CreateFunctionPass(pass_func, 4, "CombineParallelOpBatch", + {ir::StringImm::make("InferType")}); +} + +TVM_REGISTER_API("relay._transform.CombineParallelOpBatch") +.set_body_typed(CombineParallelOpBatch); + +} // namespace transform + +} // namespace relay +} // namespace tvm
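The `MakeExpandDims` handling in `MakeCombinedCallFromFollowingOps` above exists because stacking raw 1-D arguments would break broadcasting against the batched output. A numpy illustration of that corner case (hypothetical sizes):

```python
import numpy as np

y = np.zeros((2, 3, 5), dtype="float32")   # combined output: two branches of (3, 5)
b1 = np.arange(5, dtype="float32")         # per-branch biases of shape (5,)
b2 = -b1

# Stacking the raw (5,) biases yields (2, 5), which cannot broadcast
# against (2, 3, 5). Expanding each to (1, 5) first gives a (2, 1, 5)
# stack: one bias row per branch, broadcast over the rows of that branch.
b = np.stack([b1[np.newaxis, :], b2[np.newaxis, :]], axis=0)
out = y + b
assert out.shape == (2, 3, 5)
assert np.allclose(out[0], b1) and np.allclose(out[1], b2)
```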
diff --git a/src/relay/pass/combine_parallel_op_batch.h b/src/relay/pass/combine_parallel_op_batch.h new file mode 100644 index 000000000000..84ef8d353985 --- /dev/null +++ b/src/relay/pass/combine_parallel_op_batch.h @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file combine_parallel_op_batch.h + * \brief Combine parallel ops into a single batch op. + */ +#ifndef TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ +#define TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./expr_subst.h" +#include "./pattern_util.h" +#include "./combine_parallel_op.h" + +namespace tvm { +namespace relay { + +/* + * Class to find and combine parallel ops and following element-wise + * and broadcast ops into a single batch op. Ops can be combined + * if they have the same input data. The batch op is formed by + * stacking inputs; final results are retrieved by splitting the output. + * For example: + * + * data + * / \ + * dense (2,2) dense (2,2) + * | | + * elemwise/bcast (2,2) elemwise/bcast (2,2) + * + * Would become: + * + * data + * | + * batch_matmul+elemwise/bcast (2,2,2) + */ +class ParallelOpBatchCombiner : public ParallelOpCombiner { + public: + /* + * \brief Constructor. + * \param op_name name of op to combine + * \param batch_op_name name of op that combined branches will be joined into + * \param min_num_branches min number of parallel branches beginning with op + * to start combining + */ + ParallelOpBatchCombiner(const std::string& op_name, + const std::string& batch_op_name, + uint64_t min_num_branches); + + protected: + /* + * \brief Checks if node is supported to be combined + * \param n node in question + * \return True by default + */ + virtual bool IsSupportedOp(const CallNode* n); + + /* + * \brief Checks if two ops can be combined + * \param a node a + * \param b node b + * \return True if shapes and dtypes of all args of a and b are the same + */ + virtual bool CanOpsBeCombined(const CallNode* a, const CallNode* b); + + /* + * \brief Makes combined op from parallel ops in branches. This usually involves + * concatenating or stacking inputs, then creating a new call. + * \param branches branches that are to be combined + * \return new call with branches combined as batch op by stacking args + */ + Call MakeCombinedOp(const Group& branches) final; + + /* + * \brief Checks if the argument at the given index of two ops following the combined op + * can be combined + * \param a node a + * \param b node b + * \param index index of argument in question + * \return True if shapes and dtypes of args[index] of a and b are the same + */ + bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) final; + + /* + * \brief Create combined call from ops that follow the initial combined op at the depth-th level. + * This usually involves concatenating or stacking inputs, then creating a new call. + * Only called if IsArgCompatible returns true for each arg. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to combine ops + * \param parent_index index of arg that corresponds to original input that was shared among + * all combined ops + * \return new combined call as batch op by stacking args + */ + Call MakeCombinedCallFromFollowingOps(const Expr& data, + const Group& branches, + size_t depth, + size_t parent_index) final; + + /* + * \brief Updates map of expr to substitute with combined expr. This usually involves + * slicing or splitting data. + * \param data combined op + * \param branches branches of parallel ops to be combined + * \param depth depth at which to substitute + * \param subst_map map from original Expr to the Expr to replace it with + */ + void UpdateGroupOutput(const Expr& data, + const Group& branches, + size_t depth, + ExprSubstMap* subst_map) final; + + private: + /* \brief name of op to replace combined ops with.
For example, + * for combining parallel dense, this will be set to + * nn.batch_matmul + */ + std::string batch_op_name_; +}; + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_COMBINE_PARALLEL_OP_BATCH_H_ diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 18e5df3e04df..d4f7ebce46d8 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -419,6 +419,14 @@ Expr MakeConcatenate(Expr data, int axis); Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides); +Expr MakeStack(Expr data, int axis); + +Expr MakeSplit(Expr data, NodeRef indices_or_sections, int axis); + +Expr MakeSqueeze(Expr data, Array<Integer> axis); + +Expr MakeExpandDims(Expr data, int axis, int num_newaxis); + Expr StopFusion(Expr data); Expr CastHint(Expr data, DataType dtype); diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py new file mode 100644 index 000000000000..070ab8658b88 --- /dev/null +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm import relay +from tvm.relay import transform + + +def run_combine_parallel(expr, min_num_branches=3): + mod = relay.Module.from_expr(expr) + mod = transform.CombineParallelDense(min_num_branches)(mod) + return mod["main"] + +def run_opt_pass(expr, opt_pass): + assert isinstance(opt_pass, transform.Pass) + mod = relay.Module.from_expr(expr) + mod = opt_pass(mod) + return mod["main"] + + +def test_combine_parallel_dense(): + """Simple testcase.
One dense cannot be combined due to shape mismatch""" + def before(x, w1, w2, w3, w4): + args = [x, w1, w2, w3, w4] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + + # y3 cannot be combined + y3 = relay.nn.dense(x, w3) + + y4 = relay.nn.dense(x, w4) + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def expected(x, w1, w2, w3, w4): + # use a fixed order of args so alpha equal check can pass + args = [x, w1, w2, w3, w4] + x_stacked = relay.stack((x, x, x), axis=0) + w = relay.stack((w1, w2, w4), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + (y1, y2, y4) = relay.split(y, 3) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) + y4 = relay.squeeze(y4, [0]) + + # y3 cannot be combined + y3 = relay.nn.dense(x, w3) + + y = relay.Tuple((y1, y2, y3, y4)) + return relay.Function(args, y) + + def check(i, j, k): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) + w3 = relay.var("w3", shape=(j + 1, k)) + w4 = relay.var("w4", shape=(j, k)) + + y_before = before(x, w1, w2, w3, w4) + y = run_opt_pass(y_before, + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, w3, w4) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check(3, 5, 4) + check(100, 200, 300) + + +def test_combine_parallel_dense_biasadd(): + """Testcase of combining dense + 1d biasadd""" + def before(x, w1, w2, b1, b2): + args = [x, w1, w2, b1, b2] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + y1 = relay.add(y1, b1) + y2 = relay.add(y2, b2) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, b1, b2, is_2d_bias): + args = [x, w1, w2, b1, b2] + x_stacked = relay.stack((x, x), axis=0) + w = relay.stack((w1, w2), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + + if not is_2d_bias: + b1 = relay.expand_dims(b1, 0) + b2 = relay.expand_dims(b2, 0) + + b = relay.stack((b1, b2), axis=0) + y = relay.add(y, b) + (y1, y2) = relay.split(y, 2) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(i, j, k, is_2d_bias): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) + + if is_2d_bias: + b1 = relay.var("b1", shape=(i, j)) + b2 = relay.var("b2", shape=(i, j)) + else: + b1 = relay.var("b1", shape=(j,)) + b2 = relay.var("b2", shape=(j,)) + + y_before = before(x, w1, w2, b1, b2) + y = run_opt_pass(y_before, + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, b1, b2, is_2d_bias) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check(3, 5, 4, False) + check(100, 200, 300, False) + check(3, 5, 4, True) + check(100, 200, 300, True) + +def test_combine_parallel_dense_biasadd_scale_reshape(): + """Testcase of combining dense + 1d biasadd + multiply with non-fused reshape""" + def before(x, w1, w2, b1, b2, scale1, scale2, newshape): + args = [x, w1, w2, b1, b2, scale1, scale2] + y1 = relay.nn.dense(x, w1) + y2 = relay.nn.dense(x, w2) + y1 = relay.add(y1, b1) + y2 = relay.add(y2, b2) + y1 = relay.multiply(y1, scale1) + y2 = relay.multiply(y2, scale2) + y1 = relay.reshape(y1, newshape=newshape) + y2 = relay.reshape(y2, newshape=newshape) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def expected(x, w1, w2, b1, b2, scale1, scale2, newshape): + args 
= [x, w1, w2, b1, b2, scale1, scale2] + x_stacked = relay.stack((x, x), axis=0) + w = relay.stack((w1, w2), axis=0) + y = relay.nn.batch_matmul(x_stacked, w) + b1 = relay.expand_dims(b1, 0) + b2 = relay.expand_dims(b2, 0) + b = relay.stack((b1, b2), axis=0) + y = relay.add(y, b) + scale1 = relay.expand_dims(scale1, 0) + scale2 = relay.expand_dims(scale2, 0) + scale = relay.stack((scale1, scale2), axis=0) + y = relay.multiply(y, scale) + (y1, y2) = relay.split(y, 2) + y1 = relay.squeeze(y1, [0]) + y2 = relay.squeeze(y2, [0]) + y1 = relay.reshape(y1, newshape=newshape) + y2 = relay.reshape(y2, newshape=newshape) + y = relay.Tuple((y1, y2)) + return relay.Function(args, y) + + def check(i, j, k, scale1, scale2, newshape): + x = relay.var("x", shape=(i, k)) + w1 = relay.var("w1", shape=(j, k)) + w2 = relay.var("w2", shape=(j, k)) + b1 = relay.var("b1", shape=(j,)) + b2 = relay.var("b2", shape=(j,)) + scale1 = relay.var("scale1", shape=(1,)) + scale2 = relay.var("scale2", shape=(1,)) + + y_before = before(x, w1, w2, b1, b2, scale1, scale2, newshape) + y = run_opt_pass(y_before, + transform.CombineParallelDense(min_num_branches=2)) + y_expected = expected(x, w1, w2, b1, b2, scale1, scale2, newshape) + y_expected = run_opt_pass(y_expected, transform.InferType()) + assert relay.analysis.alpha_equal(y, y_expected) + + check(3, 5, 4, 0.5, 0.25, (1, 1, 15)) + check(100, 200, 300, 0.5, 0.25, (1, 1, 200)) + + +if __name__ == "__main__": + test_combine_parallel_dense() + test_combine_parallel_dense_biasadd() + test_combine_parallel_dense_biasadd_scale_reshape()
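The tests above check alpha-equality of the rewritten IR against a hand-built expectation. The arithmetic they rely on can also be replayed end to end in numpy, composing the batch_matmul, bias-add, and scale steps (a sketch with hypothetical sizes i=3, j=5, k=4 and two branches):

```python
import numpy as np

x = np.random.rand(3, 4).astype("float32")
w = [np.random.rand(5, 4).astype("float32") for _ in range(2)]
b = [np.random.rand(5).astype("float32") for _ in range(2)]
s = [np.random.rand(1).astype("float32") for _ in range(2)]

# Per-branch: dense -> biasadd -> multiply.
ref = [(x @ w[n].T + b[n]) * s[n] for n in range(2)]

# Combined: one batch_matmul, batched add/multiply, then split + squeeze.
y = np.einsum("bik,bjk->bij", np.stack([x, x]), np.stack(w))
y = y + np.stack([bn[np.newaxis, :] for bn in b])   # biases stacked to (2, 1, 5)
y = y * np.stack([sn[np.newaxis, :] for sn in s])   # scales stacked to (2, 1, 1)
for n in range(2):
    assert np.allclose(y[n], ref[n], rtol=1e-5)
```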