
[TensorIR][M2a] Decompose-Reduction #9041

Merged 16 commits on Oct 2, 2021
16 changes: 16 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -365,6 +365,22 @@ class ScheduleNode : public runtime::Object {
*/
virtual void ReverseComputeInline(const BlockRV& block) = 0;
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
 * a) The init block, which is translated from the init statement of the reduction block;
 * b) The update block, which is the original block without the init statement.
*
* The init block is inserted right before the given loop.
*
* The schedule primitive requires:
* 1) The input block is a reduction block.
 * 2) The input loop is an ancestor of the block.
 * 3) The input loop is not lower than any loop related to a reduce block var.
* \param block_rv The reduction block to be decomposed
 * \param loop_rv The loop before which the init block is inserted.
* \return The init block
*/
virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
/*!
* \brief Factorize an associative reduction block by the specified loop.
* \details An associative reduction cannot be parallelized directly,
6 changes: 6 additions & 0 deletions include/tvm/tir/schedule/state.h
@@ -142,6 +142,12 @@ class ScheduleStateNode : public Object {
/******** Property of blocks ********/
/*! \brief Returns the BlockInfo corresponding to the block sref */
TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
/*!
 * \brief Recalculate the BlockInfo recursively under stmt.
 * If stmt is a Block itself, its affine binding flag will not be reset unless it has no
 * block vars, since the affine flag depends on the outer scope of stmt.
 * \param stmt The statement under which the BlockInfo is recalculated
 */
TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
/*!
 * \brief Get the BlockScope corresponding to the sref of scope root block
* \param scope_root The block sref to be retrieved
76 changes: 76 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -1224,6 +1224,82 @@ def after_inline(a: T.handle, c: T.handle) -> None:

########## Schedule: Reduction ##########

def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
"""Decompose a reduction block into two separate blocks.

a) The init block, which is translated from the init statement of the reduction block;

b) The update block, which is the original block without the init statement.

The init block is inserted right before the given loop.

The schedule primitive requires:

1) The input block is a reduction block.

2) The input loop is an ancestor of the block.

3) The input loop is not lower than any loop related to a reduce block var.

Parameters
----------
block : BlockRV
The reduction block to be decomposed
loop : LoopRV
The loop before which the init block is inserted.

Returns
-------
init_block : BlockRV
The init block

Examples
--------
Before decompose-reduction, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
for i, j, k in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
with tir.init():
C[vi, vj] = 0.0
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

Create the schedule and do decompose-reduction with specified loop:

.. code-block:: python

sch = tir.Schedule(before_decompose)
C = sch.get_block("C")
i, j, k = sch.get_loops(C)
sch.decompose_reduction(C, i)
print(tvm.script.asscript(sch.mod["main"]))

After applying decompose-reduction, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, [128, 128])
B = tir.match_buffer(b, [128, 128])
C = tir.match_buffer(c, [128, 128])
for i in tir.serial(128):
for j in tir.serial(128):
with tir.block([128, 128]) as [vi, vj]:
C[vi, vj] = 0.0
for i, j, k in tir.grid(128, 128, 128):
with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

"""
return _ffi_api.ScheduleDecomposeReduction(self, block, loop) # type: ignore # pylint: disable=no-member

def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
"""Factorize an associative reduction block by the specified loop.

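As a complement to the docstring above, the following end-to-end sketch exercises the new primitive on an equivalent 128x128x128 matmul built with te.create_prim_func instead of hand-written TVMScript; the workload construction and the trailing vectorize call are illustrative assumptions, not part of this diff.

.. code-block:: python

    import tvm
    from tvm import te, tir

    # An equivalent matmul workload (assumed to match the shapes in the docstring example).
    A = te.placeholder((128, 128), name="A")
    B = te.placeholder((128, 128), name="B")
    k = te.reduce_axis((0, 128), name="k")
    C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[j, k], axis=k), name="C")

    sch = tir.Schedule(te.create_prim_func([A, B, C]))
    block = sch.get_block("C")
    i, j, _k = sch.get_loops(block)

    # Insert the init block right before loop j: C[vi, vj] is zero-initialized
    # once per row of C before the update block runs the reduction over k.
    init = sch.decompose_reduction(block, j)

    # The returned BlockRV can be scheduled independently of the update block,
    # e.g. by vectorizing the innermost loop of the freshly created init block.
    sch.vectorize(sch.get_loops(init)[-1])
    print(tvm.script.asscript(sch.mod["main"]))

Scheduling the init block separately from the update block (binding, vectorizing, or computing it at a different loop) is the main reason the primitive returns the init block handle.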
9 changes: 9 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
@@ -501,6 +501,15 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
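The Python wrapper in schedule.py dispatches to _ffi_api.ScheduleDecomposeReduction, which resolves to a packed function named "tir.schedule.ScheduleDecomposeReduction" backed by the ConcreteScheduleNode method above. The registration itself is not among the quoted hunks; one way to confirm the binding exists from Python is the following sketch, which uses only public TVM API:

.. code-block:: python

    import tvm

    # Look up the packed function the Python wrapper dispatches to; this returns
    # None if the global has not been registered against the C++ method.
    f = tvm.get_global_func("tir.schedule.ScheduleDecomposeReduction", allow_missing=True)
    print("registered:", f is not None)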
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
@@ -115,6 +115,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
17 changes: 17 additions & 0 deletions src/tir/schedule/primitive.h
@@ -233,6 +233,23 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
*/
TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
/******** Schedule: Reduction ********/
/*!
* \brief Decompose a reduction block into two separate blocks.
 * a) The init block, which is translated from the init statement of the reduction block;
 * b) The update block, which is the original block without the init statement.
*
* The init block is inserted right before the given loop.
*
* The schedule primitive requires:
* 1) The input block is a reduction block.
 * 2) The input loop is an ancestor of the block.
 * 3) The input loop is not lower than any loop related to a reduce block var.
 * \param self The state of the schedule
 * \param block_sref The sref of the reduction block to be decomposed
 * \param loop_sref The sref of the loop before which the init block is inserted
 * \return The init block
*/
TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref);
/*!
* \brief Factor a reduction block by the specified loop
* \details See python/tvm/tir/schedule/schedule.py