[TIR] Implement API for padded layout transformations (#12720)
Implementation of API in `tvm.tir.schedule` for layout transformations
with padding, as part of #12261,
item "Insert pad value into generated TIR, using `tir::if_then_else`,
`builtin::assume`, and `builtin::undef`".

Following the RFC discussion in
apache/tvm-rfcs#77 (comment) and
apache/tvm-rfcs#77 (comment),
this commit preferentially rewrites the loops that surround a padded
transformation where possible, in order to express padding in terms of
`tir::if_then_else`.
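
As a plain-Python sketch (not TVM code) of the approach the commit message describes: when the loops surrounding a padded transformation can be rewritten, a read of the transformed buffer selects between real data and the pad value, in the style of `tir::if_then_else`. The helper name `read_padded` is hypothetical.

```python
def read_padded(buf, i, extent, pad_value):
    # Mirrors the select tir::if_then_else(i < extent, buf[i], pad_value):
    # positions at or past the original extent read as padding.
    return buf[i] if i < extent else pad_value
```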
Lunderberg authored Sep 19, 2022
1 parent 60cf692 commit 2af9b90
Showing 17 changed files with 1,408 additions and 67 deletions.
17 changes: 16 additions & 1 deletion include/tvm/tir/schedule/schedule.h
@@ -601,9 +601,24 @@ class ScheduleNode : public runtime::Object {
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
 *
 * \param pad_value The value to write into padding introduced by
 * the transformation. If the schedule contains a producer block
 * for the specified buffer, the pad value will be written as
 * part of the producer block if possible, or after the producer
 * block otherwise. Otherwise, if the buffer is an input, an
 * annotation block will be inserted to state that the padding
 * contains the known value.
 *
 * Note: If applied to an input buffer, the calling scope is
 * responsible for ensuring that the pad_value is present.
 * Algebraic simplifications, branch elimination, and other
 * optimizations may assume that this precondition is met, and
 * may return incorrect results if it is not.
*/
virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map,
const Optional<IndexMap>& pad_value = NullOpt) = 0;

/*!
* \brief Apply a transformation represented by IndexMap to block
46 changes: 30 additions & 16 deletions python/tvm/tir/function.py
@@ -308,8 +308,9 @@ def from_func(
The function to map from source indices to target indices.
The function should accept `tir.Var` parameters and return
either a `tir.PrimExpr`, or a list of `tir.PrimExpr`.
Returning a `tir.PrimExpr` is equivalent to returning a
list of length 1 containing that `tir.PrimExpr`.
ndim: Optional[int]
@@ -356,9 +357,12 @@ def from_func_with_separators(
mapping_function : Callable
The function to map from source indices to target indices.
The function should accept tir.Var parameters and return
either a `tir.PrimExpr` or a list. Each element of the
returned list should be either a `tir.PrimExpr` or the
object `IndexMap.AXIS_SEPARATOR`. Returning a
`tir.PrimExpr` is equivalent to returning a list of length
1 containing that `tir.PrimExpr`.
ndim: Optional[int]
@@ -423,17 +427,27 @@ def from_func_with_separators(

final_indices = []
axis_separators = []

try:
iter(mapping)
is_iterable = True
except TypeError:
is_iterable = False

if is_iterable:
for val in mapping:
if isinstance(val, tvm.ir.PrimExpr):
final_indices.append(val)
elif val is IndexMap.AXIS_SEPARATOR:
axis_separators.append(len(final_indices))
else:
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
f"Instead received {val} of type {type(val)}."
)
else:
final_indices.append(mapping)

return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators
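
The `try`/`iter` dispatch above accepts either a single `tir.PrimExpr` or an iterable of them. A standalone sketch of that normalization pattern, with the hypothetical helper name `normalize_mapping`:

```python
def normalize_mapping(mapping):
    # A non-iterable return value is treated as a list of length 1,
    # mirroring the scalar-or-list handling in from_func_with_separators.
    try:
        iter(mapping)
    except TypeError:
        return [mapping]
    return list(mapping)
```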

2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/_type_checker.py
@@ -164,7 +164,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
return "atomic", [type_]


def callable_str(*subtypes):
if subtypes:
*arg_types, return_type = subtypes
arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
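
The one-line fix makes `callable_str` take its subtypes as variadic arguments rather than a single tuple. A self-contained sketch of the intended behavior, with the real `_type2str` replaced by a hypothetical `__name__` lookup:

```python
def callable_str(*subtypes):
    # Render a Callable annotation: the last subtype is the return
    # type, the preceding ones are argument types.
    if not subtypes:
        return "Callable"
    *arg_types, return_type = subtypes
    arg_str = ", ".join(t.__name__ for t in arg_types)
    return f"Callable[[{arg_str}], {return_type.__name__}]"
```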
42 changes: 41 additions & 1 deletion python/tvm/tir/schedule/schedule.py
@@ -2443,6 +2443,7 @@ def transform_layout(
block: Union[BlockRV, str],
buffer: Union[Tuple[str, int], str, Buffer],
index_map: Union[IndexMap, Callable],
pad_value: Optional[Union[int, float, IndexMap, Callable]] = None,
) -> None:
"""Apply a transformation represented by IndexMap to buffer
@@ -2479,6 +2480,36 @@
primitive will be called in addition to the
TransformLayout primitive.
pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]

The value to be used for any padding introduced by the
transformation. If the schedule contains a producer block
for the specified buffer, the pad value will be written as
part of the producer block if possible, or after the
producer block otherwise. Otherwise, if the buffer is an
input, an annotation block will be inserted to state that
the padding contains the known value.

The pad value may not contain instances of BufferLoad,
except where it loads a value from the buffer being
transformed (e.g. to create a circular buffer with
padding that consists of repeated elements).

Note: If applied to an input buffer, the calling scope is
responsible for ensuring that the pad_value is present.
Algebraic simplifications, branch elimination, and other
optimizations may assume that this precondition is met, and
may return incorrect results if it is not.

If None, the transformation may not introduce padding.

If an int, float or PrimExpr, the pad value is that
specific value, to be present in the padding.

If an IndexMap or Callable, the pad value is expressed in
terms of the transformed index.
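
As a concrete sketch of these semantics (plain Python, not TVM): transforming a 14-element buffer with the map `i -> (i // 4, i % 4)` produces a 4x4 buffer whose final two positions are padding. The helper name `transform_with_padding` is hypothetical.

```python
def transform_with_padding(data, pad_value, cols=4):
    # Apply i -> (i // cols, i % cols), filling positions introduced
    # by the layout transformation with pad_value.
    rows = -(-len(data) // cols)  # ceiling division
    padded = [data[i] if i < len(data) else pad_value
              for i in range(rows * cols)]
    return [padded[r * cols:(r + 1) * cols] for r in range(rows)]
```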
Examples
--------
Before transform_layout, in TensorIR, the IR is:
@@ -2536,9 +2567,18 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
else:
axis_separators = []

if pad_value is None:
pass
elif callable(pad_value):
pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
elif not isinstance(pad_value, IndexMap):
pad_value = IndexMap.from_func(
lambda *indices: pad_value, ndim=len(index_map.final_indices)
)

buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
_ffi_api.ScheduleTransformLayout( # type: ignore # pylint: disable=no-member
self, block, buffer_index, buffer_index_type_enum, index_map
self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
)
if axis_separators:
_ffi_api.ScheduleSetAxisSeparator( # type: ignore # pylint: disable=no-member
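
The `pad_value` dispatch above normalizes three argument forms before the FFI call: `None` passes through, a callable is converted to an `IndexMap`, and a scalar is wrapped as a constant function of the transformed indices. A standalone sketch of that dispatch, with the hypothetical helper name `normalize_pad_value` standing in for the `IndexMap.from_func` conversion:

```python
def normalize_pad_value(pad_value):
    # None passes through; a callable is used as-is; a scalar becomes
    # a constant function of the transformed indices.
    if pad_value is None or callable(pad_value):
        return pad_value
    return lambda *indices: pad_value
```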
2 changes: 1 addition & 1 deletion python/tvm/tir/tensor_intrin/cuda.py
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):


def shared_32x16_to_ldmatrix_32x16_layout(i, j):
thread_id = (i % 16) // 4 + 4 * (j % 8)
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
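
The corrected `thread_id` expression makes the layout a bijection from the 32x16 shared-memory tile onto 32 threads, each holding 16 values. A quick exhaustive check (plain Python, runnable without TVM):

```python
def shared_32x16_to_ldmatrix_32x16_layout(i, j):
    # Corrected mapping from the diff above.
    thread_id = (i % 16) // 4 + 4 * (j % 8)
    return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4

# Enumerate the full 32x16 tile; a valid layout hits every
# (thread, local) pair exactly once.
pairs = {shared_32x16_to_ldmatrix_32x16_layout(i, j)
         for i in range(32) for j in range(16)}
```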


3 changes: 2 additions & 1 deletion src/meta_schedule/postproc/rewrite_layout.cc
@@ -148,7 +148,8 @@ bool RewriteLayout(const Schedule& sch) {
// Apply schedule
BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
NullOpt);
sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
}
}
@@ -499,7 +499,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map);
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt);
};

for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) {
2 changes: 1 addition & 1 deletion src/tir/ir/index_map.cc
@@ -93,7 +93,7 @@ std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initia
// Unpack the map to an array, maintaining the same parameter order.
Array<PrimExpr> inverse_exprs;
for (const auto& index : (*this)->initial_indices) {
inverse_exprs.push_back(analyzer.Simplify(inverse_exprs_map.at(index)));
}

PrimExpr padding_predicate = padded_iter_map->padding_predicate;
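
A non-surjective inverse pairs the inverse index expressions with a padding predicate that is true exactly on the positions the transformation introduced. A plain-Python model (not the TVM API) for the forward map `i -> (i // cols, i % cols)` over `[0, extent)`, using the hypothetical helper name `non_surjective_inverse`:

```python
def non_surjective_inverse(extent, cols=4):
    # Inverse of i -> (i // cols, i % cols): recover i from (r, c),
    # plus a predicate marking transformed positions with no
    # corresponding source index (i.e. padding).
    inverse = lambda r, c: r * cols + c
    padding_predicate = lambda r, c: r * cols + c >= extent
    return inverse, padding_predicate
```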
6 changes: 4 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -761,9 +761,11 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann
/******** Schedule: Layout transformation ********/
void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map,
const Optional<IndexMap>& pad_value) {
TVM_TIR_SCHEDULE_BEGIN();
tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map,
pad_value);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_);
}
2 changes: 1 addition & 1 deletion src/tir/schedule/concrete_schedule.h
@@ -144,7 +144,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void Unannotate(const BlockRV& block_rv, const String& ann_key) override;
/******** Schedule: Layout transformation ********/
void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type,
const IndexMap& index_map, const Optional<IndexMap>& pad_value) override;
void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override;
void SetAxisSeparator(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
4 changes: 3 additions & 1 deletion src/tir/schedule/instruction_traits.h
@@ -430,7 +430,9 @@ TVM_ALWAYS_INLINE Array<ObjectRef> UnpackedInstTraits<TTraits>::_ConvertOutputs(
/********** PythonAPICall **********/

inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) {
if (!obj.defined()) {
os << "None";
} else if (const auto* str = obj.as<runtime::StringObj>()) {
os << str->data;
} else if (const auto* int_imm = obj.as<IntImmNode>()) {
os << int_imm->value;
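
The new branch makes an undefined object (such as the default `NullOpt` pad value) render as Python's `None` when an instruction is printed. A plain-Python sketch of that dispatch, with the hypothetical helper name `as_python_string`:

```python
def as_python_string(obj):
    # An undefined/None argument renders as the literal "None";
    # strings and integers render as their plain value.
    if obj is None:
        return "None"
    if isinstance(obj, (str, int)):
        return str(obj)
    raise TypeError(f"Unsupported type: {type(obj)}")
```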
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive.h
@@ -474,9 +474,11 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String&
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The transformation to apply.
* \param pad_value The value to write into padding introduced by the transformation.
*/
TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map,
const Optional<IndexMap>& pad_value);

/*!
* \brief Apply a transformation represented by IndexMap to block
Expand Down
Loading

0 comments on commit 2af9b90

Please sign in to comment.