diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index 4ede2dd90da8..9c3029ebf513 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -24,6 +24,20 @@ namespace tvm { namespace tir { +template +bool UsesVar(const T& x, const Var& var) { + return UsesVar(x, [tgt = var.get()](const VarNode* v) { return v == tgt; }); +} + +Range RangeFromExtent(const PrimExpr& extent) { + return Range::FromMinExtent(make_zero(extent->dtype), extent); +} + +template +T DeepCopy(const T& stmt) { + return Downcast(LoadJSON(SaveJSON(stmt))); +} + /*! * \brief ScheduleError that the bindings of the inner block are not divisible by the subspace * represented by the outer loops. @@ -64,16 +78,16 @@ class SubspaceNotDivisibleError : public ScheduleError { * * \param iter_vars The input iterators * \param bindings The values of iter_vars - * \param outer_loops Iterators outside the subspace. - * \param inner_loops Iterators of the subspace * \param predicate The predicate constraint on the input iterators. + * \param outer_iters The iters of the outer space + * \param inner_iters The iters of the inner space * \return The result of the subspace division. 
*/ Array> TrivialSubspaceDivision(const Array& iter_vars, const Array& bindings, + const PrimExpr& predicate, const Array& outer_iters, - const Array& inner_iters, - const PrimExpr& predicate) { + const Array& inner_iters) { if (!is_one(predicate)) return {}; Array> res; std::unordered_set outer_loop_vars; @@ -95,7 +109,7 @@ Array> TrivialSubspaceDivision(const Array& iter auto use_inner_loop_vars = make_uses_var(inner_iters); arith::IterMark unit_iter_mark(arith::IterSumExpr({}, 0), 1); - for (size_t i = 0; i < bindings.size(); ++i) { + for (int i = 0, n = bindings.size(); i < n; ++i) { bool outer = use_outer_loop_vars(bindings[i]); bool inner = use_inner_loop_vars(bindings[i]); arith::IterMark iter_mark; @@ -122,531 +136,462 @@ Array> TrivialSubspaceDivision(const Array& iter } /*! - * \brief Generate the blockized init block. - * \param block The original block with init. - * \param inner_block_realize The block realize of the inner block after blockize. - * \param inner_loops The inner loops after blockize. - * \return The subtree of the init block and its outer loops. + * \brief Subspace division. The space is divided into two subspaces: + * 1. The subspace represented by the outer loops above `loop_sref` (exclusive). + * 2. The subspace represented by the inner loops below `loop_sref` (inclusive). + * \param realize The inner block + * \param block_sref The sref to the inner block + * \param loop_sref The loop that is the root of the second subspace. + * \param loops The loops that represent the second part of the subspace. + * \param analyzer The arithmetic analyzer to use. 
*/ -Stmt GenerateBlockizedInit(const Block& block, const BlockRealize& inner_block_realize, - const std::vector& inner_loops) { - Array init_block_iters; - Array init_bindings; - const Block& inner_block = inner_block_realize->block; - - // Step 1: Collect data-parallel block iters - for (size_t i = 0; i < inner_block->iter_vars.size(); i++) { - const IterVar& iter_var = inner_block->iter_vars[i]; - const PrimExpr& binding = inner_block_realize->iter_values[i]; - if (iter_var->iter_type == IterVarType::kDataPar && - UsesVar(block->init.value(), - [tgt_var = iter_var->var.get()](const VarNode* var) { return var == tgt_var; })) { - init_block_iters.push_back(iter_var); - init_bindings.push_back(binding); +Array> SubspaceDivide(const BlockRealize& realize, + const StmtSRef& block_sref, // + const StmtSRef& loop_sref, // + std::vector* loops, + arith::Analyzer* analyzer) { + Array inner_vars; + Array outer_vars; + Map loop_var_domain; + bool inner = true; + for (StmtSRefNode* sref = block_sref->parent; // + sref && sref->stmt->IsInstance(); // + sref = sref->parent) { + const ForNode* loop = static_cast(sref->stmt); + if (inner) { + loops->push_back(loop); + inner_vars.push_back(loop->loop_var); + } else { + outer_vars.push_back(loop->loop_var); } - } - - // Step 2: Collect loops related to iters of the init block - std::vector init_loops; - for (const ForNode* inner_loop : inner_loops) { - for (const PrimExpr& init_binding : init_bindings) { - if (UsesVar(init_binding, [tgt_var = inner_loop->loop_var.get()](const VarNode* var) { - return var == tgt_var; - })) { - init_loops.push_back(inner_loop); - break; - } + loop_var_domain.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + if (sref == loop_sref.get()) { + inner = false; } } - - // Step 3: Create new block iters for the init block - Map subst_map; - for (size_t i = 0; i < init_block_iters.size(); i++) { - IterVar new_iter_var = init_block_iters[i]; - Var old_var = new_iter_var->var; - Var new_var 
= old_var.copy_with_suffix("_init"); - new_iter_var.CopyOnWrite()->var = new_var; - subst_map.Set(old_var, new_var); - init_block_iters.Set(i, std::move(new_iter_var)); - } - - // Step 4: Generate loop nests and the init block - Stmt new_init = BlockRealize( - /*iter_values=*/init_bindings, - /*predicate=*/inner_block_realize->predicate, - /*block=*/ - Block{/*iter_vars=*/init_block_iters, - /*reads=*/{}, - /*writes=*/block->writes, - /*name_hint=*/block->name_hint + "_init", - /*body=*/block->init.value(), - /*init=*/NullOpt}); - - // Step 5: Generate the parent loops for the init block - for (const ForNode* init_loop : init_loops) { - ObjectPtr new_loop = make_object(*init_loop); - new_loop->loop_var = init_loop->loop_var.copy_with_suffix(""); - subst_map.Set(init_loop->loop_var, new_loop->loop_var); - new_loop->body = std::move(new_init); - new_init = For(new_loop); + Array> result = + arith::SubspaceDivide(realize->iter_values, loop_var_domain, inner_vars, realize->predicate, + arith::IterMapLevel::Surjective, analyzer); + if (!result.empty()) { + return result; } - - // Step 6: Substitute with new loop variables and block iters to prevent duplication of - // variables in the outer block. - new_init = Substitute(new_init, subst_map); - - return new_init; + return TrivialSubspaceDivision(realize->block->iter_vars, + realize->iter_values, // + realize->predicate, // + outer_vars, inner_vars); } /*! - * \brief A helper to collect the parent loops of the block. The loops are divided into two groups, - * 'outer_loops', and 'inner_loops', by a specified loop as the separator. 'outer_loops' are the - * ancestor loops of the separator loop. 'inner_loops' include the separator loop itself, and its - * successor loops. It is possible that 'outer_loops' is empty. + * \brief Derive the block bindings for both inner and outer block + * \param iter_vars The original block iterators to the inner block + * \param division The subspace division. 
+ * \param outer_iter_vars The outer block iterators. + * \param outer_bindings The outer block bindings. + * \param inner_iter_vars The inner block iterators. + * \param inner_bindings The inner block bindings. + * \return A substitution plan to the iterators in the original inner block. */ -class LoopSubspaceCollector { - public: - /*! - * \brief Collect the parent loops of the block and store the result in the corresponding fields. - * \param block_sref The sref to the target block. - * \param loop_sref The sref to the separator loop. The loop itself is counted as an inner loop. - */ - void Collect(const StmtSRef& block_sref, const StmtSRef& loop_sref) { - bool inner = true; - for (StmtSRefNode* current_sref = block_sref->parent; - current_sref && current_sref->stmt->IsInstance(); - current_sref = current_sref->parent) { - const auto* current_loop = current_sref->StmtAs(); - ICHECK(current_loop); - if (inner) { - inner_loops.push_back(current_loop); - inner_loop_vars.push_back(current_loop->loop_var); - } else { - outer_loops.push_back(current_loop); - outer_loop_vars.push_back(current_loop->loop_var); - } - loop_var_domain.Set(current_loop->loop_var, - Range::FromMinExtent(current_loop->min, current_loop->extent)); - if (current_sref == loop_sref.get()) inner = false; +Map DeriveBlockBinding(const Array& iter_vars, // + const Array>& division, // + Array* outer_iter_vars, // + Array* outer_bindings, // + Array* inner_iter_vars, // + Array* inner_bindings) { + using arith::IterMapExpr; + using arith::IterMapExprNode; + using arith::NormalizeIterMapToExpr; + Map block_var_subst; + ICHECK_EQ(iter_vars.size() + 1, division.size()); + for (int i = 0, n = iter_vars.size(); i < n; ++i) { + const IterVar& iter_var = iter_vars[i]; + arith::IterMark outer_mark = division[i][0]; + arith::IterMark inner_mark = division[i][1]; + IterMapExpr outer_binding = Downcast(outer_mark->source); + IterMapExpr inner_binding = Downcast(inner_mark->source); + // After computing the 
subspace division, bindings[i] can be written as + // outer_binding * inner_binding->extent + inner_binding + // The outer block will have binding: iter_outer -> outer_binding + // The inner block will have binding: iter_inner -> inner_binding + // The iter in the original block will be substituted with base + iter_inner where + // base == iter_outer * iter_inner_extent + if (is_one(inner_mark->extent)) { // IsOuter + // extract this iter var to outer block directly + outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); + outer_iter_vars->push_back(iter_var); + continue; } + // create iter var for the outer block + IterVar outer_iter(/*dom=*/RangeFromExtent(outer_mark->extent), + /*var=*/iter_var->var.copy_with_suffix("_o"), + /*iter_type=*/iter_var->iter_type); + outer_bindings->push_back(NormalizeIterMapToExpr(outer_binding)); + outer_iter_vars->push_back(outer_iter); + // create iter var for the inner block + IterVar inner_iter(/*dom=*/RangeFromExtent(inner_mark->extent), + /*var=*/iter_var->var.copy_with_suffix("_i"), + /*iter_type=*/iter_var->iter_type); + inner_bindings->push_back(NormalizeIterMapToExpr(inner_binding)); + inner_iter_vars->push_back(inner_iter); + // substitution + PrimExpr sub{nullptr}; + if (is_one(outer_mark->extent)) { + sub = inner_iter->var; + } else { + sub = outer_iter * inner_mark->extent + inner_iter->var; + } + block_var_subst.Set(iter_var->var, sub); } - /*! \brief Outer loops which are ancestors of the separator. */ - std::vector outer_loops; - /*! \brief Inner loops which are the separator itself or its successors. */ - std::vector inner_loops; - /*! \brief Loop variables of the outer loops. */ - Array outer_loop_vars; - /*! \brief Loop variables of the inner loops. */ - Array inner_loop_vars; - /*! \brief Domain of the loop variables. */ - Map loop_var_domain; -}; + return block_var_subst; +} /*! - * \brief Check the bindings of the block iters can be divided by a subspace collected by the - * collector. 
- * \param mod The current IR module. - * \param block_realize The block realize to be checked. - * \param collector The collector which has collected the loops of the block. - * \param analyzer The arithmetic analyzer. - * \return The result of the subspace division. - * \throws ScheduleError If the bindings are not divisible by the subspace. + * \brief Generate the inner block for blockization + * \param is_write_reduction Whether the write regions of the inner block are actually reduction. + * \param iter_vars IterVars used in the inner block. + * \param iter_values IterVar bindings used in the inner block. + * \param predicate The predicate of the inner block. + * \param block The inner block as a template to be created from. This method will modify its + * `iter_vars`, `init` and `reads` fields. + * \return The inner block created. */ -Array> CheckSubspaceDivisible(const IRModule& mod, - const BlockRealize& block_realize, - const LoopSubspaceCollector& collector, - arith::Analyzer* analyzer) { - const Block& block = block_realize->block; - - Array> division = arith::SubspaceDivide( - block_realize->iter_values, collector.loop_var_domain, collector.inner_loop_vars, - block_realize->predicate, arith::IterMapLevel::Surjective, analyzer); - - if (division.empty()) { - // If we can't do perfect subspace division, check if it is a trivial case of subspace division. - // In this case, we can still blockize. 
- division = TrivialSubspaceDivision(block->iter_vars, block_realize->iter_values, - collector.outer_loop_vars, collector.inner_loop_vars, - block_realize->predicate); - } - if (division.empty()) { - throw SubspaceNotDivisibleError(mod, GetRef(collector.inner_loops.back()), block); +BlockRealize GenerateInner(bool is_write_reduction, + const Array& iter_vars, // + const Array& iter_values, // + const PrimExpr& predicate, // + Block block) { + BlockNode* n = block.CopyOnWrite(); + n->iter_vars = iter_vars; + n->init = NullOpt; + if (is_write_reduction) { + Array reads; + reads.reserve(block->writes.size() + block->reads.size()); + reads.insert(reads.end(), block->writes.begin(), block->writes.end()); + reads.insert(reads.end(), block->reads.begin(), block->reads.end()); + n->reads = std::move(reads); } - return division; + return BlockRealize(/*iter_values=*/iter_values, /*predicate=*/predicate, + /*block=*/block); } /*! - * \brief The binding extractor to compute the bindings of the outer and the inner blocks after - * blockize. + * \brief Generate the init stmt for the outer block + * \param block The original block with init. + * \param inner_realize The block realize of the inner block after blockize. + * \param loops The inner loops after blockize. + * \return The subtree of the init block and its outer loops. */ -class BlockizedBindingExtractor { - public: - /*! - * \brief Extract bindings for blockize. - * \param iter_vars The iter vars of the original inner block. - * \param division The result of the subspace division. 
- */ - void ExtractBindings(const Array& iter_vars, - const Array>& division, arith::Analyzer* analyzer) { - ICHECK_EQ(iter_vars.size() + 1, division.size()); - for (size_t i = 0; i < iter_vars.size(); ++i) { - const IterVar& iter_var = iter_vars[i]; - arith::IterMark outer_mark = division[i][0]; - arith::IterMark inner_mark = division[i][1]; - const auto* outer_binding = - TVM_TYPE_AS(outer_binding, outer_mark->source, arith::IterMapExprNode); - const auto* inner_binding = - TVM_TYPE_AS(inner_binding, inner_mark->source, arith::IterMapExprNode); - - // After computing the subspace division, bindings[i] can be written as - // outer_binding * inner_binding->extent + inner_binding - // The outer block will have binding: iter_outer -> outer_binding - // The inner block will have binding: iter_inner -> inner_binding - // The iter in the original block will be substituted with base + iter_inner where - // base == iter_outer * iter_inner_extent - - if (is_one(division[i][1]->extent)) { // IsOuter - // extract this iter var to outer block directly - outer_bindings.push_back( - arith::NormalizeIterMapToExpr(GetRef(outer_binding))); - outer_iter_vars.push_back(iter_var); - } else { - // create iter var for the outer block - const IterVar outer_var(/*dom=*/Range::FromMinExtent(0, division[i][0]->extent), - /*var=*/iter_var->var.copy_with_suffix("_o"), - /*iter_type=*/iter_var->iter_type); - outer_bindings.push_back( - arith::NormalizeIterMapToExpr(GetRef(outer_binding))); - outer_iter_vars.push_back(outer_var); - PrimExpr base = is_one(division[i][0]->extent) ? 
0 : outer_var * division[i][1]->extent; - // create iter var for the inner block - IterVar new_iter(Range::FromMinExtent(0, division[i][1]->extent), Var(iter_var->var), - iter_var->iter_type, iter_var->thread_tag, iter_var->span); - inner_iter_dom_map.Set(new_iter->var, arith::IntSet::FromRange(new_iter->dom)); - analyzer->Bind(new_iter->var, new_iter->dom); - inner_iter_vars.push_back(new_iter); - inner_bindings.push_back( - arith::NormalizeIterMapToExpr(GetRef(inner_binding))); - inner_iter_subst_map.Set(iter_var->var, base + new_iter->var); +Stmt GenerateOuterInit(const Stmt& block_init, const BlockRealize& inner_realize, + const std::vector& loops, String block_name) { + const Block& inner_block = inner_realize->block; + Map subst_map; + // Step 1: Create new block vars for the block inside the init stmt of outer block + // An iter is used in the block if + // 1) It is data parallel + // 2) It is used in the original init block + Array iter_vars; + Array iter_values; + ICHECK_EQ(inner_block->iter_vars.size(), inner_realize->iter_values.size()); + int n = inner_block->iter_vars.size(); + iter_vars.reserve(n); + iter_values.reserve(n); + for (int i = 0; i < n; ++i) { + const IterVar& old_iter_var = inner_block->iter_vars[i]; + const PrimExpr& iter_value = inner_realize->iter_values[i]; + if (old_iter_var->iter_type == IterVarType::kDataPar && + UsesVar(block_init, old_iter_var->var)) { + ObjectPtr new_iter_var = make_object(*old_iter_var.get()); + new_iter_var->var = new_iter_var->var.copy_with_suffix("_init"); + subst_map.Set(old_iter_var->var, new_iter_var->var); + iter_vars.push_back(IterVar(new_iter_var)); + iter_values.push_back(iter_value); + } + } + // Step 2: Generate the block inside init stmt of outer block + Stmt stmt = BlockRealize( + /*iter_values=*/iter_values, + /*predicate=*/inner_realize->predicate, + /*block=*/ + Block(/*iter_vars=*/iter_vars, + /*reads=*/{}, + /*writes=*/inner_block->writes, + /*name_hint=*/block_name, + /*body=*/block_init, + 
/*init=*/NullOpt)); + // Step 3. Create the loop nest on top of the block + for (const ForNode* loop : loops) { + bool is_init_loop = false; + for (const PrimExpr& init_binding : iter_values) { + if (UsesVar(init_binding, loop->loop_var)) { + is_init_loop = true; + break; } } + if (is_init_loop) { + ObjectPtr new_loop = make_object(*loop); + new_loop->loop_var = loop->loop_var.copy_with_suffix(""); + new_loop->body = std::move(stmt); + subst_map.Set(loop->loop_var, new_loop->loop_var); + stmt = For(new_loop); + } } - Map inner_iter_subst_map; - /*! \brief Iters of the outer block. */ - Array outer_iter_vars; - /*! \brief Iters of the outer block. */ - Array inner_iter_vars; - /*! \brief Binding values of the outer block. */ - Array outer_bindings; - /*! \brief Binding values of the inner block. */ - Array inner_bindings; - /*! \brief The domain of the inner block iters. */ - Map inner_iter_dom_map; -}; + // Step 4: Substitute the iter vars and loop vars + return Substitute(stmt, subst_map); +} /*! - * \brief Replacer for the inner block after blockize. Inner block iters will be replaced with - * base + inner_iter and the expressions after substituion will be simplified if possible. + * \brief Substitute variables in the stmt, do simplification and track block substitution + * \param stmt The stmt to be substituted. + * \param sub The substitution map. + * \param block_sref_reuse The block substitution happens during the substitution. + * \param analyzer The analyzer for arithmetic simplification. + * \return The substituted stmt. */ -class InnerIterReplacer : public StmtExprMutator { - public: - /*! - * \brief The constructor - * \param subst_map The substitution map of the inner block iters. - * \param analyzer The arithmetic analyzer. - * \param block_sref_reuse The map to save the block reuse information. 
- */ - InnerIterReplacer(Map subst_map, arith::Analyzer* analyzer, - Map* block_sref_reuse) - : subst_map_(std::move(subst_map)), - analyzer_(analyzer), - block_sref_reuse_(block_sref_reuse) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = subst_map_.find(GetRef(op)); - if (it != subst_map_.end()) { - return (*it).second; +Stmt Substitute(const Stmt& stmt, const Map& sub, + Map* block_sref_reuse, arith::Analyzer* analyzer) { + struct Replacer : public StmtExprMutator { + explicit Replacer(const Map& sub, Map* block_sref_reuse, + arith::Analyzer* analyzer) + : sub_(sub), block_sref_reuse_(block_sref_reuse), analyzer_(analyzer) {} + + PrimExpr VisitExpr(const PrimExpr& op) final { + PrimExpr result = StmtExprMutator::VisitExpr(op); + if (!result.same_as(op)) { + return analyzer_->Simplify(result); + } + return result; } - return StmtExprMutator::VisitExpr_(op); - } - PrimExpr VisitExpr(const PrimExpr& op) final { - PrimExpr result = StmtExprMutator::VisitExpr(op); - if (!result.same_as(op)) { - return analyzer_->Simplify(result); + PrimExpr VisitExpr_(const VarNode* op) final { + if (Optional e = sub_.Get(GetRef(op))) { + return e.value(); + } + return StmtExprMutator::VisitExpr_(op); } - return result; - } - Stmt VisitStmt_(const BlockNode* op) final { - Stmt result = StmtExprMutator::VisitStmt_(op); - if (!result.same_as(GetRef(op))) { - block_sref_reuse_->Set(GetRef(op), Downcast(result)); + Stmt VisitStmt_(const BlockNode* op) final { + Block src = GetRef(op); + Block tgt = Downcast(StmtExprMutator::VisitStmt_(op)); + if (!src.same_as(tgt)) { + block_sref_reuse_->Set(src, tgt); + } + return tgt; } - return result; - } - private: - Map subst_map_; - arith::Analyzer* analyzer_; - Map* block_sref_reuse_; -}; + const Map& sub_; + Map* block_sref_reuse_; + arith::Analyzer* analyzer_; + }; + return Replacer(sub, block_sref_reuse, analyzer)(stmt); +} /*! - * \brief Compute the access region of the outer block by relaxing the inner loops. 
- * \param buffer_region The original buffer region. - * \param The range of the inner loops. - * \return The new buffer region. + * \brief Relax the variables for the given regions + * \param regions The regions to be relaxed. + * \param dom_map The variables to be relaxed + * \return The relaxed regions */ -BufferRegion RelaxBlockizedInnerIters(const BufferRegion& buffer_region, - const Map& inner_iter_relaxed_range) { - Array new_region; - new_region.reserve(buffer_region->region.size()); - Array relaxed_int_set = - arith::EvalSet(buffer_region->region, inner_iter_relaxed_range); - ICHECK(buffer_region->region.size() == buffer_region->buffer->shape.size()); - for (size_t i = 0; i < buffer_region->region.size(); i++) { - Range max_range = Range::FromMinExtent(0, buffer_region->buffer->shape[i]); - new_region.push_back(relaxed_int_set[i].CoverRange(max_range)); +Array EvalSetRegions(const Array& regions, + const Map& dom_map) { + Array results; + results.reserve(regions.size()); + for (const BufferRegion& buffer_region : regions) { + const Buffer& buffer = buffer_region->buffer; + Array relaxed = arith::EvalSet(buffer_region->region, dom_map); + ICHECK_EQ(relaxed.size(), buffer->shape.size()); + int ndim = buffer->shape.size(); + Array new_region; + new_region.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + new_region.push_back(relaxed[i].CoverRange(RangeFromExtent(buffer->shape[i]))); + } + results.push_back(BufferRegion(buffer, new_region)); } - return BufferRegion(buffer_region->buffer, std::move(new_region)); + return results; } /*! - * \brief Generate the outer block after blockize. - * \param extractor The binding extractor which has extracted the blockized bindings. - * \param block The original inner block. - * \param inner_block_realize The block realize of the inner block after blockize. - * \param inner_loops The inner loops after blockize. - * \param predicate The outer predicate of the subspace division. 
- * \return The block realize of the outer block after blockize. + * \brief Create the loop nest on top of the given stmt. + * \param stmt The stmt to be wrapped. + * \param loops The loop nests + * \return The wrapped stmt. */ -BlockRealize GenerateBlockizedOuterBlock(const BlockizedBindingExtractor& extractor, - const Block& block, BlockRealize inner_block_realize, - const std::vector& inner_loops, - PrimExpr predicate) { - // Step 1: Generate the init block if needed - Optional new_init = NullOpt; - if (block->init.defined()) { - new_init = GenerateBlockizedInit(block, inner_block_realize, inner_loops); - } - - // Step 2: Compute the access regions of the outer block by relaxing the inner loops - Array new_reads = block->reads; - Array new_writes = block->writes; - - auto f_mutate = [&](const BufferRegion& buffer_region) { - return RelaxBlockizedInnerIters(buffer_region, extractor.inner_iter_dom_map); - }; - new_reads.MutateByApply(f_mutate); - new_writes.MutateByApply(f_mutate); - - // Step 3: Generate the body of the outer block. The body of the outer block is the inner block - // realize and its surrounding loops. - Stmt outer_block_body = inner_block_realize; - for (const ForNode* loop : inner_loops) { +Stmt MakeLoopNest(Stmt stmt, const std::vector& loops) { + for (const ForNode* loop : loops) { ObjectPtr new_loop = make_object(*loop); - new_loop->body = std::move(outer_block_body); - outer_block_body = For(new_loop); + new_loop->body = std::move(stmt); + stmt = For(new_loop); } - - // Step 4: Generate the outer block and block realize. 
- return BlockRealize(/*iter_values=*/std::move(extractor.outer_bindings), - /*predicate=*/std::move(predicate), - /*block=*/ - Block(/*iter_vars=*/std::move(extractor.outer_iter_vars), // - /*reads=*/std::move(new_reads), // - /*writes=*/std::move(new_writes), // - /*name_hint=*/block->name_hint + "_o", // - /*body=*/std::move(outer_block_body), // - /*init=*/std::move(new_init))); + return stmt; } -StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { +BlockRealize BlockizeImpl(const ScheduleState& self, const StmtSRef& loop_sref, + Map* block_sref_reuse, arith::Analyzer* analyzer) { const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); - arith::Analyzer analyzer; - - // Step 1: Check the loop has a single child BlockRealize on the sref tree. + // Step 1: Check and get the only block under `loop`. BlockRealize block_realize = CheckGetSingleChildBlockRealizeOnSRefTree(self, loop_sref); Block block = block_realize->block; StmtSRef block_sref = self->stmt2ref.at(block.get()); - - // Step 2: Collect loops inside and outside loop_sref. - LoopSubspaceCollector collector; - collector.Collect(block_sref, loop_sref); - - // Step 3: Calculate subspace division for the inner loops. + // Step 2: Derive subspace division + std::vector loops; Array> division = - CheckSubspaceDivisible(self->mod, block_realize, collector, &analyzer); - - // Step 4: Generate bindings for the outer block and the inner block based on the result of - // the subspace division. 
- BlockizedBindingExtractor extractor; - extractor.ExtractBindings(block->iter_vars, division, &analyzer); - const PrimExpr& outer_pred = division.back()[0]->extent; - const PrimExpr& inner_pred = division.back()[1]->extent; - - // Step 5: Substitute the iter vars in the original block with the inner iters after the subspace - // division - Map block_sref_reuse; - InnerIterReplacer replacer(std::move(extractor.inner_iter_subst_map), &analyzer, - &block_sref_reuse); - Block new_block = Downcast(replacer(block)); - - // Step 6: Generate the inner block. - bool outer_reduction = false; // whether there are outer reduction iter vars. - for (const IterVar& iter_var : extractor.outer_iter_vars) { - if (iter_var->iter_type == kCommReduce) { - outer_reduction = true; - } + SubspaceDivide(block_realize, block_sref, loop_sref, &loops, analyzer); + if (division.empty()) { + throw SubspaceNotDivisibleError(self->mod, GetRef(loops.back()), block); } - BlockRealizeNode* inner_block_realize = block_realize.CopyOnWrite(); - inner_block_realize->iter_values = extractor.inner_bindings; - inner_block_realize->predicate = inner_pred; - inner_block_realize->block = new_block; - BlockNode* inner_block = inner_block_realize->block.CopyOnWrite(); - inner_block->iter_vars = extractor.inner_iter_vars; - inner_block->init = NullOpt; - /* Add write regions to read regions if - * 1. there are outer reduction iter vars. - * 2. the init block is defined for current block. - */ - if (outer_reduction && block->init.defined()) { - Array new_reads; - for (const BufferRegion& write_access : inner_block->writes) { - new_reads.push_back(write_access); - } - for (const BufferRegion& read_access : inner_block->reads) { - new_reads.push_back(read_access); + PrimExpr outer_predicate = division.back()[0]->extent; + PrimExpr inner_predicate = division.back()[1]->extent; + // Step 3. Derive block bindings for both outer and inner block. 
+ Array outer_iter_vars; + Array inner_iter_vars; + Array outer_bindings; + Array inner_bindings; + Map block_var_subst = // + DeriveBlockBinding(block->iter_vars, division, // + &outer_iter_vars, &outer_bindings, // + &inner_iter_vars, &inner_bindings); + // Step 4: Do var substitution to adjust to the new block bindings + Map inner_iter_dom; + for (const IterVar& iter : inner_iter_vars) { + inner_iter_dom.Set(iter->var, arith::IntSet::FromRange(iter->dom)); + analyzer->Bind(iter->var, iter->dom); + } + Block block_subst = + Downcast(Substitute(block, block_var_subst, block_sref_reuse, analyzer)); + // Step 5: Generate the inner block. The write regions of the inner blocks will be reduction if + // 1. The original block has init stmt. + // 2. There are outer reduction iter vars. + bool has_outer_reduction = false; + if (block_subst->init.defined()) { + for (const IterVar& iter_var : outer_iter_vars) { + if (iter_var->iter_type == kCommReduce) { + has_outer_reduction = true; + break; + } } - inner_block->reads = std::move(new_reads); } - block_sref_reuse.Set(block, inner_block_realize->block); - + BlockRealize inner_realize = GenerateInner(/*is_write_reduction=*/has_outer_reduction, + /*iter_vars=*/inner_iter_vars, + /*iter_values*/ inner_bindings, + /*predicate=*/inner_predicate, + /*block=*/block_subst); + block_sref_reuse->Set(block, inner_realize->block); // Step 6: Generate the outer block. 
- BlockRealize outer_realize = - GenerateBlockizedOuterBlock(extractor, new_block, GetRef(inner_block_realize), - collector.inner_loops, outer_pred); - // Step 7: Do the actual replacement - self->Replace(loop_sref, outer_realize, block_sref_reuse); - - // Step 8: Update the cached flags - StmtSRef outer_block_sref = self->stmt2ref.at(outer_realize->block.get()); - StmtSRef scope_root = tir::GetScopeRoot(self, outer_block_sref, /*require_stage_pipeline=*/false); + return BlockRealize( + /*iter_values=*/std::move(outer_bindings), + /*predicate=*/std::move(outer_predicate), + /*block=*/ + Block(/*iter_vars=*/std::move(outer_iter_vars), + /*reads=*/EvalSetRegions(block_subst->reads, inner_iter_dom), + /*writes=*/EvalSetRegions(block_subst->writes, inner_iter_dom), + /*name_hint=*/block_subst->name_hint + "_o", + /*body=*/MakeLoopNest(inner_realize, loops), + /*init=*/ + block_subst->init.defined() // + ? GenerateOuterInit(block_subst->init.value(), inner_realize, loops, + block_subst->name_hint + "_init") + : Optional(NullOpt))); +} + +StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref) { + arith::Analyzer analyzer; + Map block_sref_reuse; + BlockRealize blockized = BlockizeImpl(self, loop_sref, &block_sref_reuse, &analyzer); + self->Replace(loop_sref, blockized, block_sref_reuse); + StmtSRef result = self->stmt2ref.at(blockized->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); bool scope_block_affine_binding = self->IsAffineBlockBinding(scope_root); self->UpdateScopeBlockInfo(tir::GetBlockRealize(self, scope_root)); self->block_info[scope_root].affine_binding = scope_block_affine_binding; - return outer_block_sref; -} - -/*! - * \brief Update the map from the buffers in the desc to the impl of the tensor - * intrinsic. - * \param intrinsic The tensor intrinsic. - * \param buffer_map The map to be updated. 
- */ -void RemapTensorIntrinBuffers( - const TensorIntrin& intrinsic, - std::unordered_map* buffer_map) { - ICHECK_EQ(intrinsic->desc->params.size(), intrinsic->impl->params.size()); - for (size_t i = 0; i < intrinsic->desc->params.size(); ++i) { - const Var& lhs_var = intrinsic->desc->params[i]; - const Buffer& lhs_buffer = intrinsic->desc->buffer_map[lhs_var]; - const Var& rhs_var = intrinsic->impl->params[i]; - const Buffer& rhs_buffer = intrinsic->impl->buffer_map[rhs_var]; - (*buffer_map)[rhs_buffer] = lhs_buffer; - } + return result; } -void Tensorize(ScheduleState self, const StmtSRef& block_or_loop_sref, - const TensorIntrin& intrinsic) { - /*! - * Check: - * - Check buffer binding, including type, alignment, shape and etc. - * - Check the sub AST is equal to the desc function. - * - * Mutate: - * - Blockize the sub AST (please refer blockize for details) - * - Bind buffers - * - Mutate the impl of the tensor intrinsic by replacing its buffers with new - * buffers created via match buffer region. - * - Replace the sub tree with the mutated function. 
- */ - const BlockRealize& desc_block_realize = Downcast(intrinsic->desc->body); - const BlockRealize& impl_block_realize = Downcast(intrinsic->impl->body); - Block impl_block = impl_block_realize->block; - +void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin) { // Step 1: Blockize the subtree rooted at the given loop if needed - StmtSRef block_sref{nullptr}; - if (block_or_loop_sref->StmtAs()) { - block_sref = Blockize(self, block_or_loop_sref); + BlockRealize block_realize{nullptr}; + Optional old_block = NullOpt; + if (sref->stmt->IsInstance()) { + block_realize = GetBlockRealize(self, sref); + old_block = block_realize->block; + } else if (sref->stmt->IsInstance()) { + arith::Analyzer analyzer; + Map block_sref_reuse; + block_realize = BlockizeImpl(self, sref, &block_sref_reuse, &analyzer); } else { - ICHECK(block_or_loop_sref->StmtAs()); - block_sref = block_or_loop_sref; + LOG(FATAL) << "TypeError: Tensorize only support For or Block, but gets: " + << GetRef(sref->stmt); + throw; } - const BlockRealize& block_realize = GetBlockRealize(self, block_sref); - - // Step 2: Compare the block with the desc of the tensor intrinsic, find the correspondence - // between buffers in the block and the desc. 
+ PrimFunc intrin_desc = intrin->desc; + PrimFunc intrin_impl = DeepCopy(intrin->impl); + // Step 2: Structural pattern matching TensorizeComparator comparator(self->mod, /*assert_mode=*/true); - comparator.VisitStmt(block_realize, desc_block_realize); - - // Step 3: Find the correspondence between buffers in the current AST and the impl of - // the tensor intrinsic - // Step 3.1: Map from intrinsic func buffer to desc func buffer - std::unordered_map intrin_buffer_map; - RemapTensorIntrinBuffers(intrinsic, &intrin_buffer_map); - // Step 3.2: Map form intrinsic func buffer to current AST buffer - std::unordered_map buffer_map; - for (const auto& pair : intrin_buffer_map) { - auto it = comparator.rhs_buffer_map_.find(pair.second); - ICHECK(it != comparator.rhs_buffer_map_.end()) << pair.second; - buffer_map[pair.first] = it->second; + comparator.VisitStmt(block_realize, intrin_desc->body); + // Step 3: Prepare necessary mapping + // 1) Buffer mapping from intrin impl buffers to intrin desc buffers. + // 2) Buffer mapping from intrin impl buffers to buffers in the current AST. + // 3) Mapping impl buffers to their accessed regions. + std::unordered_map impl2desc; + ICHECK_EQ(intrin_desc->params.size(), intrin_impl->params.size()); + for (int i = 0, n = intrin_desc->params.size(); i < n; ++i) { + const Buffer& desc = intrin_desc->buffer_map[intrin_desc->params[i]]; + const Buffer& impl = intrin_impl->buffer_map[intrin_impl->params[i]]; + impl2desc[impl] = desc; } - - // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor - // intrin to make them subregions of the buffer in the original IR. 
- std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_region_map; + std::unordered_map impl2cur; + for (const auto& pair : impl2desc) { + const Buffer& impl = pair.first; + const Buffer& desc = pair.second; + ICHECK(comparator.rhs_buffer_map_.count(desc)); + impl2cur[impl] = comparator.rhs_buffer_map_[desc]; + } + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> impl2region; + Block impl_block = Downcast(intrin_impl->body)->block; for (const BufferRegion& read : impl_block->reads) { - buffer_region_map.emplace(read->buffer, read->region); + impl2region.emplace(read->buffer, read->region); } for (const BufferRegion& write : impl_block->writes) { - buffer_region_map.emplace(write->buffer, write->region); + impl2region.emplace(write->buffer, write->region); } + // Step 4: Create MatchBufferRegion for the params of the impl function of the tensor + // intrin to make them subregions of the buffer in the original IR. Array match_buffer_regions; - match_buffer_regions.reserve(intrinsic->impl->params.size()); - for (size_t i = 0; i < intrinsic->impl->params.size(); ++i) { - const auto& param = intrinsic->impl->params[i]; - const auto& buffer = intrinsic->impl->buffer_map.at(param); - const auto& source = buffer_map.at(buffer); - // add the detected base indices to each buffer access region of the tensor intrinsic - Region old_region = buffer_region_map.at(buffer); - const auto& indices_base = comparator.buffer_indices_.at(source); + match_buffer_regions.reserve(intrin_impl->params.size()); + for (int i = 0, n = intrin_impl->params.size(); i < n; ++i) { + const Buffer& impl = intrin_impl->buffer_map.at(intrin_impl->params[i]); + const Buffer& cur = impl2cur.at(impl); + const Array& old_region = impl2region.at(impl); + const std::vector& indices_base = comparator.buffer_indices_.at(cur); int offset = static_cast(indices_base.size()) - static_cast(old_region.size()); ICHECK(offset >= 0); - Region new_region; - new_region.reserve(source->shape.size()); + Array 
new_region; + new_region.reserve(cur->shape.size()); for (int i = 0; i < offset; i++) { - new_region.push_back(Range::FromMinExtent(indices_base[i], 1)); + PrimExpr min = indices_base[i]; + PrimExpr extent = make_const(min.dtype(), 1); + new_region.push_back(Range::FromMinExtent(min, extent)); } for (int i = 0; i < static_cast(old_region.size()); i++) { - new_region.push_back(Range::FromMinExtent(indices_base[i + offset], old_region[i]->extent)); + PrimExpr min = indices_base[i + offset]; + PrimExpr extent = old_region[i]->extent; + new_region.push_back(Range::FromMinExtent(min, extent)); } - match_buffer_regions.push_back(MatchBufferRegion(buffer, BufferRegion(source, new_region))); + match_buffer_regions.push_back(MatchBufferRegion(impl, BufferRegion(cur, new_region))); } - // Step 5: Replace the subtree in the original IR with the tensor intrin impl. - ObjectPtr new_block_ptr = make_object(*block_realize->block.get()); - new_block_ptr->body = impl_block->body; - ICHECK(new_block_ptr->match_buffers.empty()); - new_block_ptr->match_buffers = std::move(match_buffer_regions); - Block new_block(new_block_ptr); - - self->Replace(block_sref, new_block, {{block_realize->block, new_block}}); - + { + BlockNode* block = block_realize.CopyOnWrite()->block.CopyOnWrite(); + block->body = impl_block->body; + block->match_buffers = std::move(match_buffer_regions); + } + if (old_block.defined()) { + self->Replace(sref, block_realize->block, {{old_block.value(), block_realize->block}}); + } else { + self->Replace(sref, block_realize, {}); + } // Step 6: Update the cached flags. 
- StmtSRef scope_root = tir::GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false); - self->UpdateScopeBlockInfo(static_cast(scope_root->stmt)->body); + StmtSRef result = self->stmt2ref.at(block_realize->block.get()); + StmtSRef scope_root = tir::GetScopeRoot(self, result, /*require_stage_pipeline=*/false); + self->UpdateScopeBlockInfo(scope_root->StmtAs()->body); } /******** InstructionKind Registration ********/ diff --git a/tests/python/unittest/test_tir_schedule_blockize.py b/tests/python/unittest/test_tir_schedule_blockize.py index 481421cfdf78..6d13281320c0 100644 --- a/tests/python/unittest/test_tir_schedule_blockize.py +++ b/tests/python/unittest/test_tir_schedule_blockize.py @@ -15,12 +15,10 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-function-docstring,missing-module-docstring -import sys -import pytest import tvm import tvm.testing -from tvm.script import tir as T from tvm import tir +from tvm.script import tir as T from tvm.tir.schedule.testing import verify_trace_roundtrip # fmt: off @@ -33,177 +31,219 @@ def single_elementwise(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128 vi, vj = T.axis.remap("SS", [i, j]) B[vi, vj] = A[vi, vj] * 2.0 - -@T.prim_func -def single_elementwise_blockized1( - A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] -) -> None: - with T.block("blockized_B"): - vio = T.axis.spatial(1, 0) - vjo = T.axis.spatial(1, 0) - T.reads(A[0:128, 0:128]) - T.writes(B[0:128, 0:128]) - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(A[vi, vj]) - T.writes(B[vi, vj]) - B[vi, vj] = A[vi, vj] * T.float32(2) +# fmt: on +# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks -@T.prim_func -def single_elementwise_blockized2( - A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 128), "float32"] -) -> None: - for 
i in T.serial(128): +def test_blockize_outer(): + @T.prim_func + def after_blockize_outer( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + ) -> None: with T.block("blockized_B"): - vi = T.axis.spatial(128, i) + vio = T.axis.spatial(1, 0) vjo = T.axis.spatial(1, 0) - T.reads(A[vi, 0:128]) - T.writes(B[vi, 0:128]) - for j in T.serial(128): - with T.block("B"): - vj = T.axis.remap("S", [j]) - T.reads(A[vi, vj]) - T.writes(B[vi, vj]) - B[vi, vj] = A[vi, vj] * T.float32(2) - - -@T.prim_func -def two_elementwise(A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(128, 128), "float32"]) -> None: - B = T.alloc_buffer([128, 128], dtype="float32") - for i, j in T.grid(128, 128): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(A[vi, vj]) - T.writes(B[vi, vj]) - B[vi, vj] = A[vi, vj] * T.float32(2) - for i, j in T.grid(128, 128): - with T.block("C"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(C[vi, vj]) - C[vi, vj] = B[vi, vj] + T.float32(1) - - -@T.prim_func -def two_elementwise_blockized( - A: T.Buffer[(128, 128), "float32"], - C: T.Buffer[(128, 128), "float32"] -) -> None: - B = T.alloc_buffer([128, 128], dtype="float32") - for i_0, j_0 in T.grid(8, 8): - with T.block("blockized_B"): - vio, vjo = T.axis.remap("SS", [i_0, j_0]) - T.reads(A[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) - T.writes(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) - for i_1, j_1 in T.grid(16, 16): + for i, j in T.grid(128, 128): with T.block("B"): - vi, vj = T.axis.remap("SS", [i_1, j_1]) - T.reads(A[vio * 16 + vi, vjo * 16 + vj]) - T.writes(B[vio * 16 + vi, vjo * 16 + vj]) - B[vio * 16 + vi, vjo * 16 + vj] = A[vio * 16 + vi, vjo * 16 + vj] * T.float32(2) - with T.block("blockized_C"): - vio, vjo = T.axis.remap("SS", [i_0, j_0]) - T.reads(B[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) - T.writes(C[vio * 16 : vio * 16 + 16, vjo * 16 : vjo * 16 + 16]) - for ax0, ax1 in T.grid(16, 16): - with 
T.block("C"): - vi, vj = T.axis.remap("SS", [ax0, ax1]) - T.reads(B[vio * 16 + vi, vjo * 16 + vj]) - T.writes(C[vio * 16 + vi, vjo * 16 + vj]) - C[vio * 16 + vi, vjo * 16 + vj] = B[vio * 16 + vi, vjo * 16 + vj] + T.float32(1) - - -@T.prim_func -def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: - for k, i in T.grid(128, 128): - with T.block("B"): - vk, vi = T.axis.remap("RS", [k, i]) - with T.init(): - B[vi] = 0.0 - B[vi] = B[vi] + A[vi, vk] - - -@T.prim_func -def rowsum_blockized(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: - with T.block("blockized_B"): - vko = T.axis.R(1, 0) - vio = T.axis.S(1, 0) - with T.init(): - for i1 in T.serial(0, 128): - with T.block("B_init"): - vi_init = T.axis.S(128, i1) - B[vi_init] = T.float32(0) - for i0, i1_1 in T.grid(128, 128): - with T.block("B"): - vk, vi = T.axis.remap("RS", [i0, i1_1]) - B[vi] = B[vi] + A[vi, vk] + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 - -# fmt: off -# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks - -def test_blockize_outer(): func = single_elementwise - # schedule s = tir.Schedule(func, debug_mask="all") - B = s.get_block("B") - x, y = s.get_loops(B) + x, _ = s.get_loops(s.get_block("B")) s.blockize(x) - print(s.mod['main'].script()) - tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized1) + tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_outer) verify_trace_roundtrip(sch=s, mod=func) def test_blockize_inner(): + @T.prim_func + def after_blockize_inner( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128, 128), "float32"], + ) -> None: + for i in T.serial(128): + with T.block("blockized_B"): + vi = T.axis.spatial(128, i) + vjo = T.axis.spatial(1, 0) + for j in T.serial(128): + with T.block("B"): + vj = T.axis.remap("S", [j]) + B[vi, vj] = A[vi, vj] * 2.0 + func = single_elementwise - 
# schedule s = tir.Schedule(func, debug_mask="all") - B = s.get_block("B") - x, y = s.get_loops(B) + _, y = s.get_loops(s.get_block("B")) s.blockize(y) - tvm.ir.assert_structural_equal(s.mod["main"], single_elementwise_blockized2) + tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_inner) verify_trace_roundtrip(sch=s, mod=func) def test_two_elementwise_blockize_reverse_compute_at(): - func = two_elementwise + @T.prim_func + def before_blockize_rca( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(8, 8): + with T.block("B_o"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i]) + T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i]) + B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0 + for ax0, ax1 in T.grid(16, 16): + with T.block("C"): + vi = T.axis.spatial(128, i * 16 + ax0) + vj = T.axis.spatial(128, j * 16 + ax1) + T.reads(B[vi, vj]) + T.writes(C[vi, vj]) + C[vi, vj] = B[vi, vj] + 1.0 + + @T.prim_func + def after_blockize_rca( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i, j in T.grid(8, 8): + with T.block("B_o"): + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("B"): + vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) + T.reads(A[vi * 16 + vi_i, vj * 16 + vj_i]) + T.writes(B[vi * 16 + vi_i, vj * 16 + vj_i]) + B[vi * 16 + vi_i, vj * 16 + vj_i] = A[vi * 16 + vi_i, vj * 16 + vj_i] * 2.0 + with T.block("C_o"): + vi, vj = 
T.axis.remap("SS", [i, j]) + T.reads(B[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + T.writes(C[vi * 16 : vi * 16 + 16, vj * 16 : vj * 16 + 16]) + for ax0, ax1 in T.grid(16, 16): + with T.block("C"): + vi_i, vj_i = T.axis.remap("SS", [ax0, ax1]) + T.reads(B[vi * 16 + vi_i, vj * 16 + vj_i]) + T.writes(C[vi * 16 + vi_i, vj * 16 + vj_i]) + C[vi * 16 + vi_i, vj * 16 + vj_i] = B[vi * 16 + vi_i, vj * 16 + vj_i] + 1.0 + + func = before_blockize_rca s = tir.Schedule(func, debug_mask="all") - B = s.get_block("B") - C = s.get_block("C") - x, y = s.get_loops(B) - xo, xi = s.split(x, factors=[None, 16]) - yo, yi = s.split(y, factors=[None, 16]) - s.reorder(xo, yo, xi, yi) - s.blockize(xi) - s.reverse_compute_at(C, yo) - s.blockize(s.get_loops(C)[-2]) - tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + _, _, x, _ = s.get_loops(s.get_block("C")) + s.blockize(x) + tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_rca) verify_trace_roundtrip(sch=s, mod=func) def test_two_elementwise_blockize_compute_at(): - func = two_elementwise + @T.prim_func + def before_blockize_compute_at( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + # body + # with T.block("root") + B = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0 in T.grid(8, 8): + for ax0, ax1 in T.grid(16, 16): + with T.block("B"): + vi = T.axis.spatial(128, i_0 * 16 + ax0) + vj = T.axis.spatial(128, j_0 * 16 + ax1) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + with T.block("C_o"): + vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) + T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("C"): + vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) + T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = ( + B[vi_o 
* 16 + vi_i, vj_o * 16 + vj_i] + 1.0 + ) + + @T.prim_func + def after_blockize_compute_at( + A: T.Buffer[(128, 128), "float32"], + C: T.Buffer[(128, 128), "float32"], + ) -> None: + B = T.alloc_buffer([128, 128], dtype="float32") + for i_0, j_0 in T.grid(8, 8): + with T.block("B_o"): + vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) + T.reads(A[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + T.writes(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + for ax0, ax1 in T.grid(16, 16): + with T.block("B"): + vi_i, vj_i = T.axis.remap("SS", [ax0, ax1]) + T.reads(A[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + T.writes(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = ( + A[vi_o * 16 + vi_i, vj_o * 16 + vj_i] * 2.0 + ) + with T.block("C_o"): + vi_o, vj_o = T.axis.remap("SS", [i_0, j_0]) + T.reads(B[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + T.writes(C[vi_o * 16 : vi_o * 16 + 16, vj_o * 16 : vj_o * 16 + 16]) + for i_1, j_1 in T.grid(16, 16): + with T.block("C"): + vi_i, vj_i = T.axis.remap("SS", [i_1, j_1]) + T.reads(B[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + T.writes(C[vi_o * 16 + vi_i, vj_o * 16 + vj_i]) + C[vi_o * 16 + vi_i, vj_o * 16 + vj_i] = ( + B[vi_o * 16 + vi_i, vj_o * 16 + vj_i] + 1.0 + ) + + func = before_blockize_compute_at s = tir.Schedule(func, debug_mask="all") - B = s.get_block("B") - C = s.get_block("C") - x, y = s.get_loops(C) - xo, xi = s.split(x, factors=[None, 16]) - yo, yi = s.split(y, factors=[None, 16]) - s.reorder(xo, yo, xi, yi) - s.blockize(xi) - s.compute_at(B, yo) - s.blockize(s.get_loops(B)[-2]) - tvm.ir.assert_structural_equal(s.mod["main"], two_elementwise_blockized) + _, _, x, _ = s.get_loops(s.get_block("B")) + s.blockize(x) + tvm.ir.assert_structural_equal(s.mod["main"], after_blockize_compute_at) verify_trace_roundtrip(sch=s, mod=func) def test_blockize_init_loops(): + @T.prim_func + def rowsum(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128,), "float32"]) -> None: + for k, i 
in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [k, i]) + with T.init(): + B[vi] = 0.0 + B[vi] = B[vi] + A[vi, vk] + + @T.prim_func + def after_rowsum_blockize( + A: T.Buffer[(128, 128), "float32"], + B: T.Buffer[(128,), "float32"], + ) -> None: + with T.block("blockized_B"): + vko = T.axis.R(1, 0) + vio = T.axis.S(1, 0) + with T.init(): + for i1 in T.serial(0, 128): + with T.block("B_init"): + vi_init = T.axis.S(128, i1) + B[vi_init] = T.float32(0) + for i0, i1_1 in T.grid(128, 128): + with T.block("B"): + vk, vi = T.axis.remap("RS", [i0, i1_1]) + B[vi] = B[vi] + A[vi, vk] + s = tir.Schedule(rowsum, debug_mask="all") k, _ = s.get_loops(s.get_block("B")) s.blockize(k) - tvm.ir.assert_structural_equal(s.mod["main"], rowsum_blockized) + tvm.ir.assert_structural_equal(s.mod["main"], after_rowsum_blockize) verify_trace_roundtrip(sch=s, mod=rowsum)