Skip to content

Commit

Permalink
feat: Support combine multiple matmuls to flat matmul
Browse files Browse the repository at this point in the history
  • Loading branch information
wrongtest-intellif committed Jul 15, 2020
1 parent 9fcde21 commit 2170bed
Show file tree
Hide file tree
Showing 6 changed files with 364 additions and 13 deletions.
3 changes: 2 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,11 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
* `min_num_branch`.
*
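* \param min_num_branches The minimum number of branches.
* \param to_batch Whether to combine parallel dense ops into a single batch_matmul.
*  If false, they are combined into a single flat dense op instead.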
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch = true);

/*!
* \brief Combine parallel batch_matmul ops into a single batch_matmul
Expand Down
15 changes: 13 additions & 2 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def CombineParallelConv2D(min_num_branches=3):
return _ffi_api.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
def CombineParallelDense(min_num_branches=3, to_batch=True):
"""Combine multiple dense operators into one. For example:
.. code-block
Expand All @@ -295,18 +295,29 @@ def CombineParallelDense(min_num_branches=3):
|
batch_matmul+elemwise/bcast (2,2,2)
or (if to_batch=False)
.. code-block
data
|
dense+elemwise/bcast (2,2+2)
Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
optimization.
to_batch : bool
Whether to convert multiple dense ops into a single batch_matmul. If False, they are combined into one flat dense op instead.
Returns
-------
ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
return _ffi_api.CombineParallelDense(min_num_branches, to_batch)

def CombineParallelBatchMatmul(min_num_branches=3):
"""Combine multiple batch matmul operators into one. For example:
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelDense(3, true));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelDense(3, true));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
Expand Down
166 changes: 160 additions & 6 deletions src/relay/transforms/combine_parallel_dense.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@
namespace tvm {
namespace relay {

class ParallelDenseCombiner : public ParallelOpBatchCombiner {
/*
* Class that finds and combines parallel dense ops into a single batch_matmul.
*/
class ParallelDenseBatchCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseCombiner(uint64_t min_num_branches)
explicit ParallelDenseBatchCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}

protected:
Expand All @@ -68,17 +71,168 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
}
};

/*!
 * \brief Finds parallel nn.dense ops and combines them into one nn.dense op
 *  whose number of output units equals the sum of the units of the original ops.
 *
 * The weights of all branches are concatenated along axis 0, the extra arguments of
 * the ops that follow each dense are concatenated along their last axis, and every
 * original output is recovered by slicing the combined output along its last axis.
 */
class ParallelDenseFlatCombiner : public ParallelOpCombiner {
 public:
  /*!
   * \brief Constructor.
   * \param min_num_branches Forwarded to ParallelOpCombiner together with the
   *  op name "nn.dense".
   */
  explicit ParallelDenseFlatCombiner(uint64_t min_num_branches)
      : ParallelOpCombiner("nn.dense", min_num_branches) {}

 protected:
  /*!
   * \brief Accepts every candidate call; no per-op restriction is applied here.
   * \param n The candidate call (unused).
   * \return Always true.
   */
  bool IsSupportedOp(const CallNode* n) { return true; }

  /*!
   * \brief Decide whether two dense calls may be merged.
   * \param a The first dense call.
   * \param b The second dense call.
   * \return true if both calls have the same out_dtype and their weights agree on
   *  shape[1]; weight shape[0] (the output dim) is allowed to differ.
   */
  bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
    StructuralEqual eq;
    const auto* attrs_a = a->attrs.as<DenseAttrs>();
    const auto* attrs_b = b->attrs.as<DenseAttrs>();
    const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
    const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
    CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
    // output dims (weight->shape[0]) can be different
    return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
  }

  /*!
   * \brief Build the single dense call that replaces the first op of every branch.
   * \param branches The group of branches; element [i][0] is the dense call of branch i.
   * \return A dense call on the input of the first branch and the concatenated weights.
   *  Its units is the summed output dim and its out_dtype is copied from the first branch.
   */
  Call MakeCombinedOp(const Group& branches) {
    const Op& dense_op = Op::Get("nn.dense");
    Expr input = branches[0][0]->args[0];
    Expr new_weight;
    IndexExpr new_output_dims;
    // concat all weights into one
    std::tie(new_weight, new_output_dims) = TransformWeight(branches);
    const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
    CHECK(origin_attrs);
    const auto dense_attrs = make_object<DenseAttrs>();
    dense_attrs->units = new_output_dims;
    dense_attrs->out_dtype = origin_attrs->out_dtype;
    return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
  }

  /*!
   * \brief Check whether argument `index` of two following ops can be concatenated.
   * \param a The call taken from the first branch.
   * \param b The call taken from the second branch.
   * \param index The position of the argument to compare.
   * \return true if both arguments have the same dtype and rank, neither has a higher
   *  rank than the output of its call, and all dims except the last one are equal.
   */
  bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
    StructuralEqual eq;
    auto ta = a->args[index]->type_as<TensorTypeNode>();
    auto tb = b->args[index]->type_as<TensorTypeNode>();
    auto toutput_a = a->type_as<TensorTypeNode>();
    auto toutput_b = b->type_as<TensorTypeNode>();
    CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);

    if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
      return false;
    }
    if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
      return false;  // not broadcast/elemwise
    }
    // Guard against size_t underflow of `size() - 1` for rank-0 (scalar) arguments.
    if (ta->shape.size() > 0) {
      for (size_t i = 0; i < ta->shape.size() - 1; i++) {
        // shape dims must match except last dim
        if (!eq(ta->shape[i], tb->shape[i])) return false;
      }
    }
    return true;
  }

  /*!
   * \brief Build the combined call for the ops found at position `depth` of every branch.
   * \param data The combined expression that takes the place of the parent argument.
   * \param branches The group of branches being combined.
   * \param depth Index into each branch selecting the call to combine.
   * \param parent_index Argument position at which `data` is passed.
   * \return A call to the op of the first branch whose other arguments are the per-branch
   *  arguments concatenated along their last axis.
   */
  Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
                                        size_t parent_index) {
    Array<Expr> new_args;
    const CallNode* call = branches[0][depth];
    for (size_t i = 0; i < call->args.size(); i++) {
      if (i == parent_index) {
        new_args.push_back(data);
        continue;
      }
      // The rank is read from the first branch only.
      // NOTE(review): presumably IsArgCompatible has already ensured all branches
      // agree on it — confirm in ParallelOpCombiner.
      size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
      size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
      Array<Expr> tuple;
      for (const auto& branch : branches) {
        // Last dim of this branch's parent, i.e. the width this branch occupies
        // in the combined tensor.
        auto parent = branch[depth]->args[parent_index];
        auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
        auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
        CHECK(out_dim != nullptr);

        auto arg = branch[depth]->args[i];
        auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
        bool repeat_last_dim = false;
        if (arg_ndim == 0) {
          // A scalar is first turned into a 1-D tensor so that it can be repeated.
          repeat_last_dim = true;
          arg = MakeExpandDims(arg, -1, 1);
        } else {
          auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
          CHECK(arg_last_dim != nullptr);
          // The argument was broadcast along the last axis in the original op.
          if (*out_dim > 1 && *arg_last_dim == 1) {
            repeat_last_dim = true;
          }
        }
        if (repeat_last_dim) {
          // ensure broadcast is valid after concat args
          arg = MakeRepeat(arg, *out_dim, concat_axis);
        }
        tuple.push_back(arg);
      }
      auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
      new_args.push_back(std::move(concat));
    }
    return Call(call->op, new_args, call->attrs, {});
  }

  /*!
   * \brief Map the original output of every branch to a slice of the combined output.
   * \param data The combined expression to slice.
   * \param branches The group of branches that were combined.
   * \param depth Index into each branch selecting the call whose output is replaced.
   * \param subst_map Receives one entry per branch: original call -> strided slice of `data`.
   */
  void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
                         ExprSubstMap* subst_map) {
    // Running offset along the last axis. Kept as int64_t because it accumulates
    // int64_t dims and is stored into an int64_t vector (avoids silent narrowing).
    int64_t index = 0;
    for (const auto& branch : branches) {
      const CallNode* call = branch[depth];
      auto& out_shape = call->type_as<TensorTypeNode>()->shape;
      auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
      CHECK(out_dims != nullptr);
      std::vector<int64_t> begin;
      std::vector<int64_t> end;
      std::vector<int64_t> strides;
      // Leading axes use begin 0 / end -1.
      // NOTE(review): relies on the "size" slice mode treating `end` as a slice
      // length and -1 as "up to the end of the axis" — confirm against strided_slice.
      for (size_t k = 0; k < out_shape.size() - 1; ++k) {
        begin.push_back(0);
        end.push_back(-1);
        strides.push_back(1);
      }
      // Last axis: start at the running offset and take this branch's output dim.
      begin.push_back(index);
      end.push_back(*out_dims);
      strides.push_back(1);
      index += *out_dims;
      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
      subst_map->insert({GetRef<Expr>(branch[depth]), slice});
    }
  }

 private:
  /*!
   * \brief Concatenate the dense weights of all branches along axis 0.
   * \param branches The group of branches; element [i][0] is the dense call of branch i.
   * \return The concatenated weight and the sum of all weight shape[0] values as an
   *  int32 constant.
   */
  std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
    int64_t out_dims = 0;
    Array<Expr> weights;
    for (const auto& branch : branches) {
      auto weight = branch[0]->args[1];
      weights.push_back(weight);
      // The output dim must be a compile-time constant to be summed; fail with a
      // CHECK (as the other methods of this class do) instead of dereferencing null.
      const auto* weight_out_dim = tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
      CHECK(weight_out_dim != nullptr);
      out_dims += *weight_out_dim;
    }
    return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
                           tir::make_const(DataType::Int(32), out_dims));
  }
};

/*!
 * \brief Combine parallel dense ops if the number of branches >= min_num_branches.
 * \param expr The expression to rewrite.
 * \param min_num_branches The minimum number of parallel branches, forwarded to the combiner.
 * \param to_batch If true, ParallelDenseBatchCombiner is used (dense ops become a
 *  batch_matmul); otherwise ParallelDenseFlatCombiner is used (dense ops become one dense).
 * \return The result of the selected combiner's Combine(expr).
 */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
return ParallelDenseCombiner(min_num_branches).Combine(expr);
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
if (to_batch) {
return ParallelDenseBatchCombiner(min_num_branches).Combine(expr);
} else {
return ParallelDenseFlatCombiner(min_num_branches).Combine(expr);
}
}

namespace transform {

/*!
 * \brief Create the function pass named "CombineParallelDense".
 * \param min_num_branches The minimum number of branches, captured by value and passed
 *  on to relay::CombineParallelDense.
 * \param to_batch Selects the combining strategy, captured by value and passed on to
 *  relay::CombineParallelDense.
 * \return The pass produced by CreateFunctionPass.
 */
Pass CombineParallelDense(uint64_t min_num_branches) {
Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
// Rewrite the function and cast the resulting Expr back to Function.
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch));
};
// NOTE(review): 4 is presumably the pass opt level and {"InferType"} the list of
// required passes — confirm against CreateFunctionPass.
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}
Expand Down
Loading

0 comments on commit 2170bed

Please sign in to comment.