[Relay][Pass] Support combine multiple dense op just into dense #6062

Merged
4 changes: 3 additions & 1 deletion include/tvm/relay/transform.h
@@ -223,10 +223,12 @@ TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);
* `min_num_branch`.
*
 * \param min_num_branches The minimum number of branches.
 * \param to_batch_matmul Whether to combine parallel dense ops into batch_matmul.
 * If set to false, combine dense ops into a single dense op.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3, bool to_batch_matmul = true);
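// A minimal usage sketch (comment only, not part of this diff), assuming an
// IRModule `mod` that has already gone through type inference:
//   tvm::transform::Pass pass = CombineParallelDense(3, /*to_batch_matmul=*/false);
//   mod = pass(mod);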

/*!
* \brief Combine parallel batch_matmul ops into a single batch_matmul
31 changes: 29 additions & 2 deletions python/tvm/relay/transform/transform.py
@@ -277,7 +277,7 @@ def CombineParallelConv2D(min_num_branches=3):
return _ffi_api.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
def CombineParallelDense(min_num_branches=3, to_batch=True):
"""Combine multiple dense operators into one. For example:

.. code-block
@@ -295,18 +295,30 @@ def CombineParallelDense(min_num_branches=3):
|
batch_matmul+elemwise/bcast (2,2,2)

or (if to_batch=False)

.. code-block

data
|
dense+elemwise/bcast (2,2+2)

Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
optimization.

to_batch : bool
If True, combine parallel dense ops into one batch_matmul op.
If False, combine parallel dense ops into one dense op.

Returns
-------
ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
return _ffi_api.CombineParallelDense(min_num_branches, to_batch)
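# A minimal usage sketch (illustrative, not part of this diff): three parallel
# dense branches share one input; with to_batch=False they fold into a single
# dense op whose units are the sum of the branch units (8 + 8 + 8 = 24):
#
#   import tvm
#   from tvm import relay
#
#   x = relay.var("x", shape=(1, 16))
#   ws = [relay.var("w%d" % i, shape=(8, 16)) for i in range(3)]
#   out = relay.Tuple([relay.nn.dense(x, w) for w in ws])
#   mod = tvm.IRModule.from_expr(relay.Function([x] + ws, out))
#   mod = relay.transform.InferType()(mod)
#   mod = relay.transform.CombineParallelDense(3, to_batch=False)(mod)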

def CombineParallelBatchMatmul(min_num_branches=3):
"""Combine multiple batch matmul operators into one. For example:
@@ -342,6 +354,21 @@ def CombineParallelBatchMatmul(min_num_branches=3):
return _ffi_api.CombineParallelBatchMatmul(min_num_branches)


def BatchingOps():
"""Batching parallel operators into one for Conv2D, Dense and BatchMatmul.

Returns
-------
ret: tvm.transform.Pass
The sequential pass that applies batching to the different operator types.
"""
return tvm.transform.Sequential([
CombineParallelConv2D(),
CombineParallelDense(),
CombineParallelBatchMatmul()
])
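# Usage sketch (illustrative): apply all three combiners, each with its default
# min_num_branches of 3, in one sequential pass:
#
#   mod = tvm.relay.transform.BatchingOps()(mod)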


def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
166 changes: 160 additions & 6 deletions src/relay/transforms/combine_parallel_dense.cc
@@ -48,9 +48,12 @@
namespace tvm {
namespace relay {

class ParallelDenseCombiner : public ParallelOpBatchCombiner {
/*!
 * \brief Class that finds and combines parallel dense ops into batch_matmul.
 */
class ParallelDenseToBatchCombiner : public ParallelOpBatchCombiner {
public:
explicit ParallelDenseCombiner(uint64_t min_num_branches)
explicit ParallelDenseToBatchCombiner(uint64_t min_num_branches)
: ParallelOpBatchCombiner("nn.dense", "nn.batch_matmul", min_num_branches) {}

protected:
@@ -68,17 +71,168 @@ class ParallelDenseCombiner : public ParallelOpBatchCombiner {
}
};

/*!
 * \brief Class that finds and combines parallel dense ops into one dense op
 * whose number of output units equals the sum of each sub-op's output units.
 */
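// Rewrite sketch (illustrative shapes, not from this diff): given branches
//   dense(x:(1,16), w1:(8,16))  and  dense(x:(1,16), w2:(4,16)),
// the weights are concatenated along axis 0 into a (12,16) tensor, a single
// dense produces a (1,12) output, and strided_slice recovers the original
// (1,8) and (1,4) results for each branch.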
class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
public:
explicit ParallelDenseToDenseCombiner(uint64_t min_num_branches)
: ParallelOpCombiner("nn.dense", min_num_branches) {}

protected:
bool IsSupportedOp(const CallNode* n) { return true; }

bool CanOpsBeCombined(const CallNode* a, const CallNode* b) {
StructuralEqual eq;
const auto* attrs_a = a->attrs.as<DenseAttrs>();
const auto* attrs_b = b->attrs.as<DenseAttrs>();
const auto* weight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* weight_b = b->args[1]->type_as<TensorTypeNode>();
CHECK(attrs_a != nullptr && attrs_b != nullptr && weight_a != nullptr && weight_b != nullptr);
// output dims (weight->shape[0]) can be different
return eq(attrs_a->out_dtype, attrs_b->out_dtype) && eq(weight_a->shape[1], weight_b->shape[1]);
}

Call MakeCombinedOp(const Group& branches) {
const Op& dense_op = Op::Get("nn.dense");
Expr input = branches[0][0]->args[0];
Expr new_weight;
IndexExpr new_output_dims;
// concat all weights into one
std::tie(new_weight, new_output_dims) = TransformWeight(branches);
const auto* origin_attrs = branches[0][0]->attrs.as<DenseAttrs>();
CHECK(origin_attrs);
const auto dense_attrs = make_object<DenseAttrs>();
dense_attrs->units = new_output_dims;
dense_attrs->out_dtype = origin_attrs->out_dtype;
return Call(dense_op, {input, new_weight}, Attrs{dense_attrs}, {});
}

bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
StructuralEqual eq;
auto ta = a->args[index]->type_as<TensorTypeNode>();
auto tb = b->args[index]->type_as<TensorTypeNode>();
auto toutput_a = a->type_as<TensorTypeNode>();
auto toutput_b = b->type_as<TensorTypeNode>();
CHECK(ta != nullptr && tb != nullptr && toutput_a != nullptr && toutput_b != nullptr);

if (!eq(ta->dtype, tb->dtype) || ta->shape.size() != tb->shape.size()) {
return false;
}
if (toutput_a->shape.size() < ta->shape.size() || toutput_b->shape.size() < tb->shape.size()) {
return false; // not broadcast/elemwise
}
if (ta->shape.size() > 0) {
for (size_t i = 0; i < ta->shape.size() - 1; i++) {
// shape dims must match except last dim
if (!eq(ta->shape[i], tb->shape[i])) return false;
}
}
return true;
}

Call MakeCombinedCallFromFollowingOps(const Expr& data, const Group& branches, size_t depth,
size_t parent_index) {
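    // For each follow-up elemwise/broadcast op, concatenate the non-parent
    // args of all branches along their last axis so that a single call
    // applies to the combined dense output.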
Array<Expr> new_args;
const CallNode* call = branches[0][depth];
for (size_t i = 0; i < call->args.size(); i++) {
if (i == parent_index) {
new_args.push_back(data);
continue;
}
size_t arg_ndim = call->args[i]->type_as<TensorTypeNode>()->shape.size();
size_t concat_axis = arg_ndim == 0 ? 0 : arg_ndim - 1;
Array<Expr> tuple;
for (const auto& branch : branches) {
auto parent = branch[depth]->args[parent_index];
auto& parent_shape = parent->type_as<TensorTypeNode>()->shape;
auto out_dim = tir::as_const_int(parent_shape[parent_shape.size() - 1]);
CHECK(out_dim != nullptr);

auto arg = branch[depth]->args[i];
auto& arg_shape = arg->type_as<TensorTypeNode>()->shape;
bool repeat_last_dim = false;
if (arg_ndim == 0) {
repeat_last_dim = true;
arg = MakeExpandDims(arg, -1, 1);
} else {
auto arg_last_dim = tir::as_const_int(arg_shape[arg_shape.size() - 1]);
CHECK(arg_last_dim != nullptr);
if (*out_dim > 1 && *arg_last_dim == 1) {
repeat_last_dim = true;
}
}
if (repeat_last_dim) {
// ensure broadcast is valid after concat args
arg = MakeRepeat(arg, *out_dim, concat_axis);
}
tuple.push_back(arg);
}
auto concat = MakeConcatenate(Tuple(tuple), concat_axis);
new_args.push_back(std::move(concat));
}
return Call(call->op, new_args, call->attrs, {});
}

void UpdateGroupOutput(const Expr& data, const Group& branches, size_t depth,
ExprSubstMap* subst_map) {
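    // Slice the last axis of the combined output back into per-branch pieces:
    // branch i gets [index, index + out_dims_i), where index accumulates the
    // output widths of the preceding branches.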
int index = 0;
for (const auto& branch : branches) {
const CallNode* call = branch[depth];
auto& out_shape = call->type_as<TensorTypeNode>()->shape;
auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
CHECK(out_dims != nullptr);
std::vector<int64_t> begin;
std::vector<int64_t> end;
std::vector<int64_t> strides;
for (size_t k = 0; k < out_shape.size() - 1; ++k) {
begin.push_back(0);
end.push_back(-1);
strides.push_back(1);
}
begin.push_back(index);
end.push_back(*out_dims);
strides.push_back(1);
index += *out_dims;
std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
subst_map->insert({GetRef<Expr>(branch[depth]), slice});
}
}

private:
std::tuple<Expr, IndexExpr> TransformWeight(const Group& branches) {
int64_t out_dims = 0;
Array<Expr> weights;
for (const auto& branch : branches) {
auto weight = branch[0]->args[1];
weights.push_back(weight);
out_dims += *tir::as_const_int(weight->type_as<TensorTypeNode>()->shape[0]);
}
return std::make_tuple(MakeConcatenate(Tuple(weights), 0),
tir::make_const(DataType::Int(32), out_dims));
}
};

/*! \brief Combine parallel dense ops if the number of branches >= min_num_branches */
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches) {
return ParallelDenseCombiner(min_num_branches).Combine(expr);
Expr CombineParallelDense(const Expr& expr, uint64_t min_num_branches, bool to_batch) {
if (to_batch) {
return ParallelDenseToBatchCombiner(min_num_branches).Combine(expr);
} else {
return ParallelDenseToDenseCombiner(min_num_branches).Combine(expr);
}
}

namespace transform {

Pass CombineParallelDense(uint64_t min_num_branches) {
Pass CombineParallelDense(uint64_t min_num_branches, bool to_batch_matmul) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(CombineParallelDense(f, min_num_branches));
return Downcast<Function>(CombineParallelDense(f, min_num_branches, to_batch_matmul));
};
return CreateFunctionPass(pass_func, 4, "CombineParallelDense", {"InferType"});
}