Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR] Primitive "SetScope" #9738

Merged
merged 7 commits into from
Dec 15, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,14 @@ class ScheduleNode : public runtime::Object {
*/
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* write-index
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
71 changes: 71 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1660,6 +1660,77 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:
self, block, buffer_index, axis, factor, offset
)

@type_checked
def set_scope(self, block: BlockRV, buffer_index: int, storage_scope: str) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index

Parameters
----------
block : BlockRV
The producer block of the buffer
buffer_index : int
The index of the buffer in block's write region
storage_scope : str
The storage scope to be set

Examples
--------

Before set_scope, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_set_scope(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128), dtype="float32")

for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do set_scope:

.. code-block:: python

sch = tir.Schedule(before_set_scope)
sch.set_scope(sch.get_block("B"), buffer_index=0, storage_scope="shared")
print(sch.mod["main"].script())

After applying set_scope, the IR becomes:

.. code-block:: python

@T.prim_func
def after_set_scope(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]
) -> None:
B_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")

for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B_shared[vi, vj] = A[vi, vj] * T.float32(2)
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B_shared[vi, vj] + T.float32(1)

Note
----
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
_ffi_api.ScheduleSetScope( # type: ignore # pylint: disable=no-member
self, block, buffer_index, storage_scope
)

########## Schedule: Blockize & Tensorize ##########

########## Schedule: Annotation ##########
Expand Down
2 changes: 2 additions & 0 deletions src/tir/ir/functor_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/runtime/container/array.h>

/*!
* \file tir/ir/functor_common.h
* \brief Common utils for implementing functors
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,14 @@ bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,

/******** Misc ********/

/*!
* \brief Check whether the input storage scope string is valid. Throw an error if not.
* \param self The schedule state
* \param storage_scope The storage scope string to be checked
* \throw ScheduleError If the input storage scope is not valid
*/
void CheckStorageScope(const ScheduleState& self, String storage_scope);
MasterJH5574 marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief Checks if a block could be successfully computed inline into its consumer
* \param self The schedule state
Expand Down
31 changes: 31 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1343,5 +1343,36 @@ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) {
return GetRef<StmtSRef>(p);
}

/******** Storage Scope ********/

void CheckStorageScope(const ScheduleState& self, String storage_scope) {
class InvalidStorageScopeError : public ScheduleError {
public:
explicit InvalidStorageScopeError(IRModule mod, String storage_scope)
: mod_(std::move(mod)), storage_scope_(std::move(storage_scope)) {}

String FastErrorString() const final {
return "ScheduleError: The input storage scope is invalid";
}

String DetailRenderTemplate() const final {
return "The input storage scope \"" + storage_scope_ + "\" is invalid.";
}

Array<ObjectRef> LocationsOfInterest() const final { return {}; }
IRModule mod() const final { return mod_; }

private:
IRModule mod_;
String storage_scope_;
};

try {
runtime::StorageScope::Create(std::string(storage_scope));
} catch (...) {
throw InvalidStorageScopeError(self->mod, std::move(storage_scope));
}
}

} // namespace tir
} // namespace tvm
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,14 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
this->state_->DebugVerify();
}

void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
const String& storage_scope) {
TVM_TIR_SCHEDULE_BEGIN();
tir::SetScope(state_, this->GetSRef(block_rv), buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("set-scope", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ConcreteScheduleNode : public ScheduleNode {
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ using StorageAlignAnnotation = Array<StorageAlignTuple>;
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
* memory.
* \param self The state of the schedule
* \param block_sref The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
Expand All @@ -337,6 +338,16 @@ using StorageAlignAnnotation = Array<StorageAlignTuple>;
*/
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
int axis, int factor, int offset);
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* write-index
* \param self The state of the schedule
* \param block_sref The sref of the producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& storage_scope);

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
Expand Down
Loading