Skip to content

Commit

Permalink
[TIR] Add schedule primitive ReIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed May 31, 2022
1 parent 2252f95 commit 6d48a37
Show file tree
Hide file tree
Showing 12 changed files with 833 additions and 0 deletions.
13 changes: 13 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,19 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
* buffer. It requires:
* 1) There is only one block who reads/writes the target buffer
* 2) There is only one buffer load/store of this buffer in the block
* \param block_rv The block operates on the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \return The reindex stage block.
*/
virtual BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
73 changes: 73 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,79 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
self, block, write_buffer_index, storage_scope
)

@type_checked
def reindex(self, block: BlockRV, buffer_index: int, buffer_index_type: str) -> BlockRV:
"""Create a block that read/write a buffer region into a read/write cache with reindexing.
The layout of the cache will be the same as by the iterators of the block that reads/writes
the buffer. It requires:
1) There is only one block who reads/writes the target buffer
2) There is only one buffer load/store of this buffer in the block
Parameters
----------
block: BlockRV
The block that accesses the target buffer
buffer_index: int
The index of the buffer in block's read or write region
buffer_index_type : str
Type of the buffer index, "read" or "write"
Returns
-------
reindex_block : BlockRV
The block of the reindex stage
Examples
--------
Before transform_layout, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_reindex(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vj, vi] * 2.0
Create the schedule and do transform_layout:
.. code-block:: python
sch = tir.Schedule(before_reindex)
block = sch.get_block("B")
sch.reindex(block, 0, "read)
After applying reindex, the IR becomes:
.. code-block:: python
@T.prim_func
def after_reindex(
A: T.Buffer[(128, 128), "float32"],
B: T.Buffer[(128, 128), "float32"]
) -> None:
A_reindex = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("A_reindex"):
vi, vj = T.axis.remap("SS", [i, j])
A_reindex[vi, vj] = A[vj, vi]
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A_reindex[vi, vj] * 2.0
"""
assert buffer_index_type in ["read", "write"], "Invalid buffer_index_type"
buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
return _ffi_api.ScheduleReIndex( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_index_type_enum
)

########## Schedule: Compute location ##########

@type_checked
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,16 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::ReIndex(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type);
TVM_TIR_SCHEDULE_END("reindex", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const String& storage_scope) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
15 changes: 15 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,21 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/*!
*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
* buffer. It requires:
* 1) There is only one block who reads/writes the target buffer
* 2) There is only one buffer load/store of this buffer in the block
* \param self The state of the schedule
* \param block_rv The block operates on the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \return The reindex stage block.
*/
TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type);
/******** Schedule: Compute location ********/
/*!
* \brief Move a producer block under the specific loop, and regenerate the
Expand Down
Loading

0 comments on commit 6d48a37

Please sign in to comment.