From 2af9b90ec191424724842795c552d4c15682eb8c Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Mon, 19 Sep 2022 08:20:33 -0500
Subject: [PATCH] [TIR] Implement API for padded layout transformations
 (#12720)

Implementation of API in `tvm.tir.schedule` for layout transformations
with padding, as part of https://github.com/apache/tvm/issues/12261,
item "Insert pad value into generated TIR, using `tir::if_then_else`,
`builtin::assume`, and `builtin::undef`".

Following the RFC discussion in
https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1170294348 and
https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1171290053,
this commit preferentially rewrites the loops that surround a padded
transformation where possible, in order to express padding in terms of
`tir::if_then_else`.
---
 include/tvm/tir/schedule/schedule.h           |  17 +-
 python/tvm/tir/function.py                    |  46 +-
 python/tvm/tir/schedule/_type_checker.py      |   2 +-
 python/tvm/tir/schedule/schedule.py           |  42 +-
 python/tvm/tir/tensor_intrin/cuda.py          |   2 +-
 src/meta_schedule/postproc/rewrite_layout.cc  |   3 +-
 .../multi_level_tiling_tensor_core.cc         |   2 +-
 src/tir/ir/index_map.cc                       |   2 +-
 src/tir/schedule/concrete_schedule.cc         |   6 +-
 src/tir/schedule/concrete_schedule.h          |   2 +-
 src/tir/schedule/instruction_traits.h         |   4 +-
 src/tir/schedule/primitive.h                  |   4 +-
 .../primitive/layout_transformation.cc        | 910 +++++++++++++++++-
 src/tir/schedule/schedule.cc                  |   6 +-
 src/tir/schedule/traced_schedule.cc           |  15 +-
 src/tir/schedule/traced_schedule.h            |   2 +-
 .../test_tir_schedule_transform_layout.py     | 410 ++++++++
 17 files changed, 1408 insertions(+), 67 deletions(-)

diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 8e5cd34d2e0b..049f063240df 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -601,9 +601,24 @@ class ScheduleNode : public runtime::Object {
    * \param buffer_index The index of the buffer in block's read or write region.
    * \param buffer_index_type The type of the buffer index, kRead or kWrite.
    * \param index_map The transformation to apply.
+   *
+   * \param pad_value The value to write into padding introduced by
+   *     the transformation.  If the schedule contains a producer block
+   *     for the specified buffer, the pad value will be written as
+   *     part of the producer block if possible, or after the producer
+   *     block otherwise.  Otherwise, if the buffer is an input, will
+   *     insert an annotation block to state that the padding contains
+   *     the known value.
+   *
+   *     Note: If applied to an input buffer, the calling scope is
+   *     responsible for ensuring that the pad_value is present.
+   *     Algebraic simplifications, branch elimination, and other
+   *     optimizations may assume that this precondition is met, and
+   *     may result in incorrect results being returned.
    */
   virtual void TransformLayout(const BlockRV& block_rv, int buffer_index,
-                               BufferIndexType buffer_index_type, const IndexMap& index_map) = 0;
+                               BufferIndexType buffer_index_type, const IndexMap& index_map,
+                               const Optional<IndexMap>& pad_value = NullOpt) = 0;
 
   /*!
    * \brief Apply a transformation represented by IndexMap to block
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index e525fc2cc31a..df39f8aebf71 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -308,8 +308,9 @@ def from_func(
             The function to map from source indices to target indices.
             The function should accept `tir.Var` parameters and return
-            a list. Each element of the returned list should be a
-            `tir.PrimExpr`.
+            either a `tir.PrimExpr`, or a list of `tir.PrimExpr`.
+            Returning a `tir.PrimExpr` is equivalent to returning a
+            list of length 1 containing that `tir.PrimExpr`.
 
         ndim: Optional[int]
@@ -356,9 +357,12 @@ def from_func_with_separators(
         mapping_function : Callable
 
             The function to map from source indices to target indices.
-            The function should accept tir.Var parameters and return a
-            list. Each element of the returned list should be either a
-            `tir.PrimExpr` or the object `IndexMap.AXIS_SEPARATOR`.
+            The function should accept tir.Var parameters and return
+            either a `tir.PrimExpr` or a list.  Each element of the
+            returned list should be either a `tir.PrimExpr` or the
+            object `IndexMap.AXIS_SEPARATOR`.  Returning a
+            `tir.PrimExpr` is equivalent to returning a list of length
+            1 containing that `tir.PrimExpr`.
 
         ndim: Optional[int]
@@ -423,17 +427,27 @@ def from_func_with_separators(
 
         final_indices = []
         axis_separators = []
-        for val in mapping:
-            if isinstance(val, tvm.ir.PrimExpr):
-                final_indices.append(val)
-            elif val is IndexMap.AXIS_SEPARATOR:
-                axis_separators.append(len(final_indices))
-            else:
-                raise TypeError(
-                    "Expected mapping function to return list of "
-                    "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
-                    f"Instead received {val} of type {type(val)}."
-                )
+
+        try:
+            iter(mapping)
+            is_iterable = True
+        except TypeError:
+            is_iterable = False
+
+        if is_iterable:
+            for val in mapping:
+                if isinstance(val, tvm.ir.PrimExpr):
+                    final_indices.append(val)
+                elif val is IndexMap.AXIS_SEPARATOR:
+                    axis_separators.append(len(final_indices))
+                else:
+                    raise TypeError(
+                        "Expected mapping function to return list of "
+                        "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
+                        f"Instead received {val} of type {type(val)}."
+                    )
+        else:
+            final_indices.append(mapping)
 
         return IndexMap(initial_indices, final_indices, inverse_index_map), axis_separators
diff --git a/python/tvm/tir/schedule/_type_checker.py b/python/tvm/tir/schedule/_type_checker.py
index 0b48dfc2b0e6..0c66f7ef6cdf 100644
--- a/python/tvm/tir/schedule/_type_checker.py
+++ b/python/tvm/tir/schedule/_type_checker.py
@@ -164,7 +164,7 @@ def _dispatcher(type_: Any) -> Tuple[str, List[type]]:
     return "atomic", [type_]
 
 
-def callable_str(subtypes):
+def callable_str(*subtypes):
     if subtypes:
         *arg_types, return_type = subtypes
         arg_str = ", ".join(_type2str(arg_type) for arg_type in arg_types)
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index fdc871703275..b8f696b7a134 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -2443,6 +2443,7 @@ def transform_layout(
         block: Union[BlockRV, str],
         buffer: Union[Tuple[str, int], str, Buffer],
         index_map: Union[IndexMap, Callable],
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]] = None,
     ) -> None:
         """Apply a transformation represented by IndexMap to buffer
@@ -2479,6 +2480,36 @@ def transform_layout(
             primitive will be called in addition to the
             TransformLayout primitive.
 
+        pad_value: Optional[Union[int, float, PrimExpr, IndexMap, Callable]]
+
+            The value to be used for any padding introduced by the
+            transformation.  If the schedule contains a producer block
+            for the specified buffer, the pad value will be written as
+            part of the producer block if possible, or after the
+            producer block otherwise.  Otherwise, if the buffer is an
+            input, will insert an annotation block to state that the
+            padding contains the known value.
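+
+            For example (an illustrative sketch mirroring the unit
+            tests added below; the block name ``"block"`` and buffer
+            name ``"B"`` are assumed), a buffer of extent 14 can be
+            transformed to shape [4, 4] with the two padding elements
+            filled with zero:
+
+            .. code-block:: python
+
+                sch.transform_layout(
+                    "block", "B", lambda i: [i // 4, i % 4], pad_value=0
+                )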
+
+            The pad value may not contain instances of BufferLoad,
+            except where it loads a value from the buffer being
+            transformed (e.g. to create a circular buffer with
+            padding that consists of repeated elements).
+
+            Note: If applied to an input buffer, the calling scope is
+            responsible for ensuring that the pad_value is present.
+            Algebraic simplifications, branch elimination, and other
+            optimizations may assume that this precondition is met,
+            and may result in incorrect results being returned.
+
+            If None, the transformation may not introduce padding.
+
+            If an int, float or PrimExpr, the transformation is the
+            specific value to be present in the padding.
+
+            If an IndexMap or Callable, the transformation is the
+            value to be present in the padding in terms of the
+            transformed index.
+
         Examples
         --------
         Before transform_layout, in TensorIR, the IR is:
@@ -2536,9 +2567,18 @@ def two_elementwise_transformed_intermediate_buffer(a: T.handle, c: T.handle) ->
         else:
             axis_separators = []
 
+        if pad_value is None:
+            pass
+        elif callable(pad_value):
+            pad_value = IndexMap.from_func(pad_value, ndim=len(index_map.final_indices))
+        elif not isinstance(pad_value, IndexMap):
+            pad_value = IndexMap.from_func(
+                lambda *indices: pad_value, ndim=len(index_map.final_indices)
+            )
+
         buffer_index_type_enum = 0 if buffer_index_type == "read" else 1
         _ffi_api.ScheduleTransformLayout(  # type: ignore # pylint: disable=no-member
-            self, block, buffer_index, buffer_index_type_enum, index_map
+            self, block, buffer_index, buffer_index_type_enum, index_map, pad_value
         )
         if axis_separators:
             _ffi_api.ScheduleSetAxisSeparator(  # type: ignore # pylint: disable=no-member
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 64d7c24840ae..a309b091285b 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -36,7 +36,7 @@ def shared_16x32_to_ldmatrix_32x16_layout(i, j):
 
 
 def shared_32x16_to_ldmatrix_32x16_layout(i, j):
-    thread_id = (i % 4) + 4 * (j % 8)
+    thread_id = (i % 16) // 4 + 4 * (j % 8)
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
diff --git a/src/meta_schedule/postproc/rewrite_layout.cc b/src/meta_schedule/postproc/rewrite_layout.cc
index 6ff9958c791f..998b22b57463 100644
--- a/src/meta_schedule/postproc/rewrite_layout.cc
+++ b/src/meta_schedule/postproc/rewrite_layout.cc
@@ -148,7 +148,8 @@ bool RewriteLayout(const Schedule& sch) {
       // Apply schedule
       BlockRV block_rv = sch->GetBlock(block->name_hint, func_name);
       BlockRV cached_block_rv = sch->CacheRead(block_rv, buffer_index, "global");
-      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value());
+      sch->TransformLayout(block_rv, buffer_index, BufferIndexType::kRead, index_map.value(),
+                           NullOpt);
       sch->Annotate(cached_block_rv, attr::meta_schedule_layout_rewrite_preproc, const_true());
     }
   }
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 8fcb8fe503b7..6759b59a3245 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -499,7 +499,7 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
     const tir::BufferRegion& reindexed_buffer_region = tir::GetNthAccessBufferRegion(
         state->sch->state(), GetRef<tir::Block>(block), buffer_index, index_type);
     auto sub_index_map = f_get_sub_index_map(lhs_buffer, reindexed_buffer_region->region);
-
state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map); + state->sch->TransformLayout(state->block_rv, buffer_index, index_type, sub_index_map, NullOpt); }; for (int i = 0, n = block_before_reindex->reads.size(); i < n; ++i) { diff --git a/src/tir/ir/index_map.cc b/src/tir/ir/index_map.cc index cceff72ec82f..64c5d5d5ddde 100644 --- a/src/tir/ir/index_map.cc +++ b/src/tir/ir/index_map.cc @@ -93,7 +93,7 @@ std::pair IndexMap::NonSurjectiveInverse(Array initia // Unpack the map to an array, maintaining the same parameter order. Array inverse_exprs; for (const auto& index : (*this)->initial_indices) { - inverse_exprs.push_back(inverse_exprs_map.at(index)); + inverse_exprs.push_back(analyzer.Simplify(inverse_exprs_map.at(index))); } PrimExpr padding_predicate = padded_iter_map->padding_predicate; diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 9d7dc6b95f50..4558ad04baed 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -761,9 +761,11 @@ void ConcreteScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann /******** Schedule: Layout transformation ********/ void ConcreteScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) { + const IndexMap& index_map, + const Optional& pad_value) { TVM_TIR_SCHEDULE_BEGIN(); - tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map); + tir::TransformLayout(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type, index_map, + pad_value); this->state_->DebugVerify(); TVM_TIR_SCHEDULE_END("transform_layout", this->error_render_level_); } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 1aa9dafcc93e..59a9e3752859 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -144,7 +144,7 @@ class ConcreteScheduleNode : public ScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) override; + const IndexMap& index_map, const Optional& pad_value) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, diff --git a/src/tir/schedule/instruction_traits.h b/src/tir/schedule/instruction_traits.h index 56c69224fe17..122c5ff0d9fe 100644 --- a/src/tir/schedule/instruction_traits.h +++ b/src/tir/schedule/instruction_traits.h @@ -430,7 +430,9 @@ TVM_ALWAYS_INLINE Array UnpackedInstTraits::_ConvertOutputs( /********** PythonAPICall **********/ inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) { - if (const auto* str = obj.as()) { + if (!obj.defined()) { + os << "None"; + } else if (const auto* str = obj.as()) { os << str->data; } else if (const auto* int_imm = obj.as()) { os << int_imm->value; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 97233fe4bc6f..21388ff132ae 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -474,9 +474,11 @@ TVM_DLL void Unannotate(ScheduleState self, const StmtSRef& sref, const String& * \param buffer_index The index of the buffer in block's read or write region. 
  * \param buffer_index_type The type of the buffer index, kRead or kWrite.
  * \param index_map The transformation to apply.
+ * \param pad_value The value to write into padding introduced by the transformation.
  */
 TVM_DLL void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
-                             BufferIndexType buffer_index_type, const IndexMap& index_map);
+                             BufferIndexType buffer_index_type, const IndexMap& index_map,
+                             const Optional<IndexMap>& pad_value);
 
 /*!
  * \brief Apply a transformation represented by IndexMap to block
diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 32ed279f028f..025723e1793d 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -16,12 +16,647 @@
  * specific language governing permissions and limitations
  * under the License.
  */
+
+#include <optional>
+#include <variant>
+
 #include "../../../arith/ir_mutator_with_analyzer.h"
 #include "../utils.h"
 
 namespace tvm {
 namespace tir {
 
+/*! \brief Planning stage prior to rewriting in TransformLayoutRewriter
+ *
+ * There are four ways that a transformation may be handled.  Each
+ * updates the buffer shape and the indices used to access the buffer
+ * in BufferStore/BufferLoad nodes, but they differ in how they handle
+ * the `pad_value`.  In order of preference, the different strategies
+ * are as follows:
+ *
+ * 1. NoPaddingRequired.  The transformation does not introduce
+ * padding, so only local changes to update the indices of
+ * BufferLoad/BufferStore nodes are required.  No blocks are added,
+ * removed, or replaced.
+ *
+ * 2. ProloguePlan.  The transformation introduces padding, but the
+ * analyzed block has no write stages for the transformed buffer.
+ * This buffer is an input and the caller is responsible for ensuring
+ * that the padding contains the specified `pad_value`.  The generated
+ * prologue contains `builtin::assume()` calls that will expose this
+ * known value during scheduling/simplification, but will be removed
+ * during lowering.
+ *
+ * 3. ReplacementPlan.  The transformation introduces padding, has at
+ * least one write stage for the transformed buffer, and at least one
+ * of those write stages writes to all pre-transformation indices
+ * following a row-major traversal.  Such a write stage is rewritten
+ * to be a row-major traversal of the post-transformation indices,
+ * with a `tir::if_then_else` call to write either the specified
+ * `pad_value` into padding or the computed value into non-padding.
+ *
+ * 4. EpiloguePlan.  The transformation introduces padding, has at
+ * least one write stage for the transformed buffer, but no write
+ * stage can be rewritten to use `tir::if_then_else`.  The
+ * transformation still requires the `pad_value` to be written into
+ * the padding, so a new block is inserted after the last write stage
+ * to explicitly fill the padding.
+ *
+ */
+class TransformLayoutPlanner : private StmtExprVisitor {
+ public:
+  // Statement to be inserted prior to the analyzed block
+  struct ProloguePlan {
+    Stmt prologue;
+  };
+
+  // Loops within the analyzed block that should be replaced
+  struct ReplacementPlan {
+    Map<For, Stmt> replacements;
+    Map<Block, Block> block_sref_reuse;
+  };
+
+  // The block to be inserted, along with the location at which it
+  // should be inserted.  The location will be either a For or a
+  // Block, and will be after all writes to the transformed buffer.
+ struct EpiloguePlan { + Stmt insert_after; + Stmt new_block; + }; + + struct NoPaddingRequired {}; + + using TransformPlan = + std::variant; + + static TransformPlan Plan(Block block, Buffer old_buffer, Buffer new_buffer, IndexMap index_map, + IndexMap inverse, PrimExpr padding_predicate, + Optional pad_value) { + ICHECK(!pad_value.defined() || pad_value.value()->final_indices.size() == 1) + << "Internal error: Should be caught by ScheduleError checks prior to this point"; + TransformLayoutPlanner visitor(old_buffer); + visitor(block); + return visitor.Finalize(new_buffer, index_map, inverse, padding_predicate, pad_value); + } + + private: + explicit TransformLayoutPlanner(Buffer old_buffer) : old_buffer_(old_buffer) {} + + void VisitStmt_(const ForNode* op) override { + BindLoopVar context(this, GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const LetStmtNode* op) override { + BindVariableDefinition context(this, op->var, op->value); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BlockRealizeNode* op) override { + BindBlockRealize context(this, GetRef(op)); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const BufferStoreNode* op) override { + if (!op->buffer.same_as(old_buffer_)) { + return; + } + + std::optional> loop_dependency_range = std::nullopt; + for (const auto& index : op->indices) { + if (auto index_depth = LoopDependencyRange(index); index_depth.has_value()) { + if (loop_dependency_range) { + loop_dependency_range = { + std::min(loop_dependency_range.value().first, index_depth.value().first), + std::max(loop_dependency_range.value().second, index_depth.value().second)}; + } else { + loop_dependency_range = index_depth; + } + } + } + + WriteInfo write_info; + write_info.store = GetRef(op); + if (loop_dependency_range) { + size_t i = loop_dependency_range.value().first; + size_t j = loop_dependency_range.value().second; + ICHECK_LT(i, active_loops_.size()); + ICHECK_LT(j, active_loops_.size()); + + write_info.dependent_loopnest = {active_loops_.begin() + i, active_loops_.begin() + j + 1}; + } + write_info.innermost_block_realize = innermost_block_realize_; + + write_info.contains_row_major_traversal = [&]() -> bool { + const auto& loopnest = write_info.dependent_loopnest; + if (loopnest.empty()) { + return false; + } + + if (loopnest.size() != old_buffer_->shape.size() || loopnest.size() != op->indices.size()) { + return false; + } + + for (size_t i = 0; i < loopnest.size(); i++) { + const For& loop = loopnest[i]; + const PrimExpr& buffer_dim = old_buffer_->shape[i]; + PrimExpr index = Substitute(op->indices[i], active_var_bindings_); + bool is_loop_over_axis = index.same_as(loop->loop_var) && is_const_int(loop->min, 0) && + ExprDeepEqual()(loop->extent, buffer_dim) && + loop->kind == ForKind::kSerial; + if (!is_loop_over_axis) { + return false; + } + } + + return true; + }(); + + write_info_.push_back(write_info); + + // Don't need to continue recursing, as the entire goal was to + // find the BufferStore. 
+ } + + std::optional> LoopDependencyRange(const PrimExpr& expr) const { + std::optional> prev = std::nullopt; + for (const auto& var : UndefinedVars(expr)) { + auto it = loop_depth_lookup_.find(var.get()); + if (it != loop_depth_lookup_.end()) { + if (prev.has_value()) { + prev = {std::min(prev.value().first, it->second.first), + std::max(prev.value().second, it->second.second)}; + } else { + prev = it->second; + } + } + } + + return prev; + } + + class BufferStoreReplacer : public StmtExprMutator { + public: + BufferStoreReplacer(std::function(const BufferStoreNode*)> replace_store, + std::function(const BlockRealizeNode*, const BlockRealize&)> + replace_block_realize) + : replace_store_(replace_store), replace_block_realize_(replace_block_realize) {} + + Stmt VisitStmt_(const BufferStoreNode* op) final { + if (auto replacement = replace_store_(op)) { + auto store = Downcast(replacement.value()); + return StmtExprMutator::VisitStmt_(store.get()); + } else { + return StmtExprMutator::VisitStmt_(op); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + auto realize = Downcast(StmtExprMutator::VisitStmt_(op)); + if (auto replacement = replace_block_realize_(op, realize)) { + return replacement.value(); + } else { + return std::move(realize); + } + } + + private: + std::function(const BufferStoreNode*)> replace_store_; + std::function(const BlockRealizeNode*, const BlockRealize&)> + replace_block_realize_; + }; + + TransformPlan Finalize(Buffer new_buffer, IndexMap index_map, IndexMap inverse, + PrimExpr padding_predicate, Optional pad_value) const { + if (auto prologue_plan = + FinalizeProloguePlan(new_buffer, index_map, inverse, padding_predicate, pad_value); + prologue_plan.has_value()) { + return prologue_plan.value(); + } else if (auto replacement_plan = FinalizeReplacementPlan(new_buffer, index_map, inverse, + padding_predicate, pad_value); + replacement_plan.has_value()) { + return replacement_plan.value(); + } else if (auto epilogue_plan = FinalizeEpiloguePlan(new_buffer, index_map, inverse, + padding_predicate, pad_value); + epilogue_plan.has_value()) { + return epilogue_plan.value(); + } else { + return NoPaddingRequired(); + } + } + + std::optional FinalizeProloguePlan(Buffer new_buffer, IndexMap index_map, + IndexMap inverse, PrimExpr padding_predicate, + Optional pad_value) const { + if (write_info_.size() || is_zero(padding_predicate) || !pad_value.defined()) { + return std::nullopt; + } + + Array iter_vars; + Array iter_values; + Array indices; + Map loop_indices_to_block_indices; + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + const auto& loop_var = inverse->initial_indices[i]; + const auto& dim = new_buffer->shape[i]; + Var block_var("v_" + loop_var->name_hint, loop_var->dtype); + IterVar iter_var(Range(0, dim), block_var, kDataPar); + loop_indices_to_block_indices.Set(loop_var, block_var); + indices.push_back(iter_var->var); + iter_vars.push_back(iter_var); + iter_values.push_back(loop_var); + } + padding_predicate = Substitute(std::move(padding_predicate), loop_indices_to_block_indices); + + PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0]; + PrimExpr expr = (!padding_predicate) || (BufferLoad(new_buffer, indices) == pad_value_at_index); + Stmt stmt = Evaluate(Call(DataType::Bool(), builtin::assume(), {expr})); + + std::stringstream block_name; + block_name << "buffer_" << new_buffer->name << "_assumptions"; + auto read_region = 
BufferRegion::FromPoint(new_buffer, indices); + stmt = BlockRealize(iter_values, Bool(true), + Block(iter_vars, {read_region}, {}, block_name.str(), stmt)); + + for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { + size_t i = (inverse->initial_indices.size() - 1) - rev_i; + Var loop_var = inverse->initial_indices[i]; + PrimExpr extent = new_buffer->shape[i]; + stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt); + } + return ProloguePlan{stmt}; + } + + std::optional FinalizeReplacementPlan(Buffer new_buffer, IndexMap index_map, + IndexMap inverse, + PrimExpr padding_predicate, + Optional pad_value) const { + if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) { + return std::nullopt; + } + + auto generate_if_then_else_block = [&](const WriteInfo& info) -> Optional { + if (!info.contains_row_major_traversal || !pad_value.defined() || + is_zero(padding_predicate)) { + return NullOpt; + } + + Array old_indices = info.store->indices; + PrimExpr if_then_else_condition = padding_predicate; + Array new_indices; + for (const auto& var : inverse->initial_indices) { + new_indices.push_back(var); + } + + auto replace_block_realize = + [&]() -> std::function(const BlockRealizeNode*, const BlockRealize&)> { + auto no_change = [](const BlockRealizeNode*, const BlockRealize&) -> Optional { + return NullOpt; + }; + if (!info.innermost_block_realize) { + return no_change; + } + if (old_indices.empty()) { + return no_change; + } + + BlockRealize block_realize = info.innermost_block_realize.value(); + const auto& block = block_realize->block; + + // Find the block iterators that are used to access the buffer. Must be in the same order + // as they appear in the indices. + if (block->iter_vars.size() < old_indices.size()) { + return no_change; + } + const auto& iter_vars = block->iter_vars; + size_t block_index_start = 0; + for (; block_index_start < iter_vars.size() - old_indices.size(); block_index_start++) { + if (old_indices[0].same_as(iter_vars[block_index_start]->var)) { + break; + } + } + if (block_index_start > iter_vars.size() - old_indices.size()) { + return no_change; + } + + for (size_t i = 0; i < old_indices.size(); i++) { + if (!old_indices[i].same_as(iter_vars[block_index_start + i]->var) || + iter_vars[block_index_start + i]->iter_type != kDataPar) { + return no_change; + } + } + + // If we got to this point, all indices used to access the + // buffer are virtual indices defined in the innermost block. + // Therefore, generate new virtual indices for iterating over + // the post-transform buffer. 
+ Array new_iter_values; // For BlockRealize + Array new_iter_vars; // For Block + Array new_access_indices; // For BufferStore + Map loop_var_to_virtual_var; // For updating if_then_else_condition + + for (size_t i = 0; i < block_index_start; i++) { + new_iter_vars.push_back(iter_vars[i]); + new_iter_values.push_back(block_realize->iter_values[i]); + } + + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t i = 0; i < inverse->initial_indices.size(); i++) { + Var var = inverse->initial_indices[i]; + PrimExpr dim = new_buffer->shape[i]; + std::stringstream ss; + ss << "v_" << var->name_hint; + Var virtual_var(ss.str(), var.dtype()); + new_iter_values.push_back(var); + new_iter_vars.push_back(IterVar(Range::FromMinExtent(0, dim), virtual_var, kDataPar)); + new_access_indices.push_back(virtual_var); + loop_var_to_virtual_var.Set(var, virtual_var); + } + + for (size_t i = block_index_start + old_indices.size(); i < iter_vars.size(); i++) { + new_iter_vars.push_back(iter_vars[i]); + new_iter_values.push_back(block_realize->iter_values[i]); + } + + Map old_virtual_var_to_new_virtual_var; + ICHECK_EQ(inverse->final_indices.size(), old_indices.size()); + for (size_t i = 0; i < old_indices.size(); i++) { + Var var = Downcast(old_indices[i]); + PrimExpr expr = Substitute(inverse->final_indices[i], loop_var_to_virtual_var); + old_virtual_var_to_new_virtual_var.Set(var, expr); + } + + if_then_else_condition = Substitute(if_then_else_condition, loop_var_to_virtual_var); + new_indices = new_access_indices; + + return [target_realize = info.innermost_block_realize, new_iter_vars, new_iter_values, + old_virtual_var_to_new_virtual_var](const BlockRealizeNode* op, + const BlockRealize& visited) -> Optional { + if (op == target_realize.get()) { + Block block = visited->block; + block = + Downcast(Substitute(std::move(block), old_virtual_var_to_new_virtual_var)); + block.CopyOnWrite()->iter_vars = new_iter_vars; + + BlockRealize realize = visited; + { + auto write_ptr = realize.CopyOnWrite(); + write_ptr->block = block; + write_ptr->iter_values = new_iter_values; + } + return realize; + } else { + return NullOpt; + } + }; + }(); + + bool all_stores_replaced = true; + auto replace_store = [&](const BufferStoreNode* op) -> Optional { + if (!op->buffer.same_as(info.store->buffer)) { + all_stores_replaced = false; + return NullOpt; + } + ICHECK_EQ(old_indices.size(), op->indices.size()); + ExprDeepEqual expr_equal; + for (size_t i = 0; i < old_indices.size(); i++) { + if (!expr_equal(old_indices[i], op->indices[i])) { + all_stores_replaced = false; + return NullOpt; + } + } + + PrimExpr pad_value_at_index = pad_value.value()->MapIndices(new_indices)[0]; + return BufferStore(new_buffer, + if_then_else(if_then_else_condition, pad_value_at_index, op->value), + new_indices); + }; + + BufferStoreReplacer replacer(replace_store, replace_block_realize); + Stmt stmt = replacer(info.dependent_loopnest.back()->body); + if (!all_stores_replaced) { + return NullOpt; + } + + std::unordered_map var_remap; + ICHECK_EQ(info.dependent_loopnest.size(), inverse->final_indices.size()); + for (size_t i = 0; i < info.dependent_loopnest.size(); i++) { + Var var = info.dependent_loopnest[i]->loop_var; + PrimExpr expr = inverse->final_indices[i]; + var_remap[var.get()] = expr; + } + stmt = Substitute(std::move(stmt), var_remap); + + ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size()); + for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) { + size_t i = 
+            (inverse->initial_indices.size() - 1) - rev_i;
+        Var loop_var = inverse->initial_indices[i];
+        PrimExpr extent = new_buffer->shape[i];
+        stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+      }
+
+      return stmt;
+    };
+
+    Map<For, Stmt> loop_replacements;
+
+    for (const auto& info : write_info_) {
+      if (info.dependent_loopnest.size()) {
+        if (auto opt_stmt = generate_if_then_else_block(info)) {
+          loop_replacements.Set(info.dependent_loopnest[0], opt_stmt.value());
+        }
+      }
+    }
+
+    if (loop_replacements.size()) {
+      return ReplacementPlan{std::move(loop_replacements)};
+    } else {
+      return std::nullopt;
+    }
+  }
+
+  std::optional<EpiloguePlan> FinalizeEpiloguePlan(Buffer new_buffer, IndexMap index_map,
+                                                   IndexMap inverse, PrimExpr padding_predicate,
+                                                   Optional<IndexMap> pad_value) const {
+    if (write_info_.empty() || is_zero(padding_predicate) || !pad_value.defined()) {
+      return std::nullopt;
+    }
+
+    Array<IterVar> iter_vars;
+    Array<PrimExpr> iter_values;
+    Array<PrimExpr> indices;
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t i = 0; i < inverse->initial_indices.size(); i++) {
+      const auto& loop_var = inverse->initial_indices[i];
+      const auto& dim = new_buffer->shape[i];
+      Var block_var("v_" + loop_var->name_hint, loop_var->dtype);
+      IterVar iter_var(Range(0, dim), block_var, kDataPar);
+      indices.push_back(iter_var->var);
+      iter_vars.push_back(iter_var);
+      iter_values.push_back(loop_var);
+    }
+
+    PrimExpr pad_value_at_index = pad_value.value()->MapIndices(indices)[0];
+    Stmt stmt = BufferStore(new_buffer, pad_value_at_index, indices);
+
+    std::stringstream block_name;
+    block_name << "buffer_" << new_buffer->name << "_padding";
+    auto write_region = BufferRegion::FromPoint(new_buffer, indices);
+    stmt = BlockRealize(iter_values, padding_predicate,
+                        Block(iter_vars, {}, {write_region}, block_name.str(), stmt));
+
+    ICHECK_EQ(inverse->initial_indices.size(), new_buffer->shape.size());
+    for (size_t rev_i = 0; rev_i < inverse->initial_indices.size(); rev_i++) {
+      size_t i = (inverse->initial_indices.size() - 1) - rev_i;
+      Var loop_var = inverse->initial_indices[i];
+      PrimExpr extent = new_buffer->shape[i];
+      stmt = For(loop_var, 0, extent, ForKind::kSerial, stmt);
+    }
+
+    const auto& info = write_info_.back();
+    Stmt insert_after = [&]() -> Stmt {
+      if (info.dependent_loopnest.size()) {
+        return info.dependent_loopnest.front();
+      } else if (info.innermost_block_realize) {
+        return info.innermost_block_realize.value();
+      } else {
+        LOG(FATAL) << "Write occurred outside of any block/loop";
+        return Stmt();
+      }
+    }();
+    return EpiloguePlan{insert_after, stmt};
+  }
+
+  struct BindLoopVar {
+    BindLoopVar(TransformLayoutPlanner* self, For for_node)
+        : self_(self), var_(for_node->loop_var) {
+      size_t loop_depth = self_->active_loops_.size();
+      self_->loop_depth_lookup_[var_.get()] = {loop_depth, loop_depth};
+      self_->active_loops_.push_back(std::move(for_node));
+    }
+    ~BindLoopVar() {
+      self_->active_loops_.pop_back();
+      self_->loop_depth_lookup_.erase(var_.get());
+    }
+    BindLoopVar(const BindLoopVar&) = delete;
+    BindLoopVar& operator=(const BindLoopVar&) = delete;
+    BindLoopVar(BindLoopVar&&) = delete;
+    BindLoopVar& operator=(BindLoopVar&&) = delete;
+
+    TransformLayoutPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindVariableDefinition {
+    BindVariableDefinition() {}
+    BindVariableDefinition(TransformLayoutPlanner* self, Var var, PrimExpr value)
+        : self_(self), var_(var) {
+      if (auto loop_depth = self->LoopDependencyRange(value); loop_depth.has_value()) {
+        self_->loop_depth_lookup_[var_.get()] = loop_depth.value();
+        self_->active_var_bindings_[var_.get()] = Substitute(value, self_->active_var_bindings_);
+      }
+    }
+    ~BindVariableDefinition() {
+      if (self_) {
+        self_->loop_depth_lookup_.erase(var_.get());
+        self_->active_var_bindings_.erase(var_.get());
+      }
+    }
+    BindVariableDefinition(const BindVariableDefinition&) = delete;
+    BindVariableDefinition& operator=(const BindVariableDefinition&) = delete;
+    BindVariableDefinition(BindVariableDefinition&& other) : BindVariableDefinition() {
+      swap(other);
+    }
+    BindVariableDefinition& operator=(BindVariableDefinition&& other) {
+      swap(other);
+      return *this;
+    }
+    void swap(BindVariableDefinition& other) {
+      std::swap(self_, other.self_);
+      std::swap(var_, other.var_);
+    }
+
+    TransformLayoutPlanner* self_{nullptr};
+    Var var_;
+  };
+
+  struct BindBlockRealize {
+    BindBlockRealize(TransformLayoutPlanner* self, BlockRealize block_realize) : self_(self) {
+      ICHECK_EQ(block_realize->iter_values.size(), block_realize->block->iter_vars.size());
+      for (size_t i = 0; i < block_realize->iter_values.size(); i++) {
+        bound_vars_.emplace_back(self, block_realize->block->iter_vars[i]->var,
+                                 block_realize->iter_values[i]);
+      }
+      cache_ = std::move(block_realize);
+      std::swap(self_->innermost_block_realize_, cache_);
+    }
+    ~BindBlockRealize() { std::swap(self_->innermost_block_realize_, cache_); }
+    BindBlockRealize(const BindBlockRealize&) = delete;
+    BindBlockRealize& operator=(const BindBlockRealize&) = delete;
+    BindBlockRealize(BindBlockRealize&&) = delete;
+    BindBlockRealize& operator=(BindBlockRealize&&) = delete;
+
+    TransformLayoutPlanner* self_{nullptr};
+    Optional<BlockRealize> cache_;
+    std::vector<BindVariableDefinition> bound_vars_;
+  };
+
+  struct WriteInfo {
+    // The BufferStore object
+    BufferStore store;
+
+    // The block realize that contains the store, if any.
+    Optional<BlockRealize> innermost_block_realize;
+
+    // The nested loops whose values contribute to the indices used in
+    // the store.  Not all loop variables in the loopnest need to
+    // contribute, but the first and last must.
+    std::vector<For> dependent_loopnest;
+
+    // Whether the padding could be represented as a tir::if_then_else
+    // node.  This requires that the surrounding loop iterators
+    // iterate over all pre-transformation buffer axes, that there are
+    // no data dependencies between loop iterations, and that each
+    // axis is visited by a serial loop in row-major order.
+    bool contains_row_major_traversal{false};
+  };
+
+  /*! \brief Collected information about each BufferStore */
+  std::vector<WriteInfo> write_info_;
+
+  /*! \brief The loop iterators surrounding the current node
+   *
+   * The outermost loop iterator is `active_loops_.front()`, and the
+   * innermost loop iterator is `active_loops_.back()`.
+   *
+   * Used to fill the `WriteInfo::dependent_loopnest` field.
+   */
+  std::vector<For> active_loops_;
+
+  /*! \brief Lookup for the outer/inner loops
+   *
+   * Used to fill the `WriteInfo::dependent_loopnest` field.
+   */
+  std::unordered_map<const VarNode*, std::pair<size_t, size_t>> loop_depth_lookup_;
+
+  /*! \brief The variable mappings that are currently in-scope
+   *
+   * Used to determine whether the indices of a BufferStore are a
+   * row-major traversal, even if they are rebound in let/block
+   * mappings.
+   */
+  std::unordered_map<const VarNode*, PrimExpr> active_var_bindings_;
+
+  /*! \brief The innermost BlockRealize surrounding the current node
+   *
+   * Used to fill the `WriteInfo::innermost_block_realize` field.
+   */
+  Optional<BlockRealize> innermost_block_realize_{NullOpt};
+
+  /*! \brief The buffer to be replaced */
+  Buffer old_buffer_;
+};
+
 class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
  public:
   /*!
@@ -33,23 +668,33 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
    * \return The new AST rooted at the original parent scope and the map from the old block to the
    * new block
    */
-  static std::pair<Stmt, Map<Block, Block>> Rewrite(const Stmt& scope_stmt,
-                                                    const Buffer& old_buffer,
-                                                    const Buffer& new_buffer,
-                                                    const IndexMap& index_map) {
+  static std::pair<Block, Map<Block, Block>> Rewrite(
+      const Block& scope_stmt, const Buffer& old_buffer, const Buffer& new_buffer,
+      const IndexMap& index_map, const IndexMap& inverse, const PrimExpr& padding_predicate,
+      const Optional<IndexMap>& pad_value) {
+    auto plan = TransformLayoutPlanner::Plan(scope_stmt, old_buffer, new_buffer, index_map, inverse,
+                                             padding_predicate, pad_value);
+
     arith::Analyzer analyzer;
-    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, &analyzer);
-    Stmt result = rewriter(scope_stmt);
+    TransformLayoutRewriter rewriter(old_buffer, new_buffer, index_map, plan, &analyzer);
+    Block result = Downcast<Block>(rewriter(scope_stmt));
+    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::ProloguePlan>(&plan)) {
+      auto write_ptr = result.CopyOnWrite();
+      write_ptr->body = SeqStmt({plan_ptr->prologue, write_ptr->body});
+    }
     return {result, rewriter.block_sref_reuse_};
   }
 
  private:
   TransformLayoutRewriter(const Buffer& old_buffer, const Buffer& new_buffer,
-                          const IndexMap& index_map, arith::Analyzer* analyzer)
+                          const IndexMap& index_map,
+                          const TransformLayoutPlanner::TransformPlan& plan,
+                          arith::Analyzer* analyzer)
       : IRMutatorWithAnalyzer(analyzer),
         old_buffer_(old_buffer),
         new_buffer_(new_buffer),
         index_map_(index_map),
+        plan_(plan),
         buffer_data_to_buffer_{{new_buffer->data, new_buffer}} {}
 
   void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
@@ -61,6 +706,31 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {
   using Parent::VisitExpr_;
   using Parent::VisitStmt_;
 
+  Stmt VisitStmt(const Stmt& stmt) final {
+    Stmt output = Parent::VisitStmt(stmt);
+    if (auto plan_ptr = std::get_if<TransformLayoutPlanner::EpiloguePlan>(&plan_)) {
+      if (plan_ptr->insert_after.same_as(stmt)) {
+        return SeqStmt({output, plan_ptr->new_block});
+      }
+    }
+    return output;
+  }
+
+  Stmt VisitStmt_(const ForNode* op) final {
+    // Some replacements may include the original statement, such as
+    // replacing `loop` with `{loop, post_proc}`.  In this case, avoid
+    // infinite recursion.
+ + For node = GetRef(op); + if (auto plan_ptr = std::get_if(&plan_)) { + auto it = plan_ptr->replacements.find(node); + if (it != plan_ptr->replacements.end()) { + return VisitStmt((*it).second); + } + } + return Parent::VisitStmt_(op); + } + PrimExpr VisitExpr_(const BufferLoadNode* op) final { BufferLoad buffer_load = Downcast(Parent::VisitExpr_(op)); if (buffer_load->buffer.same_as(old_buffer_)) { @@ -97,6 +767,13 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { auto* n = block.CopyOnWrite(); RewriteAccessRegion(&n->reads, infered_access_regions[0]); RewriteAccessRegion(&n->writes, infered_access_regions[1]); + n->alloc_buffers.MutateByApply([this](const Buffer& buffer) { + if (buffer.same_as(old_buffer_)) { + return new_buffer_; + } else { + return buffer; + } + }); block_sref_reuse_.Set(GetRef(op), block); return std::move(block); } @@ -104,6 +781,7 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer { const Buffer& old_buffer_; const Buffer& new_buffer_; const IndexMap& index_map_; + const TransformLayoutPlanner::TransformPlan& plan_; Map buffer_data_to_buffer_; Map block_sref_reuse_; }; @@ -132,8 +810,158 @@ class BufferIsSubregionError : public ScheduleError { Buffer buffer_; }; +class TransformationPaddingIndexMapError : public ScheduleError { + public: + TransformationPaddingIndexMapError(IRModule mod, IndexMap pad_value) + : mod_(mod), pad_value_(pad_value) {} + + String FastErrorString() const final { + std::ostringstream ss; + ss << "ScheduleError: The IndexMap specifying pad_value has " + << pad_value_->final_indices.size() << " outputs, should only have one output"; + return ss.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream ss; + ss << "ScheduleError: Pad value is specified as " << pad_value_ << " which has " + << pad_value_->final_indices.size() << " outputs, but should only have one output"; + return ss.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + IndexMap pad_value_; +}; + +class TransformationPaddingTypeError : public ScheduleError { + public: + TransformationPaddingTypeError(IRModule mod, Buffer buffer, IndexMap pad_value) + : mod_(mod), buffer_(buffer), pad_value_(pad_value) { + ICHECK_EQ(pad_value_->final_indices.size(), 1); + pad_value_dtype_ = pad_value_->final_indices[0].dtype(); + } + + String FastErrorString() const final { + std::ostringstream ss; + ss << "ScheduleError: Type mismatch " << buffer_->dtype << " vs " << pad_value_dtype_; + return ss.str(); + } + + String DetailRenderTemplate() const final { + std::ostringstream ss; + ss << "ScheduleError: Buffer " << buffer_->name << " has elements of type " << buffer_->dtype + << ", but the transformation fills padding with " << pad_value_ << ", which is of type " + << pad_value_dtype_; + return ss.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + Buffer buffer_; + IndexMap pad_value_; + DataType pad_value_dtype_; +}; + +class TransformationPaddingExpressionError : public ScheduleError { + public: + static void Check(IRModule mod, Buffer buffer, IndexMap pad_value) { + Visitor visitor(buffer); + ICHECK_EQ(pad_value->final_indices.size(), 1) + << "Internal error: Should be caught by ScheduleError checks prior to this point"; + visitor(pad_value->final_indices[0]); + if (visitor.illegal_load) { + throw 
+          TransformationPaddingExpressionError(mod, buffer, pad_value,
+                                               visitor.illegal_load.value());
+    }
+  }
+
+ private:
+  struct Visitor : ExprVisitor {
+    explicit Visitor(const Buffer& buffer) : buffer_(buffer) {}
+
+    void VisitExpr_(const BufferLoadNode* op) final {
+      if (!op->buffer.same_as(buffer_)) {
+        illegal_load = GetRef<BufferLoad>(op);
+      }
+      ExprVisitor::VisitExpr_(op);
+    }
+
+    const Buffer& buffer_;
+    Optional<BufferLoad> illegal_load;
+  };
+
+  TransformationPaddingExpressionError(IRModule mod, Buffer buffer, IndexMap pad_value,
+                                       BufferLoad illegal_load)
+      : mod_(mod), buffer_(buffer), pad_value_(pad_value), illegal_load_(illegal_load) {}
+
+  String FastErrorString() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Pad value may not contain load from " << illegal_load_->buffer->name;
+    return ss.str();
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Pad value may only contain BufferLoad from the transformed buffer "
+       << buffer_->name << ", but pad_value " << pad_value_ << " contains expression "
+       << illegal_load_;
+    return ss.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {}; }
+
+  IRModule mod_;
+  Buffer buffer_;
+  IndexMap pad_value_;
+  BufferLoad illegal_load_;
+};
+
+class TransformationIntroducesPaddingError : public ScheduleError {
+ public:
+  TransformationIntroducesPaddingError(IRModule mod, Buffer buffer, IndexMap index_map,
+                                       PrimExpr padding_predicate)
+      : mod_(std::move(mod)),
+        buffer_(std::move(buffer)),
+        index_map_(std::move(index_map)),
+        padding_predicate_(std::move(padding_predicate)) {}
+
+  String FastErrorString() const final {
+    std::ostringstream ss;
+    ss << "ScheduleError: Transformation would introduce padding at " << padding_predicate_ << ".";
+    return ss.str();
+  }
+
+  String DetailRenderTemplate() const final {
+    auto new_shape = index_map_->MapShape(buffer_->shape);
+    std::ostringstream os;
+    os << "The transformation " << index_map_ << " applied on buffer " << buffer_->name
+       << " of shape " << buffer_->shape << " would result in shape " << new_shape
+       << ". 
However, this would introduce padding wherever " << padding_predicate_ << " is true."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + private: + IRModule mod_; + Buffer buffer_; + IndexMap index_map_; + PrimExpr padding_predicate_; +}; + void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const IndexMap& index_map) { + BufferIndexType buffer_index_type, const IndexMap& index_map, + const Optional& pad_value) { + // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); @@ -141,33 +969,48 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_ if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); } + if (pad_value) { + if (pad_value.value()->final_indices.size() != 1) { + throw TransformationPaddingIndexMapError(self->mod, pad_value.value()); + } + if (pad_value.value()->final_indices[0]->dtype != old_buffer->dtype) { + throw TransformationPaddingTypeError(self->mod, old_buffer, pad_value.value()); + } + + TransformationPaddingExpressionError::Check(self->mod, old_buffer, pad_value.value()); + } StmtSRef scope_sref = defining_site_sref.defined() ? defining_site_sref.value() : GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_sref); - // Step 1: Infer the shape of the new buffer - ObjectPtr new_buffer_node = make_object(*(old_buffer.get())); - new_buffer_node->shape = index_map->MapShape(old_buffer->shape); - Buffer new_buffer{new_buffer_node}; + auto [inverse, padding_predicate] = [&]() { + Array region; + for (const auto& dim : old_buffer->shape) { + region.push_back(Range::FromMinExtent(0, dim)); + } + return index_map.NonSurjectiveInverse(region); + }(); + + bool has_padding = !is_zero(padding_predicate); + if (has_padding && !pad_value.defined()) { + throw TransformationIntroducesPaddingError(self->mod, old_buffer, index_map, padding_predicate); + } - // Step 2: Rewrite access indices and regions of the buffer - auto [new_stmt, block_sref_reuse] = TransformLayoutRewriter::Rewrite( - GetRef(scope_block), old_buffer, new_buffer, index_map); + // Step 2: Infer the shape of the new buffer + Buffer new_buffer = old_buffer; + new_buffer.CopyOnWrite()->shape = index_map->MapShape(old_buffer->shape); + + // Step 3: Rewrite BufferLoad/BufferStore access indices, block read/write regions, and block + // alloc_buffers. + auto [new_stmt, block_sref_reuse] = + TransformLayoutRewriter::Rewrite(GetRef(scope_block), old_buffer, new_buffer, + index_map, inverse, padding_predicate, pad_value); Block new_scope_block = Downcast(new_stmt); - // Step 3: Rewrite alloc_buffer of the block or buffer_map of the PrimFunc. - if (defining_site_sref.defined()) { - auto* n = new_scope_block.CopyOnWrite(); - n->alloc_buffers.MutateByApply([&old_buffer, &new_buffer](const Buffer& buffer) { - if (buffer.same_as(old_buffer)) { - return new_buffer; - } - return buffer; - }); - block_sref_reuse.Set(GetRef(scope_block), new_scope_block); - } else { + // Step 4: Rewrite buffer_map of the PrimFunc if necessary. 
+ if (!defining_site_sref.defined()) { GlobalVar g_var; GetRootPrimFunc(self->mod, scope_block, &g_var); IRModuleNode* new_mod = self->mod.CopyOnWrite(); @@ -502,17 +1345,20 @@ struct TransformLayoutTraits : public UnpackedInstTraits private: static constexpr size_t kNumInputs = 1; - static constexpr size_t kNumAttrs = 3; + static constexpr size_t kNumAttrs = 4; static constexpr size_t kNumDecisions = 0; static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + Integer buffer_index_type, IndexMap index_map, + Optional pad_value) { return sch->TransformLayout(block_rv, buffer_index.IntValue(), - static_cast(buffer_index_type->value), index_map); + static_cast(buffer_index_type->value), index_map, + pad_value); } static String UnpackedAsPython(Array outputs, String block_rv, Integer buffer_index, - Integer buffer_index_type, IndexMap index_map) { + Integer buffer_index_type, IndexMap index_map, + Optional pad_value) { PythonAPICall py("transform_layout"); py.Input("block", block_rv); @@ -522,6 +1368,8 @@ struct TransformLayoutTraits : public UnpackedInstTraits py.Input("buffer", os.str()); py.Input("index_map", index_map->ToPythonString()); + py.Input("pad_value", pad_value ? pad_value.value()->ToPythonString() : "None"); + return py.Str(); } @@ -532,6 +1380,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits attrs_record.push_back(attrs[0]); attrs_record.push_back(attrs[1]); attrs_record.push_back(String(::tvm::SaveJSON(attrs[2]))); + attrs_record.push_back(attrs[3]); return std::move(attrs_record); } @@ -541,6 +1390,7 @@ struct TransformLayoutTraits : public UnpackedInstTraits attrs.push_back(attrs_record[0]); attrs.push_back(attrs_record[1]); attrs.push_back(::tvm::LoadJSON(Downcast(attrs_record[2]))); + attrs.push_back(attrs_record[3]); return attrs; } diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index d72f67fb7c2d..2f27dbb9fbf1 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -248,9 +248,11 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnannotate") /******** (FFI) Layout transformation ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformLayout") .set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index, - int buffer_index_type, const IndexMap& index_map) { + int buffer_index_type, const IndexMap& index_map, + const Optional& pad_value) { return self->TransformLayout(block_rv, buffer_index, - static_cast(buffer_index_type), index_map); + static_cast(buffer_index_type), index_map, + pad_value); }); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleTransformBlockLayout") .set_body_method(&ScheduleNode::TransformBlockLayout); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index a31950d33115..9ff793dc39dd 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -487,14 +487,17 @@ void TracedScheduleNode::Unannotate(const BlockRV& block_rv, const String& ann_k void TracedScheduleNode::TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) { - ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map); + const IndexMap& index_map, + const Optional& pad_value) { + ConcreteScheduleNode::TransformLayout(block_rv, buffer_index, buffer_index_type, index_map, + pad_value); static const InstructionKind& kind = InstructionKind::Get("TransformLayout"); 
trace_->Append( - /*inst=*/Instruction(/*kind=*/kind, - /*inputs=*/{block_rv}, - /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), index_map}, - /*outputs=*/{})); + /*inst=*/Instruction( + /*kind=*/kind, + /*inputs=*/{block_rv}, + /*attrs=*/{Integer(buffer_index), Integer(buffer_index_type), index_map, pad_value}, + /*outputs=*/{})); } void TracedScheduleNode::TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index ad44cc6ae552..0e83b35f44e9 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -103,7 +103,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { void Unannotate(const BlockRV& block_rv, const String& ann_key) override; /******** Schedule: Layout transformation ********/ void TransformLayout(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, - const IndexMap& index_map) override; + const IndexMap& index_map, const Optional& pad_value) override; void TransformBlockLayout(const BlockRV& block_rv, const IndexMap& index_map) override; void SetAxisSeparator(const BlockRV& block_rv, int buffer_index, BufferIndexType buffer_index_type, diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 0332df7fd312..8ed350cc4c46 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -329,5 +329,415 @@ def test_transform_block_layout_fail_mixed_iter_type(use_block_name): ) +class BasePaddingCompare(tvm.testing.CompareBeforeAfter): + pad_value = tvm.testing.parameter(None) + + transformed_buffer = tvm.testing.parameter("A") + + @pytest.fixture + def transform(self, pad_value, transformed_buffer): + def transform(mod): + sch = tir.Schedule(mod) + sch.transform_layout( + "block", transformed_buffer, lambda i: [i // 4, i % 4], pad_value=pad_value + ) + return sch.mod + + return transform + + +class TestNoPadding(BasePaddingCompare): + """Transformations without padding do not depend on pad_value.""" + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi // 4, vi % 4] = 0 + + +class TestNoPaddingMultipleUsage(BasePaddingCompare): + """Transformations without padding do not depend on pad_value. + + Like TestNoPadding, but the buffer A shows up in multiple + locations. To remain internally consistent, all instances of the + buffer should be rewritten. 
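+
+    In sketch form, the fixture above applies the equivalent of the
+    following schedule (with the parametrized ``pad_value``):
+
+    .. code-block:: python
+
+        sch = tir.Schedule(mod)
+        sch.transform_layout("block", "A", lambda i: [i // 4, i % 4])
+
+    so that both the write in "block" and the read in "other" are
+    rewritten to use the transformed indices.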
+ """ + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + B = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("other"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi // 4, vi % 4] = 0 + + B = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("other"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // 4, vi % 4] + + +class TestNoPaddingOpaqueBlock(BasePaddingCompare): + """Transformations without padding do not depend on pad_value. + + Like TestNoPadding, but buffer access is done in an opaque block. + """ + + pad_value = tvm.testing.parameter(None, 42) + + def before(): + A = T.alloc_buffer(16, "int32") + for i in T.serial(16): + with T.block("block"): + A[i] = 0 + + def expected(): + A = T.alloc_buffer([4, 4], "int32") + for i in T.serial(16): + with T.block("block"): + A[i // 4, i % 4] = 0 + + +class TestErrorIfPaddingForbidden(BasePaddingCompare): + """Unless padding is explicitly enabled, should raise error""" + + def before(): + A = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + expected = tvm.tir.schedule.schedule.ScheduleError + + +class TestErrorOnWrongPaddingType(BasePaddingCompare): + """The padding must have the same dtype as the buffer""" + + pad_value = tvm.testing.parameter(0.5) + + def before(): + A = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + A[vi] = 0 + + expected = tvm.tir.schedule.schedule.ScheduleError + + +class TestPaddedTransformIfThenElse(BasePaddingCompare): + """Use if_then_else to represent padding, if possible. + + For a block that is a producer of the pre-transformation buffer, + which visits all indices according to a row-major traversal, and + which has no effect other than producing the transformed buffer, + transform the loop iterators to be a row-major traversal of the + post-transformation buffer, with padding represented by + `T.if_then_else`. + """ + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, A[vi * 4 + vj], dtype="int32") + + +class TestPaddedTransformWithoutLoop(BasePaddingCompare): + """Handle padded writes without a loop + + The statement being replaced may be something other than a + for-loop, such as if a loop has already been unrolled. 
+ """ + + pad_value = tvm.testing.parameter(0) + + def before(A: T.Buffer[14, "int32"]): + with T.block("root"): + T.reads() + T.writes() + with T.block("block"): + A[0] = 0 + + def expected(A: T.Buffer[(4, 4), "int32"]): + with T.block("block"): + A[0, 0] = 0 + + for i, j in T.grid(4, 4): + with T.block("buffer_A_padding"): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(i == 3 and 2 <= j) + A[vi, vj] = 0 + + +class TestPaddedTransformIfThenElseReduction(BasePaddingCompare): + """Like TestPaddedTransformIfThenElse, but with a reduction axis""" + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer(14, "int32") + for i, k in T.grid(14, 32): + with T.block("block"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + B[vi] = 0 + B[vi] = B[vi] + A[vi, vk] + + def expected(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j, k in T.grid(4, 4, 32): + with T.block("block"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + B[vi, vj] = T.if_then_else(vi == 3 and 2 <= vj, 0, 0, dtype="int32") + B[vi, vj] = T.if_then_else( + vi == 3 and 2 <= vj, 0, B[vi, vj] + A[vi * 4 + vj, vk], dtype="int32" + ) + + +class TestPaddedTransformIfThenElseReductionOpaque(BasePaddingCompare): + """Like TestPaddedTransformIfThenElseReduction, but with opaque blocks""" + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + B[i] = 0 + for k in T.serial(32): + with T.block("block"): + B[i] = B[i] + A[i, k] + + def expected(A: T.Buffer[(14, 32), "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + B[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 0, dtype="int32") + for k in T.serial(32): + with T.block("block"): + B[i, j] = T.if_then_else( + i == 3 and 2 <= j, 0, B[i, j] + A[i * 4 + j, k], dtype="int32" + ) + + +class TestPaddedTransformPostProcIfRequiredDueToSideEffects(BasePaddingCompare): + """Set the transformation padding in a post-processing block. + + Like TestPaddedTransformIfThenElse, but the block that produces B + also has the effect of setting `C`. 
+ """ + + pad_value = tvm.testing.parameter(0) + transformed_buffer = tvm.testing.parameter("B") + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + C = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + C[vi] = 0 + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + C = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi // 4, vi % 4] = A[vi] + C[vi] = 0 + + for i, j in T.grid(4, 4): + with T.block("block_pad_B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(i == 3 and 2 <= j) + B[vi, vj] = 0 + + +class TestPaddedTransformOfInputCreatesAssumption(BasePaddingCompare): + """Transformation of an input buffer places T.assume locally""" + + pad_value = tvm.testing.parameter(42) + + def before(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]): + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(A: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]): + for i, j in T.grid(4, 4): + with T.block("buffer_A_assumption"): + vi, vj = T.axis.remap("SS", [i, j]) + T.assume(not (vi == 3 and 2 <= vj) or A[vi, vj] == 42) + + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // 4, vi % 4] + + +class TestPaddedTransformNonConstantValue(tvm.testing.CompareBeforeAfter): + """Allow an expression to specify the pad value. + + Like TestPaddedTransformIfThenElse, but the pad value depends on + the indices. + """ + + @pytest.fixture + def transform(self): + def transform(mod): + sch = tir.Schedule(mod) + sch.transform_layout( + "block", + "B", + lambda i: [i // 4, i % 4], + pad_value=lambda i, j: i + j, + ) + return sch.mod + + return transform + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer([4, 4], "int32") + for i, j in T.grid(4, 4): + with T.block("block"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = T.if_then_else( + vi == 3 and 2 <= vj, vi + vj, A[vi * 4 + vj], dtype="int32" + ) + + +@pytest.mark.xfail(reason="Not yet implemented") +class TestPaddedTransformRepeatedBufferElement(tvm.testing.CompareBeforeAfter): + """Allow an expression to specify the pad value. + + Like TestPaddedTransformOfInputCreatesAssumption, but the pad + value depends on another portion of the buffer. In this case, the + padding at the end of A contains repeated elements from the + beginning of A. 
+ """ + + @pytest.fixture + def transform(self): + def transform(mod): + sch = tir.Schedule(mod) + + A = sch.get(sch.get_block("block")).reads[0].buffer + sch.transform_layout( + "block", + "A", + lambda i: [i // 4, i % 4], + pad_value=lambda i, j: A[(4 * i + j) % 14], + ) + return sch.mod + + return transform + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + def expected(A: T.Buffer[(4, 4), "int32"]): + for i, j in T.grid(4, 4): + with T.block("buffer_A_assumption"): + vi, vj = T.axis.remap("SS", [i, j]) + T.assume( + not (vi == 3 and 2 <= vj) + or A[vi, vj] == A[((4 * vi + j) % 14) // 4, ((4 * vi + j) % 14) % 4] + ) + + B = T.alloc_buffer(14, "int32") + for i in T.grid(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi // 4, vi % 4] + + +class TestPadValueMayNotReferenceOtherBuffer(tvm.testing.CompareBeforeAfter): + """Allow an expression to specify the pad value. + + Like TestPaddedTransformRepeatedBufferElement, but the pad value depends on + a different buffer, which is not allowed. + """ + + @pytest.fixture + def transform(self): + def transform(mod): + sch = tir.Schedule(mod) + + A = sch.get(sch.get_block("block")).reads[0].buffer + other = tir.decl_buffer(1, A.dtype, name="other") + sch.transform_layout( + "block", + "A", + lambda i: [i // 4, i % 4], + pad_value=lambda i, j: other[0], + ) + return sch.mod + + return transform + + def before(A: T.Buffer[14, "int32"]): + B = T.alloc_buffer(14, "int32") + for i in T.serial(14): + with T.block("block"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + expected = tvm.tir.schedule.schedule.ScheduleError + + if __name__ == "__main__": tvm.testing.main()