
[Relay] Add new IR pass CombineParallelDense #3862

Merged
merged 21 commits on Sep 24, 2019
Changes from 12 commits
2 changes: 2 additions & 0 deletions docs/api/python/relay/transform.rst
@@ -46,6 +46,8 @@ tvm.relay.transform

.. autofunction:: tvm.relay.transform.CombineParallelConv2D

.. autofunction:: tvm.relay.transform.CombineParallelDense

.. autofunction:: tvm.relay.transform.AlterOpLayout

.. autofunction:: tvm.relay.transform.Legalize
11 changes: 11 additions & 0 deletions include/tvm/relay/transform.h
@@ -482,6 +482,17 @@ TVM_DLL Pass EliminateCommonSubexpr(PackedFunc fskip = nullptr);
*/
TVM_DLL Pass CombineParallelConv2D(uint64_t min_num_branches = 3);

/*!
* \brief Combine parallel dense ops into a single batch_matmul if the
* number of branches of this dense operator is not less than
* `min_num_branches`.
*
* \param min_num_branches The minimum number of branches.
*
* \return The pass.
*/
TVM_DLL Pass CombineParallelDense(uint64_t min_num_branches = 3);
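The rewrite this pass performs rests on a simple algebraic identity: several dense ops that share one input are equivalent to a single batch_matmul over the stacked weight matrices. A minimal NumPy sketch of that identity (illustrative only, not TVM code):

```python
import numpy as np

# Two dense branches sharing the same input `data`.
data = np.arange(4, dtype="float32").reshape(2, 2)
w1 = np.ones((2, 2), dtype="float32")
w2 = 2.0 * np.ones((2, 2), dtype="float32")

# Per-branch: dense(data, w) computes data @ w.T.
out1 = data @ w1.T
out2 = data @ w2.T

# Combined: one batched matmul over weights stacked on a new batch axis.
stacked_w = np.stack([w1, w2])                      # shape (2, 2, 2)
batched = data[np.newaxis] @ stacked_w.transpose(0, 2, 1)

# Each batch slice matches the corresponding original branch.
assert np.allclose(batched[0], out1)
assert np.allclose(batched[1], out2)
```

Replacing N small matmuls with one batched matmul amortizes kernel-launch overhead and exposes more parallelism to the backend, which is the motivation for the pass.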

/*!
* \brief Backward fold axis scaling into weights of conv/dense operators.
*
28 changes: 28 additions & 0 deletions python/tvm/relay/transform.py
@@ -138,6 +138,7 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
}

fallback_device : int, str, or tvm.TVMContext, optional
@@ -400,6 +401,33 @@ def CombineParallelConv2D(min_num_branches=3):
return _transform.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
"""Combine multiple dense operators into one. For example:

data
/ \
dense (2,2) dense (2,2)

Would become:

data
|
batch_matmul (2,2,2)

Parameters
----------
min_num_branches : int
The minimum number of required parallel branches for performing this
optimization.

Returns
-------
ret : tvm.relay.Pass
The registered pass that combines parallel dense operators.
"""
return _transform.CombineParallelDense(min_num_branches)
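The `min_num_branches` gate can be pictured with a small, hypothetical Python sketch (the function name `plan_combines` and the dict-based call representation are illustrative, not the actual C++ implementation): dense calls are grouped by their shared input, and only groups meeting the threshold are rewritten.

```python
from collections import defaultdict

def plan_combines(dense_calls, min_num_branches=3):
    # Group dense calls by the input expression they consume.
    groups = defaultdict(list)
    for call in dense_calls:
        groups[call["data"]].append(call["weight"])
    # Only groups with at least min_num_branches parallel branches
    # are worth combining into one batch_matmul.
    return {data: weights for data, weights in groups.items()
            if len(weights) >= min_num_branches}

calls = [{"data": "x", "weight": f"w{i}"} for i in range(3)]
calls.append({"data": "y", "weight": "w_other"})

plan = plan_combines(calls, min_num_branches=3)
# "x" has 3 parallel branches, so it is combined; "y" has only 1 and is left alone.
```

The threshold exists because batching carries its own overhead (stacking weights, slicing outputs), so combining very few branches may not pay off.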


def AlterOpLayout():
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
@@ -299,6 +299,7 @@ class RelayBuildModule : public runtime::ModuleNode {
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());