diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index cf14febb02c1f..d322710ec95a3 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -238,10 +238,12 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
  * `min_num_branch`.
  *
  * \param min_num_branches The minimum number of branches.
+ * \param to_batch_matmul Whether to combine parallel dense ops into batch_matmul.
+ *        If set to false, combine them into a single dense op instead.
  *
  * \return The pass.
  */
-TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
+TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
 
 /*!
  * \brief Combine parallel batch_matmul ops into a single batch_matmul
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 3abc3822f0ef6..cc92141b73db4 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -277,7 +277,7 @@ def CombineParallelConv2D(min_num_branches=3):
     return _ffi_api.CombineParallelConv2D(min_num_branches)
 
 
-def CombineParallelDense(min_num_branches=3):
+def CombineParallelDense(min_num_branches=3, to_batch=True):
     """Combine multiple dense operators into one. For example:
 
     .. code-block
@@ -295,18 +295,30 @@ def CombineParallelDense(min_num_branches=3):
            |
     batch_matmul+elemwise/bcast (2,2,2)
 
+    or (if to_batch=False)
+
+    .. code-block
+
+            data
+            |
+    dense+elemwise/bcast (2,2+2)
+
     Parameters
     ----------
     min_num_branches : int
         The minimum number of required parallel branches for performing this
         optimization.
 
+    to_batch : bool
+        If True, combine parallel dense ops into a batch_matmul op.
+        If False, combine parallel dense ops into one dense op.
+
     Returns
     -------
     ret: tvm.transform.Pass
         The registered pass that combines parallel dense operators.
     """
-    return _ffi_api.CombineParallelDense(min_num_branches)
+    return _ffi_api.CombineParallelDense(min_num_branches, to_batch)
 
 
 def CombineParallelBatchMatmul(min_num_branches=3):
     """Combine multiple batch matmul operators into one. For example:
@@ -342,6 +354,21 @@ def CombineParallelBatchMatmul(min_num_branches=3):
     return _ffi_api.CombineParallelBatchMatmul(min_num_branches)
 
 
+def BatchingOps():
+    """Batch parallel operators of type Conv2D, Dense, and BatchMatmul into one.
+
+    Returns
+    -------
+    ret: tvm.transform.Pass
+        The sequential pass that applies batching to the different operator types.
+    """
+    return tvm.transform.Sequential([
+        CombineParallelConv2D(),
+        CombineParallelDense(),
+        CombineParallelBatchMatmul()
+    ])
+
+
 def AlterOpLayout():
     """Alternate the layouts of operators or replace primitive operators with
     other expressions.
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index 8613dbe1466e8..aec4315cb083f 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -48,9 +48,12 @@
 namespace tvm {
 namespace relay {
 
-class ParallelDenseCombiner : public ParallelOpBatchCombiner {
+/*
+ * Class that finds and combines parallel dense ops into batch_matmul.
+ */
+class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
  public:
-  explicit ParallelDenseCombiner(uint64_t min_num_branches)
+  explicit ParallelDenseToBatchCombiner(uint64_t min_num_branches)
       : ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}
 
  protected:
@@ -68,17 +71,168 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
   }
 };
 
+/*
+ * Class that finds and combines parallel dense ops into one dense op
+ * whose number of output units equals the sum of each sub-op's output units.
+ */
+class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
+ public:
+  explicit ParallelDenseToDenseCombiner(uint64_t min_num_branches)
+      : ParallelOpCombiner("nn.dense", min_num_branches) {}
+
+ protected:
+  bool IsSupportedOp(const CallNode* n) { return true; }
+
+  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
+    StructuralEqual eq;
+    const auto* attrs_a = a->attrs.as<DenseAttrs>();
+    const auto* attrs_b = b->attrs.as<DenseAttrs>();
+    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
+    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
+    CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
+    // output dims (weight->shape[0]) can be different
+    return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
+  }
+
+  Call MakeCombinedOp(const Group& branches) {
+    const Op& dense_op = Op::Get("nn.dense");
+    Expr input = branches[0][0]->args[0];
+    Expr new_weight;
+    IndexExpr new_output_dims;
+    // concat all weights into one
+    std::tie(new_weight, new_output_dims) = TransformWeight(branches);
+    const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
+    CHECK(origin_attrs);
+    const auto dense_attrs = make_object<DenseAttrs>();
+    dense_attrs->units = new_output_dims;
+    dense_attrs->out_dtype = origin_attrs->out_dtype;
+    return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
+  }
+
+  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
+    StructuralEqual eq;
+    auto ta = a->args[index]->type_as<TensorTypeNode>();
+    auto tb = b->args[index]->type_as<TensorTypeNode>();
+    auto toutput_a = a->type_as<TensorTypeNode>();
+    auto toutput_b = b->type_as<TensorTypeNode>();
+    CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);
+
+    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
+      return false;
+    }
+    if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
+      return false;  // not broadcast/elemwise
+    }
+    if (ta->shape.size() > 0) {
+      for (size_t i = 0; i < ta->shape.size() - 1; i++) {
+        // shape dims must match except the last dim
+        if (!eq(ta->shape[i], tb->shape[i])) return false;
+      }
+    }
+    return true;
+  }
+
+  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
+                                        size_t parent_index) {
+    Array<Expr> new_args;
+    const CallNode* call = branches[0][depth];
+    for (size_t i = 0; i < call->args.size(); i++) {
+      if (i == parent_index) {
+        new_args.push_back(data);
+        continue;
+      }
+      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
+      size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
+      Array<Expr> tuple;
+      for (const auto& branch : branches) {
+        auto parent = branch[depth]->args[parent_index];
+        auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
+        auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
+        CHECK(out_dim != nullptr);
+
+        auto arg = branch[depth]->args[i];
+        auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
+        bool repeat_last_dim = false;
+        if (arg_ndim == 0) {
+          repeat_last_dim = true;
+          arg = MakeExpandDims(arg, -1, 1);
+        } else {
+          auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
+          CHECK(arg_last_dim != nullptr);
+          if (*out_dim > 1 && *arg_last_dim == 1) {
+            repeat_last_dim = true;
+          }
+        }
+        if (repeat_last_dim) {
+          // ensure broadcast is valid after concat args
+          arg = MakeRepeat(arg, *out_dim, concat_axis);
+        }
+        tuple.push_back(arg);
+      }
+      auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
+      new_args.push_back(std::move(concat));
+    }
+    return Call(call->op, new_args, call->attrs, {});
+  }
+
+  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
+                         ExprSubstMap* subst_map) {
+    int index = 0;
+    for (const auto& branch : branches) {
+      const CallNode* call = branch[depth];
+      auto& out_shape = call->type_as<TensorTypeNode>()->shape;
+      auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
+      CHECK(out_dims != nullptr);
+      std::vector<int64_t> begin;
+      std::vector<int64_t> end;
+      std::vector<int64_t> strides;
+      for (size_t k = 0; k < out_shape.size() - 1; ++k) {
+        begin.push_back(0);
+        end.push_back(-1);
+        strides.push_back(1);
+      }
+      begin.push_back(index);
+      end.push_back(*out_dims);
+      strides.push_back(1);
+      index += *out_dims;
+      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
+      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
+      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
+      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
+      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
+    }
+  }
+
+ private:
+  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
+    int64_t out_dims = 0;
+    Array<Expr> weights;
+    for (const auto& branch : branches) {
+      auto weight = branch[0]->args[1];
+      weights.push_back(weight);
+      out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
+    }
+    return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
+                           tir::make_const(DataType::Int(32), out_dims));
+  }
+};
+
 /*! \brief Combine parallel dense if number of branches >= min_num_branches */
-Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
-  return ParallelDenseCombiner(min_num_branches).Combine(expr);
+Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
+  if (to_batch) {
+    return ParallelDenseToBatchCombiner(min_num_branches).Combine(expr);
+  } else {
+    return ParallelDenseToDenseCombiner(min_num_branches).Combine(expr);
+  }
 }
 
 namespace transform {
 
-Pass CombineParallelDense(uint64_t min_num_branches) {
+Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch_matmul) {
   runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
       [=](Function f, IRModule m, PassContext pc) {
-        return Downcast<Function>(CombineParallelDense(f, min_num_branches));
+        return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch_matmul));
       };
   return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
 }
diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py
index 12beafb2c578d..a8d7f11c471a3 100644
--- a/tests/python/relay/test_pass_combine_parallel_dense.py
+++ b/tests/python/relay/test_pass_combine_parallel_dense.py
@@ -20,9 +20,9 @@
 from tvm.relay import transform
 
 
-def run_combine_parallel(expr, min_num_branches=3):
+def run_combine_parallel(expr, min_num_branches=3, to_batch=True):
     mod = tvm.IRModule.from_expr(expr)
-    mod = transform.CombineParallelDense(min_num_branches)(mod)
+    mod = transform.CombineParallelDense(min_num_branches, to_batch)(mod)
     return mod["main"]
 
 def run_opt_pass(expr, opt_pass):
@@ -190,7 +190,192 @@ def check(i, j, k, scale1, scale2, newshape):
     check(100, 200, 300, 0.5, 0.25, (1, 1, 200))
 
 
+def test_combine_parallel_dense_flat():
+    """Simple testcase. All matmuls with different output dims can be combined."""
+    def before(x, w1, w2, w3):
+        args = [x, w1, w2, w3]
+        y1 = relay.nn.dense(x, w1)
+        y2 = relay.nn.dense(x, w2)
+        y3 = relay.nn.dense(x, w3)
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, w3, j):
+        args = [x, w1, w2, w3]
+        w_stacked = relay.concatenate((w1, w2, w3), axis=0)
+        y = relay.nn.dense(x, w_stacked, units=6 * j)
+        strides = relay.const([1, 1], 'int64')
+        y1 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0], "int64"),
+                                 end=relay.const([-1, j], "int64"),
+                                 strides=strides, slice_mode="size")
+        y2 = relay.strided_slice(y,
+                                 begin=relay.const([0, j], "int64"),
+                                 end=relay.const([-1, 2 * j], "int64"),
+                                 strides=strides, slice_mode="size")
+        y3 = relay.strided_slice(y,
+                                 begin=relay.const([0, 3 * j], "int64"),
+                                 end=relay.const([-1, 3 * j], "int64"),
+                                 strides=strides, slice_mode="size")
+        y = relay.Tuple((y1, y2, y3))
+        return relay.Function(args, y)
+
+    def check(i, j, k):
+        x = relay.var("x", shape=(i, k))
+        w1 = relay.var("w1", shape=(j, k))
+        w2 = relay.var("w2", shape=(2 * j, k))
+        w3 = relay.var("w3", shape=(3 * j, k))
+
+        y_before = before(x, w1, w2, w3)
+        combine_pass = transform.CombineParallelDense(min_num_branches=3,
+                                                      to_batch=False)
+        y = run_opt_pass(y_before, combine_pass)
+        y_expected = expected(x, w1, w2, w3, j)
+        y_expected = run_opt_pass(y_expected, transform.InferType())
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+    check(3, 5, 4)
+    check(100, 200, 300)
+
+
+def test_combine_parallel_dense_flat_biasadd():
+    """Testcase of combining dense + 1d biasadd with different out dims"""
+    def before(x, w1, w2, b1, b2):
+        args = [x, w1, w2, b1, b2]
+        y1 = relay.nn.dense(x, w1)
+        y2 = relay.nn.dense(x, w2)
+        y1 = relay.add(y1, b1)
+        y2 = relay.add(y2, b2)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2):
+        args = [x, w1, w2, b1, b2]
+        w_stacked = relay.concatenate((w1, w2), axis=0)
+        y = relay.nn.dense(x, w_stacked, units=3 * j)
+        n_out_dims = max(len(bias_shape1), 2)
+        if len(bias_shape1) == 0:
+            b1 = relay.repeat(relay.expand_dims(b1, -1), j, 0)
+        elif bias_shape1[-1] == 1:
+            b1 = relay.repeat(b1, j, len(bias_shape1) - 1)
+        if len(bias_shape2) == 0:
+            b2 = relay.repeat(relay.expand_dims(b2, -1), 2 * j, 0)
+        elif bias_shape2[-1] == 1:
+            b2 = relay.repeat(b2, 2 * j, len(bias_shape2) - 1)
+        b = relay.concatenate((b1, b2), axis=max(0, len(bias_shape1) - 1))
+        y = relay.add(y, b)
+        begin = [0 for _ in range(n_out_dims - 1)]
+        end = [-1 for _ in range(n_out_dims - 1)]
+        strides = [1 for _ in range(n_out_dims)]
+        y1 = relay.strided_slice(y,
+                                 begin=relay.const(begin + [0], "int64"),
+                                 end=relay.const(end + [j], "int64"),
+                                 strides=relay.const(strides, "int64"),
+                                 slice_mode="size")
+        y2 = relay.strided_slice(y,
+                                 begin=relay.const(begin + [j], "int64"),
+                                 end=relay.const(end + [2 * j], "int64"),
+                                 strides=relay.const(strides, "int64"),
+                                 slice_mode="size")
+        return relay.Function(args, relay.Tuple((y1, y2)))
+
+    def check(i, j, k, bias_shape1, bias_shape2):
+        x = relay.var("x", shape=(i, k))
+        w1 = relay.var("w1", shape=(j, k))
+        w2 = relay.var("w2", shape=(2 * j, k))
+        b1 = relay.var("b1", shape=bias_shape1)
+        b2 = relay.var("b2", shape=bias_shape2)
+
+        y_before = before(x, w1, w2, b1, b2)
+        combine_pass = transform.CombineParallelDense(min_num_branches=2,
+                                                      to_batch=False)
+        y = run_opt_pass(y_before, combine_pass)
+        y_expected = expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2)
+        y_expected = run_opt_pass(y_expected, transform.InferType())
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+    check(3, 5, 4, (), ())
+    check(3, 5, 4, (1,), (1,))
+    check(3, 5, 4, (5,), (1,))
+    check(3, 5, 4, (1,), (10,))
+    check(3, 5, 4, (3, 1), (3, 1))
+    check(3, 5, 4, (3, 5), (3, 10))
+    check(3, 5, 4, (3, 1), (3, 10))
+    check(3, 5, 4, (3, 5), (3, 1))
+    check(3, 5, 4, (9, 3, 5), (9, 3, 10))
+    check(3, 5, 4, (9, 3, 5), (9, 3, 1))
+    check(3, 5, 4, (9, 3, 1), (9, 3, 10))
+
+
+def test_combine_parallel_dense_flat_biasadd_scale_reshape():
+    """Testcase of combining dense ops with different out dims,
+    followed by bias add, scale, and reshape ops
+    """
+    def before(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2):
+        args = [x, w1, w2, b1, b2, scale1, scale2]
+        y1 = relay.nn.dense(x, w1)
+        y2 = relay.nn.dense(x, w2)
+        y1 = relay.add(y1, b1)
+        y2 = relay.add(y2, b2)
+        y1 = relay.multiply(y1, scale1)
+        y2 = relay.multiply(y2, scale2)
+        y1 = relay.reshape(y1, newshape=newshape1)
+        y2 = relay.reshape(y2, newshape=newshape2)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def expected(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2, j):
+        args = [x, w1, w2, b1, b2, scale1, scale2]
+        w_stacked = relay.concatenate((w1, w2), axis=0)
+        y = relay.nn.dense(x, w_stacked, units=3 * j)
+        b = relay.concatenate((b1, b2), axis=0)
+        y = relay.add(y, b)
+        scale1 = relay.repeat(scale1, j, 0)
+        scale2 = relay.repeat(scale2, 2 * j, 0)
+        scale = relay.concatenate((scale1, scale2), axis=0)
+        y = relay.multiply(y, scale)
+        strides = relay.const([1, 1], 'int64')
+        y1 = relay.strided_slice(y,
+                                 begin=relay.const([0, 0], "int64"),
+                                 end=relay.const([-1, j], "int64"),
+                                 strides=strides, slice_mode="size")
+        y2 = relay.strided_slice(y,
+                                 begin=relay.const([0, j], "int64"),
+                                 end=relay.const([-1, 2 * j], "int64"),
+                                 strides=strides, slice_mode="size")
+        y1 = relay.reshape(y1, newshape=newshape1)
+        y2 = relay.reshape(y2, newshape=newshape2)
+        y = relay.Tuple((y1, y2))
+        return relay.Function(args, y)
+
+    def check(i, j, k, scale1, scale2, newshape1, newshape2):
+        x = relay.var("x", shape=(i, k))
+        w1 = relay.var("w1", shape=(j, k))
+        w2 = relay.var("w2", shape=(2 * j, k))
+        b1 = relay.var("b1", shape=(j,))
+        b2 = relay.var("b2", shape=(2 * j,))
+        scale1 = relay.var("scale1", shape=(1,))
+        scale2 = relay.var("scale2", shape=(1,))
+
+        y_before = before(x, w1, w2, b1, b2, scale1, scale2,
+                          newshape1, newshape2)
+        combine_pass = transform.CombineParallelDense(min_num_branches=2,
+                                                      to_batch=False)
+        y = run_opt_pass(y_before, combine_pass)
+        y_expected = expected(x, w1, w2, b1, b2, scale1, scale2,
+                              newshape1, newshape2, j)
+        y_expected = run_opt_pass(y_expected, transform.InferType())
+        tvm.ir.assert_structural_equal(y, y_expected, map_free_vars=True)
+
+    check(3, 5, 4, 0.5, 0.25, (1, 1, 15), (1, 1, 30))
+    check(100, 200, 300, 0.5, 0.25, (1, 1, 200), (1, 1, 400))
+
+
 if __name__ == "__main__":
     test_combine_parallel_dense()
     test_combine_parallel_dense_biasadd()
     test_combine_parallel_dense_biasadd_scale_reshape()
+    test_combine_parallel_dense_flat()
+    test_combine_parallel_dense_flat_biasadd()
+    test_combine_parallel_dense_flat_biasadd_scale_reshape()
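
Reviewer note: for anyone who wants to try the new flag interactively, below is a minimal sketch of how the pass behaves with this patch applied. The two-branch module here is a hypothetical example (not part of the patch), and it mirrors the pattern used in the tests above. With to_batch=False the branches combine even though their output dims differ (8 vs. 16), which the batch_matmul path cannot handle:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # Two parallel dense branches sharing one input. The output dims
    # differ (8 vs. 16), so only the to_batch=False path combines them.
    x = relay.var("x", shape=(4, 32))
    w1 = relay.var("w1", shape=(8, 32))
    w2 = relay.var("w2", shape=(16, 32))
    y = relay.Tuple((relay.nn.dense(x, w1), relay.nn.dense(x, w2)))
    mod = tvm.IRModule.from_expr(relay.Function([x, w1, w2], y))

    # Flatten into one dense with 8 + 16 = 24 output units; each branch's
    # result is then recovered with a strided_slice in "size" mode.
    flat = transform.CombineParallelDense(min_num_branches=2, to_batch=False)(mod)
    print(flat["main"])

    # The new BatchingOps() helper chains all three combine passes with
    # their defaults (min_num_branches=3, to_batch=True), so it leaves
    # this 2-branch module unchanged; it is shown only as the entry point.
    batched = transform.BatchingOps()(mod)
    print(batched["main"])

The pass declares InferType as a required pass, so applying it directly to an un-typed module, as the tests do, is fine.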