
[TensorIR][M2a] CacheRead/Write #8863

Merged 12 commits on Aug 31, 2021
22 changes: 22 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -282,6 +282,28 @@ class ScheduleNode : public runtime::Object {
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block that writes the buffer in the scope.
* 2) The scope block has the stage-pipeline property.
* \param block_rv The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
virtual BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
* 2) The scope block has the stage-pipeline property.
* \param block_rv The producer block of the target buffer.
* \param write_buffer_index The index of the buffer in block's write region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
5 changes: 5 additions & 0 deletions include/tvm/tir/schedule/state.h
@@ -128,6 +128,11 @@ class ScheduleStateNode : public Object {
*/
TVM_DLL void Replace(const tir::StmtSRef& src_sref, const Stmt& tgt_stmt,
const Map<Block, Block>& block_sref_reuse);
/*!
* \brief Recalculate the `affine_binding` flag of the scope block info.
* \param scope_sref The sref to the interested scope block.
*/
TVM_DLL void UpdateAffineFlag(const StmtSRef& scope_sref);
Comment on lines +131 to +135
Member:

TQ and I intentionally removed this method, because in most schedule primitives it is already known whether a block binding is affine or not. In our particular case, trivial bindings are always affine AFAICT.

Member Author:

In this case, it is not always affine. Please see the following case:

for io in 4:
    for ii in 5:
        with tir.block([17]) as vi:
            tir.bind(vi, io * 4 + ii)
            B_shared[vi] = B[vi]
    for ii in 4:
        with tir.block([16]) as vi:
            tir.bind(vi, io * 4 + ii)
            A[vi] = tir.max(B_shared[vi], B_shared[vi + 1])

There are two options now:

  1. Keep this method. I think it's OK because it must be called explicitly in schedule primitives.
  2. Inline the method into the primitive (or move it into a private helper function in the .cc file). However, this may duplicate code, since other primitives (such as compute_at) may also need it.
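The counterexample above can be checked mechanically. The sketch below (plain Python for illustration, not part of the TVM API) enumerates the loop space and tests whether a trivial binding maps it one-to-one onto the block's iteration domain, which fails for the `[17]` block above:

```python
from itertools import product

def is_affine_bijection(loop_extents, binding, block_extent):
    """Check by enumeration whether `binding` maps the loop space
    one-to-one onto the block iteration domain [0, block_extent)."""
    values = [binding(*idx) for idx in product(*(range(e) for e in loop_extents))]
    return sorted(values) == list(range(block_extent))

# First block: io in 4, ii in 5, vi = io * 4 + ii, block domain [17].
# The 20 loop iterations produce overlapping values (e.g. vi == 4 twice),
# so the binding is not an affine bijection.
print(is_affine_bijection((4, 5), lambda io, ii: io * 4 + ii, 17))  # → False

# Second block: io in 4, ii in 4, block domain [16] — a perfect split.
print(is_affine_bijection((4, 4), lambda io, ii: io * 4 + ii, 16))  # → True
```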

/*!
* \brief Trigger the verification according to the `debug_mask` bitmask.
* 1) If the bitmask `kVerifySRefTree` is on, verify the correctness of the sref tree.
135 changes: 135 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -790,6 +790,141 @@ def after_unroll(a: ty.handle, b: ty.handle) -> None:

########## Schedule: Insert cache stages ##########

def cache_read(self, block: BlockRV, read_buffer_index: int, storage_scope: str) -> BlockRV:
"""Create a block that reads a buffer region into a read cache. It requires:

1) There is at most one block that writes the buffer in the scope.

2) The scope block has the stage-pipeline property.

Parameters
----------
block : BlockRV
The consumer block of the target buffer.

read_buffer_index: int
The index of the buffer in block's read region.

storage_scope: str
The target storage scope.

Returns
-------
cached_block : BlockRV
The block of the cache stage

Examples
--------
Before cache_read, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_cache_read(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_read:

.. code-block:: python

sch = tir.Schedule(before_cache_read)
block_b = sch.get_block("B")
sch.cache_read(block_b, 0, "local")
print(tvm.script.asscript(sch.mod["main"]))

After applying cache_read, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_cache_read(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
A_local = tir.alloc_buffer((128, 128), scope="local")
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "A_local") as [vi, vj]:
A_local[vi, vj] = A[vi, vj]
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A_local[vi, vj] * 2.0

"""
return _ffi_api.ScheduleCacheRead( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
)

def cache_write(self, block: BlockRV, write_buffer_index: int, storage_scope: str) -> BlockRV:
"""Create a block that writes a buffer region into a write cache. It requires:

1) There is only one block that writes the buffer in the scope.

2) The scope block has the stage-pipeline property.

Parameters
----------
block : BlockRV
The producer block of the target buffer.

write_buffer_index: int
The index of the buffer in block's write region.

storage_scope: str
The target storage scope.


Returns
-------
cached_block : BlockRV
The block of the cache stage

Examples
--------
Before cache_write, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_cache_write(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and cache_write:

.. code-block:: python

sch = tir.Schedule(before_cache_write)
block_b = sch.get_block("B")
sch.cache_write(block_b, 0, "local")
print(tvm.script.asscript(sch.mod["main"]))

After applying cache_write, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_cache_write(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
B_local = tir.alloc_buffer((128, 128), scope="local")
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B_local") as [vi, vj]:
B_local[vi, vj] = A[vi, vj] * 2.0
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = B_local[vi, vj]

"""
return _ffi_api.ScheduleCacheWrite( # type: ignore # pylint: disable=no-member
self, block, write_buffer_index, storage_scope
)
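As a rough mental model of what the two primitives do to the dataflow, `cache_read` inserts a copy stage before the consumer and redirects its reads, while `cache_write` redirects the producer's writes into the cache and appends a copy-back stage. A toy sketch in plain Python (not the TVM API; stage and buffer names are illustrative):

```python
# Stage = (name, reads, writes); a program is an ordered list of stages.

def cache_read(stages, consumer, buf, scope):
    """Insert a stage copying `buf` into a cache before `consumer`,
    and make `consumer` read from the cache instead."""
    cached = f"{buf}_{scope}"
    out = []
    for name, reads, writes in stages:
        if name == consumer:
            out.append((cached, [buf], [cached]))  # new cache stage
            reads = [cached if r == buf else r for r in reads]
        out.append((name, reads, writes))
    return out

def cache_write(stages, producer, buf, scope):
    """Make `producer` write into a cache, then copy the cache back to `buf`."""
    cached = f"{buf}_{scope}"
    out = []
    for name, reads, writes in stages:
        if name == producer:
            out.append((name, reads, [cached if w == buf else w for w in writes]))
            out.append((cached, [cached], [buf]))  # copy-back stage
        else:
            out.append((name, reads, writes))
    return out

prog = [("B", ["A"], ["B"])]  # B[vi, vj] = A[vi, vj] * 2.0
print(cache_read(prog, "B", "A", "local"))
# → [('A_local', ['A'], ['A_local']), ('B', ['A_local'], ['B'])]
print(cache_write(prog, "B", "B", "local"))
# → [('B', ['A'], ['B_local']), ('B_local', ['B_local'], ['B'])]
```

This mirrors the before/after TVMScript in the docstrings above: the cache stage is always a plain copy, and only the producer/consumer side named by the primitive is rewired.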

########## Schedule: Compute location ##########

def compute_inline(self, block: BlockRV) -> None:
21 changes: 14 additions & 7 deletions src/tir/schedule/analysis.h
@@ -56,6 +56,13 @@ void VerifyCachedFlags(const ScheduleState& self);
const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_block,
GlobalVar* result_g_var);

/*!
* \brief Get the root node of the sref tree, which is the root block of the PrimFunc.
* \param sref The given sref.
* \return The root node of the sref tree which contains the given node.
*/
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);

/******** Scope ********/
/*!
* \brief Checks if the scope that the specified sref is in is a stage-pipeline and returns it
@@ -228,15 +235,15 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
/******** Block-buffer relation ********/

/*!
* \brief Get the BlockRealize of the single child block of the block or loop specified by
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
* \param self The schedule state
* \param block The queried block
* \param n The index of the queried buffer
* \return The buffer of the n-th write region of the block.
* \brief Get the n-th read or write buffer of the given block.
* \param self The schedule state.
* \param block The queried block.
* \param n The index of the queried buffer.
* \param is_write A boolean flag to indicate querying write buffer or read buffer.
* \return The buffer of the n-th read/write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write);

/******** Commutative Reducer ********/

50 changes: 36 additions & 14 deletions src/tir/schedule/analysis/analysis.cc
@@ -588,25 +588,37 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr

/******** Block-buffer relation ********/

Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
class WriteBufferIndexOutOfRangeError : public ScheduleError {
Buffer GetNthAccessBuffer(const ScheduleState& self, const Block& block, int n, bool is_write) {
class BufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}
explicit BufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index, bool is_write)
: mod_(std::move(mod)),
block_(std::move(block)),
buffer_index_(buffer_index),
is_write_(is_write) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range [0, num_write_regions) where `num_write_regions` is the number of buffer "
"regions written by the block.";
if (is_write_) {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range "
"[0, num_write_regions) where `num_write_regions` is the number of buffer regions "
"written by the block.";
} else {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range "
"[0, num_read_regions) where `num_read_regions` is the number of buffer regions "
"read by the block.";
}
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, " << num_writes
size_t num = is_write_ ? block_->writes.size() : block_->reads.size();
std::string access_type = is_write_ ? "write" : "read";
os << "The block {0} has " << num << " " << access_type
<< " regions, so `buffer_index` is required to be in [0, " << num
<< "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
<< ", which is out of the expected range.";
return os.str();
}

@@ -617,12 +629,15 @@ Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
IRModule mod_;
Block block_;
int buffer_index_;
bool is_write_;
};

if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
const Array<BufferRegion>& access_region = is_write ? block->writes : block->reads;

if (n < 0 || static_cast<int>(access_region.size()) <= n) {
throw BufferIndexOutOfRangeError(self->mod, block, n, is_write);
}
return block->writes[n]->buffer;
return access_region[n]->buffer;
}

/******** Pattern Matcher ********/
@@ -941,5 +956,12 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
return false;
}

/******** SRef Tree Related ********/
StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
const StmtSRefNode* p = sref.get();
for (; p->parent != nullptr; p = p->parent) {
}
return GetRef<StmtSRef>(p);
}
} // namespace tir
} // namespace tvm
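The parent-pointer walk in `GetSRefTreeRoot` is simply "follow `parent` until it is null". A minimal Python analogue (illustrative classes, not the TVM API):

```python
class SRef:
    """Toy stand-in for StmtSRefNode: a node with a parent pointer."""
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent

def sref_tree_root(sref):
    # Walk up until the node with no parent, i.e. the root block's sref.
    while sref.parent is not None:
        sref = sref.parent
    return sref

root = SRef("root")
loop = SRef("loop_i", parent=root)
block = SRef("block_B", parent=loop)
print(sref_tree_root(block).name)  # → root
```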
21 changes: 21 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
@@ -416,6 +416,27 @@ void ConcreteScheduleNode::Unroll(const LoopRV& loop_rv) {
}

/******** Schedule: Insert cache stages ********/

BlockRV ConcreteScheduleNode::CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheRead(state_, this->GetSRef(block_rv), read_buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("cache-read", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheWrite(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("cache-write", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<BlockRV>(result);
}

/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
4 changes: 4 additions & 0 deletions src/tir/schedule/concrete_schedule.h
@@ -103,6 +103,10 @@ class ConcreteScheduleNode : public ScheduleNode {
void Bind(const LoopRV& loop_rv, const String& thread_axis) override;
void Unroll(const LoopRV& loop_rv) override;
/******** Schedule: Insert cache stages ********/
BlockRV CacheRead(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
/******** Schedule: Compute location ********/
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
24 changes: 24 additions & 0 deletions src/tir/schedule/primitive.h
@@ -135,6 +135,30 @@ TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar&
*/
TVM_DLL void Unroll(ScheduleState self, const StmtSRef& loop_sref);
/******** Schedule: Insert cache stages ********/
/*!
* \brief Create a block that reads a buffer region into a read cache. It requires:
* 1) There is at most one block that writes the buffer in the scope.
* 2) The scope block has the stage-pipeline property.
* \param self The state of the schedule
* \param block_sref The consumer block of the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buffer_index,
const String& storage_scope);
/*!
* \brief Create a block that writes a buffer region into a write cache. It requires:
* 1) There is only one block that writes the target buffer.
* 2) The scope block has the stage-pipeline property.
* \param self The state of the schedule.
* \param block_sref The producer block of the target buffer.
* \param write_buffer_index The index of the buffer in block's write region.
* \param storage_scope The target storage scope.
* \return The cache stage block.
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/******** Schedule: Compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/block_annotate.cc
@@ -16,7 +16,6 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "../transform.h"
#include "../utils.h"

namespace tvm {
@@ -237,7 +236,8 @@ class StorageAlignInvalidAnnotationError : public ScheduleError {
void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index, int axis,
int factor, int offset) {
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_ptr, block_sref);
Buffer buffer = GetNthWriteBuffer(self, GetRef<Block>(block_ptr), buffer_index);
Buffer buffer =
GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, /*is_write=*/true);
StorageAlignInvalidFactorError::Check(self->mod, factor);
axis = StorageAlignAxisOutOfRangeError::CheckAndUpdate(self->mod, buffer, axis);
NonAllocatedBufferError::CheckBufferAllocated(self->mod, block_sref, buffer);