diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h
index 9f48d9ab9b1f..c4aa1c953ab6 100644
--- a/include/tvm/tir/schedule/schedule.h
+++ b/include/tvm/tir/schedule/schedule.h
@@ -365,6 +365,22 @@ class ScheduleNode : public runtime::Object {
    */
   virtual void ReverseComputeInline(const BlockRV& block) = 0;
   /******** Schedule: Reduction ********/
+  /*!
+   * \brief Decompose a reduction block into two separate blocks.
+   * a) The init block, which is translated from the init statement of the reduction block;
+   * b) The update block, which is the original block without the init statement.
+   *
+   * The init block is inserted right before the given loop.
+   *
+   * The schedule primitive requires:
+   * 1) The input block is a reduction block.
+   * 2) The input loop is an ancestor of the block.
+   * 3) The input loop is not lower than any loop related to a reduce block var.
+   * \param block_rv The reduction block to be decomposed
+   * \param loop_rv The loop before which the init block is inserted
+   * \return The init block
+   */
+  virtual BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) = 0;
   /*!
    * \brief Factorize an associative reduction block by the specified loop.
    * \details An associative reduction cannot be parallelized directly,
diff --git a/include/tvm/tir/schedule/state.h b/include/tvm/tir/schedule/state.h
index 7cd1b00c15ef..201d78fe631c 100644
--- a/include/tvm/tir/schedule/state.h
+++ b/include/tvm/tir/schedule/state.h
@@ -142,6 +142,12 @@ class ScheduleStateNode : public Object {
   /******** Property of blocks ********/
   /*! \brief Returns the BlockInfo correpsonding to the block sref */
   TVM_DLL BlockInfo GetBlockInfo(const StmtSRef& block_sref) const;
+  /*!
+   * \brief Recalculate the BlockInfo recursively under stmt.
+   * If stmt is a Block itself, we do not reset its affine binding flag unless the block
+   * has no block vars, since the affine flag depends on the outer scope of stmt.
+   */
+  TVM_DLL void UpdateScopeBlockInfo(const Stmt& stmt);
   /*!
    * \brief Get the BlockScope correpsonding to the sref of scope root block
    * \param scope_root The block sref to be retrieved
diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py
index 6e27015648f0..09a52d2e7037 100644
--- a/python/tvm/tir/schedule/schedule.py
+++ b/python/tvm/tir/schedule/schedule.py
@@ -1224,6 +1224,82 @@ def after_inline(a: T.handle, c: T.handle) -> None:
 
     ########## Schedule: Reduction ##########
 
+    def decompose_reduction(self, block: BlockRV, loop: LoopRV) -> BlockRV:
+        """Decompose a reduction block into two separate blocks.
+
+        a) The init block, which is translated from the init statement of the reduction block;
+
+        b) The update block, which is the original block without the init statement.
+
+        The init block is inserted right before the given loop.
+
+        The schedule primitive requires:
+
+        1) The input block is a reduction block.
+
+        2) The input loop is an ancestor of the block.
+
+        3) The input loop is not lower than any loop related to a reduce block var.
+
+        Parameters
+        ----------
+        block : BlockRV
+            The reduction block to be decomposed
+        loop : LoopRV
+            The loop before which the init block is inserted
+
+        Returns
+        -------
+        init_block : BlockRV
+            The init block
+
+        Examples
+        --------
+        Before decompose-reduction, in TensorIR, the IR is:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def before_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+                        with tir.init():
+                            C[vi, vj] = 0.0
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        Create the schedule and do decompose-reduction with the specified loop:
+
+        .. code-block:: python
+
+            sch = tir.Schedule(before_decompose)
+            C = sch.get_block("C")
+            i, j, k = sch.get_loops(C)
+            sch.decompose_reduction(C, i)
+            print(tvm.script.asscript(sch.mod["main"]))
+
+        After applying decompose-reduction, the IR becomes:
+
+        .. code-block:: python
+
+            @tvm.script.tir
+            def after_decompose(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
+                A = tir.match_buffer(a, [128, 128])
+                B = tir.match_buffer(b, [128, 128])
+                C = tir.match_buffer(c, [128, 128])
+                for i in tir.serial(0, 128):
+                    for j in tir.serial(0, 128):
+                        with tir.block([128, 128], "C_init") as [vi, vj]:
+                            C[vi, vj] = 0.0
+                for i, j, k in tir.grid(128, 128, 128):
+                    with tir.block([128, 128, tir.reduce_axis(0, 128)], "C_update") as [vi, vj, vk]:
+                        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+        """
+        return _ffi_api.ScheduleDecomposeReduction(self, block, loop)  # type: ignore # pylint: disable=no-member
+
     def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
         """Factorize an associative reduction block by the specified loop.
diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 93eba520f9d1..42839075af30 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -501,6 +501,15 @@ void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_inde
 
 /******** Schedule: Reduction ********/
 
+BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  StmtSRef result{nullptr};
+  TVM_TIR_SCHEDULE_BEGIN();
+  result = tir::DecomposeReduction(state_, this->GetSRef(block_rv), this->GetSRef(loop_rv));
+  TVM_TIR_SCHEDULE_END("decompose-reduction", this->error_render_level_);
+  this->state_->DebugVerify();
+  return CreateRV<BlockRV>(result);
+}
+
 BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
   StmtSRef result{nullptr};
   TVM_TIR_SCHEDULE_BEGIN();
diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h
index c9a9402832f2..1f9aeecfc776 100644
--- a/src/tir/schedule/concrete_schedule.h
+++ b/src/tir/schedule/concrete_schedule.h
@@ -115,6 +115,7 @@ class ConcreteScheduleNode : public ScheduleNode {
   void ReverseComputeInline(const BlockRV& block) override;
   /******** Schedule: Reduction ********/
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
+  BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) override;
   /******** Schedule: Block annotation ********/
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) override;
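The three preconditions map directly onto what a user sees from the Python API. A minimal sketch, reusing the `before_decompose` matmul from the docstring above (illustrative only, not part of the change):

.. code-block:: python

    import tvm
    from tvm import tir

    sch = tir.Schedule(before_decompose)
    C = sch.get_block("C")
    i, j, k = sch.get_loops(C)

    # Requirements 1)-3) hold for loop i: C is a reduction block, i is an
    # ancestor of C, and no reduce-related loop (only k here) lies above i.
    init = sch.decompose_reduction(C, i)

    # Decomposing at k would also be legal, since only the data-parallel loops
    # i and j sit above it.  With a loop order (i, k, j), however, decomposing
    # at j would violate requirement 3) and raise tvm.tir.ScheduleError,
    # because the reduce-related loop k would be above the target loop
    # (exactly the `matmul_decompose_fail3` test case further below).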
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index 8d8acd2693f4..057e845dbd48 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -233,6 +233,23 @@ TVM_DLL void ComputeInline(ScheduleState self, const StmtSRef& block_sref);
  */
 TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref);
 /******** Schedule: Reduction ********/
+/*!
+ * \brief Decompose a reduction block into two separate blocks.
+ * a) The init block, which is translated from the init statement of the reduction block;
+ * b) The update block, which is the original block without the init statement.
+ *
+ * The init block is inserted right before the given loop.
+ *
+ * The schedule primitive requires:
+ * 1) The input block is a reduction block.
+ * 2) The input loop is an ancestor of the block.
+ * 3) The input loop is not lower than any loop related to a reduce block var.
+ * \param self The state of the schedule
+ * \param block_sref The reduction block to be decomposed
+ * \param loop_sref The loop before which the init block is inserted
+ * \return The init block
+ */
+TVM_DLL StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                                    const StmtSRef& loop_sref);
 /*!
  * \brief Factor a reduction block by the specified loop
  * \details See python/tvm/tir/schedule/schedule.py
diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc
index 677b64311855..0653f6e18e7d 100644
--- a/src/tir/schedule/primitive/reduction.cc
+++ b/src/tir/schedule/primitive/reduction.cc
@@ -21,6 +21,282 @@
 namespace tvm {
 namespace tir {
 
+/*!
+ * \brief A helper class to create a new scope that contains the decomposed init body
+ * and the replaced old reduction block.
+ */
+class DecomposeReductionBlockReplacer : public StmtMutator {
+ public:
+  /*!
+   * \brief The open interface for users to call the helper class
+   * \param old_scope_root The original block scope before decomposition
+   * \param target_loop The loop we insert the decomposed init body before
+   * \param decomposed_body The decomposed init body
+   * \param old_reduction_block The reduction block we want to decompose
+   * \return The new block scope and the updated reduction block
+   */
+  static std::pair<Block, Block> Replace(Block old_scope_root, For target_loop,
+                                         Stmt decomposed_body, Block old_reduction_block) {
+    DecomposeReductionBlockReplacer replacer(std::move(target_loop), std::move(decomposed_body),
+                                             std::move(old_reduction_block));
+    return std::make_pair(Downcast<Block>(replacer(std::move(old_scope_root))),
+                          replacer.new_reduction_block_);
+  }
+
+ private:
+  explicit DecomposeReductionBlockReplacer(For target_loop, Stmt decomposed_body,
+                                           Block old_reduction_block)
+      : target_loop_(std::move(target_loop)),
+        decomposed_body_(std::move(decomposed_body)),
+        old_reduction_block_(std::move(old_reduction_block)) {}
+
+  Stmt VisitStmt_(const ForNode* loop) final {
+    Stmt mutated_stmt = StmtMutator::VisitStmt_(loop);
+    if (loop == target_loop_.get()) {
+      return SeqStmt({decomposed_body_, mutated_stmt});
+    } else {
+      return mutated_stmt;
+    }
+  }
+
+  Stmt VisitStmt_(const BlockNode* block) final {
+    if (block == old_reduction_block_.get()) {
+      ObjectPtr<BlockNode> p_new_block = CopyOnWrite(block);
+      p_new_block->name_hint = p_new_block->name_hint + "_update";
+      p_new_block->init = NullOpt;
+      new_reduction_block_ = Block(p_new_block);
+      return new_reduction_block_;
+    } else {
+      return StmtMutator::VisitStmt_(block);
+    }
+  }
+
+  Stmt VisitStmt_(const SeqStmtNode* seq) final {
+    Array<Stmt> new_stmts;
+    new_stmts.reserve(seq->seq.size());
+    for (const Stmt& old_stmt : seq->seq) {
+      new_stmts.push_back(VisitStmt(old_stmt));
+    }
+    return SeqStmt::Flatten(new_stmts);
+  }
+
+ private:
+  For target_loop_;
+  Stmt decomposed_body_;
+  Block old_reduction_block_;
+  Block new_reduction_block_;
+};
+
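The replacer performs two local rewrites in a single pass: it splices the decomposed init body into a `SeqStmt` right before the target loop, and it strips `init` from the reduction block while renaming it with an `_update` suffix. A toy Python model of the same tree rewrite (illustration only, not a TVM API):

.. code-block:: python

    # A "stmt" is either ("loop", name, [children]) or ("block", name).
    def replace(stmt, target_loop, init_body, target_block):
        if stmt[0] == "block":
            # Like `name_hint + "_update"` above; the init stripping is implied.
            return ("block", stmt[1] + "_update") if stmt[1] == target_block else stmt
        _, name, children = stmt
        new_children = [replace(c, target_loop, init_body, target_block) for c in children]
        new_loop = ("loop", name, new_children)
        # SeqStmt({decomposed_body_, mutated_stmt}): init body goes right before the loop.
        return ("seq", [init_body, new_loop]) if name == target_loop else new_loop

    tree = ("loop", "i", [("loop", "j", [("loop", "k", [("block", "C")])])])
    print(replace(tree, "j", ("block", "C_init"), "C"))
    # ('loop', 'i', [('seq', [('block', 'C_init'), ('loop', 'j', ...)])])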
+class LoopPositionError : public ScheduleError {
+ public:
+  explicit LoopPositionError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expects the loop to be an ancestor of the block";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: The input loop {0} of decompose_reduction is required to be an "
+          "ancestor of block {1}.";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+class LoopHeightError : public ScheduleError {
+ public:
+  static void CheckLoopHigherThanReduceLoops(const IRModule& mod, const BlockNode* block,
+                                             const BlockRealizeNode* realize,
+                                             const Array<StmtSRef>& loops,
+                                             const StmtSRef& loop_sref) {
+    for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+      // For each block var of type kCommReduce, check its binding
+      const IterVar& iter_var = block->iter_vars[i];
+      const PrimExpr& binding = realize->iter_values[i];
+      if (iter_var->iter_type != IterVarType::kCommReduce) {
+        continue;
+      }
+      for (const StmtSRef& higher_loop : loops) {
+        // Only check loops not lower than the target loop
+        if (higher_loop.same_as(loop_sref)) {
+          break;
+        }
+        // The binding of a reduce block var must not use the loop var of a loop
+        // above the target loop
+        const Var& loop_var = higher_loop->StmtAs<ForNode>()->loop_var;
+        if (UsesVar(binding, [v = loop_var.get()](const VarNode* var) { return var == v; })) {
+          const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+          throw LoopHeightError(mod, GetRef<For>(loop), GetRef<Block>(block));
+        }
+      }
+    }
+  }
+
+  explicit LoopHeightError(IRModule mod, For loop, Block block)
+      : mod_(std::move(mod)), loop_(std::move(loop)), block_(std::move(block)) {}
+
+  String FastErrorString() const final {
+    return "ScheduleError: decompose_reduction expects the loop to be higher than all the loops "
+           "related to reduce block vars";
+  }
+
+  String DetailRenderTemplate() const final {
+    std::ostringstream os;
+    os << "ScheduleError: decompose_reduction expects the loop {0} to be higher than all the "
+          "loops related to reduce block vars of block {1}";
+    return os.str();
+  }
+
+  IRModule mod() const final { return mod_; }
+  Array<ObjectRef> LocationsOfInterest() const final { return {loop_, block_}; }
+
+  IRModule mod_;
+  For loop_;
+  Block block_;
+};
+
+PrimExpr RemakePredicate(PrimExpr pred, const std::unordered_set<const VarNode*>& discarded_loops) {
+  if (is_one(pred)) return Bool(true);
+  PrimExpr new_pred = Bool(true);
+  auto f = [&](const VarNode* var) { return discarded_loops.count(var); };
+  arith::PVar<PrimExpr> lhs, rhs, rest;
+  for (;;) {
+    if ((rest && (lhs < rhs)).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      pred = rest.Eval();
+    } else if ((lhs < rhs).Match(pred)) {
+      if (!UsesVar(lhs.Eval(), f)) new_pred = new_pred && (lhs.Eval() < rhs.Eval());
+      break;
+    } else {
+      ICHECK(false) << "Unexpected predicate for reduction block";
+    }
+  }
+  return new_pred;
+}
+
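`RemakePredicate` assumes the realize predicate is a conjunction of `lhs < rhs` bounds, as produced by `split`, and keeps only the conjuncts whose left-hand side avoids the discarded (reduction-related) loop vars. The same filtering, sketched in plain Python over a list of conjuncts (illustration only):

.. code-block:: python

    # Illustration only: a conjunct is (vars_in_lhs, text).
    def remake_predicate(conjuncts, discarded_loops):
        kept = [text for lhs_vars, text in conjuncts if not (lhs_vars & discarded_loops)]
        return " and ".join(kept) if kept else "True"

    # In test_reduction_decompose4 below, the update block's predicate is
    # i2_0 * 7 + i2_1 < 128; both i2 loops are reduction-related, hence
    # discarded, so the init block ends up with no predicate at all:
    pred = [({"i2_0", "i2_1"}, "i2_0 * 7 + i2_1 < 128")]
    print(remake_predicate(pred, {"i2_0", "i2_1"}))  # -> "True"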
+StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref,
+                            const StmtSRef& loop_sref) {
+  /*!
+   * Check
+   *   - block is a reduction block
+   *   - loop is not lower than any loop related to reduce block vars
+   * Mutate
+   *   - generate loops related to data par block vars
+   *   - generate corresponding init block and update block
+   */
+  // Condition Checks and Information Collection
+  const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
+  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
+  // Get the outer loops from high to low
+  Array<StmtSRef> loops = GetLoops(block_sref);
+  const BlockRealizeNode* realize = GetBlockRealize(self, block_sref).get();
+  // Cond 0. Check loop_sref is an ancestor of block_sref
+  if (std::find(loops.begin(), loops.end(), loop_sref) == loops.end()) {
+    throw LoopPositionError(self->mod, GetRef<For>(loop), GetRef<Block>(block));
+  }
+  // Cond 1. Check block is a reduction block
+  StmtSRef scope_root_sref = GetScopeRoot(self, block_sref,
+                                          /*require_stage_pipeline=*/false,
+                                          /*require_subtree_compact_dataflow=*/false);
+  CheckReductionBlock(self, block_sref, scope_root_sref);
+  // Cond 2. Check 'loop' is higher than all the loops related to block vars of type reduction
+  LoopHeightError::CheckLoopHigherThanReduceLoops(self->mod, block, realize, loops, loop_sref);
+  // IR Manipulation
+  ObjectPtr<BlockNode> init_block = make_object<BlockNode>();
+  ObjectPtr<BlockRealizeNode> init_realize = make_object<BlockRealizeNode>();
+  init_block->name_hint = block->name_hint + "_init";
+  init_realize->iter_values = {};
+  init_realize->block = Block(init_block);
+  // Step 1. Create new block vars and their bindings
+  // Maps an old block var to the new corresponding block var
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> block_var_map;
+  block_var_map.reserve(block->iter_vars.size());
+  for (int i = 0, n = block->iter_vars.size(); i < n; ++i) {
+    const IterVar& iter_var = block->iter_vars[i];
+    const PrimExpr& binding = realize->iter_values[i];
+    // Only process data parallel block vars
+    if (iter_var->iter_type != IterVarType::kDataPar) {
+      continue;
+    }
+    // Create a new block var
+    IterVar new_iter_var(/*dom=*/iter_var->dom,
+                         /*var=*/iter_var->var.copy_with_suffix(""),
+                         /*iter_type=*/iter_var->iter_type,
+                         /*thread_tag=*/iter_var->thread_tag);
+    // Add a block var and its binding
+    init_block->iter_vars.push_back(new_iter_var);
+    init_realize->iter_values.push_back(binding);
+    // Add a mapping from old block vars to new block vars
+    block_var_map[iter_var->var] = new_iter_var->var;
+  }
+  // Step 2. After copying block vars, substitute them in the init block
+  init_block->body = Substitute(block->init.value(), block_var_map);
+  for (const BufferRegion& write : block->writes) {
+    init_block->writes.push_back(
+        BufferRegion(write->buffer, Substitute(write->region, block_var_map)));
+  }
+  // Step 3. Scan loops not higher than the specified loop above the reduction block.
+  //         If a loop is used in the init block bindings, it is chosen; otherwise discarded.
+  std::unordered_set<const VarNode*> discarded_loops;
+  std::vector<int> chosen_loops;
+  for (int i = static_cast<int>(loops.size()) - 1; i >= 0; --i) {
+    const VarNode* loop_var = loops[i]->StmtAs<ForNode>()->loop_var.get();
+    bool discarded = true;
+    for (const PrimExpr& expr : init_realize->iter_values) {
+      if (!UsesVar(expr, [v = loop_var](const VarNode* var) { return var == v; })) {
+        continue;
+      }
+      // The loop is related to the init block bindings
+      chosen_loops.push_back(i);
+      discarded = false;
+      break;
+    }
+    if (discarded) discarded_loops.insert(loop_var);
+    // Only scan loops not higher than the given loop
+    if (loops[i].same_as(loop_sref)) {
+      break;
+    }
+  }
+  // Step 4. After scanning loops, make a new predicate in the init block realize;
+  //         discard the parts of the predicate related to discarded loops
+  init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops);
+  // Step 5. Create new loops above the init block
+  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> loop_var_map;
+  Stmt body = BlockRealize(init_realize);
+  for (int i : chosen_loops) {
+    const ForNode* old_loop = TVM_SREF_TO_FOR(old_loop, loops[i]);
+    // Create a new loop equivalent to the chosen loop
+    Var old_loop_var = old_loop->loop_var;
+    Var new_loop_var = old_loop_var.copy_with_suffix("_init");
+    loop_var_map[old_loop_var] = new_loop_var;
+    body = For(/*loop_var=*/new_loop_var,
+               /*min=*/old_loop->min,
+               /*extent=*/old_loop->extent,
+               /*kind=*/ForKind::kSerial,
+               /*body=*/body);
+  }
+  body = Substitute(body, loop_var_map);
+  // Step 6. Mutate the IR
+  const BlockNode* old_scope_root = TVM_SREF_TO_BLOCK(old_scope_root, scope_root_sref);
+  Block new_scope_root{nullptr};
+  Block new_reduction_block{nullptr};
+  std::tie(new_scope_root, new_reduction_block) = DecomposeReductionBlockReplacer::Replace(
+      GetRef<Block>(old_scope_root), GetRef<For>(loop), body, GetRef<Block>(block));
+  self->Replace(scope_root_sref, new_scope_root,
+                {{GetRef<Block>(old_scope_root), new_scope_root},
+                 {GetRef<Block>(block), new_reduction_block}});
+  self->UpdateScopeBlockInfo(new_scope_root);
+  return self->stmt2ref.at(init_block.get());
+}
+
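The returned sref points at the freshly created init block, so callers can keep scheduling the init and update blocks independently. For example, through the public Python API (a sketch reusing the `matmul` workload from the tests below):

.. code-block:: python

    from tvm import tir

    sch = tir.Schedule(matmul)
    update = sch.get_block("update")
    i, j, k = sch.get_loops(update)
    init = sch.decompose_reduction(update, i)  # BlockRV of the new "update_init"

    # The init block received its own loop nest (only the data-parallel loops),
    # which can now be scheduled without touching the reduction:
    i_init, j_init = sch.get_loops(init)
    sch.parallel(i_init)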
 /******** Commutative Reducer ********/
 
 /*!
@@ -958,6 +1234,31 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
 
 /******** InstructionKind Registration ********/
 
+struct DecomposeReductionTraits : public UnpackedInstTraits<DecomposeReductionTraits> {
+  static constexpr const char* kName = "DecomposeReduction";
+  static constexpr bool kIsPure = false;
+
+ private:
+  static constexpr size_t kNumInputs = 2;
+  static constexpr size_t kNumAttrs = 0;
+  static constexpr size_t kNumDecisions = 0;
+
+  static BlockRV UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv) {
+    return sch->DecomposeReduction(block_rv, loop_rv);
+  }
+
+  static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv) {
+    PythonAPICall py("decompose_reduction");
+    py.Input("block", block_rv);
+    py.Input("loop", loop_rv);
+    py.SingleOutput(outputs);
+    return py.Str();
+  }
+
+  template <typename>
+  friend struct ::tvm::tir::UnpackedInstTraits;
+};
+
 struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
   static constexpr const char* kName = "RFactor";
   static constexpr bool kIsPure = false;
@@ -984,6 +1285,7 @@ struct RFactorTraits : public UnpackedInstTraits<RFactorTraits> {
 };
 
 TVM_REGISTER_INST_KIND_TRAITS(RFactorTraits);
+TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits);
 
 /******** FFI ********/
diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc
index 4262a099b59d..84a37c392e81 100644
--- a/src/tir/schedule/schedule.cc
+++ b/src/tir/schedule/schedule.cc
@@ -155,6 +155,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleComputeInline")
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleReverseComputeInline")
     .set_body_method<Schedule>(&ScheduleNode::ReverseComputeInline);
 /******** (FFI) Reduction ********/
+TVM_REGISTER_GLOBAL("tir.schedule.ScheduleDecomposeReduction")
+    .set_body_method<Schedule>(&ScheduleNode::DecomposeReduction);
 TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRFactor")
     .set_body_method<Schedule>(&ScheduleNode::RFactor);
 /******** (FFI) Block annotation ********/
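The registration above is what backs `_ffi_api.ScheduleDecomposeReduction` on the Python side; the packed function can be looked up directly to confirm the binding (a sketch):

.. code-block:: python

    import tvm

    # Registered by the TVM_REGISTER_GLOBAL above.
    f = tvm.get_global_func("tir.schedule.ScheduleDecomposeReduction", allow_missing=True)
    assert f is not None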
diff --git a/src/tir/schedule/state.cc b/src/tir/schedule/state.cc
index 4604add3bdb4..faeb0b9907d7 100644
--- a/src/tir/schedule/state.cc
+++ b/src/tir/schedule/state.cc
@@ -169,34 +169,16 @@ void UpdateSRef(ScheduleStateNode* self, StmtSRefNode* sref, const StmtNode* new_stmt) {
 }
 
 /**************** Creation ****************/
-
-/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
-class StateCreator : private StmtVisitor {
+/*! \brief A helper class to update BlockInfo for a ScheduleStateNode */
+class BlockInfoCollector : private StmtVisitor {
  public:
-  /*!
-   * \brief The entry function
-   * \param self The schedule state to be completed
-   */
-  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
-    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
-    ScheduleStateNode* self = n.get();
-    // Set `n->mod`
-    n->mod = std::move(mod);
-    // Set `n->debug_mask`
-    n->debug_mask = debug_mask;
-    // Set `n->stmt2ref` and `n->block_info`
-    StateCreator creator(self);
-    for (const auto& kv : n->mod->functions) {
-      const BaseFunc& base_func = kv.second;
-      if (const auto* func = base_func.as<PrimFuncNode>()) {
-        creator.VisitStmt(func->body);
-      }
-    }
-    return n;
+  static void Collect(ScheduleStateNode* self, const Stmt& stmt) {
+    BlockInfoCollector collector(self);
+    collector.VisitStmt(stmt);
   }
 
  private:
-  explicit StateCreator(ScheduleStateNode* self)
+  explicit BlockInfoCollector(ScheduleStateNode* self)
       : self_(self), srefs_{}, block2realize_{}, block_frames_{} {
     block_frames_.emplace({});
   }
@@ -206,25 +188,11 @@ class StateCreator : private StmtVisitor {
    * \param stmt A for-loop statement or a block statement
    * \return A sref to the stmt
    */
-  StmtSRef PushSRef(const StmtNode* stmt) {
-    if (srefs_.empty()) {
-      srefs_.push_back(
-          StmtSRef(stmt,
-                   /*parent=*/nullptr,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    } else {
-      StmtSRefNode* parent = srefs_.back().get();
-      srefs_.push_back(
-          StmtSRef(stmt, parent,
-                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
-    }
-    return srefs_.back();
-  }
+  void PushSRef(const StmtNode* stmt) { srefs_.push_back(self_->stmt2ref.at(stmt)); }
 
-  /*! \brief Pop the top of the scope and record it in stmt2ref map */
-  StmtSRef PopAndRecordSRef() {
-    StmtSRef sref = std::move(srefs_.back());
-    self_->stmt2ref[sref->stmt] = sref;
+  /*! \brief Pop the top of the scope */
+  StmtSRef PopSRef() {
+    StmtSRef sref = srefs_.back();
     srefs_.pop_back();
     return sref;
   }
@@ -238,7 +206,10 @@ class StateCreator : private StmtVisitor {
             .first->second;
     // Set `affine_binding`
     if (is_root_block) {
-      info.affine_binding = true;
+      // If the block doesn't have outer loops and BlockRealize,
+      // then we set the affine binding flag as true only if the block has no block vars
+      const BlockNode* block = TVM_SREF_TO_BLOCK(block, scope_root);
+      if (block->iter_vars.empty()) info.affine_binding = true;
     } else {
       info.affine_binding =
           IsAffineBinding(/*realize=*/block2realize_.at(scope_root->stmt),
@@ -385,7 +356,7 @@ class StateCreator : private StmtVisitor {
     analyzer_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
     PushSRef(loop);
     VisitStmt(loop->body);
-    PopAndRecordSRef();
+    PopSRef();
   }
 
   void VisitStmt_(const BlockRealizeNode* realize) final {
@@ -395,7 +366,7 @@ class StateCreator : private StmtVisitor {
     // Recursive visit
     PushSRef(block);
     VisitStmt(block->body);  // `block->init` is not visited
-    StmtSRef sref = PopAndRecordSRef();
+    StmtSRef sref = PopSRef();
     // Create BlockInfo for the block
     MakeBlockInfo(sref);
     // Update parent scope
@@ -409,7 +380,7 @@ class StateCreator : private StmtVisitor {
     SetSeqIndexInChildren(self_, seq_stmt);
   }
 
-  /*! \brief The result ScheduleStateNode */
+  /*! \brief The ScheduleStateNode we are operating on */
   ScheduleStateNode* self_;
   /*! \brief The stack frame used to indicate the current scope */
   std::vector<StmtSRef> srefs_;
@@ -421,6 +392,86 @@ class StateCreator : private StmtVisitor {
   arith::Analyzer analyzer_;
 };
 
+/*! \brief A helper class to create a new ScheduleStateNode from an IRModule */
+class StateCreator : private StmtVisitor {
+ public:
+  /*!
+   * \brief The entry function
+   * \param self The schedule state to be completed
+   */
+  static ObjectPtr<ScheduleStateNode> Create(IRModule mod, int debug_mask) {
+    ObjectPtr<ScheduleStateNode> n = make_object<ScheduleStateNode>();
+    ScheduleStateNode* self = n.get();
+    // Set `n->mod`
+    n->mod = std::move(mod);
+    // Set `n->debug_mask`
+    n->debug_mask = debug_mask;
+    // Set `n->stmt2ref` and `n->block_info`
+    StateCreator creator(self);
+    for (const auto& kv : n->mod->functions) {
+      const BaseFunc& base_func = kv.second;
+      if (const auto* func = base_func.as<PrimFuncNode>()) {
+        creator.VisitStmt(func->body);
+        BlockInfoCollector::Collect(self, func->body);
+      }
+    }
+    return n;
+  }
+
+ private:
+  explicit StateCreator(ScheduleStateNode* self) : self_(self) {}
+
+  /*!
+   * \brief Add a new statement to the stack, which becomes the current scope
+   * \param stmt A for-loop statement or a block statement
+   */
+  void PushSRef(const StmtNode* stmt) {
+    if (srefs_.empty()) {
+      srefs_.push_back(
+          StmtSRef(stmt,
+                   /*parent=*/nullptr,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    } else {
+      StmtSRefNode* parent = srefs_.back().get();
+      srefs_.push_back(
+          StmtSRef(stmt, parent,
+                   /*seq_index=*/-1));  // `seq_index` will be set properly in SetSeqIndex
+    }
+  }
+
+  /*! \brief Pop the top of the scope and record it in stmt2ref map */
+  void PopAndRecordSRef() {
+    StmtSRef sref = std::move(srefs_.back());
+    self_->stmt2ref[sref->stmt] = sref;
+    srefs_.pop_back();
+  }
+
+  void VisitStmt_(const ForNode* loop) final {
+    PushSRef(loop);
+    VisitStmt(loop->body);
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const BlockRealizeNode* realize) final {
+    const BlockNode* block = realize->block.get();
+    PushSRef(block);
+    VisitStmt(block->body);  // `block->init` is not visited
+    PopAndRecordSRef();
+  }
+
+  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
+    // Set `seq_index` information for SeqStmtNode
+    StmtVisitor::VisitStmt_(seq_stmt);
+    SetSeqIndexInChildren(self_, seq_stmt);
+  }
+
+  /*! \brief The result ScheduleStateNode */
+  ScheduleStateNode* self_;
+  /*! \brief The stack frame used to indicate the current scope */
+  std::vector<StmtSRef> srefs_;
+};
+
 /**************** Constructor ****************/
 
 ScheduleState::ScheduleState(IRModule mod, int debug_mask) {
@@ -1034,6 +1085,10 @@ BlockInfo ScheduleStateNode::GetBlockInfo(const StmtSRef& block_sref) const {
   return it->second;
 }
 
+void ScheduleStateNode::UpdateScopeBlockInfo(const Stmt& stmt) {
+  BlockInfoCollector::Collect(this, stmt);
+}
+
 TVM_DLL Array<Bool> GetCachedFlags(const ScheduleState& self, const StmtSRef& block_sref) {
   const BlockInfo& info = self->GetBlockInfo(block_sref);
   return {Bool(info.affine_binding),  //
diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc
index 6f679598c9d1..cc48f2b9e7ce 100644
--- a/src/tir/schedule/traced_schedule.cc
+++ b/src/tir/schedule/traced_schedule.cc
@@ -236,6 +236,16 @@ void TracedScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
 
 /******** Schedule: Reduction ********/
 
+BlockRV TracedScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
+  BlockRV result = ConcreteScheduleNode::DecomposeReduction(block_rv, loop_rv);
+  static const InstructionKind& kind = InstructionKind::Get("DecomposeReduction");
+  trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
+                                      /*inputs=*/{block_rv, loop_rv},
+                                      /*attrs=*/{},
+                                      /*outputs=*/{result}));
+  return result;
+}
+
 BlockRV TracedScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
   BlockRV result = ConcreteScheduleNode::RFactor(loop_rv, factor_axis);
   static const InstructionKind& kind = InstructionKind::Get("RFactor");
diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h
index fb89783b6036..fae5ca8608dd 100644
--- a/src/tir/schedule/traced_schedule.h
+++ b/src/tir/schedule/traced_schedule.h
@@ -82,6 +82,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
   void ComputeInline(const BlockRV& block_rv) final;
   void ReverseComputeInline(const BlockRV& block_rv) final;
   /******** Schedule: Reduction ********/
+  BlockRV DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) final;
   BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) final;
   /******** Schedule: Block annotation ********/
   void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
                     int offset) final;
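With the traced override in place, a `decompose_reduction` call is recorded as a `DecomposeReduction` instruction and survives serialization and replay; this is exactly what `verify_trace_roundtrip` exercises in the tests below (a sketch):

.. code-block:: python

    from tvm import tir
    from tvm.tir.schedule.testing import verify_trace_roundtrip

    s = tir.Schedule(matmul, debug_mask="all")  # `matmul` from the tests below
    C = s.get_block("update")
    i, _, _ = s.get_loops(C)
    s.decompose_reduction(C, i)
    # Replays the recorded trace on a fresh schedule of `matmul` and checks it
    # reproduces s.mod; UnpackedAsPython renders the recorded step as
    #   b1 = sch.decompose_reduction(block=b0, loop=l0)
    verify_trace_roundtrip(s, mod=matmul)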
diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py
index d79338ace726..8460b5cf3e66 100644
--- a/tests/python/unittest/test_tir_schedule_reduction.py
+++ b/tests/python/unittest/test_tir_schedule_reduction.py
@@ -28,607 +28,174 @@
 @T.prim_func
-def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [128, 128])
-    B = T.match_buffer(b, [128, 128])
-    C = T.match_buffer(c, [128, 128])
-
-    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
-        with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
-            T.bind(vi, i0)
-            T.bind(vj, i1)
-            T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner))
-            T.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
-            T.writes([C[vi, vj]])
+def rowsum_blockized(a: T.handle, b: T.handle) -> None:
+    B = T.match_buffer(b, [32, 4])
+    A = T.match_buffer(a, [32, 4, 128])
+    for i0, i2_0 in T.grid(32, 16):
+        with T.block([32, T.reduce_axis(0, 16)], "blockized_B") as [io, ko]:
+            T.bind(io, i0)
+            T.bind(ko, i2_0)
             with T.init():
-                C[vi, vj] = 0.0
-            C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
+                for i1 in T.serial(0, 4):
+                    with T.block([4], "B_init") as [ii_init]:
+                        T.bind(ii_init, i1)
+                        B[io, ii_init] = 0.0
+            for i1_1, i2_1 in T.grid(4, 8):
+                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+                    T.bind(ii, i1_1)
+                    T.bind(k, ko * 8 + i2_1)
+                    B[io, ii] = B[io, ii] + A[io, ii, k]
 
 
 @T.prim_func
-def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None:
+def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
-    C_rf = T.alloc_buffer([4, 128, 128])
-
-    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4):
-        with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [
-            vi2_inner_inner,
-            vi,
-            vj,
-            vi2_outer,
-            vi2_inner_outer,
-        ]:
-            T.bind(vi2_inner_inner, i2_inner_inner)
-            T.bind(vi, i0)
-            T.bind(vj, i1)
-            T.bind(vi2_outer, i2_outer)
-            T.bind(vi2_inner_outer, i2_inner_outer)
-            with T.init():
-                C_rf[vi2_inner_inner, vi, vj] = 0.0
-            C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + (
-                A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
-                * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)]
-            )
-
-    for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4):
-        with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [
-            vi2_inner_inner_1,
-            vi_1,
-            vj_1,
-        ]:
-            T.bind(vi2_inner_inner_1, i2_inner_inner_1)
-            T.bind(vi_1, i0_1)
-            T.bind(vj_1, i1_1)
-            with T.init():
-                C[vi_1, vj_1] = 0.0
-            C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1]
-
-
-@T.prim_func
-def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [256, 256])
-    B = T.match_buffer(b, [256, 256])
-    D = T.match_buffer(d, [256, 256])
-    C = T.alloc_buffer([256, 256])
-
-    with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
+    with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
         with T.init():
             C[vi, vj] = 0.0
-        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
-
-    with T.block([256, 256], "D") as [vi, vj]:
-        D[vi, vj] = C[vi, vj]
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
 @T.prim_func
-def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-    C = T.match_buffer(c, (128, 128))
-
-    with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]:
-        with T.init():
-            C[vi, vj] = 0.0
-        C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj]
-
-
-@T.prim_func
-def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None:
+def matmul_decompose0(a: T.handle, b: T.handle, c: T.handle) -> None:
     A = T.match_buffer(a, [128, 128])
     B = T.match_buffer(b, [128, 128])
     C = T.match_buffer(c, [128, 128])
-    D = T.match_buffer(d, [128, 128])
-
-    for k, i, j in T.grid(128, 128, 128):
-        with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]:
-            T.bind(ck, k)
-            T.bind(ci, i)
-            T.bind(cj, j)
-            with T.init():
-                C[ci, cj] = 0.0
-            C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj]
-        with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
-            T.bind(dk, k)
-            T.bind(di, i)
-            T.bind(dj, j)
-            with T.init():
-                D[di, dj] = 0.0
-            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
+    with T.block([128, 128], "init") as [vi, vj]:
+        C[vi, vj] = 0.0
 
-
-@T.prim_func
-def square_sum(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    C = T.match_buffer(c, [16])
-
-    with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
-        with T.init():
-            C[b] = 0.0
-        C[b] = C[b] + A[b, i, j] * A[b, i, j]
+    with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
+        C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
 
 
 @T.prim_func
-def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    C = T.match_buffer(c, [16])
-    C_rf = T.alloc_buffer([16, 256])
-
-    for i0, i1, i2 in T.grid(16, 256, 256):
-        with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
-            T.bind(vi2, i2)
-            T.bind(b, i0)
-            T.bind(i, i1)
-            with T.init():
-                C_rf[b, vi2] = 0.0
-            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
+def matmul_decompose1(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [32, 4, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [32, 4], elem_offset=0, align=128, offset_factor=1)
 
-    for i0_1, i2_1 in T.grid(16, 256):
-        with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
-            T.bind(vi2_1, i2_1)
-            T.bind(b_1, i0_1)
-            with T.init():
-                C[b_1] = 0.0
-            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
-
-
-@T.prim_func
-def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    D = T.match_buffer(d, [16])
-    C = T.alloc_buffer([16])
-
-    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
-        with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
-            T.bind(b, i0)
-            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
-            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
-            T.reads([C[b], A[b, i, j]])
-            T.writes([C[b]])
-            with T.init():
-                C[b] = 0.0
-            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
-    for i0_1 in T.serial(0, 16):
-        with T.block([16], "D") as [b_1]:
-            T.bind(b_1, i0_1)
-            T.reads([C[b_1]])
-            T.writes([D[b_1]])
-            D[b_1] = T.sqrt(C[b_1], dtype="float32")
+    for i0 in T.serial(0, 32):
+        with T.block([32], "blockized_B_init") as [io]:
+            for i1 in T.serial(0, 4):
+                with T.block([4], "B_init") as [ii]:
+                    B[io, ii] = T.float32(0)
+    for i0, i2_o in T.grid(32, 16):
+        with T.block([32, T.reduce_axis(0, 16)], "blockized_B_update") as [io, ko]:
+            for i1, i2_i in T.grid(4, 8):
+                with T.block([4, T.reduce_axis(0, 128)], "B") as [ii, k]:
+                    T.bind(ii, i1)
+                    T.bind(k, ((ko * 8) + i2_i))
+                    B[io, ii] = B[io, ii] + A[io, ii, k]
 
 
 @T.prim_func
-def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, [16, 256, 256])
-    D = T.match_buffer(d, [16])
-    C = T.alloc_buffer([16])
-    C_rf = T.alloc_buffer([1, 16])
-
-    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
-        with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [
-            vi1_i2_fused_inner,
-            b,
-            i,
-            j,
-        ]:
-            T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
-            T.bind(b, i0)
-            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
-            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
-            with T.init():
-                C_rf[vi1_i2_fused_inner, b] = 0.0
-            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
-
-    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
-        with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
-            T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
-            T.bind(b_1, i0_1)
-            with T.init():
-                C[b_1] = 0.0
-            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
+def matmul_decompose2(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
 
-    for i0_2 in T.serial(0, 16):
-        with T.block([16], "D") as [b_2]:
-            T.bind(b_2, i0_2)
-            D[b_2] = T.sqrt(C[b_2], dtype="float32")
+    for i0, i1 in T.grid(128, 128):
+        with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+            C[vi_init, vj_init] = T.float32(0)
+        for i2 in T.serial(0, 128):
+            with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [vi, vj, vk]:
+                C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
 
 
 @T.prim_func
-def element_wise(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-
-    with T.block([128, 128], "B") as [vi, vj]:
-        B[vi, vj] = A[vi, vj] * 2.0
-
-
-@T.prim_func
-def rowsum(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 0.0
-        B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for i, k in T.grid(128, 16):
-        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-            T.bind(vi, i)
-            T.bind(vk, T.floordiv(k * k, 2))
-            with T.init():
-                B[vi] = 0.0
-            B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_dominant(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128, 128))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi, vk] = 0.0
-        B[vi, vk] = B[vi, vk] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for i in T.serial(0, 128):
-        for k in T.parallel(0, 128):
-            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-                T.bind(vi, i)
-                T.bind(vk, k)
-                with T.init():
-                    B[vi] = 0.0
-                B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 1.0
-        B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-        with T.init():
-            B[vi] = 0.0
-        B[vi] = B[vi] - A[vi, vk]
-
-
-@T.prim_func
-def rowsum_transformed(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, (128, 128))
-    B = T.match_buffer(b, (128,))
-
-    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
-        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
-            T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
-            T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
-            with T.init():
-                B[vi] = 0.0
-            B[vi] = B[vi] + A[vi, vk]
-
-
-@T.prim_func
-def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, [128])
-    B = T.match_buffer(b, [])
-
-    with T.block([T.reduce_axis(0, 128)], "B") as [k]:
-        with T.init():
-            B[()] = 0.0
-        B[()] = B[()] + A[k]
-
-
-@T.prim_func
-def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, [128])
-    B = T.match_buffer(b, [])
-    B_rf = T.alloc_buffer([128])
-
-    with T.block([128], "B_rf") as [vi0]:
-        with T.init():
-            B_rf[vi0] = 0.0
-        B_rf[vi0] = B_rf[vi0] + A[vi0]
-
-    with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
-        with T.init():
-            B[()] = 0.0
-        B[()] = B[()] + B_rf[vi0_1]
-
-
-@T.prim_func
-def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16, 16))
-    C = T.alloc_buffer((16, 16))
-    D = T.alloc_buffer((16, 16))
-    E = T.alloc_buffer((16, 16))
-    F = T.match_buffer(f, (16, 16))
-
-    for i in T.serial(0, 16):
-        for j1 in T.serial(0, 16):
-            for k1o, k1i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
-                    T.bind(ci, i)
-                    T.bind(cj, j1)
-                    T.bind(ck, k1o * 4 + k1i)
-                    with T.init():
-                        C[ci, cj] = 0.0
-                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
-            for k2o, k2i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
-                    T.bind(di, i)
-                    T.bind(dj, j1)
-                    T.bind(dk, k2o * 4 + k2i)
-                    with T.init():
-                        D[di, dj] = 0.0
-                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
-        for j2 in T.serial(0, 16):
-            for k3o, k3i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
-                    T.bind(ei, i)
-                    T.bind(ej, j2)
-                    T.bind(ek, k3o * 4 + k3i)
-                    with T.init():
-                        E[ei, ej] = 0.0
-                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
-            for k4o, k4i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
-                    T.bind(fi, i)
-                    T.bind(fj, j2)
-                    T.bind(fk, k4o * 4 + k4i)
-                    with T.init():
-                        F[fi, fj] = 0.0
-                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
-
+def matmul_decompose_fail3(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [128, 128])
+    B = T.match_buffer(b, [128, 128])
+    C = T.match_buffer(c, [128, 128])
 
-@T.prim_func
-def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
-    A = T.match_buffer(a, [16, 16, 16])
-    C = T.alloc_buffer([16, 16])
-    D = T.alloc_buffer([16, 16])
-    E = T.alloc_buffer([16, 16])
-    F = T.match_buffer(f, [16, 16])
-    C_rf = T.alloc_buffer([16, 16, 4])
-
-    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
-        with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
-            T.bind(vk1o, k1o)
-            T.bind(ci, i)
-            T.bind(cj, j1)
-            T.bind(vk1i, k1i)
+    for i, k, j in T.grid(128, 128, 128):
+        with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
             with T.init():
-                C_rf[ci, cj, vk1o] = 0.0
-            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
-    for i_1 in T.serial(0, 16):
-        for j1_1 in T.serial(0, 16):
-            for k1o_1 in T.serial(0, 4):
-                with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
-                    T.bind(vk1o_1, k1o_1)
-                    T.bind(ci_1, i_1)
-                    T.bind(cj_1, j1_1)
-                    with T.init():
-                        C[ci_1, cj_1] = 0.0
-                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
-            for k2o, k2i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
-                    T.bind(di, i_1)
-                    T.bind(dj, j1_1)
-                    T.bind(dk, (k2o * 4) + k2i)
-                    with T.init():
-                        D[di, dj] = 0.0
-                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
-        for j2 in T.serial(0, 16):
-            for k3o, k3i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
-                    T.bind(ei, i_1)
-                    T.bind(ej, j2)
-                    T.bind(ek, (k3o * 4) + k3i)
-                    with T.init():
-                        E[ei, ej] = 0.0
-                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
-            for k4o, k4i in T.grid(4, 4):
-                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
-                    T.bind(fi, i_1)
-                    T.bind(fj, j2)
-                    T.bind(fk, (k4o * 4) + k4i)
-                    with T.init():
-                        F[fi, fj] = 0.0
-                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
+                C[vi, vj] = 0.0
+            T.bind(vi, i)
+            T.bind(vj, j)
+            T.bind(vk, k)
+            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
+
+
+@T.prim_func
+def matmul_decompose4(a: T.handle, b: T.handle, c: T.handle) -> None:
+    C = T.match_buffer(c, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    B = T.match_buffer(b, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    A = T.match_buffer(a, [128, 128], elem_offset=0, align=128, offset_factor=1)
+    # body
+    with T.block([], "root"):
+        T.reads([])
+        T.writes([])
+        for i0_0 in T.serial(0, 16):
+            for i0_1_init, i1_init in T.grid(8, 128):
+                with T.block([128, 128], "update_init") as [vi_init, vj_init]:
+                    T.bind(vi_init, ((i0_0 * 8) + i0_1_init))
+                    T.bind(vj_init, i1_init)
+                    C[vi_init, vj_init] = T.float32(0)
+            for i0_1, i1, i2_0, i2_1 in T.grid(8, 128, 19, 7):
+                with T.block([128, 128, T.reduce_axis(0, 128)], "update_update") as [
+                    vi,
+                    vj,
+                    vk,
+                ]:
+                    T.where((((i2_0 * 7) + i2_1) < 128))
+                    T.bind(vi, ((i0_0 * 8) + i0_1))
+                    T.bind(vj, i1)
+                    T.bind(vk, ((i2_0 * 7) + i2_1))
+                    C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk])
 
 
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
 
 
-def test_reduction_rfactor_matmul():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    update = s.get_block("update")
-    _, _, _, _, kii = s.get_loops(update)
-    rf_block = s.rfactor(kii, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
-    assert s.get(update).same_as(s.get(s.get_block("update")))
-    verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_square_sum():
-    s = tir.Schedule(square_sum, debug_mask="all")
-    C = s.get_block("C")
-    _, _, j = s.get_loops(C)
-    rf_block = s.rfactor(j, 1)
-    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=square_sum)
-
-
-def test_reduction_rfactor_square_sum_square_root():
-    s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
-    C = s.get_block("C")
-    _, _, f_i = s.get_loops(C)
-    rf_block = s.rfactor(f_i, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=transformed_square_sum_square_root)
-
-
-def test_reduction_rfactor_loop_multiple_children():
-    s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all")
-    k, _, _ = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_stage_pipeline():
-    s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all")
-    _, _, k = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block1():
-    s = tir.Schedule(element_wise, debug_mask="all")
-    i, _ = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(i, 0)
-
-
-def test_reduction_rfactor_not_reduction_block2():
-    s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_reduction_block3():
-    s = tir.Schedule(rowsum_not_dominant, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_serial_loop():
-    s = tir.Schedule(rowsum_not_serial, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_not_same_buffer_access():
-    s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all")
-    _, _, k = s.get_loops(s.get_block("C"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
-
-
-def test_reduction_rfactor_factor_axis_range_fail():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    _, _, _, _, kii = s.get_loops(s.get_block("update"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(kii, 3)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(kii, -4)
-
-
-def test_reduction_rfactor_factor_axis_range():
-    s = tir.Schedule(transformed_matmul, debug_mask="all")
-    update = s.get_block("update")
-    _, _, _, _, kii = s.get_loops(update)
-    rf_block = s.rfactor(kii, -3)
-    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
-    assert s.get(update).same_as(s.get(s.get_block("update")))
-    verify_trace_roundtrip(s, mod=transformed_matmul)
-
-
-def test_reduction_rfactor_wrong_reduce_pattern1():
-    s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
+def test_reduction_decompose0():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    i, j, k = s.get_loops(C)
+    s.decompose_reduction(C, i)
+    tvm.ir.assert_structural_equal(matmul_decompose0, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
-def test_reduction_rfactor_wrong_reduce_pattern2():
-    s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all")
-    _, k = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k, 0)
+def test_reduction_decompose1():
+    s = tir.Schedule(rowsum_blockized, debug_mask="all")
+    blockized_B = s.get_block("blockized_B")
+    io, ko = s.get_loops(blockized_B)
+    s.decompose_reduction(blockized_B, io)
+    tvm.ir.assert_structural_equal(matmul_decompose1, s.mod["main"])
+    verify_trace_roundtrip(s, mod=rowsum_blockized)
 
 
-def test_reduction_rfactor_wrong_loops1():
-    s = tir.Schedule(rowsum, debug_mask="all")
-    i, _ = s.get_loops(s.get_block("B"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(i, 0)
+def test_reduction_decompose2():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    i, j, k = s.get_loops(C)
+    s.decompose_reduction(C, k)
+    tvm.ir.assert_structural_equal(matmul_decompose2, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
-def test_reduction_rfactor_wrong_loops2():
-    s = tir.Schedule(rowsum_transformed, debug_mask="all")
-    _, _, k_i = s.get_loops(s.get_block("B"))
+def test_reduction_decompose3():
+    s = tir.Schedule(matmul_decompose_fail3, debug_mask="all")
+    C = s.get_block("update")
+    i, j, k = s.get_loops(C)
     with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k_i, 0)
+        s.decompose_reduction(C, k)
 
 
-def test_reduction_rfactor_zero_dim():
-    s = tir.Schedule(rowsum_zero_dim, debug_mask="all")
-    B = s.get_block("B")
-    (k,) = s.get_loops(B)
-    rf_block = s.rfactor(k, 0)
-    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
-    assert s.get(B).same_as(s.get(s.get_block("B")))
-    verify_trace_roundtrip(s, mod=rowsum_zero_dim)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children_fail():  # pylint: disable=invalid-name
-    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
-    _, _, k2o, k2i = s.get_loops(s.get_block("D"))
-    _, _, k3o, k3i = s.get_loops(s.get_block("E"))
-    _, _, k4o, k4i = s.get_loops(s.get_block("F"))
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k2o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k2i, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k3o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k3i, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k4o, 0)
-    with pytest.raises(tvm.tir.ScheduleError):
-        s.rfactor(k4i, 0)
-
-
-def test_reduction_rfactor_outermost_loop_multiple_children():  # pylint: disable=invalid-name
-    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
-    C = s.get_block("C")
-    _, _, k1o, _ = s.get_loops(C)
-    rf_block = s.rfactor(k1o, 2)
-    tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor)
-    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
-    assert s.get(C).same_as(s.get(s.get_block("C")))
-    verify_trace_roundtrip(s, mod=multiple_reduction_blocks)
+def test_reduction_decompose4():
+    s = tir.Schedule(matmul, debug_mask="all")
+    C = s.get_block("update")
+    i, j, k = s.get_loops(C)
+    io, ii = s.split(i, factors=[16, 8])
+    ko, ki = s.split(k, factors=[19, 7])
+    s.decompose_reduction(C, ii)
+    tvm.ir.assert_structural_equal(matmul_decompose4, s.mod["main"])
+    verify_trace_roundtrip(s, mod=matmul)
 
 
 if __name__ == "__main__":
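The rfactor tests move verbatim into their own file below; only the decompose-reduction tests stay in test_tir_schedule_reduction.py. The two primitives are complementary: rfactor introduces an intermediate rf-buffer and a new block var over the factored axis, while decompose_reduction only splits the init from the update. A sketch of the contrast, using the workloads from these test files:

.. code-block:: python

    from tvm import tir

    # decompose_reduction: same buffers, init and update become separate blocks.
    s1 = tir.Schedule(matmul, debug_mask="all")
    s1.decompose_reduction(s1.get_block("update"), s1.get_loops(s1.get_block("update"))[0])

    # rfactor: a new "update_rf" block writing an intermediate C_rf buffer,
    # as exercised by test_reduction_rfactor_matmul in the file below.
    s2 = tir.Schedule(transformed_matmul, debug_mask="all")
    _, _, _, _, kii = s2.get_loops(s2.get_block("update"))
    rf_block = s2.rfactor(kii, 0)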
+# pylint: disable=missing-function-docstring,missing-module-docstring +import sys + +import pytest +import tvm +import tvm.testing +from tvm import tir +from tvm.script import tir as T +from tvm.tir.schedule.testing import verify_trace_roundtrip + +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + +@T.prim_func +def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block([128, 128, T.reduce_axis(0, 128)], "update") as [vi, vj, vk]: + T.bind(vi, i0) + T.bind(vj, i1) + T.bind(vk, (((i2_outer * 32) + (i2_inner_outer * 4)) + i2_inner_inner)) + T.reads([C[vi, vj], A[vi, vk], B[vj, vk]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + + +@T.prim_func +def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + C_rf = T.alloc_buffer([4, 128, 128]) + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block([4, 128, 128, T.reduce_axis(0, 4), T.reduce_axis(0, 8)], "update_rf") as [ + vi2_inner_inner, + vi, + vj, + vi2_outer, + vi2_inner_outer, + ]: + T.bind(vi2_inner_inner, i2_inner_inner) + T.bind(vi, i0) + T.bind(vj, i1) + T.bind(vi2_outer, i2_outer) + T.bind(vi2_inner_outer, i2_inner_outer) + with T.init(): + C_rf[vi2_inner_inner, vi, vj] = 0.0 + C_rf[vi2_inner_inner, vi, vj] = C_rf[vi2_inner_inner, vi, vj] + ( + A[vi, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + * B[vj, (((vi2_outer * 32) + (vi2_inner_outer * 4)) + vi2_inner_inner)] + ) + + for i0_1, i1_1, i2_inner_inner_1 in T.grid(128, 128, 4): + with T.block([T.reduce_axis(0, 4), 128, 128], "update") as [ + vi2_inner_inner_1, + vi_1, + vj_1, + ]: + T.bind(vi2_inner_inner_1, i2_inner_inner_1) + T.bind(vi_1, i0_1) + T.bind(vj_1, i1_1) + with T.init(): + C[vi_1, vj_1] = 0.0 + C[vi_1, vj_1] = C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1] + + +@T.prim_func +def matmul_not_stage_pipeline(a: T.handle, b: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [256, 256]) + B = T.match_buffer(b, [256, 256]) + D = T.match_buffer(d, [256, 256]) + C = T.alloc_buffer([256, 256]) + + with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + with T.block([256, 256], "D") as [vi, vj]: + D[vi, vj] = C[vi, vj] + + +@T.prim_func +def matmul_not_same_buffer_access(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + with T.block([128, 128, T.reduce_axis(0, 128)], "C") as [vi, vj, vk]: + with T.init(): + C[vi, vj] = 0.0 + C[vj, vi] = C[vj, vi] + A[vi, vk] * B[vk, vj] + + +@T.prim_func +def matmul_loop_multiple_children(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + D = T.match_buffer(d, [128, 128]) + + for k, i, j in T.grid(128, 128, 128): + with T.block([T.reduce_axis(0, 128), 128, 128], "C") as [ck, ci, cj]: + T.bind(ck, k) + T.bind(ci, i) + T.bind(cj, j) + with T.init(): + C[ci, cj] = 0.0 + C[ci, cj] = C[ci, cj] + A[ci, ck] * B[ck, cj] + 
+        with T.block([T.reduce_axis(0, 128), 128, 128], "D") as [dk, di, dj]:
+            T.bind(dk, k)
+            T.bind(di, i)
+            T.bind(dj, j)
+            with T.init():
+                D[di, dj] = 0.0
+            D[di, dj] = D[di, dj] + B[di, dk] * A[dk, dj]
+
+
+@T.prim_func
+def square_sum(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [16, 256, 256])
+    C = T.match_buffer(c, [16])
+
+    with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
+        with T.init():
+            C[b] = 0.0
+        C[b] = C[b] + A[b, i, j] * A[b, i, j]
+
+
+@T.prim_func
+def square_sum_rfactor(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [16, 256, 256])
+    C = T.match_buffer(c, [16])
+    C_rf = T.alloc_buffer([16, 256])
+
+    for i0, i1, i2 in T.grid(16, 256, 256):
+        with T.block([256, 16, T.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
+            T.bind(vi2, i2)
+            T.bind(b, i0)
+            T.bind(i, i1)
+            with T.init():
+                C_rf[b, vi2] = 0.0
+            C_rf[b, vi2] = C_rf[b, vi2] + (A[b, i, vi2] * A[b, i, vi2])
+
+    for i0_1, i2_1 in T.grid(16, 256):
+        with T.block([T.reduce_axis(0, 256), 16], "C") as [vi2_1, b_1]:
+            T.bind(vi2_1, i2_1)
+            T.bind(b_1, i0_1)
+            with T.init():
+                C[b_1] = 0.0
+            C[b_1] = C[b_1] + C_rf[b_1, vi2_1]
+
+
+@T.prim_func
+def transformed_square_sum_square_root(a: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, [16, 256, 256])
+    D = T.match_buffer(d, [16])
+    C = T.alloc_buffer([16])
+
+    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
+        with T.block([16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C") as [b, i, j]:
+            T.bind(b, i0)
+            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
+            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
+            T.reads([C[b], A[b, i, j]])
+            T.writes([C[b]])
+            with T.init():
+                C[b] = 0.0
+            C[b] = C[b] + (A[b, i, j] * A[b, i, j])
+    for i0_1 in T.serial(0, 16):
+        with T.block([16], "D") as [b_1]:
+            T.bind(b_1, i0_1)
+            T.reads([C[b_1]])
+            T.writes([D[b_1]])
+            D[b_1] = T.sqrt(C[b_1], dtype="float32")
+
+
+@T.prim_func
+def square_sum_square_root_rfactor(a: T.handle, d: T.handle) -> None:
+    A = T.match_buffer(a, [16, 256, 256])
+    D = T.match_buffer(d, [16])
+    C = T.alloc_buffer([16])
+    C_rf = T.alloc_buffer([1, 16])
+
+    for i0, i1_i2_fused_outer, i1_i2_fused_inner in T.grid(16, 65536, 1):
+        with T.block([1, 16, T.reduce_axis(0, 256), T.reduce_axis(0, 256)], "C_rf") as [
+            vi1_i2_fused_inner,
+            b,
+            i,
+            j,
+        ]:
+            T.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
+            T.bind(b, i0)
+            T.bind(i, T.floordiv(i1_i2_fused_outer, 256))
+            T.bind(j, T.floormod(i1_i2_fused_outer, 256))
+            with T.init():
+                C_rf[vi1_i2_fused_inner, b] = 0.0
+            C_rf[vi1_i2_fused_inner, b] = C_rf[vi1_i2_fused_inner, b] + (A[b, i, j] * A[b, i, j])
+
+    for i0_1, i1_i2_fused_inner_1 in T.grid(16, 1):
+        with T.block([T.reduce_axis(0, 1), 16], "C") as [vi1_i2_fused_inner_1, b_1]:
+            T.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
+            T.bind(b_1, i0_1)
+            with T.init():
+                C[b_1] = 0.0
+            C[b_1] = C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1]
+
+    for i0_2 in T.serial(0, 16):
+        with T.block([16], "D") as [b_2]:
+            T.bind(b_2, i0_2)
+            D[b_2] = T.sqrt(C[b_2], dtype="float32")
+
+
+@T.prim_func
+def element_wise(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128, 128))
+
+    with T.block([128, 128], "B") as [vi, vj]:
+        B[vi, vj] = A[vi, vj] * 2.0
+
+
+@T.prim_func
+def rowsum(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with T.init():
+            B[vi] = 0.0
+        B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_not_quasi_affine(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    for i, k in T.grid(128, 16):
+        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+            T.bind(vi, i)
+            T.bind(vk, T.floordiv(k * k, 2))
+            with T.init():
+                B[vi] = 0.0
+            B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_not_dominant(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128, 128))
+
+    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with T.init():
+            B[vi, vk] = 0.0
+        B[vi, vk] = B[vi, vk] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_not_serial(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    for i in T.serial(0, 128):
+        for k in T.parallel(0, 128):
+            with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+                T.bind(vi, i)
+                T.bind(vk, k)
+                with T.init():
+                    B[vi] = 0.0
+                B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_wrong_reduce_pattern1(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with T.init():
+            B[vi] = 1.0
+        B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+        with T.init():
+            B[vi] = 0.0
+        B[vi] = B[vi] - A[vi, vk]
+
+
+@T.prim_func
+def rowsum_transformed(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, (128, 128))
+    B = T.match_buffer(b, (128,))
+
+    for io, ii_ko_fused, ki in T.grid(32, 128, 4):
+        with T.block([128, T.reduce_axis(0, 128)], "B") as [vi, vk]:
+            T.bind(vi, io * 4 + T.floordiv(ii_ko_fused, 32))
+            T.bind(vk, T.floormod(ii_ko_fused, 32) * 4 + ki)
+            with T.init():
+                B[vi] = 0.0
+            B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def rowsum_zero_dim(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128])
+    B = T.match_buffer(b, [])
+
+    with T.block([T.reduce_axis(0, 128)], "B") as [k]:
+        with T.init():
+            B[()] = 0.0
+        B[()] = B[()] + A[k]
+
+
+@T.prim_func
+def rowsum_zero_dim_rfactor(a: T.handle, b: T.handle) -> None:
+    A = T.match_buffer(a, [128])
+    B = T.match_buffer(b, [])
+    B_rf = T.alloc_buffer([128])
+
+    with T.block([128], "B_rf") as [vi0]:
+        with T.init():
+            B_rf[vi0] = 0.0
+        B_rf[vi0] = B_rf[vi0] + A[vi0]
+
+    with T.block([T.reduce_axis(0, 128)], "B") as [vi0_1]:
+        with T.init():
+            B[()] = 0.0
+        B[()] = B[()] + B_rf[vi0_1]
+
+
+@T.prim_func
+def multiple_reduction_blocks(a: T.handle, f: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16, 16))
+    C = T.alloc_buffer((16, 16))
+    D = T.alloc_buffer((16, 16))
+    E = T.alloc_buffer((16, 16))
+    F = T.match_buffer(f, (16, 16))
+
+    for i in T.serial(0, 16):
+        for j1 in T.serial(0, 16):
+            for k1o, k1i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "C") as [ci, cj, ck]:
+                    T.bind(ci, i)
+                    T.bind(cj, j1)
+                    T.bind(ck, k1o * 4 + k1i)
+                    with T.init():
+                        C[ci, cj] = 0.0
+                    C[ci, cj] = C[ci, cj] + A[ci, cj, ck]
+            for k2o, k2i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
+                    T.bind(di, i)
+                    T.bind(dj, j1)
+                    T.bind(dk, k2o * 4 + k2i)
+                    with T.init():
+                        D[di, dj] = 0.0
+                    D[di, dj] = D[di, dj] + A[di, dj, dk] + C[di, dj]
+        for j2 in T.serial(0, 16):
+            for k3o, k3i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
+                    T.bind(ei, i)
+                    T.bind(ej, j2)
+                    T.bind(ek, k3o * 4 + k3i)
+                    with T.init():
+                        E[ei, ej] = 0.0
+                    E[ei, ej] = E[ei, ej] + A[ei, ej, ek] + D[ei, ej]
+            for k4o, k4i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
+                    T.bind(fi, i)
+                    T.bind(fj, j2)
+                    T.bind(fk, k4o * 4 + k4i)
+                    with T.init():
+                        F[fi, fj] = 0.0
+                    F[fi, fj] = F[fi, fj] + A[fi, fj, fk] + E[fi, fj]
+
+
+@T.prim_func
+def multiple_reduction_blocks_rfactor(a: T.handle, f: T.handle) -> None:
+    A = T.match_buffer(a, [16, 16, 16])
+    C = T.alloc_buffer([16, 16])
+    D = T.alloc_buffer([16, 16])
+    E = T.alloc_buffer([16, 16])
+    F = T.match_buffer(f, [16, 16])
+    C_rf = T.alloc_buffer([16, 16, 4])
+
+    for i, j1, k1o, k1i in T.grid(16, 16, 4, 4):
+        with T.block([4, 16, 16, T.reduce_axis(0, 4)], "C_rf") as [vk1o, ci, cj, vk1i]:
+            T.bind(vk1o, k1o)
+            T.bind(ci, i)
+            T.bind(cj, j1)
+            T.bind(vk1i, k1i)
+            with T.init():
+                C_rf[ci, cj, vk1o] = 0.0
+            C_rf[ci, cj, vk1o] = C_rf[ci, cj, vk1o] + A[ci, cj, ((vk1o * 4) + vk1i)]
+    for i_1 in T.serial(0, 16):
+        for j1_1 in T.serial(0, 16):
+            for k1o_1 in T.serial(0, 4):
+                with T.block([T.reduce_axis(0, 4), 16, 16], "C") as [vk1o_1, ci_1, cj_1]:
+                    T.bind(vk1o_1, k1o_1)
+                    T.bind(ci_1, i_1)
+                    T.bind(cj_1, j1_1)
+                    with T.init():
+                        C[ci_1, cj_1] = 0.0
+                    C[ci_1, cj_1] = C[ci_1, cj_1] + C_rf[ci_1, cj_1, vk1o_1]
+            for k2o, k2i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "D") as [di, dj, dk]:
+                    T.bind(di, i_1)
+                    T.bind(dj, j1_1)
+                    T.bind(dk, (k2o * 4) + k2i)
+                    with T.init():
+                        D[di, dj] = 0.0
+                    D[di, dj] = (D[di, dj] + A[di, dj, dk]) + C[di, dj]
+        for j2 in T.serial(0, 16):
+            for k3o, k3i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "E") as [ei, ej, ek]:
+                    T.bind(ei, i_1)
+                    T.bind(ej, j2)
+                    T.bind(ek, (k3o * 4) + k3i)
+                    with T.init():
+                        E[ei, ej] = 0.0
+                    E[ei, ej] = (E[ei, ej] + A[ei, ej, ek]) + D[ei, ej]
+            for k4o, k4i in T.grid(4, 4):
+                with T.block([16, 16, T.reduce_axis(0, 16)], "F") as [fi, fj, fk]:
+                    T.bind(fi, i_1)
+                    T.bind(fj, j2)
+                    T.bind(fk, (k4o * 4) + k4i)
+                    with T.init():
+                        F[fi, fj] = 0.0
+                    F[fi, fj] = (F[fi, fj] + A[fi, fj, fk]) + E[fi, fj]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
+
+
+def test_reduction_rfactor_matmul():
+    s = tir.Schedule(transformed_matmul, debug_mask="all")
+    update = s.get_block("update")
+    _, _, _, _, kii = s.get_loops(update)
+    rf_block = s.rfactor(kii, 0)
+    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
+    assert s.get(update).same_as(s.get(s.get_block("update")))
+    verify_trace_roundtrip(s, mod=transformed_matmul)
+
+
+def test_reduction_rfactor_square_sum():
+    s = tir.Schedule(square_sum, debug_mask="all")
+    C = s.get_block("C")
+    _, _, j = s.get_loops(C)
+    rf_block = s.rfactor(j, 1)
+    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
+    assert s.get(C).same_as(s.get(s.get_block("C")))
+    verify_trace_roundtrip(s, mod=square_sum)
+
+
+def test_reduction_rfactor_square_sum_square_root():
+    s = tir.Schedule(transformed_square_sum_square_root, debug_mask="all")
+    C = s.get_block("C")
+    _, _, f_i = s.get_loops(C)
+    rf_block = s.rfactor(f_i, 0)
+    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
+    assert s.get(C).same_as(s.get(s.get_block("C")))
+    verify_trace_roundtrip(s, mod=transformed_square_sum_square_root)
+
+
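+# The tests below exercise rfactor's preconditions: a loop with multiple
+# children, a non-staged pipeline, a non-reduction block, a non-serial loop,
+# inconsistent buffer access, an out-of-range factor_axis, an unsupported
+# init/update pattern, or a loop that does not bind a reduction variable must
+# raise tvm.tir.ScheduleError, while the borderline-but-legal cases (in-range
+# negative factor_axis, zero-dim reduction) must still succeed.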
+def test_reduction_rfactor_loop_multiple_children():
+    s = tir.Schedule(matmul_loop_multiple_children, debug_mask="all")
+    k, _, _ = s.get_loops(s.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_stage_pipeline():
+    s = tir.Schedule(matmul_not_stage_pipeline, debug_mask="all")
+    _, _, k = s.get_loops(s.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_reduction_block1():
+    s = tir.Schedule(element_wise, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(i, 0)
+
+
+def test_reduction_rfactor_not_reduction_block2():
+    s = tir.Schedule(rowsum_not_quasi_affine, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_reduction_block3():
+    s = tir.Schedule(rowsum_not_dominant, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_serial_loop():
+    s = tir.Schedule(rowsum_not_serial, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_not_same_buffer_access():
+    s = tir.Schedule(matmul_not_same_buffer_access, debug_mask="all")
+    _, _, k = s.get_loops(s.get_block("C"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_factor_axis_range_fail():
+    s = tir.Schedule(transformed_matmul, debug_mask="all")
+    _, _, _, _, kii = s.get_loops(s.get_block("update"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(kii, 3)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(kii, -4)
+
+
+def test_reduction_rfactor_factor_axis_range():
+    s = tir.Schedule(transformed_matmul, debug_mask="all")
+    update = s.get_block("update")
+    _, _, _, _, kii = s.get_loops(update)
+    rf_block = s.rfactor(kii, -3)
+    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))
+    assert s.get(update).same_as(s.get(s.get_block("update")))
+    verify_trace_roundtrip(s, mod=transformed_matmul)
+
+
+def test_reduction_rfactor_wrong_reduce_pattern1():
+    s = tir.Schedule(rowsum_wrong_reduce_pattern1, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_wrong_reduce_pattern2():
+    s = tir.Schedule(rowsum_wrong_reduce_pattern2, debug_mask="all")
+    _, k = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k, 0)
+
+
+def test_reduction_rfactor_wrong_loops1():
+    s = tir.Schedule(rowsum, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(i, 0)
+
+
+def test_reduction_rfactor_wrong_loops2():
+    s = tir.Schedule(rowsum_transformed, debug_mask="all")
+    _, _, k_i = s.get_loops(s.get_block("B"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k_i, 0)
+
+
+def test_reduction_rfactor_zero_dim():
+    s = tir.Schedule(rowsum_zero_dim, debug_mask="all")
+    B = s.get_block("B")
+    (k,) = s.get_loops(B)
+    rf_block = s.rfactor(k, 0)
+    tvm.ir.assert_structural_equal(s.mod["main"], rowsum_zero_dim_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("B_rf")))
+    assert s.get(B).same_as(s.get(s.get_block("B")))
+    verify_trace_roundtrip(s, mod=rowsum_zero_dim)
+
+
+def test_reduction_rfactor_outermost_loop_multiple_children_fail():  # pylint: disable=invalid-name
+    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
+    _, _, k2o, k2i = s.get_loops(s.get_block("D"))
+    _, _, k3o, k3i = s.get_loops(s.get_block("E"))
+    _, _, k4o, k4i = s.get_loops(s.get_block("F"))
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k2o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k2i, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k3o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k3i, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k4o, 0)
+    with pytest.raises(tvm.tir.ScheduleError):
+        s.rfactor(k4i, 0)
+
+
+def test_reduction_rfactor_outermost_loop_multiple_children():  # pylint: disable=invalid-name
+    s = tir.Schedule(multiple_reduction_blocks, debug_mask="all")
+    C = s.get_block("C")
+    _, _, k1o, _ = s.get_loops(C)
+    rf_block = s.rfactor(k1o, 2)
+    tvm.ir.assert_structural_equal(s.mod["main"], multiple_reduction_blocks_rfactor)
+    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))
+    assert s.get(C).same_as(s.get(s.get_block("C")))
+    verify_trace_roundtrip(s, mod=multiple_reduction_blocks)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
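For reference, a minimal sketch of how the rfactor primitive is driven outside the test harness, assuming the `rowsum` fixture above (the split factors here are arbitrary, not part of the patch):

    import tvm
    from tvm import tir

    # Build a schedule over the rowsum workload defined in the test file.
    sch = tir.Schedule(rowsum, debug_mask="all")
    B = sch.get_block("B")
    _, k = sch.get_loops(B)
    # Split the serial reduction loop, then factorize on the inner piece;
    # factor_axis=0 places the new rfactor dimension first in the B_rf buffer.
    ko, ki = sch.split(k, factors=[32, 4])
    rf_block = sch.rfactor(ki, 0)
    print(tvm.script.asscript(sch.mod["main"]))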