Skip to content

Commit

Permalink
[MetaSchedule] Support padding for irregular shapes for CUDA tensor core
Browse files Browse the repository at this point in the history
  • Loading branch information
vinx13 committed Sep 16, 2022
1 parent 8f8b6d8 commit 078d773
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 26 deletions.
7 changes: 4 additions & 3 deletions python/tvm/tir/schedule/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class TensorizeInfo(Object):


def get_tensorize_loop_mapping(
sch: Schedule, block: BlockRV, desc_func: PrimFunc
sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False
) -> Optional[TensorizeInfo]:
"""Establish a mapping between loops in a target block and an intrinsic description
Expand All @@ -80,13 +80,14 @@ def get_tensorize_loop_mapping(
The target block to match against
desc_func : PrimFunc
The prim func describing the computation to be tensorized
allow_padding : bool
Whether to allow padding the block iters to match the intrinsic description
Returns
-------
tensorize_info : Optional[TensorizeInfo]
TensorizeInfo structure if a valid mapping is found, None otherwise
"""
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore
return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding) # type: ignore


@tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
state->sch->TransformBlockLayout(state->block_rv, index_map);

return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
/*allow_padding=*/true);
}

inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
Expand Down
8 changes: 7 additions & 1 deletion src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -731,10 +731,15 @@ class TensorizeInfoNode : public Object {
Map<tir::StmtSRef, tir::For> loop_map;
/*! \brief Maps loops in an intrinsic description to its index, outer to inner */
Map<tir::For, Integer> desc_loop_indexer;
/*! \brief Optional padded extents of the block iters when padding is needed to match the
* intrinsic description
*/
Optional<Array<Integer>> block_iter_paddings;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_map", &loop_map);
v->Visit("desc_loop_indexer", &desc_loop_indexer);
v->Visit("block_iter_paddings", &block_iter_paddings);
}

static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
Expand All @@ -751,11 +756,12 @@ class TensorizeInfo : public ObjectRef {
* \param self The schedule state to be tensorized
* \param block_sref The target block to match against
* \param desc_func The prim func describing the computation to be tensorized
* \param allow_padding Whether to allow padding the block iters to match the intrinsic description
* \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
*/
Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func);
const tir::PrimFunc& desc_func, bool allow_padding);

/*!\brief Necessary information used to perform transformations for tensorization */
class AutoTensorizeMappingInfoNode : public Object {
Expand Down
53 changes: 44 additions & 9 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1699,7 +1699,8 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,

Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const tir::StmtSRef& block_sref,
const tir::PrimFunc& desc_func) {
const tir::PrimFunc& desc_func,
bool allow_padding) {
arith::Analyzer analyzer;
const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
// Step 1. Analyze desc_func, extract its block, loops and loop vars
Expand Down Expand Up @@ -1732,6 +1733,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
const int n_desc_vars = desc_block->iter_values.size();
const int offset = n_block_vars - n_desc_vars;

std::unordered_map<int, int> block_index_to_padding; // padding of each block iter if necessary

if (offset < 0) {
return NullOpt;
}
Expand Down Expand Up @@ -1782,10 +1785,11 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

// Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
PrimExpr block_bind;
for (int i = next_block_ind; i >= 0; --i) {
if (iter_types_block[i] == iter_type_desc) {
next_block_ind = i - 1;
block_bind = block->iter_values[i];
int current_block_ind = next_block_ind;
for (; current_block_ind >= 0; --current_block_ind) {
if (iter_types_block[current_block_ind] == iter_type_desc) {
next_block_ind = current_block_ind - 1;
block_bind = block->iter_values[current_block_ind];
break;
}
}
Expand All @@ -1802,15 +1806,30 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,

PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
if (UsesVar(residual,
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
[&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
continue;
}
// padding is allowed only when the block has trivial bindings
if (allow_padding && !is_zero(residual)) {
allow_padding = false;
}

const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();

// Check divisibility
if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
if (!int_block_extent) {
return NullOpt;
}
int64_t remainder = int_block_extent->value % int_desc_extent->value;
if (remainder != 0) {
if (allow_padding) {
// If the block loop is not divisible by the desc loop, we pad the block loop to make it
// divisible if padding is allowed.
block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
} else {
return NullOpt;
}
}

ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
break;
Expand All @@ -1820,13 +1839,29 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
for (int i = 0, n = desc_loops.size(); i < n; ++i) {
ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
}
if (!block_index_to_padding.empty()) {
if (!allow_padding) {
return NullOpt;
}
Array<Integer> paddings;
for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
const IterVar& iter_var = block->block->iter_vars[i];
if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
} else {
paddings.push_back(IntImm(iter_var->var.dtype(), 0));
}
}
ret->block_iter_paddings = std::move(paddings);
}

return TensorizeInfo(ret);
}

TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
});

/******** Auto Tensorization ********/
Expand Down
10 changes: 7 additions & 3 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,15 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
}

Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name) {
Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
const String& intrin_name, bool allow_padding) {
Optional<tir::TensorizeInfo> opt_tensorize_info =
GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin_name)->desc, allow_padding);
if (!opt_tensorize_info) return NullOpt;
const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
if (info->block_iter_paddings.defined()) {
sch->PadEinsum(block_rv, info->block_iter_paddings.value());
}
// Construct a mapping from tir loops back to LoopRVs
Map<tir::StmtSRef, LoopRV> loop2rv;
{
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
* block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
*/
Optional<tir::LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);
const String& intrin_name, bool allow_padding = false);

/******** Block mutation ********/

Expand Down
Loading

0 comments on commit 078d773

Please sign in to comment.