
[Relay][Pass] Support combining multiple dense ops into a single dense #6062

Merged
3 changes: 2 additions & 1 deletion include/tvm/relay/transform.h
@@ -223,10 +223,11 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
* `min_num_branch`.
*
* \param min_num_branches The minimum number of branches.
* \param to_batch Whether to combine the parallel dense ops into a single batch_matmul.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch = true);

/*!
* \brief Combine parallel batch_matmul ops into a single batch_matmul
15 changes: 13 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -277,7 +277,7 @@ def CombineParallelConv2D(min_num_branches=3):
return _ffi_api.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
def CombineParallelDense(min_num_branches=3, to_batch=True):
"""Combine multiple dense operators into one. For example:

.. code-block
@@ -295,18 +295,29 @@ def CombineParallelDense(min_num_branches=3):
|
batch_matmul+elemwise/bcast (2,2,2)

or (if to_batch=False)

.. code-block

data
|
dense+elemwise/bcast (2,2+2)

Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
optimization.

to_batch : bool
Whether to combine the parallel dense ops into a single batch_matmul. If False, combine them into one dense op with concatenated output units instead.

Returns
-------
ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
return _ffi_api.CombineParallelDense(min_num_branches, to_batch)
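
As a usage sketch of the new flag (variable names, shapes, and the `min_num_branches` value below are illustrative, not taken from this patch):

```python
import tvm
from tvm import relay

# Two parallel dense branches that share the same input.
x = relay.var("x", shape=(2, 8))
w1 = relay.var("w1", shape=(4, 8))
w2 = relay.var("w2", shape=(4, 8))
out = relay.Tuple([relay.nn.dense(x, w1), relay.nn.dense(x, w2)])
mod = tvm.IRModule.from_expr(relay.Function([x, w1, w2], out))

# Default: rewrite the branches into a single batch_matmul.
batched = relay.transform.CombineParallelDense(min_num_branches=2)(mod)

# New option in this PR: keep a single, wider dense instead.
flat = relay.transform.CombineParallelDense(min_num_branches=2, to_batch=False)(mod)
```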

def CombineParallelBatchMatmul(min_num_branches=3):
"""Combine multiple batch matmul operators into one. For example:
2 changes: 1 addition & 1 deletion src/relay/backend/build_module.cc
@@ -277,7 +277,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelDense(3, true));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
2 changes: 1 addition & 1 deletion src/relay/backend/vm/compiler.cc
@@ -948,7 +948,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelDense(3, true));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
166 changes: 160 additions & 6 deletions src/relay/transforms/combine_parallel_dense.cc
Expand Up @@ -48,9 +48,12 @@
namespace tvm {
namespace relay {

class ParallelDenseCombiner : public ParallelOpBatchCombiner {
/*
* Class that finds and combines parallel dense ops into a single batch_matmul.
*/
class ParallelDenseBatchCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseCombiner(uint64_t min_num_branches)
explicit ParallelDenseBatchCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}

protected:
@@ -68,17 +71,168 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
}
};
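
A quick numpy sanity check of the batch rewrite above (shapes are made up; the transpose reflects TVM's batch_matmul convention, where the second operand is laid out as (batch, N, K)):

```python
import numpy as np

x = np.random.randn(2, 8).astype("float32")
w1 = np.random.randn(4, 8).astype("float32")
w2 = np.random.randn(4, 8).astype("float32")

# Each branch: dense(x, w) == x @ w.T, a (2, 4) result per branch.
per_branch = np.stack([x @ w1.T, x @ w2.T])                    # (2, 2, 4)

# Combined: one batch_matmul over the stacked input and stacked weights.
combined = np.stack([x, x]) @ np.stack([w1, w2]).transpose(0, 2, 1)

assert np.allclose(per_branch, combined, atol=1e-5)
```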

/*
* Class that finds and combines parallel dense ops into one dense op
* whose number of output units equals the sum of the sub-ops' units.
*/
class ParallelDenseFlatCombiner : public ParallelOpCombiner {
public:
explicit ParallelDenseFlatCombiner(uint64_t min_num_branches)
: ParallelOpCombiner("nn.dense", min_num_branches) {}

protected:
bool IsSupportedOp(const CallNode* n) { return true; }

bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
// output dims (weight->shape[0]) can be different
return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
}

Call MakeCombinedOp(const Group& branches) {
const Op& dense_op = Op::Get("nn.dense");
Expr input = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_output_dims;
// concat all weights into one
std::tie(new_weight, new_output_dims) = TransformWeight(branches);
const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
CHECK(origin_attrs);
const auto dense_attrs = make_object<DenseAttrs>();
dense_attrs->units = new_output_dims;
dense_attrs->out_dtype = origin_attrs->out_dtype;
return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
}

bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();
CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);

if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
return false;
}
if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
return false; // not broadcast/elemwise
}
if (ta->shape.size() > 0) {
for (size_t i = 0; i < ta->shape.size() - 1; i++) {
// shape dims must match except last dim
if (!eq(ta->shape[i], tb->shape[i])) return false;
}
}
return true;
}

Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) {
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
Array<Expr> tuple;
for (const auto& branch : branches) {
auto parent = branch[depth]->args[parent_index];
auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
CHECK(out_dim != nullptr);

auto arg = branch[depth]->args[i];
auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
bool repeat_last_dim = false;
if (arg_ndim == 0) {
repeat_last_dim = true;
arg = MakeExpandDims(arg, -1, 1);
} else {
auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
CHECK(arg_last_dim != nullptr);
if (*out_dim > 1 && *arg_last_dim == 1) {
repeat_last_dim = true;
}
}
if (repeat_last_dim) {
// ensure broadcast is valid after concat args
arg = MakeRepeat(arg, *out_dim, concat_axis);
}
tuple.push_back(arg);
}
auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
new_args.push_back(std::move(concat));
}
return Call(call->op, new_args, call->attrs, {});
}

void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) {
int index = 0;
for (const auto& branch : branches) {
const CallNode* call = branch[depth];
auto& out_shape = call->type_as<TensorTypeNode>()->shape;
auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
CHECK(out_dims != nullptr);
std::vector<int64_t> begin;
std::vector<int64_t> end;
std::vector<int64_t> strides;
for (size_t k = 0; k < out_shape.size() - 1; ++k) {
begin.push_back(0);
end.push_back(-1);
strides.push_back(1);
}
begin.push_back(index);
end.push_back(*out_dims);
strides.push_back(1);
index += *out_dims;
std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}

private:
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t out_dims = 0;
Array<Expr> weights;
for (const auto& branch : branches) {
auto weight = branch[0]->args[1];
weights.push_back(weight);
out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
}
return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
tir::make_const(DataType::Int(32), out_dims));
}
};
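
To illustrate the flat combiner end to end (a sketch with made-up shapes): two branches with different output units, each followed by a broadcast add, are rewritten into one dense over the concatenated (10, 8) weight plus one add, and strided_slice recovers the per-branch results:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 8))
w1 = relay.var("w1", shape=(4, 8))   # 4 output units
w2 = relay.var("w2", shape=(6, 8))   # 6 output units: unequal units are allowed here
b1 = relay.var("b1", shape=(4,))
b2 = relay.var("b2", shape=(6,))
y1 = relay.add(relay.nn.dense(x, w1), b1)
y2 = relay.add(relay.nn.dense(x, w2), b2)
mod = tvm.IRModule.from_expr(relay.Function([x, w1, w2, b1, b2], relay.Tuple([y1, y2])))

# Weights concatenate along axis 0 and the biases along their last axis, so the
# rewritten graph has one dense (units=10), one add, and two strided_slice ops.
mod = relay.transform.CombineParallelDense(min_num_branches=2, to_batch=False)(mod)
print(mod)
```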

/*! \brief Combine parallel dense if number of branches >= min_num_branches */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
return ParallelDenseCombiner(min_num_branches).Combine(expr);
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
if (to_batch) {
return ParallelDenseBatchCombiner(min_num_branches).Combine(expr);
} else {
return ParallelDenseFlatCombiner(min_num_branches).Combine(expr);
}
}

namespace transform {

Pass CombineParallelDense(uint64_t min_num_branches) {
Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}