From 375c8b15c926f5104203bacd0af5c013af380ed3 Mon Sep 17 00:00:00 2001 From: Hongyi Jin <3231950289@qq.com> Date: Wed, 14 Jul 2021 22:11:47 +0800 Subject: [PATCH 01/16] Fuse&split (#408) * first commit * fix cpplint * fix * remove redundant blank * address comments * lint * address comments * address comments * address comments * change fuse * change split * polish * lint * fix rebase * fix bug and add tests * clang format * address comments * format * address comments * address comments * add symbolic test * lint * address comment * check stage pipeline * fix mypy * check stage_pipeline * Revert "check stage_pipeline" This reverts commit a5a7f4fe * add stage_pipeline_assert Co-authored-by: jinhongyi <323195289@qq.com> --- include/tvm/arith/iter_affine_map.h | 12 + include/tvm/tir/schedule/schedule.h | 31 ++ python/tvm/tir/schedule/schedule.py | 131 ++++- src/arith/iter_affine_map.cc | 16 + src/arith/rewrite_simplify.cc | 4 +- src/tir/schedule/analysis.h | 6 + src/tir/schedule/analysis/analysis.cc | 30 ++ src/tir/schedule/concrete_schedule.cc | 28 + src/tir/schedule/concrete_schedule.h | 49 +- src/tir/schedule/primitive.h | 22 +- src/tir/schedule/primitive/fuse_split.cc | 483 ++++++++++++++++++ src/tir/schedule/schedule.cc | 2 + .../unittest/test_tir_schedule_split_fuse.py | 469 +++++++++++++++++ 13 files changed, 1268 insertions(+), 15 deletions(-) mode change 100644 => 100755 python/tvm/tir/schedule/schedule.py create mode 100644 src/tir/schedule/primitive/fuse_split.cc create mode 100644 tests/python/unittest/test_tir_schedule_split_fuse.py diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index d671339fb66b..6c72cbeafdd4 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr { Array DetectIterMap(const Array& indices, const Map& input_iters, const PrimExpr& predicate, bool require_bijective, arith::Analyzer* analyzer); +/*! 
+ * \brief Use IterVarMap detector to rewrite and simplify the indices + * + * \param indices The indices to detect pattern for. + * \param input_iters Map from variable to iterator's range. + * \param input_pred The predicate constraints on the input iterators + * \param require_bijective A boolean flag that indicates whether the mapping should be bijective. + * + * \return The indices after rewrite + */ +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective); /*! * \brief Apply the inverse of the affine transformation to the outputs. diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 9a09d0ad211f..38a15a814370 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -151,6 +151,18 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; + /*! + * \brief Get the block srefs corresponding to an array of BlockRVs + * \param block_rvs The BlockRVs to be looked up + * \return The corresponding block srefs + */ + virtual Array GetSRefs(const Array& block_rvs) const = 0; + /*! + * \brief Get the loop srefs corresponding to an array of LoopRVs + * \param loop_rvs The LoopRVs to be looked up + * \return The corresponding loop srefs + */ + virtual Array GetSRefs(const Array& loop_rvs) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -196,6 +208,25 @@ class ScheduleNode : public runtime::Object { */ virtual Array GetLoops(const BlockRV& block_rv) = 0; /******** Schedule: loops manipulation ********/ + /*! + * \brief Fuse consecutive loops into one. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The (i+1)-th loop must be the only child of the i-th loop. + * 3) All loops must start with 0. 
+ * \param loop_rvs The loops to be fused + * \return The fused loop + */ + virtual LoopRV Fuse(const Array& loop_rvs) = 0; + /*! + * \brief Split a specified loop into two or more with the specific factor.It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. + * \param loop_rv The loop to be split + * \param factors The tiling factors, and at most one of which is -1, which means that + * factor is inferred. + * \return The loops after splitting + */ + virtual Array Split(const LoopRV& loop_rv, const Array& factors) = 0; /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py old mode 100644 new mode 100755 index 2091f4d80ab3..67350bd109d0 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=unused-import """The TensorIR schedule class""" -from typing import List, Optional, Union +from typing import List, Optional, Union, Tuple from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error @@ -43,7 +43,7 @@ class BlockRV(Object): """A random variable that refers to a block""" -ExprRV = PrimExpr # A random variable that evaluates to an integer +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name @@ -257,6 +257,133 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]: return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member ########## Schedule: loops manipulation ########## + def fuse(self, *loops: List[LoopRV]) -> LoopRV: + """Fuse a list of consecutive loops into one. It requires: + 1) The loops can't have annotations or thread bindings. 
+ 2) The (i+1)-th loop must be the only child of the i-th loop. + 3) All loops must start with 0. + + Parameters + ---------- + *loops : List[LoopRV] + The loops to be fused + + Returns + ---------- + fused_loop : LoopRV + The new loop after fusion + + Examples + -------- + + Before fuse, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do fuse: + + .. code-block:: python + + sch = tir.Schedule(before_fuse, debug_mode=True) + i, j = sch.get_loops(sch.get_block("B")) + sch.fuse(i, j) + print(tvm.script.asscript(sch.mod["main"])) + + After applying fuse, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_fuse(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, [128, 128]) + for i0_i1_fused in tir.serial(0, 16384): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i0_i1_fused, 128)) + tir.bind(vj, tir.floormod(i0_i1_fused, 128)) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member + + def split( + self, + loop: LoopRV, + factors: List[Optional[ExprRV]], + ) -> List[LoopRV]: + """Split a loop into a list of consecutive loops. It requires: + 1) The loop can't have annotation or thread binding. + 2) The loop must start with 0. + Predicates may be added to ensure the total loop numbers keeps unchanged. + In `factors`, at most one of the factors can be None or -1, + which will be automatically inferred. 
+ Parameters + ---------- + loop : LoopRV + The loop to be split + + factors: List[Optional[ExprRV]] + The splitting factors + + Returns + ---------- + split_loops : List[LoopRV] + The new loops after split + + Examples + -------- + + Before split, in TensorIR, the IR is: + + .. code-block:: python + + @tvm.script.tir + def before_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, (128, 128)) + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 + + Create the schedule and do split: + + .. code-block:: python + + sch = tir.Schedule(before_split, debug_mode=True) + i, j = sch.get_loops(sch.get_block("B")) + sch.split(i, factors=[2, 64]) + print(tvm.script.asscript(sch.mod["main"])) + + After applying split, the IR becomes: + + .. code-block:: python + + @tvm.script.tir + def after_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128)) + B = tir.match_buffer(b, [128, 128]) + for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128): + with tir.block([128, 128], "B") as [vi, vj]: + tir.bind(vi, ((i0_outer*64) + i0_inner)) + tir.bind(vj, i1) + tir.reads([A[vi, vj]]) + tir.writes([B[vi, vj]]) + B[vi, vj] = A[vi, vj] * 2.0 + + """ + for i, factor in enumerate(factors): + if factor is None: + factors[i] = -1 + return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member + ########## Schedule: compute location ########## def compute_inline(self, block: BlockRV) -> None: """Inline a block into its consumer(s). 
It requires: diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index cd482279efe0..c7e4d7b4335b 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1085,6 +1085,22 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +Array IterMapSimplify(const Array& indices, const Map& input_iters, + const PrimExpr& input_pred, bool require_bijective) { + Analyzer analyzer; + Array rewrite = + DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer); + if (rewrite.empty()) { + return indices; + } else { + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; + } +} + /*! * \brief Divider to divide the bindings into two sets of bindings(outer and inner) * such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433dadd..49ecb85b89b3 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0); + TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0)); @@ -881,7 +883,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value 
!= 0); TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index dd7fee37e2d1..0d713707a5ac 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -142,6 +142,12 @@ Array GetLoops(const StmtSRef& block_sref); * \return A list of leaf blocks */ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); +/*! + * \brief Get the direct child Schedulable Stmt (Block and For) + * \param stmt the parent stmt. + * \return the list of child stmts + */ +Array GetChildren(const Stmt& stmt); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index d58dece3c644..7584d36a65f6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -298,5 +298,35 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent throw; } +Array GetChildren(const Stmt& stmt) { + /*! \note Nested SeqStmt is not allowed in schedule. 
*/ + Stmt body; + if (const auto* block = stmt.as()) { + body = block->body; + } else if (const auto* loop = stmt.as()) { + body = loop->body; + } else { + LOG(FATAL) << "The Stmt can only be a Block or a For"; + } + if (const auto* seq = body.as()) { + Array ret; + for (const Stmt& child : seq->seq) { + ICHECK(!child->IsInstance()) << "Nested SeqStmt is not allowed in schedule."; + if (child->IsInstance()) { + ret.push_back(child.as()->block); + } else { + ret.push_back(child); + } + } + return ret; + } else { + if (body->IsInstance()) { + return Array{body.as()->block}; + } else { + return Array{body}; + } + } +} + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 0563d39427b1..a180bd76134b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -258,6 +258,34 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { } /******** Schedule: loops manipulation ********/ + +LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; + Array loop_srefs = this->GetSRefs(loop_rvs); + StmtSRef fused_sref = tir::Fuse(state_, loop_srefs); + this->state_->DebugVerify(); + return CreateRV(fused_sref); + TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); + throw; +} + +Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array& factor_rvs) { + TVM_TIR_SCHEDULE_BEGIN(); + // Prepare for the splitting + StmtSRef loop_sref = this->GetSRef(loop_rv); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + Array factors; + factors.reserve(factor_rvs.size()); + for (const ExprRV& factor_rv : factor_rvs) { + factors.push_back(this->Get(factor_rv)); + } + Array results = tir::Split(state_, loop_sref, factors); + return CreateRV(results); + TVM_TIR_SCHEDULE_END("split", this->error_render_level_); + throw; +} + /******** 
Schedule: compute location ********/ void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 8945fb9ee0dc..250246a01e17 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,6 +68,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; + inline Array GetSRefs(const Array& rvs) const final; + inline Array GetSRefs(const Array& rvs) const final; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -78,6 +80,8 @@ class ConcreteScheduleNode : public ScheduleNode { BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; /******** Schedule: loops manipulation ********/ + LoopRV Fuse(const Array& loop_rvs) override; + Array Split(const LoopRV& loop_rv, const Array& factors) override; /******** Schedule: compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; @@ -143,17 +147,22 @@ inline For ConcreteScheduleNode::Get(const LoopRV& loop_rv) const { } inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { - auto it = this->symbol_table_.find(expr_rv); - if (it == this->symbol_table_.end()) { - LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << expr_rv; - } - const ObjectRef& obj = (*it).second; - const auto* expr_node = obj.as(); - if (expr_node == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? 
obj->GetTypeKey() : "None"); - } - return GetRef(expr_node); + PrimExpr transformed = Substitute(expr_rv, [this](const Var& var) -> Optional { + auto it = this->symbol_table_.find(var); + if (it == this->symbol_table_.end()) { + LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; + } + const ObjectRef& obj = (*it).second; + const auto* int_imm = obj.as(); + if (int_imm == nullptr) { + LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " + << (obj.defined() ? obj->GetTypeKey() : "None"); + } + return Integer(int_imm->value); + }); + PrimExpr simplified = this->analyzer_->Simplify(transformed); + CHECK(is_const_int(transformed)) << "ValueError: The ExprRV does not have a specific value"; + return simplified; } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { @@ -198,6 +207,24 @@ inline StmtSRef ConcreteScheduleNode::GetSRef(const LoopRV& loop_rv) const { return GetRef(sref); } +template +inline Array GetSRefsHelper(const ConcreteScheduleNode* sch, const Array& rvs) { + Array result; + result.reserve(rvs.size()); + for (const T& rv : rvs) { + result.push_back(sch->GetSRef(rv)); + } + return result; +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + +inline Array ConcreteScheduleNode::GetSRefs(const Array& rvs) const { + return GetSRefsHelper(this, rvs); +} + /******** Adding/Removing elements in the symbol table ********/ template diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index ab8299e38169..4f3691098942 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -25,7 +25,27 @@ namespace tvm { namespace tir { /******** Schedule: loops manipulation ********/ - +/*! + * Split a loop into several consecutive loops. It requires: + * 1) The loop can't have annotation or thread binding. + * 2) The loop must start with 0. 
+ * \param self The state of the schedule + * \param loop_sref The sref to the loop being split + * \param factors The splitting factors + * \return An array of srefs to the loops after splitting + */ +TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors); +/*! + * \brief Fuse consecutive loops. It requires: + * 1) The loops can't have annotations or thread bindings. + * 2) The inner loop must be the only child of the outer loop. + * 3) All loops must start with 0. + * \param self The state of the schedule + * \param loop_srefs An array of srefs to the loops to be fused + * \return The sref to the fused loop + */ +TVM_DLL StmtSRef Fuse(ScheduleState self, const Array& loop_srefs); /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/src/tir/schedule/primitive/fuse_split.cc b/src/tir/schedule/primitive/fuse_split.cc new file mode 100644 index 000000000000..02a8774f9120 --- /dev/null +++ b/src/tir/schedule/primitive/fuse_split.cc @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" +namespace tvm { +namespace tir { + +/*! 
\brief Append a new predicate to each child of type BlockRealize (not recursively) */ +class PredicateUpdater : public StmtMutator { + public: + /*! + * \brief Constructor + * \param predicate The predicate to be appended to BlockRealizeNode + */ + explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana) + : predicate_(predicate) { + if (!ana->CanProve(predicate)) { + add_predicate_ = true; + } + } + + private: + // For each direct child of type BlockRealizeNode, append the predicate + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + // We do not recursively do this + if (add_predicate_) { + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && predicate_; + return BlockRealize(n); + } else { + return GetRef(realize); + } + } + + /*! \brief The predicate to be added */ + const PrimExpr& predicate_; + /*! \brief whether to add predicate */ + bool add_predicate_; +}; +/*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ +class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { + public: + explicit IRSubstituteAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) + : vmap_(vmap), opaque_blocks_(opaque_blocks) {} + + private: + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + Optional ret = vmap_(var); + if (ret.defined()) { + return ret.value(); + } else { + return std::move(var); + } + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + Stmt res = StmtMutator::VisitStmt_(op); + if (op->block->iter_vars.empty()) { + const BlockRealizeNode* realize = res.as(); + opaque_blocks_->Set(op->block, realize->block); + } + return res; + } + + /*! \brief The substitute function */ + std::function(const Var&)> vmap_; + /*! 
\brief The reuse mapping */ + Map* opaque_blocks_; +}; + +Stmt SubstituteAndCollectOpaqueBlock(Stmt stmt, Map* opaque_blocks, + std::function(const Var&)> vmap) { + return IRSubstituteAndCollectOpaqueBlock(vmap, opaque_blocks)(std::move(stmt)); +} + +/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping*/ +class BlockRealizeRewriter : public StmtExprMutator { + public: + explicit BlockRealizeRewriter( + const std::unordered_map& loop_map, + Map* opaque_blocks) + : opaque_blocks_(opaque_blocks) { + loop_map_.insert(loop_map.begin(), loop_map.end()); + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + loop_map_[op->loop_var] = Range::FromMinExtent(op->min, op->extent); + Stmt res = StmtMutator::VisitStmt_(op); + loop_map_.erase(op->loop_var); + return res; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // skip opaque block and update mapping + if (op->iter_values.empty()) { + Stmt res = StmtMutator::VisitStmt_(op); + const BlockRealizeNode* realize = res.as(); + for (const std::pair& entry : *opaque_blocks_) { + if (entry.second.same_as(op->block)) { + opaque_blocks_->Set(entry.first, realize->block); + break; + } + } + return res; + } + auto v = arith::IterMapSimplify(op->iter_values, loop_map_, op->predicate, false); + if (v.same_as(op->iter_values)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + return Stmt(n); + } + } + /*! \brief The range of loops */ + std::unordered_map loop_map_; + /*! 
\brief The reuse mapping */ + Map* opaque_blocks_; +}; + +Stmt SimplifyBindings(const Stmt& stmt, const Array& loops, + Map* opaque_blocks) { + std::unordered_map loop_map; + for (const StmtSRef& sref : loops) { + const auto* loop = sref->StmtAs(); + loop_map[loop->loop_var] = Range::FromMinExtent(loop->min, loop->extent); + } + BlockRealizeRewriter rewriter(loop_map, opaque_blocks); + return rewriter(stmt); +} + +class NotLoopError : public ScheduleError { + public: + explicit NotLoopError(IRModule mod, String type) : mod_(mod), type_(type) {} + + String FastErrorString() const final { + return "ScheduleError: this primitive only operates on a " + "loop"; + } + + String DetailRenderTemplate() const final { + return "this primitive only operates on a loop, but the StmtSref passed in points to" + "type: {0} "; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {type_}; } + + IRModule mod_; + String type_; +}; + +class HasAnnotationError : public ScheduleError { + public: + explicit HasAnnotationError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has annotation"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has annotation"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class HasThreadBindingError : public ScheduleError { + public: + explicit HasThreadBindingError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The primitive can't be applied because the loop has thread binding"; + } + + String DetailRenderTemplate() const final { + return "The primitive can't be applied because the loop {0} has thread binding"; + } + + IRModule mod() const final { return 
mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class OuterNotInnerParent : public ScheduleError { + public: + explicit OuterNotInnerParent(IRModule mod, For outer, For inner) + : mod_(mod), outer_(outer), inner_(inner) {} + + String FastErrorString() const final { + return "ScheduleError: the outer loop is not the parent of the inner loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the outer loop {0} is not the parent of the inner " + "loop {1}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class NotOnlyChildError : public ScheduleError { + public: + explicit NotOnlyChildError(IRModule mod, For outer, For inner) + : mod_(mod), outer_(outer), inner_(inner) {} + + String FastErrorString() const final { + return "ScheduleError: the inner loop is not the only child of outer loop"; + } + + String DetailRenderTemplate() const final { + return "The loops can't be fused because the inner loop {1} is not the only child of outer " + "loop {0}."; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {outer_, inner_}; } + + IRModule mod_; + For outer_; + For inner_; +}; + +class LoopNotStartWithZeroError : public ScheduleError { + public: + explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: the primitive only supports loop starting with 0"; + } + + String DetailRenderTemplate() const final { + return "The loop {0} does not start with 0, which is not supported"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +class NotSingleInferFactorError : public ScheduleError { + public: + 
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; +}; + +class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {loop_}; } + + IRModule mod_; + For loop_; +}; + +Array Split(ScheduleState self, const StmtSRef& loop_sref, + const Array& factors) { + // Invariance + // - The total repeat number has not changed for each direct child block with updating predicate. + // - The execution order has not changed. (The block executes with the same args and the same + // order with before. + // Step 1. 
Check correctness + GetScopeRootAndCheckStagePipeline(self, loop_sref); + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr) { + throw NotLoopError(self->mod, loop_sref->stmt->GetTypeKey()); + } + if (!loop->annotations.empty()) { + throw HasAnnotationError(self->mod, GetRef(loop)); + } + if (loop->thread_binding.defined()) { + throw HasThreadBindingError(self->mod, GetRef(loop)); + } + // Currently, loops not starting with 0 are not supported + arith::Analyzer analyzer; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + PrimExpr tot_length = 1; + int infer_index = -1; + for (size_t i = 0; i < factors.size(); i++) { + if (!analyzer.CanProve(factors[i] == -1)) { + tot_length *= factors[i]; + } else { + if (infer_index != -1) { + throw NotSingleInferFactorError(self->mod); + } else { + infer_index = i; + } + } + } + // Step 2. infer factors if needed + Array inferred_factors(factors); + if (infer_index != -1) { + inferred_factors.Set(infer_index, + analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else { + if (!analyzer.CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(self->mod, GetRef(loop)); + } + } + // Step 3. 
Replace all occurrence of the original loop var with new variables + std::vector new_loop_vars; + new_loop_vars.reserve(inferred_factors.size()); + for (size_t i = 0; i < inferred_factors.size(); i++) { + new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); + } + PrimExpr substitute_value = 0; + for (size_t i = 0; i < inferred_factors.size(); i++) { + substitute_value *= inferred_factors[i]; + substitute_value += new_loop_vars[i]; + } + Map opaque_block_reuse; + auto substitute_function = [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }; + Stmt new_loop_body = + SubstituteAndCollectOpaqueBlock(loop->body, &opaque_block_reuse, substitute_function); + for (size_t i = 0; i < inferred_factors.size(); i++) { + analyzer.Bind(new_loop_vars[i], Range::FromMinExtent(0, inferred_factors[i])); + } + // Step 4. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + new_loop_body = PredicateUpdater(predicate, &analyzer)(new_loop_body); + // Step 5. 
Generate nested loops to replace the original loop and simplify the binding + Stmt outer_stmt = new_loop_body; + for (int i = inferred_factors.size() - 1; i >= 0; i--) { + outer_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, outer_stmt); + } + + outer_stmt = + Downcast(SimplifyBindings(outer_stmt, GetLoops(loop_sref), &opaque_block_reuse)); + self->Replace(loop_sref, outer_stmt, opaque_block_reuse); + Array result_srefs; + result_srefs.reserve(inferred_factors.size()); + for (size_t i = 0; i < inferred_factors.size(); i++) { + result_srefs.push_back(self->stmt2ref.at(outer_stmt.get())); + const ForNode* outer_loop = outer_stmt.as(); + ICHECK(outer_loop); + outer_stmt = outer_loop->body; + } + return result_srefs; +} + +StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order as before.) + std::vector loops; + loops.reserve(loop_srefs.size()); + StmtSRef outer_sref{nullptr}; + const ForNode* outer_loop = nullptr; + arith::Analyzer analyzer; + // Step 1. 
check correctness + GetScopeRootAndCheckStagePipeline(self, loop_srefs[0]); + for (const StmtSRef& sref : loop_srefs) { + const auto* loop = sref->StmtAs(); + if (loop == nullptr) { + throw NotLoopError(self->mod, sref->stmt->GetTypeKey()); + } + if (!loop->annotations.empty()) { + throw HasAnnotationError(self->mod, GetRef(loop)); + } + if (loop->thread_binding.defined()) { + throw HasThreadBindingError(self->mod, GetRef(loop)); + } + if (outer_sref.defined()) { + if (sref->parent != outer_sref.get()) { + throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); + } + Array outer_children = GetChildren(GetRef(outer_loop)); + if (outer_children.size() != 1 || outer_children[0].get() != loop) { + throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); + } + } + outer_sref = sref; + outer_loop = loop; + if (!analyzer.CanProve(loop->min == 0)) { + throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); + } + loops.push_back(loop); + } + // Step 2. Create fused loop var and replace the original loop vars + std::string suffix; + for (size_t i = 1; i < loops.size(); i++) { + suffix += "_" + loops[i]->loop_var->name_hint; + } + suffix += "_fused"; + Var fused_var = loops[0]->loop_var.copy_with_suffix(suffix); + Array substitute_value; + substitute_value.resize(loops.size()); + PrimExpr tot = fused_var; + for (int i = loops.size() - 1; i >= 0; i--) { + substitute_value.Set(i, floormod(tot, loops[i]->extent)); + tot = floordiv(tot, loops[i]->extent); + } + Stmt loop_body = loops.back()->body; + Map opaque_block_reuse; + auto substitute_function = [&](const Var& v) -> Optional { + for (size_t i = 0; i < loops.size(); i++) { + if (v.same_as(loops[i]->loop_var)) { + return substitute_value[i]; + } + } + return NullOpt; + }; + Stmt new_loop_body = + SubstituteAndCollectOpaqueBlock(loop_body, &opaque_block_reuse, substitute_function); + // Step 3. 
Generate a loop to replace the original loops + PrimExpr fused_min = 0; + PrimExpr fused_extent = 1; + for (size_t i = 0; i < loops.size(); i++) { + fused_extent *= loops[i]->extent; + } + fused_extent = analyzer.Simplify(fused_extent); + For fused_loop = For(fused_var, fused_min, fused_extent, loops[0]->kind, new_loop_body); + fused_loop = + Downcast(SimplifyBindings(fused_loop, GetLoops(loop_srefs[0]), &opaque_block_reuse)); + self->Replace(loop_srefs[0], fused_loop, opaque_block_reuse); + return self->stmt2ref.at(fused_loop.get()); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 115f7936f64e..77d17c9dc6e9 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -123,6 +123,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetLoops") .set_body_method(&ScheduleNode::GetLoops); /******** (FFI) loops manipulation ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSplit").set_body_method(&ScheduleNode::Split); /******** (FFI) compute location ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline") .set_body_method(&ScheduleNode::ComputeInline); diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py new file mode 100644 index 000000000000..56f1a4a3fff7 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -0,0 +1,469 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k in tir.grid(128, 128, n): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_fused(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i_j_k_fused in tir.serial(0, (n * 16384)): + with tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, (n * 128))) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, n), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, n)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_symbolic_split(a: ty.handle, b: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (128, 128, n)) + B = tir.match_buffer(b, (128, 128, n)) + for i, j, k0, k1 in tir.grid(128, 128, 10, tir.floordiv((n + 9), 10)): + with 
tir.block([128, 128, n], "B") as [vi, vj, vk]: + tir.where((((k0 * tir.floordiv((n + 9), 10)) + k1) < n)) + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, ((k0 * tir.floordiv((n + 9), 10)) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_seq(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + C = tir.alloc_buffer((128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "C") as [vi, vj, vk]: + C[vi, vj, vk] = A[vi, vj, vk] * 2.0 + for k in tir.serial(0, 128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = C[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_anno(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(0, 128, annotations={"useless_annotation": True}): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_thread_binding(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.thread_binding(0, 128, thread="threadIdx.x"): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_starting_point(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j in tir.grid(128, 128): + for k in tir.serial(10, 
128): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for i, j, k in tir.grid(128, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i, j, k]]) + tir.writes([B[i, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + for fused in tir.serial(0, 2097152): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(fused, 128), 128)) + tir.bind(vk, tir.floormod(fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case0(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, k1, k2 in tir.grid(2, 1, 64, 4, 32, 16, 8): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, ((i1 * 64) + i3)) + tir.bind(vj, ((j1 * 32) + j2)) + tir.bind(vk, ((k1 * 8) + k2)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_case1(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [128, 128, 128]) + B = tir.match_buffer(b, [128, 128, 128]) + for i1, i2, i3, j1, j2, j3, k1, k2, k3 in tir.grid(2, 1, 64, 2, 1, 64, 2, 1, 64): + with 
tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i1 * 64 + i3) + tir.bind(vj, j1 * 64 + j3) + tir.bind(vk, k1 * 64 + k3) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_predicate(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i0, i1, i2, j0, j1, k0, k1 in tir.grid(1000, 2, 3, 1, 129, 3, 43): + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.where( + ( + ((((((i0 * 2) + i1) * 3) + i2) < 128) and (((j0 * 129) + j1) < 128)) + and (((k0 * 43) + k1) < 128) + ) + ) + tir.bind(vi, (((i0 * 6) + (i1 * 3)) + i2)) + tir.bind(vj, j1) + tir.bind(vk, ((k0 * 43) + k1)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_fuse_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, [128, 128, 128]) + for i_j_k_fused in tir.serial(0, 2097152): + with tir.block([], "opaque"): + tir.reads( + [ + A[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + tir.writes( + [ + B[ + tir.floormod(tir.floordiv(tir.floordiv(i_j_k_fused, 128), 128), 128), + tir.floormod(tir.floordiv(i_j_k_fused, 128), 128), + tir.floormod(i_j_k_fused, 128), + ] + ] + ) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, tir.floordiv(i_j_k_fused, 16384)) + tir.bind(vj, tir.floormod(tir.floordiv(i_j_k_fused, 128), 128)) + tir.bind(vk, tir.floormod(i_j_k_fused, 128)) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def elementwise_split_with_opaque_block(a: ty.handle, b: ty.handle) -> None: + B = tir.match_buffer(b, [128, 128, 128]) + A = tir.match_buffer(a, 
[128, 128, 128]) + + for i0, i1, j, k in tir.grid(8, 16, 128, 128): + with tir.block([], "opaque"): + tir.reads([A[i0 * 16 + i1, j, k]]) + tir.writes([B[i0 * 16 + i1, j, k]]) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + tir.bind(vi, i0 * 16 + i1) + tir.bind(vj, j) + tir.bind(vk, k) + tir.reads([A[vi, vj, vk]]) + tir.writes([B[vi, vj, vk]]) + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +@tvm.script.tir +def opaque_access(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16], "float32") + B = tir.match_buffer(b, [16, 16], "float32") + with tir.block([16, 16], "A") as [vi, vj]: + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, vi * 16 + vj, 1) + with tir.block([16, 16], "B") as [vi, vj]: + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate(tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, vi * 16 + vj, dtype="handle")) + + +@tvm.script.tir +def opaque_access_fused(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + B = tir.match_buffer(b, [16, 16]) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i_j_fused in tir.serial(0, 256): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, tir.floordiv(i_j_fused, 16)) + tir.bind(vj, tir.floormod(i_j_fused, 16)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +@tvm.script.tir +def opaque_access_split(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + B = tir.match_buffer(b, (16, 16)) + for i, j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "A") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([A[0:16, 0:16]]) + tir.store(A.data, ((vi * 16) + vj), 1, 1) + for i, 
j0, j1 in tir.grid(16, 4, 4): + with tir.block([16, 16], "B") as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, ((j0 * 4) + j1)) + tir.reads([]) + tir.writes([B[0:16, 0:16]]) + tir.evaluate( + tir.tvm_fill_fragment(B.data, 16, 16, 16, 0, ((vi * 16) + vj), dtype="handle") + ) + + +# pylint: enable=no-member,invalid-name,unused-variable + + +def test_fuse(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) + + +def test_split(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[2, 1, 64]) + sch.split(j, factors=[4, 32]) + sch.split(k, factors=[16, 8]) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) + + +def test_split_with_inferred_factor(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[None, 1, 64]) + sch.split(j, factors=[2, None, 64]) + sch.split(k, factors=[2, 1, -1]) + tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) + + +def test_split_with_predicate(): + sch = tir.Schedule(elementwise, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(i, factors=[1000, 2, 3]) + sch.split(j, factors=[None, 129]) + sch.split(k, factors=[3, None]) + assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline + tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) + + +def test_fuse_fail_not_only_child(): + sch = tir.Schedule(elementwise_with_seq, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with 
pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + + +def test_fuse_split_fail_with_annotation(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_split_fail_not_start_with_zero(): + sch = tir.Schedule(elementwise_with_anno, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_fuse_with_opaque_block, sch.mod["main"]) + + +def test_fuse_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.fuse(i, j) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.fuse(i, j) + tvm.ir.assert_structural_equal(opaque_access_fused, sch.mod["main"]) + + +def test_split_with_opaque_block(): + sch = tir.Schedule(elementwise_with_opaque_block, debug_mode=True) + block_opaque = sch.get_block("opaque") + i, j, k = sch.get_loops(block_opaque) + sch.split(i, factors=[None, 16]) + tvm.ir.assert_structural_equal(elementwise_split_with_opaque_block, sch.mod["main"]) + + +def test_split_with_opaque_access(): + sch = tir.Schedule(opaque_access, debug_mode=True) + block_a = sch.get_block("A") + i, j = sch.get_loops(block_a) + sch.split(j, factors=[None, 4]) + block_b = sch.get_block("B") + i, j = sch.get_loops(block_b) + sch.split(j, factors=[None, 4]) + tvm.ir.assert_structural_equal(opaque_access_split, sch.mod["main"]) + + +def 
test_fuse_split_fail_with_thread_binding(): + sch = tir.Schedule(elementwise_with_thread_binding, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + with pytest.raises(tvm.tir.ScheduleError): + sch.fuse(j, k) + with pytest.raises(tvm.tir.ScheduleError): + sch.split(k, factors=[None, 10]) + + +def test_fuse_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.fuse(i, j, k) + tvm.ir.assert_structural_equal(elementwise_symbolic_fused, sch.mod["main"]) + + +def test_split_symbolic(): + sch = tir.Schedule(elementwise_symbolic, debug_mode=True) + block_b = sch.get_block("B") + i, j, k = sch.get_loops(block_b) + sch.split(k, factors=[10, None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + + +if __name__ == "__main__": + test_fuse() + test_fuse_with_opaque_block() + test_fuse_with_opaque_access() + test_fuse_symbolic() + test_split() + test_split_with_inferred_factor() + test_split_with_opaque_block() + test_split_with_opaque_access() + test_split_with_predicate() + test_split_symbolic() + test_fuse_fail_not_only_child() + test_fuse_split_fail_with_annotation() + test_fuse_split_fail_not_start_with_zero() + test_fuse_split_fail_with_thread_binding() From 36527c1f06a05a6595fb9b20c65c706dd67bf08b Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Thu, 15 Jul 2021 12:15:35 +0800 Subject: [PATCH 02/16] address comments --- python/tvm/tir/schedule/schedule.py | 10 +++++++--- src/tir/schedule/primitive/fuse_split.cc | 14 +++++--------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 67350bd109d0..d7ed32a8e9c8 100755 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -43,7 +43,10 @@ class BlockRV(Object): """A random variable that refers to a block""" -ExprRV = Union[PrimExpr] 
# A random variable that evaluates to an integer +# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370 +# This feature is not supported until python 3.10: +# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias +ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name @@ -318,7 +321,7 @@ def after_fuse(a: ty.handle, b: ty.handle) -> None: def split( self, loop: LoopRV, - factors: List[Optional[ExprRV]], + factors: List[Union[ExprRV, None]], ) -> List[LoopRV]: """Split a loop into a list of consecutive loops. It requires: 1) The loop can't have annotation or thread binding. @@ -331,7 +334,7 @@ def split( loop : LoopRV The loop to be split - factors: List[Optional[ExprRV]] + factors: List[Union[ExprRV, None]] The splitting factors Returns @@ -379,6 +382,7 @@ def after_split(a: ty.handle, b: ty.handle) -> None: B[vi, vj] = A[vi, vj] * 2.0 """ + # it will be checked later in C++ implementation that there is at most one None or -1 in `factors` for i, factor in enumerate(factors): if factor is None: factors[i] = -1 diff --git a/src/tir/schedule/primitive/fuse_split.cc b/src/tir/schedule/primitive/fuse_split.cc index 02a8774f9120..c9b543f55f13 100644 --- a/src/tir/schedule/primitive/fuse_split.cc +++ b/src/tir/schedule/primitive/fuse_split.cc @@ -29,9 +29,7 @@ class PredicateUpdater : public StmtMutator { */ explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana) : predicate_(predicate) { - if (!ana->CanProve(predicate)) { - add_predicate_ = true; - } + add_predicate_ = !ana->CanProve(predicate)); } private: @@ -52,6 +50,7 @@ class PredicateUpdater : public StmtMutator { /*! \brief whether to add predicate */ bool add_predicate_; }; + /*! 
\brief Substitute vars and collect the reuse mapping of opaque blocks */ class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { public: @@ -152,8 +151,7 @@ class NotLoopError : public ScheduleError { explicit NotLoopError(IRModule mod, String type) : mod_(mod), type_(type) {} String FastErrorString() const final { - return "ScheduleError: this primitive only operates on a " - "loop"; + return "ScheduleError: this primitive only operates on a loop"; } String DetailRenderTemplate() const final { @@ -348,10 +346,8 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, if (infer_index != -1) { inferred_factors.Set(infer_index, analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); - } else { - if (!analyzer.CanProve(tot_length >= loop->extent)) { - throw WrongFactorProductError(self->mod, GetRef(loop)); - } + } else if (!analyzer.CanProve(tot_length >= loop->extent)) { + throw WrongFactorProductError(self->mod, GetRef(loop)); } // Step 3. Replace all occurrence of the original loop var with new variables std::vector new_loop_vars; From 0bd76fb72a6d3d93cc54c9bb40dafa5eae525b15 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:14:03 +0800 Subject: [PATCH 03/16] address comments --- include/tvm/tir/schedule/schedule.h | 22 +- python/tvm/tir/schedule/schedule.py | 43 +-- src/arith/iter_affine_map.cc | 11 +- src/arith/rewrite_simplify.cc | 2 + src/tir/schedule/analysis.h | 6 - src/tir/schedule/analysis/analysis.cc | 30 -- src/tir/schedule/concrete_schedule.cc | 26 +- src/tir/schedule/concrete_schedule.h | 8 +- src/tir/schedule/primitive.h | 4 +- .../{fuse_split.cc => loop_transformation.cc} | 269 +++++++----------- .../unittest/test_tir_schedule_split_fuse.py | 15 +- 11 files changed, 165 insertions(+), 271 deletions(-) rename src/tir/schedule/primitive/{fuse_split.cc => loop_transformation.cc} (56%) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 
38a15a814370..868454b18b74 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -151,18 +151,6 @@ class ScheduleNode : public runtime::Object { * \return The corresponding loop sref */ virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0; - /*! - * \brief Get the block srefs corresponding to an array of BlockRVs - * \param block_rvs The BlockRVs to be looked up - * \return The corresponding block srefs - */ - virtual Array GetSRefs(const Array& block_rvs) const = 0; - /*! - * \brief Get the loop srefs corresponding to an array of LoopRVs - * \param loop_rvs The LoopRVs to be looked up - * \return The corresponding loop srefs - */ - virtual Array GetSRefs(const Array& loop_rvs) const = 0; /*! * \brief Get the block/loop sref corresponding to the specific statement * \param stmt The statement to be looked up @@ -209,24 +197,24 @@ class ScheduleNode : public runtime::Object { virtual Array GetLoops(const BlockRV& block_rv) = 0; /******** Schedule: loops manipulation ********/ /*! - * \brief Fuse consecutive loops into one. It requires: + * \brief Fuse a list of consecutive loops into one. It requires: * 1) The loops can't have annotations or thread bindings. * 2) The (i+1)-th loop must be the only child of the i-th loop. * 3) All loops must start with 0. * \param loop_rvs The loops to be fused - * \return The fused loop + * \return The new loop after fusion */ virtual LoopRV Fuse(const Array& loop_rvs) = 0; /*! - * \brief Split a specified loop into two or more with the specific factor.It requires: + * \brief Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. * 2) The loop must start with 0. * \param loop_rv The loop to be split * \param factors The tiling factors, and at most one of which is -1, which means that * factor is inferred. 
- * \return The loops after splitting + * \return The new loops after split */ - virtual Array Split(const LoopRV& loop_rv, const Array& factors) = 0; + virtual Array Split(const LoopRV& loop_rv, const Array>& factors) = 0; /******** Schedule: compute location ********/ /*! * \brief Inline a block into its consumer(s). It requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index d7ed32a8e9c8..8bfc9646cfa8 100755 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -279,7 +279,7 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: Examples -------- - Before fuse, in TensorIR, the IR is: + Before applying fuse, in TensorIR, the IR is: .. code-block:: python @@ -287,14 +287,15 @@ def fuse(self, *loops: List[LoopRV]) -> LoopRV: def before_fuse(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128)) B = tir.match_buffer(b, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: .. 
code-block:: python - sch = tir.Schedule(before_fuse, debug_mode=True) + sch = tir.Schedule(before_fuse) i, j = sch.get_loops(sch.get_block("B")) sch.fuse(i, j) print(tvm.script.asscript(sch.mod["main"])) @@ -306,13 +307,12 @@ def before_fuse(a: ty.handle, b: ty.handle) -> None: @tvm.script.tir def after_fuse(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, [128, 128]) - for i0_i1_fused in tir.serial(0, 16384): + B = tir.match_buffer(b, (128, 128)) + # the 2 loops are fused into 1 + for i_j_fused in tir.serial(0, 16384): with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, tir.floordiv(i0_i1_fused, 128)) - tir.bind(vj, tir.floormod(i0_i1_fused, 128)) - tir.reads([A[vi, vj]]) - tir.writes([B[vi, vj]]) + tir.bind(vi, tir.floordiv(i_j_fused, 128)) + tir.bind(vj, tir.floormod(i_j_fused, 128)) B[vi, vj] = A[vi, vj] * 2.0 """ @@ -329,6 +329,7 @@ def split( Predicates may be added to ensure the total loop numbers keeps unchanged. In `factors`, at most one of the factors can be None or -1, which will be automatically inferred. + Parameters ---------- loop : LoopRV @@ -336,6 +337,10 @@ def split( factors: List[Union[ExprRV, None]] The splitting factors + Potential inputs are: + - None or -1 + - ExprRV + - Nonnegative constant integers Returns ---------- @@ -353,14 +358,15 @@ def split( def before_split(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128)) B = tir.match_buffer(b, (128, 128)) - with tir.block([128, 128], "B") as [vi, vj]: - B[vi, vj] = A[vi, vj] * 2.0 + for i, j in tir.grid(128, 128): + with tir.block([128, 128], "B") as [vi, vj]: + B[vi, vj] = A[vi, vj] * 2.0 Create the schedule and do fuse: .. 
code-block:: python - sch = tir.Schedule(before_split, debug_mode=True) + sch = tir.Schedule(before_split) i, j = sch.get_loops(sch.get_block("B")) sch.split(i, factors=[2, 64]) print(tvm.script.asscript(sch.mod["main"])) @@ -372,13 +378,12 @@ def before_split(a: ty.handle, b: ty.handle) -> None: @tvm.script.tir def after_split(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128)) - B = tir.match_buffer(b, [128, 128]) - for i0_outer, i0_inner, i1 in tir.grid(2, 64, 128): + B = tir.match_buffer(b, (128, 128)) + # the original loop is split into 2 loops + for i0, i1, j in tir.grid(2, 64, 128): with tir.block([128, 128], "B") as [vi, vj]: - tir.bind(vi, ((i0_outer*64) + i0_inner)) - tir.bind(vj, i1) - tir.reads([A[vi, vj]]) - tir.writes([B[vi, vj]]) + tir.bind(vi, ((i0*64) + i1)) + tir.bind(vj, j) B[vi, vj] = A[vi, vj] * 2.0 """ diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index c7e4d7b4335b..ac78c55ed610 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1092,13 +1092,12 @@ Array IterMapSimplify(const Array& indices, const Map res; - res.reserve(rewrite.size()); - IterMapToExprNormalizer converter(&analyzer); - for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); - return res; } + Array res; + res.reserve(rewrite.size()); + IterMapToExprNormalizer converter(&analyzer); + for (const auto& expr : rewrite) res.push_back(converter.Convert(expr)); + return res; } /*! 
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 49ecb85b89b3..52a9fe916e87 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -883,7 +883,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 0d713707a5ac..dd7fee37e2d1 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -142,12 +142,6 @@ Array GetLoops(const StmtSRef& block_sref); * \return A list of leaf blocks */ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent_sref); -/*! - * \brief Get the direct child Schedulable Stmt (Block and For) - * \param stmt the parent stmt. - * \return the list of child stmts - */ -Array GetChildren(const Stmt& stmt); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 7584d36a65f6..d58dece3c644 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -298,35 +298,5 @@ Array GetChildBlocks(const ScheduleState& self, const StmtSRef& parent throw; } -Array GetChildren(const Stmt& stmt) { - /*! \note Nested SeqStmt is not allowed in schedule. 
*/ - Stmt body; - if (const auto* block = stmt.as()) { - body = block->body; - } else if (const auto* loop = stmt.as()) { - body = loop->body; - } else { - LOG(FATAL) << "The Stmt can only be a Block or a For"; - } - if (const auto* seq = body.as()) { - Array ret; - for (const Stmt& child : seq->seq) { - ICHECK(!child->IsInstance()) << "Nested SeqStmt is not allowed in schedule."; - if (child->IsInstance()) { - ret.push_back(child.as()->block); - } else { - ret.push_back(child); - } - } - return ret; - } else { - if (body->IsInstance()) { - return Array{body.as()->block}; - } else { - return Array{body}; - } - } -} - } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index a180bd76134b..69304d05ea3c 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -260,30 +260,32 @@ Array ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) { /******** Schedule: loops manipulation ********/ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { - TVM_TIR_SCHEDULE_BEGIN(); CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)"; Array loop_srefs = this->GetSRefs(loop_rvs); - StmtSRef fused_sref = tir::Fuse(state_, loop_srefs); - this->state_->DebugVerify(); - return CreateRV(fused_sref); + StmtSRef result{nullptr}; + TVM_TIR_SCHEDULE_BEGIN(); + result = tir::Fuse(state_, loop_srefs); TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_); - throw; + this->state_->DebugVerify(); + return CreateRV(result); } -Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array& factor_rvs) { - TVM_TIR_SCHEDULE_BEGIN(); +Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, + const Array>& factor_rvs) { // Prepare for the splitting StmtSRef loop_sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); Array factors; factors.reserve(factor_rvs.size()); - for (const ExprRV& factor_rv : factor_rvs) { - 
factors.push_back(this->Get(factor_rv)); + for (const Optional& factor_rv : factor_rvs) { + factors.push_back(this->Get(factor_rv.value_or(Integer(-1)))); } - Array results = tir::Split(state_, loop_sref, factors); - return CreateRV(results); + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + results = tir::Split(state_, loop_sref, factors); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); - throw; + this->state_->DebugVerify(); + return CreateRV(results); } /******** Schedule: compute location ********/ diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 250246a01e17..1dd06dae0140 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -68,8 +68,8 @@ class ConcreteScheduleNode : public ScheduleNode { inline PrimExpr Get(const ExprRV& expr_rv) const final; inline StmtSRef GetSRef(const BlockRV& block_rv) const final; inline StmtSRef GetSRef(const LoopRV& loop_rv) const final; - inline Array GetSRefs(const Array& rvs) const final; - inline Array GetSRefs(const Array& rvs) const final; + inline Array GetSRefs(const Array& rvs) const; + inline Array GetSRefs(const Array& rvs) const; void RemoveRV(const BlockRV& block_rv) final { RemoveFromSymbolTable(block_rv); } void RemoveRV(const LoopRV& loop_rv) final { RemoveFromSymbolTable(loop_rv); } void RemoveRV(const ExprRV& expr_rv) final { RemoveFromSymbolTable(expr_rv); } @@ -81,7 +81,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array GetLoops(const BlockRV& block_rv) override; /******** Schedule: loops manipulation ********/ LoopRV Fuse(const Array& loop_rvs) override; - Array Split(const LoopRV& loop_rv, const Array& factors) override; + Array Split(const LoopRV& loop_rv, const Array>& factors) override; /******** Schedule: compute location ********/ void ComputeInline(const BlockRV& block) override; void ReverseComputeInline(const BlockRV& block) override; @@ -153,7 +153,7 @@ inline PrimExpr 
ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { LOG(FATAL) << "IndexError: Cannot find corresponding ExprRV: " << var; } const ObjectRef& obj = (*it).second; - const auto* int_imm = obj.as(); + const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); if (int_imm == nullptr) { LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " << (obj.defined() ? obj->GetTypeKey() : "None"); diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 4f3691098942..088c4df58859 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -26,7 +26,7 @@ namespace tir { /******** Schedule: loops manipulation ********/ /*! - * Split a loop into several consecutive loops. It requires: + * Split a loop into a list of consecutive loops. It requires: * 1) The loop can't have annotation or thread binding. * 2) The loop must start with 0. * \param self The state of the schedule @@ -37,7 +37,7 @@ namespace tir { TVM_DLL Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array& factors); /*! - * \brief Fuse consecutive loops. It requires: + * \brief Fuse a list of consecutive loops into one. It requires: * 1) The loops can't have annotations or thread bindings. * 2) The inner loop must be the only child of the outer loop. * 3) All loops must start with 0. diff --git a/src/tir/schedule/primitive/fuse_split.cc b/src/tir/schedule/primitive/loop_transformation.cc similarity index 56% rename from src/tir/schedule/primitive/fuse_split.cc rename to src/tir/schedule/primitive/loop_transformation.cc index c9b543f55f13..eb2a8b185bcf 100644 --- a/src/tir/schedule/primitive/fuse_split.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -17,19 +17,20 @@ * under the License. */ #include "../utils.h" + namespace tvm { namespace tir { -/*! \brief Append a new predicate to the each children of type BlockRealize (not recursively) */ -class PredicateUpdater : public StmtMutator { +/*! 
\brief Append a new predicate to the each child of type BlockRealize (not recursively) */ +class BlockPredicateAppender : public StmtMutator { public: /*! * \brief Constructor - * \param predicate The predicate to be apppend to BlockRealizeNode + * \param to_append The predicate to be appended to BlockRealizeNode */ - explicit PredicateUpdater(const PrimExpr& predicate, arith::Analyzer* ana) - : predicate_(predicate) { - add_predicate_ = !ana->CanProve(predicate)); + explicit BlockPredicateAppender(const PrimExpr& to_append, arith::Analyzer* analyzer) + : to_append_(to_append) { + add_predicate_ = !analyzer->CanProve(to_append); } private: @@ -38,31 +39,30 @@ class PredicateUpdater : public StmtMutator { // We do not recursively do this if (add_predicate_) { ObjectPtr n = CopyOnWrite(realize); - n->predicate = n->predicate && predicate_; + n->predicate = n->predicate && to_append_; return BlockRealize(n); } else { return GetRef(realize); } } - /*! \brief The predicate to be added */ - const PrimExpr& predicate_; - /*! \brief whether to add predicate */ + /*! \brief The predicate to be appended */ + const PrimExpr& to_append_; + /*! \brief Whether to add predicate */ bool add_predicate_; }; /*! 
\brief Substitute vars and collect the reuse mapping of opaque blocks */ -class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { +class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { public: - explicit IRSubstituteAndCollectOpaqueBlock(std::function(const Var&)> vmap, - Map* opaque_blocks) + explicit SubstituteVarAndCollectOpaqueBlock(std::function(const Var&)> vmap, + Map* opaque_blocks) : vmap_(vmap), opaque_blocks_(opaque_blocks) {} private: PrimExpr VisitExpr_(const VarNode* op) final { Var var = GetRef(op); - Optional ret = vmap_(var); - if (ret.defined()) { + if (Optional ret = vmap_(var)) { return ret.value(); } else { return std::move(var); @@ -72,7 +72,7 @@ class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { Stmt VisitStmt_(const BlockRealizeNode* op) final { Stmt res = StmtMutator::VisitStmt_(op); if (op->block->iter_vars.empty()) { - const BlockRealizeNode* realize = res.as(); + const BlockRealizeNode* realize = TVM_TYPE_AS(realize, res, BlockRealizeNode); opaque_blocks_->Set(op->block, realize->block); } return res; @@ -80,30 +80,37 @@ class IRSubstituteAndCollectOpaqueBlock : public StmtExprMutator { /*! \brief The substitute function */ std::function(const Var&)> vmap_; - /*! \brief The reuse mapping */ + /*! \brief The reuse mapping of opaque blocks */ Map* opaque_blocks_; }; Stmt SubstituteAndCollectOpaqueBlock(Stmt stmt, Map* opaque_blocks, std::function(const Var&)> vmap) { - return IRSubstituteAndCollectOpaqueBlock(vmap, opaque_blocks)(std::move(stmt)); + return SubstituteVarAndCollectOpaqueBlock(vmap, opaque_blocks)(std::move(stmt)); } -/*! \brief Simplify the binding of block realize and update the opaque block reuse mapping*/ -class BlockRealizeRewriter : public StmtExprMutator { +/*! 
\brief Simplify the binding of block realize and update the opaque block reuse mapping */ +class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit BlockRealizeRewriter( - const std::unordered_map& loop_map, - Map* opaque_blocks) - : opaque_blocks_(opaque_blocks) { - loop_map_.insert(loop_map.begin(), loop_map.end()); + explicit IterMapSimplifyBlockBinding(const Map& loop_map, + Map* opaque_blocks) + : opaque_blocks_(opaque_blocks), loop_var2extent_(std::move(loop_map)) {} + + static For SimplifyBindings(const Stmt& stmt, const Array& loop_srefs, + Map* opaque_blocks) { + Map loop_var2extent; + for (const StmtSRef& sref : loop_srefs) { + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); + loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + } + return Downcast(IterMapSimplifyBlockBinding(loop_var2extent, opaque_blocks)(stmt)); } private: Stmt VisitStmt_(const ForNode* op) final { - loop_map_[op->loop_var] = Range::FromMinExtent(op->min, op->extent); + loop_var2extent_.Set(op->loop_var, Range::FromMinExtent(op->min, op->extent)); Stmt res = StmtMutator::VisitStmt_(op); - loop_map_.erase(op->loop_var); + loop_var2extent_.erase(op->loop_var); return res; } @@ -112,89 +119,46 @@ class BlockRealizeRewriter : public StmtExprMutator { if (op->iter_values.empty()) { Stmt res = StmtMutator::VisitStmt_(op); const BlockRealizeNode* realize = res.as(); + MapNode* mutable_map = opaque_blocks_->CopyOnWrite(); for (const std::pair& entry : *opaque_blocks_) { if (entry.second.same_as(op->block)) { - opaque_blocks_->Set(entry.first, realize->block); + mutable_map->at(entry.first) = realize->block; break; } } return res; } - auto v = arith::IterMapSimplify(op->iter_values, loop_map_, op->predicate, false); + Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, + /*input_iters=*/loop_var2extent_, + /*input_pred=*/op->predicate, + /*require_bijective=*/false); if (v.same_as(op->iter_values)) { return GetRef(op); } else 
{ - auto n = CopyOnWrite(op); + ObjectPtr n = CopyOnWrite(op); n->iter_values = std::move(v); return Stmt(n); } } - /*! \brief The range of loops */ - std::unordered_map loop_map_; + /*! \brief The reuse mapping */ Map* opaque_blocks_; + /*! \brief The range of loops */ + Map loop_var2extent_; }; -Stmt SimplifyBindings(const Stmt& stmt, const Array& loops, - Map* opaque_blocks) { - std::unordered_map loop_map; - for (const StmtSRef& sref : loops) { - const auto* loop = sref->StmtAs(); - loop_map[loop->loop_var] = Range::FromMinExtent(loop->min, loop->extent); - } - BlockRealizeRewriter rewriter(loop_map, opaque_blocks); - return rewriter(stmt); -} - -class NotLoopError : public ScheduleError { - public: - explicit NotLoopError(IRModule mod, String type) : mod_(mod), type_(type) {} - - String FastErrorString() const final { - return "ScheduleError: this primitive only operates on a loop"; - } - - String DetailRenderTemplate() const final { - return "this primitive only operates on a loop, but the StmtSref passed in points to" - "type: {0} "; - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {type_}; } - - IRModule mod_; - String type_; -}; - -class HasAnnotationError : public ScheduleError { - public: - explicit HasAnnotationError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} - - String FastErrorString() const final { - return "ScheduleError: The primitive can't be applied because the loop has annotation"; - } - - String DetailRenderTemplate() const final { - return "The primitive can't be applied because the loop {0} has annotation"; - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {loop_}; } - - IRModule mod_; - For loop_; -}; - -class HasThreadBindingError : public ScheduleError { +class HasAnnotationOrThreadBindingError : public ScheduleError { public: - explicit HasThreadBindingError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + explicit 
HasAnnotationOrThreadBindingError(IRModule mod, For loop) + : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { - return "ScheduleError: The primitive can't be applied because the loop has thread binding"; + return "ScheduleError: The primitive can't be applied because the loop has annotation or " + "thread binding"; } String DetailRenderTemplate() const final { - return "The primitive can't be applied because the loop {0} has thread binding"; + return "The primitive can't be applied because the loop {0} has annotation or thread binding"; } IRModule mod() const final { return mod_; } @@ -207,10 +171,10 @@ class HasThreadBindingError : public ScheduleError { class OuterNotInnerParent : public ScheduleError { public: explicit OuterNotInnerParent(IRModule mod, For outer, For inner) - : mod_(mod), outer_(outer), inner_(inner) {} + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { - return "ScheduleError: the outer loop is not the parent of the inner loop"; + return "ScheduleError: The outer loop is not the parent of the inner loop"; } String DetailRenderTemplate() const final { @@ -229,10 +193,10 @@ class OuterNotInnerParent : public ScheduleError { class NotOnlyChildError : public ScheduleError { public: explicit NotOnlyChildError(IRModule mod, For outer, For inner) - : mod_(mod), outer_(outer), inner_(inner) {} + : mod_(mod), outer_(std::move(outer)), inner_(std::move(inner)) {} String FastErrorString() const final { - return "ScheduleError: the inner loop is not the only child of outer loop"; + return "ScheduleError: The inner loop is not the only child of outer loop"; } String DetailRenderTemplate() const final { @@ -250,10 +214,10 @@ class NotOnlyChildError : public ScheduleError { class LoopNotStartWithZeroError : public ScheduleError { public: - explicit LoopNotStartWithZeroError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + explicit LoopNotStartWithZeroError(IRModule mod, 
For loop) : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { - return "ScheduleError: the primitive only supports loop starting with 0"; + return "ScheduleError: The primitive only supports loop starting with 0"; } String DetailRenderTemplate() const final { @@ -287,7 +251,7 @@ class NotSingleInferFactorError : public ScheduleError { class WrongFactorProductError : public ScheduleError { public: - explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(loop) {} + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} String FastErrorString() const final { return "ScheduleError: The product of factors is not larger than or equal to the extent of " @@ -312,33 +276,26 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, // - The execution order has not changed. (The block executes with the same args and the same // order with before. // Step 1. Check correctness - GetScopeRootAndCheckStagePipeline(self, loop_sref); - const auto* loop = loop_sref->StmtAs(); - if (loop == nullptr) { - throw NotLoopError(self->mod, loop_sref->stmt->GetTypeKey()); - } - if (!loop->annotations.empty()) { - throw HasAnnotationError(self->mod, GetRef(loop)); + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + ICHECK(loop) << "the input sref does not point to a loop"; + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } - if (loop->thread_binding.defined()) { - throw HasThreadBindingError(self->mod, GetRef(loop)); - } - // Currently, loops starting with 0 is not supported + // Currently, loops not starting with 0 are not supported arith::Analyzer analyzer; if (!analyzer.CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); } PrimExpr tot_length = 1; int infer_index = -1; - for (size_t i = 0; i < factors.size(); i++) { + size_t n = factors.size(); + for (size_t i = 0; i < 
n; i++) { if (!analyzer.CanProve(factors[i] == -1)) { tot_length *= factors[i]; + } else if (infer_index != -1) { + throw NotSingleInferFactorError(self->mod); } else { - if (infer_index != -1) { - throw NotSingleInferFactorError(self->mod); - } else { - infer_index = i; - } + infer_index = i; } } // Step 2. infer factors if needed @@ -349,86 +306,77 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, } else if (!analyzer.CanProve(tot_length >= loop->extent)) { throw WrongFactorProductError(self->mod, GetRef(loop)); } - // Step 3. Replace all occurrence of the original loop var with new variables + // Step 3. Replace all occurrences of the original loop var with new variables std::vector new_loop_vars; - new_loop_vars.reserve(inferred_factors.size()); - for (size_t i = 0; i < inferred_factors.size(); i++) { + new_loop_vars.reserve(n); + for (size_t i = 0; i < n; i++) { new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); } PrimExpr substitute_value = 0; - for (size_t i = 0; i < inferred_factors.size(); i++) { + for (size_t i = 0; i < n; i++) { substitute_value *= inferred_factors[i]; substitute_value += new_loop_vars[i]; } Map opaque_block_reuse; - auto substitute_function = [&](const Var& v) -> Optional { + auto f_substitute = [&](const Var& v) -> Optional { if (v.same_as(loop->loop_var)) { return substitute_value; } else { return NullOpt; } }; - Stmt new_loop_body = - SubstituteAndCollectOpaqueBlock(loop->body, &opaque_block_reuse, substitute_function); - for (size_t i = 0; i < inferred_factors.size(); i++) { + Stmt new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop->body)); + for (size_t i = 0; i < n; i++) { analyzer.Bind(new_loop_vars[i], Range::FromMinExtent(0, inferred_factors[i])); } // Step 4. Update predicate to guard the loop - PrimExpr predicate = substitute_value < loop->extent; - new_loop_body = PredicateUpdater(predicate, &analyzer)(new_loop_body); - // Step 5. 
Generate tnested loops to replace the original loop and simplify the binding - Stmt outer_stmt = new_loop_body; - for (int i = inferred_factors.size() - 1; i >= 0; i--) { - outer_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, outer_stmt); + new_stmt = + BlockPredicateAppender(/*predicate=*/substitute_value < loop->extent, &analyzer)(new_stmt); + // Step 5. Generate nested loops to replace the original loop and simplify the binding + for (int i = n - 1; i >= 0; i--) { + new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, new_stmt); } - outer_stmt = - Downcast(SimplifyBindings(outer_stmt, GetLoops(loop_sref), &opaque_block_reuse)); - self->Replace(loop_sref, outer_stmt, opaque_block_reuse); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_sref), + &opaque_block_reuse); + self->Replace(loop_sref, new_stmt, opaque_block_reuse); Array result_srefs; - result_srefs.reserve(inferred_factors.size()); - for (size_t i = 0; i < inferred_factors.size(); i++) { - result_srefs.push_back(self->stmt2ref.at(outer_stmt.get())); - const ForNode* outer_loop = outer_stmt.as(); - ICHECK(outer_loop); - outer_stmt = outer_loop->body; + result_srefs.reserve(n); + for (size_t i = 0; i < n; i++) { + result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); + const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); + new_stmt = outer_loop->body; } return result_srefs; } StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { - // Invariance - // - The total repeat number has not changed for each direct child block. - // - The execution order has not changed. (The block executes with the same - // args and the same order with before.) + // Invariance + // - The total repeat number has not changed for each direct child block. + // - The execution order has not changed. (The block executes with the same + // args and the same order with before.) 
std::vector loops; loops.reserve(loop_srefs.size()); - StmtSRef outer_sref{nullptr}; + StmtSRef outer_loop_sref{nullptr}; const ForNode* outer_loop = nullptr; arith::Analyzer analyzer; // Step 1. check correctness - GetScopeRootAndCheckStagePipeline(self, loop_srefs[0]); for (const StmtSRef& sref : loop_srefs) { const auto* loop = sref->StmtAs(); - if (loop == nullptr) { - throw NotLoopError(self->mod, sref->stmt->GetTypeKey()); - } - if (!loop->annotations.empty()) { - throw HasAnnotationError(self->mod, GetRef(loop)); - } - if (loop->thread_binding.defined()) { - throw HasThreadBindingError(self->mod, GetRef(loop)); + ICHECK(loop) << "the input sref does not point to a loop"; + if (!loop->annotations.empty() || loop->thread_binding.defined()) { + throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } - if (outer_sref.defined()) { - if (sref->parent != outer_sref.get()) { + if (outer_loop_sref.defined()) { + if (sref->parent != outer_loop_sref.get()) { throw OuterNotInnerParent(self->mod, GetRef(outer_loop), GetRef(loop)); } - Array outer_children = GetChildren(GetRef(outer_loop)); - if (outer_children.size() != 1 || outer_children[0].get() != loop) { + if (!outer_loop->body.same_as(GetRef(loop))) { throw NotOnlyChildError(self->mod, GetRef(outer_loop), GetRef(loop)); } } - outer_sref = sref; + outer_loop_sref = sref; outer_loop = loop; if (!analyzer.CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); @@ -451,7 +399,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } Stmt loop_body = loops.back()->body; Map opaque_block_reuse; - auto substitute_function = [&](const Var& v) -> Optional { + auto f_substitute = [&](const Var& v) -> Optional { for (size_t i = 0; i < loops.size(); i++) { if (v.same_as(loops[i]->loop_var)) { return substitute_value[i]; @@ -459,20 +407,19 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } return NullOpt; }; - Stmt new_loop_body = - 
SubstituteAndCollectOpaqueBlock(loop_body, &opaque_block_reuse, substitute_function); + Stmt new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop_body)); // Step 3. Generate a loop to replace the original loops - PrimExpr fused_min = 0; PrimExpr fused_extent = 1; for (size_t i = 0; i < loops.size(); i++) { fused_extent *= loops[i]->extent; } fused_extent = analyzer.Simplify(fused_extent); - For fused_loop = For(fused_var, fused_min, fused_extent, loops[0]->kind, new_loop_body); - fused_loop = - Downcast(SimplifyBindings(fused_loop, GetLoops(loop_srefs[0]), &opaque_block_reuse)); - self->Replace(loop_srefs[0], fused_loop, opaque_block_reuse); - return self->stmt2ref.at(fused_loop.get()); + new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_srefs[0]), + &opaque_block_reuse); + self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); + return self->stmt2ref.at(new_stmt.get()); } } // namespace tir diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 56f1a4a3fff7..47bb81722956 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -453,17 +453,4 @@ def test_split_symbolic(): if __name__ == "__main__": - test_fuse() - test_fuse_with_opaque_block() - test_fuse_with_opaque_access() - test_fuse_symbolic() - test_split() - test_split_with_inferred_factor() - test_split_with_opaque_block() - test_split_with_opaque_access() - test_split_with_predicate() - test_split_symbolic() - test_fuse_fail_not_only_child() - test_fuse_split_fail_with_annotation() - test_fuse_split_fail_not_start_with_zero() - test_fuse_split_fail_with_thread_binding() + pytest.main([__file__]) From 9673c00c7ccd2da771ce13bec27fe07c93a5707d Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: 
Sat, 17 Jul 2021 15:16:42 +0800 Subject: [PATCH 04/16] address comments --- src/tir/schedule/primitive/loop_transformation.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index eb2a8b185bcf..59c0c8b72006 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -99,9 +99,10 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { static For SimplifyBindings(const Stmt& stmt, const Array& loop_srefs, Map* opaque_blocks) { Map loop_var2extent; + MapNode* loop_var2extent_mutable = loop_var2extent.CopyOnWrite(); for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); - loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + loop_var2extent_mutable->at(loop->loop_var) = Range::FromMinExtent(loop->min, loop->extent); } return Downcast(IterMapSimplifyBlockBinding(loop_var2extent, opaque_blocks)(stmt)); } From 8b4bee8daade8d052e478bc64d5e3ffded97ab33 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:17:48 +0800 Subject: [PATCH 05/16] address comments --- python/tvm/tir/schedule/schedule.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 8bfc9646cfa8..a2ace82d52e8 100755 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -388,9 +388,6 @@ def after_split(a: ty.handle, b: ty.handle) -> None: """ # it will be checked later in C++ implementation that there is at most one None or -1 in `factors` - for i, factor in enumerate(factors): - if factor is None: - factors[i] = -1 return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member ########## Schedule: compute location ########## From ce0bd7a38a2bd0e3e476ba6426ed125a3235a827 
Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:26:30 +0800 Subject: [PATCH 06/16] fix --- src/tir/schedule/primitive/loop_transformation.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 59c0c8b72006..eb2a8b185bcf 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -99,10 +99,9 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { static For SimplifyBindings(const Stmt& stmt, const Array& loop_srefs, Map* opaque_blocks) { Map loop_var2extent; - MapNode* loop_var2extent_mutable = loop_var2extent.CopyOnWrite(); for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); - loop_var2extent_mutable->at(loop->loop_var) = Range::FromMinExtent(loop->min, loop->extent); + loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } return Downcast(IterMapSimplifyBlockBinding(loop_var2extent, opaque_blocks)(stmt)); } From cf9a7294742f4aebf4cd2cf591f792ce6a65d4ad Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:27:45 +0800 Subject: [PATCH 07/16] fix --- tests/python/unittest/test_tir_schedule_split_fuse.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 47bb81722956..845b25598c15 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -324,7 +324,6 @@ def test_fuse(): block_b = sch.get_block("B") i, j, k = sch.get_loops(block_b) sch.fuse(i, j, k) - assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline tvm.ir.assert_structural_equal(elementwise_fused, sch.mod["main"]) @@ -335,7 +334,6 @@ def test_split(): sch.split(i, factors=[2, 
1, 64]) sch.split(j, factors=[4, 32]) sch.split(k, factors=[16, 8]) - assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline tvm.ir.assert_structural_equal(elementwise_split_case0, sch.mod["main"]) @@ -356,7 +354,6 @@ def test_split_with_predicate(): sch.split(i, factors=[1000, 2, 3]) sch.split(j, factors=[None, 129]) sch.split(k, factors=[3, None]) - assert sch.state._get_cached_flags(sch.get_sref(block_b)).stage_pipeline tvm.ir.assert_structural_equal(elementwise_split_with_predicate, sch.mod["main"]) From a097b88ff0ad734113519f97a466310506224242 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:29:52 +0800 Subject: [PATCH 08/16] fix --- src/arith/rewrite_simplify.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 52a9fe916e87..ff6536ab066b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -883,9 +883,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2), c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0); - + TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0); - + TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x)); TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y)); From 76e7443ea24d737ba17b0ea7814bbc7d2d114814 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:37:33 +0800 Subject: [PATCH 09/16] fix --- python/tvm/tir/schedule/schedule.py | 0 src/tir/schedule/primitive/loop_transformation.cc | 5 ----- 2 files changed, 5 deletions(-) mode change 100755 => 100644 python/tvm/tir/schedule/schedule.py diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py old mode 100755 new mode 100644 diff --git a/src/tir/schedule/primitive/loop_transformation.cc 
b/src/tir/schedule/primitive/loop_transformation.cc index eb2a8b185bcf..71538ce09d27 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -84,11 +84,6 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { Map* opaque_blocks_; }; -Stmt SubstituteAndCollectOpaqueBlock(Stmt stmt, Map* opaque_blocks, - std::function(const Var&)> vmap) { - return SubstituteVarAndCollectOpaqueBlock(vmap, opaque_blocks)(std::move(stmt)); -} - /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: From 87c98e2a329345ea86d9405aee8a898527f97ee4 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Sat, 17 Jul 2021 15:53:17 +0800 Subject: [PATCH 10/16] fix --- python/tvm/tir/schedule/schedule.py | 3 ++- src/tir/schedule/primitive/loop_transformation.cc | 3 +-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index a2ace82d52e8..83576019ff45 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -387,7 +387,8 @@ def after_split(a: ty.handle, b: ty.handle) -> None: B[vi, vj] = A[vi, vj] * 2.0 """ - # it will be checked later in C++ implementation that there is at most one None or -1 in `factors` + # it will be checked later in C++ implementation + # sthat there is at most one None or -1 in `factors` return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member ########## Schedule: compute location ########## diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 71538ce09d27..dbf238dff32c 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -358,8 +358,7 @@ StmtSRef Fuse(ScheduleState self, 
const Array& loop_srefs) { arith::Analyzer analyzer; // Step 1. check correctness for (const StmtSRef& sref : loop_srefs) { - const auto* loop = sref->StmtAs(); - ICHECK(loop) << "the input sref does not point to a loop"; + const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } From ac7afc708b2c9435996550e7bc773d18b4dad589 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Mon, 19 Jul 2021 20:38:43 +0800 Subject: [PATCH 11/16] address comment --- python/tvm/tir/schedule/schedule.py | 2 +- src/tir/schedule/concrete_schedule.cc | 6 +- src/tir/schedule/concrete_schedule.h | 8 +- .../schedule/primitive/loop_transformation.cc | 88 +++++++++---------- 4 files changed, 50 insertions(+), 54 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 83576019ff45..a9211ef3f777 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -388,7 +388,7 @@ def after_split(a: ty.handle, b: ty.handle) -> None: """ # it will be checked later in C++ implementation - # sthat there is at most one None or -1 in `factors` + # that there is at most one None or -1 in `factors` return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member ########## Schedule: compute location ########## diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 69304d05ea3c..d57720ef5ff9 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -278,7 +278,11 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, Array factors; factors.reserve(factor_rvs.size()); for (const Optional& factor_rv : factor_rvs) { - factors.push_back(this->Get(factor_rv.value_or(Integer(-1)))); + if (factor_rv.defined()) { + factors.push_back(Integer(-1)); + } else { + 
factors.push_back(this->Get(factor_rv.value())); + } } Array results; TVM_TIR_SCHEDULE_BEGIN(); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 1dd06dae0140..fab3e259b752 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -154,15 +154,9 @@ inline PrimExpr ConcreteScheduleNode::Get(const ExprRV& expr_rv) const { } const ObjectRef& obj = (*it).second; const auto* int_imm = TVM_TYPE_AS(int_imm, obj, IntImmNode); - if (int_imm == nullptr) { - LOG(FATAL) << "ValueError: ExprRV's corresponding type is invalid: " - << (obj.defined() ? obj->GetTypeKey() : "None"); - } return Integer(int_imm->value); }); - PrimExpr simplified = this->analyzer_->Simplify(transformed); - CHECK(is_const_int(transformed)) << "ValueError: The ExprRV does not have a specific value"; - return simplified; + return this->analyzer_->Simplify(transformed); } inline StmtSRef ConcreteScheduleNode::GetSRef(const BlockRV& block_rv) const { diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index dbf238dff32c..039ea6ade091 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -70,12 +70,11 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { } Stmt VisitStmt_(const BlockRealizeNode* op) final { - Stmt res = StmtMutator::VisitStmt_(op); - if (op->block->iter_vars.empty()) { - const BlockRealizeNode* realize = TVM_TYPE_AS(realize, res, BlockRealizeNode); + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + if (realize->block->iter_vars.empty()) { opaque_blocks_->Set(op->block, realize->block); } - return res; + return std::move(realize); } /*! \brief The substitute function */ @@ -87,18 +86,19 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { /*! 
\brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(const Map& loop_map, - Map* opaque_blocks) - : opaque_blocks_(opaque_blocks), loop_var2extent_(std::move(loop_map)) {} + explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, + Map loop_var2extent) + : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} - static For SimplifyBindings(const Stmt& stmt, const Array& loop_srefs, + static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, Map* opaque_blocks) { Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } - return Downcast(IterMapSimplifyBlockBinding(loop_var2extent, opaque_blocks)(stmt)); + return Downcast(IterMapSimplifyBlockBinding(opaque_blocks->CopyOnWrite(), + std::move(loop_var2extent))(std::move(stmt))); } private: @@ -112,16 +112,15 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Stmt VisitStmt_(const BlockRealizeNode* op) final { // skip opaque block and update mapping if (op->iter_values.empty()) { - Stmt res = StmtMutator::VisitStmt_(op); - const BlockRealizeNode* realize = res.as(); - MapNode* mutable_map = opaque_blocks_->CopyOnWrite(); - for (const std::pair& entry : *opaque_blocks_) { + Block block = op->block; + BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); + for (const std::pair& entry : *opaque_blocks_) { if (entry.second.same_as(op->block)) { - mutable_map->at(entry.first) = realize->block; + opaque_blocks_->at(entry.first) = realize->block; break; } } - return res; + return std::move(realize); } Array v = arith::IterMapSimplify(/*indices=*/op->iter_values, /*input_iters=*/loop_var2extent_, @@ -137,7 +136,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { } /*! 
\brief The reuse mapping */ - Map* opaque_blocks_; + MapNode* opaque_blocks_; /*! \brief The range of loops */ Map loop_var2extent_; }; @@ -272,7 +271,6 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, // order with before. // Step 1. Check correctness const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - ICHECK(loop) << "the input sref does not point to a loop"; if (!loop->annotations.empty() || loop->thread_binding.defined()) { throw HasAnnotationOrThreadBindingError(self->mod, GetRef(loop)); } @@ -283,8 +281,8 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, } PrimExpr tot_length = 1; int infer_index = -1; - size_t n = factors.size(); - for (size_t i = 0; i < n; i++) { + int n = factors.size(); + for (int i = 0; i < n; i++) { if (!analyzer.CanProve(factors[i] == -1)) { tot_length *= factors[i]; } else if (infer_index != -1) { @@ -302,43 +300,42 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, throw WrongFactorProductError(self->mod, GetRef(loop)); } // Step 3. 
Replace all occurrences of the original loop var with new variables + PrimExpr substitute_value = 0; std::vector new_loop_vars; new_loop_vars.reserve(n); for (size_t i = 0; i < n; i++) { - new_loop_vars.push_back(loop->loop_var.copy_with_suffix("_" + std::to_string(i))); - } - PrimExpr substitute_value = 0; - for (size_t i = 0; i < n; i++) { - substitute_value *= inferred_factors[i]; - substitute_value += new_loop_vars[i]; + const PrimExpr& factor = inferred_factors[i]; + Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); + substitute_value = substitute_value * factor + var; + analyzer.Bind(var, Range::FromMinExtent(0, factor)); + new_loop_vars.emplace_back(std::move(var)); } Map opaque_block_reuse; - auto f_substitute = [&](const Var& v) -> Optional { - if (v.same_as(loop->loop_var)) { - return substitute_value; - } else { - return NullOpt; - } - }; - Stmt new_stmt = - SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop->body)); - for (size_t i = 0; i < n; i++) { - analyzer.Bind(new_loop_vars[i], Range::FromMinExtent(0, inferred_factors[i])); - } + Stmt new_stmt = loop->body; + new_stmt = SubstituteVarAndCollectOpaqueBlock( + [&](const Var& v) -> Optional { + if (v.same_as(loop->loop_var)) { + return substitute_value; + } else { + return NullOpt; + } + }, + &opaque_block_reuse + )(std::move(new_stmt)); // Step 4. Update predicate to guard the loop new_stmt = BlockPredicateAppender(/*predicate=*/substitute_value < loop->extent, &analyzer)(new_stmt); // Step 5. 
Generate nested loops to replace the original loop and simplify the binding for (int i = n - 1; i >= 0; i--) { - new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], loop->kind, new_stmt); + new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], ForKind::kSerial, new_stmt); } - new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_sref), + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), &opaque_block_reuse); self->Replace(loop_sref, new_stmt, opaque_block_reuse); Array result_srefs; result_srefs.reserve(n); - for (size_t i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { result_srefs.push_back(self->stmt2ref.at(new_stmt.get())); const ForNode* outer_loop = TVM_TYPE_AS(outer_loop, new_stmt, ForNode); new_stmt = outer_loop->body; @@ -387,11 +384,11 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { Array substitute_value; substitute_value.resize(loops.size()); PrimExpr tot = fused_var; - for (int i = loops.size() - 1; i >= 0; i--) { + for (int i = static_cast(loops.size()) - 1; i >= 0; i--) { substitute_value.Set(i, floormod(tot, loops[i]->extent)); tot = floordiv(tot, loops[i]->extent); } - Stmt loop_body = loops.back()->body; + Stmt new_stmt = loops.back()->body; Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> Optional { for (size_t i = 0; i < loops.size(); i++) { @@ -401,8 +398,8 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } return NullOpt; }; - Stmt new_stmt = - SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(loop_body)); + new_stmt = + SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); // Step 3. 
Generate a loop to replace the original loops PrimExpr fused_extent = 1; for (size_t i = 0; i < loops.size(); i++) { @@ -410,7 +407,8 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } fused_extent = analyzer.Simplify(fused_extent); new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); - new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(new_stmt, GetLoops(loop_srefs[0]), + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops + (loop_srefs[0]), &opaque_block_reuse); self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); From 5d45b905b622f7380d7d3594d70f1682d3c184aa Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Mon, 19 Jul 2021 20:48:18 +0800 Subject: [PATCH 12/16] fix --- src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/primitive/loop_transformation.cc | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index d57720ef5ff9..c5a60d7b5e2b 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -278,7 +278,7 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, Array factors; factors.reserve(factor_rvs.size()); for (const Optional& factor_rv : factor_rvs) { - if (factor_rv.defined()) { + if (!factor_rv.defined()) { factors.push_back(Integer(-1)); } else { factors.push_back(this->Get(factor_rv.value())); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 039ea6ade091..dbb0e6299cc7 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -115,7 +115,7 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { Block block = op->block; BlockRealize realize = Downcast(StmtMutator::VisitStmt_(op)); for (const std::pair& entry : *opaque_blocks_) 
{ - if (entry.second.same_as(op->block)) { + if (entry.second.same_as(block)) { opaque_blocks_->at(entry.first) = realize->block; break; } @@ -303,7 +303,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, PrimExpr substitute_value = 0; std::vector new_loop_vars; new_loop_vars.reserve(n); - for (size_t i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { const PrimExpr& factor = inferred_factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); substitute_value = substitute_value * factor + var; From c6167cac36be7bf5a3a81907e4160e1e3299bc97 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Mon, 19 Jul 2021 20:52:24 +0800 Subject: [PATCH 13/16] fix --- src/tir/schedule/primitive/loop_transformation.cc | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index dbb0e6299cc7..1c29dee15489 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -86,8 +86,7 @@ class SubstituteVarAndCollectOpaqueBlock : public StmtExprMutator { /*! \brief Simplify the binding of block realize and update the opaque block reuse mapping */ class IterMapSimplifyBlockBinding : public StmtExprMutator { public: - explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, - Map loop_var2extent) + explicit IterMapSimplifyBlockBinding(MapNode* opaque_blocks, Map loop_var2extent) : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, @@ -320,8 +319,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, return NullOpt; } }, - &opaque_block_reuse - )(std::move(new_stmt)); + &opaque_block_reuse)(std::move(new_stmt)); // Step 4. 
Update predicate to guard the loop new_stmt = BlockPredicateAppender(/*predicate=*/substitute_value < loop->extent, &analyzer)(new_stmt); @@ -407,9 +405,8 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } fused_extent = analyzer.Simplify(fused_extent); new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); - new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops - (loop_srefs[0]), - &opaque_block_reuse); + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( + std::move(new_stmt), GetLoops(loop_srefs[0]), &opaque_block_reuse); self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } From 09c6a418f32c5710941db38c1f208d6ed87c053f Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Mon, 19 Jul 2021 22:32:19 +0800 Subject: [PATCH 14/16] address comments --- python/tvm/tir/schedule/schedule.py | 6 +- src/tir/schedule/concrete_schedule.cc | 65 +++++++++++++++++-- .../schedule/primitive/loop_transformation.cc | 53 ++++----------- .../unittest/test_tir_schedule_split_fuse.py | 2 +- 4 files changed, 77 insertions(+), 49 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index a9211ef3f777..a71e2e1241be 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -327,7 +327,7 @@ def split( 1) The loop can't have annotation or thread binding. 2) The loop must start with 0. Predicates may be added to ensure the total loop numbers keeps unchanged. - In `factors`, at most one of the factors can be None or -1, + In `factors`, at most one of the factors can be None, which will be automatically inferred. 
Parameters @@ -338,7 +338,7 @@ def split( factors: List[Union[ExprRV, None]] The splitting factors Potential inputs are: - - None or -1 + - None - ExprRV - Nonnegative constant integers @@ -388,7 +388,7 @@ def after_split(a: ty.handle, b: ty.handle) -> None: """ # it will be checked later in C++ implementation - # that there is at most one None or -1 in `factors` + # that there is at most one None in `factors` return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member ########## Schedule: compute location ########## diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c5a60d7b5e2b..07b49e459483 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -272,20 +272,75 @@ LoopRV ConcreteScheduleNode::Fuse(const Array& loop_rvs) { Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, const Array>& factor_rvs) { + class NotSingleInferFactorError : public ScheduleError { + public: + explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {} + + String FastErrorString() const final { + return "ScheduleError: only one factor can be specified as -1 or none"; + } + + String DetailRenderTemplate() const final { + return "Only one factor can be specified as -1 or none"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + }; + + class WrongFactorProductError : public ScheduleError { + public: + explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {} + + String FastErrorString() const final { + return "ScheduleError: The product of factors is not larger than or equal to the extent of " + "loop"; + } + + String DetailRenderTemplate() const final { + return "The product of factors is not larger than or equal to the extent of loop {0}"; + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { 
return {loop_}; } + + IRModule mod_; + For loop_; + }; // Prepare for the splitting StmtSRef loop_sref = this->GetSRef(loop_rv); const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); Array factors; factors.reserve(factor_rvs.size()); - for (const Optional& factor_rv : factor_rvs) { - if (!factor_rv.defined()) { + int infer_index = -1; + PrimExpr tot_length = 1; + Array results; + TVM_TIR_SCHEDULE_BEGIN(); + // infer factor if needed and check validity of factors + for (size_t i = 0; i < factor_rvs.size(); i++) { + if (!factor_rvs[i].defined()) { factors.push_back(Integer(-1)); + if (infer_index == -1) { + infer_index = i; + } else { + throw NotSingleInferFactorError(state_->mod); + } } else { - factors.push_back(this->Get(factor_rv.value())); + PrimExpr factor = this->Get(factor_rvs[i].value()); + factors.push_back(factor); + tot_length *= factor; } } - Array results; - TVM_TIR_SCHEDULE_BEGIN(); + arith::Analyzer analyzer; + if (infer_index != -1) { + factors.Set(infer_index, + analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else if (!analyzer.CanProve(tot_length >= loop->extent)) { + LOG(INFO) << infer_index; + throw WrongFactorProductError(state_->mod, GetRef(loop)); + } results = tir::Split(state_, loop_sref, factors); TVM_TIR_SCHEDULE_END("split", this->error_render_level_); this->state_->DebugVerify(); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 1c29dee15489..72b0dfe468f4 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -28,28 +28,19 @@ class BlockPredicateAppender : public StmtMutator { * \brief Constructor * \param to_append The predicate to be appended to BlockRealizeNode */ - explicit BlockPredicateAppender(const PrimExpr& to_append, arith::Analyzer* analyzer) - : to_append_(to_append) { - add_predicate_ = !analyzer->CanProve(to_append); - } + explicit 
BlockPredicateAppender(const PrimExpr& to_append) : to_append_(to_append) {} private: // For each direct child of type BlockRealizeNode, append the predicate Stmt VisitStmt_(const BlockRealizeNode* realize) final { // We do not recursively do this - if (add_predicate_) { - ObjectPtr n = CopyOnWrite(realize); - n->predicate = n->predicate && to_append_; - return BlockRealize(n); - } else { - return GetRef(realize); - } + ObjectPtr n = CopyOnWrite(realize); + n->predicate = n->predicate && to_append_; + return BlockRealize(n); } /*! \brief The predicate to be appended */ const PrimExpr& to_append_; - /*! \brief Whether to add predicate */ - bool add_predicate_; }; /*! \brief Substitute vars and collect the reuse mapping of opaque blocks */ @@ -278,32 +269,13 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, if (!analyzer.CanProve(loop->min == 0)) { throw LoopNotStartWithZeroError(self->mod, GetRef(loop)); } - PrimExpr tot_length = 1; - int infer_index = -1; + // Step 2. Replace all occurrences of the original loop var with new variables int n = factors.size(); - for (int i = 0; i < n; i++) { - if (!analyzer.CanProve(factors[i] == -1)) { - tot_length *= factors[i]; - } else if (infer_index != -1) { - throw NotSingleInferFactorError(self->mod); - } else { - infer_index = i; - } - } - // Step 2. infer factors if needed - Array inferred_factors(factors); - if (infer_index != -1) { - inferred_factors.Set(infer_index, - analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); - } else if (!analyzer.CanProve(tot_length >= loop->extent)) { - throw WrongFactorProductError(self->mod, GetRef(loop)); - } - // Step 3. 
Replace all occurrences of the original loop var with new variables PrimExpr substitute_value = 0; std::vector new_loop_vars; new_loop_vars.reserve(n); for (int i = 0; i < n; i++) { - const PrimExpr& factor = inferred_factors[i]; + const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); @@ -320,14 +292,15 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, } }, &opaque_block_reuse)(std::move(new_stmt)); - // Step 4. Update predicate to guard the loop - new_stmt = - BlockPredicateAppender(/*predicate=*/substitute_value < loop->extent, &analyzer)(new_stmt); - // Step 5. Generate nested loops to replace the original loop and simplify the binding + // Step 3. Update predicate to guard the loop + PrimExpr predicate = substitute_value < loop->extent; + if (!analyzer.CanProve(predicate)) { + new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); + } + // Step 4. 
Generate nested loops to replace the original loop and simplify the binding for (int i = n - 1; i >= 0; i--) { - new_stmt = For(new_loop_vars[i], 0, inferred_factors[i], ForKind::kSerial, new_stmt); + new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); } - new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), &opaque_block_reuse); self->Replace(loop_sref, new_stmt, opaque_block_reuse); diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 845b25598c15..4c5c49a1a039 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -343,7 +343,7 @@ def test_split_with_inferred_factor(): i, j, k = sch.get_loops(block_b) sch.split(i, factors=[None, 1, 64]) sch.split(j, factors=[2, None, 64]) - sch.split(k, factors=[2, 1, -1]) + sch.split(k, factors=[2, 1, None]) tvm.ir.assert_structural_equal(elementwise_split_case1, sch.mod["main"]) From c603478815a97cd3d4d301d77a837dc09aad8cd5 Mon Sep 17 00:00:00 2001 From: jinhongyi <3231950289@qq.com> Date: Tue, 20 Jul 2021 22:09:57 +0800 Subject: [PATCH 15/16] address comments --- src/tir/schedule/concrete_schedule.cc | 6 ++---- .../schedule/primitive/loop_transformation.cc | 17 +++++++++-------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 07b49e459483..0d5bfce46e37 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -333,12 +333,10 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, tot_length *= factor; } } - arith::Analyzer analyzer; if (infer_index != -1) { factors.Set(infer_index, - analyzer.Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); - } else if (!analyzer.CanProve(tot_length >= loop->extent)) { - LOG(INFO) << infer_index; + 
this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length))); + } else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) { throw WrongFactorProductError(state_->mod, GetRef(loop)); } results = tir::Split(state_, loop_sref, factors); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 72b0dfe468f4..2a2d9ed2a888 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -81,14 +81,14 @@ class IterMapSimplifyBlockBinding : public StmtExprMutator { : opaque_blocks_(opaque_blocks), loop_var2extent_(loop_var2extent) {} static For SimplifyBindings(Stmt stmt, const Array& loop_srefs, - Map* opaque_blocks) { + MapNode* opaque_blocks) { Map loop_var2extent; for (const StmtSRef& sref : loop_srefs) { const ForNode* loop = TVM_SREF_TO_FOR(loop, sref); loop_var2extent.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } - return Downcast(IterMapSimplifyBlockBinding(opaque_blocks->CopyOnWrite(), - std::move(loop_var2extent))(std::move(stmt))); + return Downcast( + IterMapSimplifyBlockBinding(opaque_blocks, std::move(loop_var2extent))(std::move(stmt))); } private: @@ -302,7 +302,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); } new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), - &opaque_block_reuse); + opaque_block_reuse.CopyOnWrite()); self->Replace(loop_sref, new_stmt, opaque_block_reuse); Array result_srefs; result_srefs.reserve(n); @@ -347,7 +347,8 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { } // Step 2. 
Create fused loop var and replace the original loop vars std::string suffix; - for (size_t i = 1; i < loops.size(); i++) { + int n = loops.size(); + for (int i = 1; i < n; i++) { suffix += "_" + loops[i]->loop_var->name_hint; } suffix += "_fused"; @@ -362,7 +363,7 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { Stmt new_stmt = loops.back()->body; Map opaque_block_reuse; auto f_substitute = [&](const Var& v) -> Optional { - for (size_t i = 0; i < loops.size(); i++) { + for (int i = 0; i < n; i++) { if (v.same_as(loops[i]->loop_var)) { return substitute_value[i]; } @@ -373,13 +374,13 @@ StmtSRef Fuse(ScheduleState self, const Array& loop_srefs) { SubstituteVarAndCollectOpaqueBlock(f_substitute, &opaque_block_reuse)(std::move(new_stmt)); // Step 3. Generate a loop to replace the original loops PrimExpr fused_extent = 1; - for (size_t i = 0; i < loops.size(); i++) { + for (int i = 0; i < n; i++) { fused_extent *= loops[i]->extent; } fused_extent = analyzer.Simplify(fused_extent); new_stmt = For(fused_var, 0, fused_extent, ForKind::kSerial, new_stmt); new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( - std::move(new_stmt), GetLoops(loop_srefs[0]), &opaque_block_reuse); + std::move(new_stmt), GetLoops(loop_srefs[0]), opaque_block_reuse.CopyOnWrite()); self->Replace(loop_srefs[0], new_stmt, opaque_block_reuse); return self->stmt2ref.at(new_stmt.get()); } From 5e442f9bef26db18ae3ae12be48e06261d764891 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Wed, 21 Jul 2021 04:05:27 +0000 Subject: [PATCH 16/16] retrigger ci