From d93b7f1ba5ffe8ed9ec92019eb5f616383031615 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Mon, 6 Jun 2022 22:53:26 +0800 Subject: [PATCH] [TIR][MetaSchedule] Support Tuple Reduction This PR improves our TIR scheduling primitives/transformations (rfactor & cross-thread reduction) designed for reduction operators, so that they can be applied to blocks of tuple-reduction. --- .../schedule_rule/cross_thread_reduction.cc | 7 + src/tir/schedule/analysis.h | 48 +- src/tir/schedule/analysis/analysis.cc | 524 +------------ src/tir/schedule/analysis/reducer.cc | 702 ++++++++++++++++++ src/tir/schedule/primitive/reduction.cc | 402 ++++++---- .../lower_cross_thread_reduction.cc | 323 ++++---- ...meta_schedule_schedule_rule_add_rfactor.py | 166 +++++ ...le_schedule_rule_cross_thread_reduction.py | 99 +++ .../unittest/test_tir_schedule_rfactor.py | 649 +++++++++++++++- ..._transform_lower_cross_thread_reduction.py | 244 +++++- 10 files changed, 2314 insertions(+), 850 deletions(-) create mode 100644 src/tir/schedule/analysis/reducer.cc diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index 0f0ab99e7259..35be33f72e21 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -184,6 +184,13 @@ class CrossThreadReductionNode : public ScheduleRuleNode { */ std::tuple GetComputeTargetLoopAndBlock( const tir::Schedule& sch, const tir::BlockRV& block_rv) { + // Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing + // a tuple reduction, fusion is temporarily not supported. + if (sch->Get(block_rv)->writes.size() != 1) { + return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr}, + tir::LoopRV{nullptr}); + } + // Step 1. Get all the consumers of the input block. Array consumers = sch->GetConsumers(block_rv); diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 52ef17df162c..489df8959d1b 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -455,15 +455,14 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ /******** Reduction Block Related ********/ /*! - * \brief Convert the `init` and `body` of the input block to BufferStores - * \param self The schedule state - * \param block The block to be analyzed - * \return The BufferStores of the `init` and `body` of the input block - * \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same - * buffer + * \brief Get the init values and the BufferStore updates from the input reduction block + * \param self The schedule state, used for error reporting + * \param block The block from which the init values and BufferStore updates are extracted from + * \return The extracted init values and BufferStore updates + * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block */ -std::pair GetBufferStoresFromReductionBlock( - const Optional& self, const Block& block); +std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( + const Optional& self, Block block); /*! * \brief Check whether the input array of IterVars only contains data-parallel and reduction block @@ -484,16 +483,17 @@ bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters); bool ReductionIterNotIndexOutputBuffer(const Block& block); /*! - * \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative - * reducer, and extract the combiner lhs and combiner rhs + * \brief Given a list of reduction identities and a list of reduction combiners, detect the + * corresponding commutative reducer, and extract the combiner LHS values and combiner RHS values * \param self The schedule state - * \param identity The reduction identity to be analyzed - * \param combiner The reduction combiner to be analyzed - * \return The corresponding CommReducer, the combiner lhs and the combiner rhs + * \param identities The reduction identities to be analyzed + * \param combiners The reduction combiners to be analyzed + * \return The corresponding CommReducer, combiner LHS values and combiner RHS values * \throw ScheduleError If no corresponding commutative reducer can be matched */ -std::tuple GetReducerAndCombinerLhsRhs( - const Optional& self, const PrimExpr& identity, const BufferStore& combiner); +std::tuple, Array> GetReducerAndCombinerLhsRhs( + const Optional& self, const Array& identities, + const Array& combiners); /******** Commutative Reducer ********/ @@ -502,20 +502,20 @@ std::tuple GetReducerAndCombinerLhsRhs( * \return The list of the registered reducer-getter functions * \sa ReducerRegistry */ -std::vector> GetReducerGetters(); +std::vector(Array)>> GetReducerGetters(); /*! - * \brief Given the input identity and the combiner BufferStore of a reduction, extract the - * corresponding commutative reducer and its lhs, rhs if possible. - * \param identity The identity of the reduction - * \param combiner The combiner of the reduction + * \brief Given the input identities and the combiner BufferStores of a reduction, extract the + * corresponding commutative reducer, LHS values and RHS values, if possible. + * \param identities The identities of the reduction + * \param combiners The combiners of the reduction * \param result_reducer The extracted CommReducer - * \param lhs The extracted lhs of the reducer - * \param rhs The extracted rhs of the reducer + * \param lhs The extracted LHS values of the reducer + * \param rhs The extracted RHS values of the reducer * \return A boolean indicating whether a corresponding commutative reducer is found */ -bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, - CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs); +bool FromIdentityCombiner(const Array& identities, const Array& combiners, + CommReducer* result_reducer, Array* lhs, Array* rhs); /******** Misc ********/ diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index fb09a3480a3a..7ed60876ab22 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,9 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include - #include "../ir_comparator.h" #include "../utils.h" @@ -1237,523 +1234,6 @@ std::pair, bool> GetBufferDefiningSite(const StmtSRef& block_ return {NullOpt, false}; } -/******** Pattern Matcher ********/ - -/*! - * \brief PrimExpr pattern matcher. - * - * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated - * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific - * patterns. - * - * The code below shows how to use the pattern matcher. - * - * \code - * - * Var x("x"), y("y"); - * // use PrimExpr to declare patterns, x, y are holes that can be filled with - * PatternMatcher pattern_matcher(x + y); - * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match - * pattern_matcher.Match(expr); - * - * if (pattern_matcher.Success()) { - * pattern_matcher.Eval(x) // C[i, j] - * pattern_matcher.Eval(y) // A[i, k] * B[k, j] - * } - * - * \endcode - */ -class PatternMatcher : public ExprVisitor { - public: - explicit PatternMatcher(PrimExpr pattern) : pattern_(std::move(pattern)) {} - - void VisitExpr_(const VarNode* op) final { - auto it = filled_map_.find(op); - if (it == filled_map_.end()) { - filled_map_[op] = expr_to_match_; - } else { - ExprDeepEqual equal; - if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; - match_success_ = false; - } - } - - void VisitExpr_(const LoadNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (!op->buffer_var.same_as(ptr->buffer_var)) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->predicate; - VisitExpr(op->predicate); - expr_to_match_ = ptr->index; - VisitExpr(op->index); - std::swap(expr_to_match_, tmp); - } - } - } - - void VisitExpr_(const LetNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->var; - VisitExpr(op->var); - expr_to_match_ = ptr->value; - VisitExpr(op->value); - expr_to_match_ = ptr->body; - VisitExpr(op->body); - std::swap(expr_to_match_, tmp); - } - } - - void VisitExpr_(const CallNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (!op->op.same_as(ptr->op)) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - for (size_t i = 0; i < op->args.size(); ++i) { - expr_to_match_ = ptr->args[i]; - VisitExpr(op->args[i]); - } - std::swap(expr_to_match_, tmp); - } - } - } - -#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ - void VisitExpr_(const OpName* op) { \ - const auto* ptr = expr_to_match_.as(); \ - if (ptr == nullptr) { \ - match_success_ = false; \ - } else { \ - PrimExpr current = expr_to_match_; \ - expr_to_match_ = ptr->a; \ - VisitExpr(op->a); \ - expr_to_match_ = ptr->b; \ - VisitExpr(op->b); \ - std::swap(expr_to_match_, current); \ - } \ - } - - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); - TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); - - void VisitExpr_(const CastNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->value; - VisitExpr(op->value); - std::swap(expr_to_match_, tmp); - } - } - } - - void VisitExpr_(const NotNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->a; - VisitExpr(op->a); - std::swap(expr_to_match_, tmp); - } - } - - void VisitExpr_(const SelectNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->condition; - VisitExpr(op->condition); - expr_to_match_ = ptr->true_value; - VisitExpr(op->true_value); - expr_to_match_ = ptr->false_value; - VisitExpr(op->false_value); - std::swap(expr_to_match_, tmp); - } - } - - void VisitExpr_(const RampNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->base; - VisitExpr(op->base); - expr_to_match_ = ptr->stride; - VisitExpr(op->stride); - std::swap(expr_to_match_, tmp); - } - } - } - - void VisitExpr_(const BroadcastNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (op->lanes != ptr->lanes) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - expr_to_match_ = ptr->value; - VisitExpr(op->value); - std::swap(expr_to_match_, tmp); - } - } - } - - void VisitExpr_(const ShuffleNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - for (size_t i = 0; i < op->indices.size(); ++i) { - expr_to_match_ = ptr->indices[i]; - VisitExpr(op->indices[i]); - } - for (size_t i = 0; i < op->vectors.size(); ++i) { - expr_to_match_ = ptr->vectors[i]; - VisitExpr(op->vectors[i]); - } - std::swap(expr_to_match_, tmp); - } - } - } - - void VisitExpr_(const IntImmNode* op) final { - const auto* ptr = expr_to_match_.as(); - match_success_ = ptr != nullptr && op->value == ptr->value; - } - - void VisitExpr_(const FloatImmNode* op) final { - const auto* ptr = expr_to_match_.as(); - match_success_ = ptr != nullptr && op->value == ptr->value; - } - - void VisitExpr_(const StringImmNode* op) final { - const auto* ptr = expr_to_match_.as(); - match_success_ = ptr != nullptr && op->value == ptr->value; - } - - void VisitExpr_(const BufferLoadNode* op) final { - const auto* ptr = expr_to_match_.as(); - if (ptr == nullptr) { - match_success_ = false; - } else { - if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { - match_success_ = false; - } else { - PrimExpr tmp = expr_to_match_; - for (size_t i = 0; i < op->indices.size(); ++i) { - expr_to_match_ = ptr->indices[i]; - VisitExpr(op->indices[i]); - } - std::swap(expr_to_match_, tmp); - } - } - } - - void Match(const PrimExpr& expr_to_match) { - this->match_success_ = true; - this->filled_map_.clear(); - this->expr_to_match_ = expr_to_match; - this->operator()(pattern_); - } - - PrimExpr Eval(const Var& var) { - auto it = filled_map_.find(var.operator->()); - ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; - ICHECK(match_success_) << "Match failed"; - return it->second; - } - - bool Success() const { return match_success_; } - - private: - bool match_success_{true}; - PrimExpr pattern_, expr_to_match_; - std::unordered_map filled_map_; -}; - -/******** Reduction Block Related ********/ - -class InitBodyNotBufferStoreError : public ScheduleError { - public: - explicit InitBodyNotBufferStoreError(IRModule mod, Block block, bool init_is_bufferstore, - bool body_is_bufferstore) - : mod_(std::move(mod)), - block_(std::move(block)), - init_is_bufferstore_(init_is_bufferstore), - body_is_bufferstore_(body_is_bufferstore) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of reduction block are required to be both " - "BufferStore so that rfactor or cross-thread reduction can be applied"; - } - - String DetailRenderTemplate() const final { - if (!init_is_bufferstore_ && !body_is_bufferstore_) { - return "The `init` and `body` of block {0} are required to be BufferStore so that rfactor or " - "cross-thread reduction can be applied"; - } else if (!init_is_bufferstore_) { - return "The `init` of block {0} is required to be BufferStore so that rfactor or cross-thread" - " reduction can be applied"; - } else { - ICHECK(!body_is_bufferstore_); - return "The `body` of block {0} is required to be BufferStore so that rfactor or cross-thread" - " reduction can be applied"; - } - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; - bool init_is_bufferstore_; - bool body_is_bufferstore_; -}; - -class InitBodyNotSameBufferAccessError : public ScheduleError { - public: - explicit InitBodyNotSameBufferAccessError(IRModule mod, Block block) - : mod_(std::move(mod)), block_(std::move(block)) {} - - String FastErrorString() const final { - return "ScheduleError: The `init` and `body` of the reduction block are required to have the " - "same buffer access pattern"; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - const auto* init = block_->init.as(); - const auto* update = block_->body.as(); - os << "The `init` and `body` of the block {0} is required to have the same buffer access " - "pattern. However, in block {0} the `init` writes to " - << init->buffer->name << init->indices << ", and the `body` writes to " - << update->buffer->name << update->indices; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {block_}; } - - IRModule mod_; - Block block_; -}; - -std::pair GetBufferStoresFromReductionBlock( - const Optional& self, const Block& block) { - static constexpr const char* error_str1 = - "ValueError: The `init` and `body` of the reduction block are required to be both " - "BufferStore so that rfactor or cross-thread reduction can be applied. However, a reduction " - "block that doesn't meet this requirement is "; - static constexpr const char* error_str2 = - "ValueError: The `init` and `body` of the reduction block are required to have the same " - "buffer access pattern so that rfactor or cross-thread reduction can be applied. However, a " - "reduction block that doesn't meet this requirement is "; - - const auto* init = block->init.as(); - const auto* body = block->body.as(); - if (!(init && body)) { - if (self.defined()) { - throw InitBodyNotBufferStoreError(self.value()->mod, block, init != nullptr, body != nullptr); - } else { - LOG(FATAL) << error_str1 << block; - } - } - if (!init->buffer.same_as(body->buffer)) { - if (self.defined()) { - throw InitBodyNotSameBufferAccessError(self.value()->mod, block); - } else { - LOG(FATAL) << error_str2 << block; - } - } - int ndim = static_cast(init->buffer->shape.size()); - for (int i = 0; i < ndim; ++i) { - if (!ExprDeepEqual()(init->indices[i], body->indices[i])) { - if (self.defined()) { - throw InitBodyNotSameBufferAccessError(self.value()->mod, block); - } else { - LOG(FATAL) << error_str2 << block; - } - } - } - return std::make_pair(GetRef(init), GetRef(body)); -} - -bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) { - for (const IterVar& iter_var : iters) { - if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { - return false; - } - } - return true; -} - -bool ReductionIterNotIndexOutputBuffer(const Block& block) { - // Step 1. Collect the reduction block iters. - std::unordered_set reduction_block_iters; - reduction_block_iters.reserve(block->iter_vars.size()); - for (const IterVar& iter_var : block->iter_vars) { - if (iter_var->iter_type == kCommReduce) { - reduction_block_iters.insert(iter_var->var.get()); - } - } - // Step 2. Check if the reduction block iters are used to index the output buffer. - std::unordered_set buffer_written; - buffer_written.reserve(block->writes.size()); - for (const BufferRegion& write_region : block->writes) { - buffer_written.insert(write_region->buffer.get()); - } - auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool { - return UsesVar(expr, [&](const VarNode* var) { // - return reduction_block_iters.count(var); - }); - }; - bool affected = false; - PreOrderVisit(block->body, [&](const ObjectRef& obj) { - if (affected) { - return false; - } - const auto* store = obj.as(); - if (!store) { - return true; - } - ICHECK(buffer_written.count(store->buffer.get())) - << "ValueError: The buffer \"" << store->buffer - << "\" is written in the block but is not in the block's signature"; - for (const PrimExpr& index : store->indices) { - if (f_uses_reduction_block_var(index)) { - affected = true; - return false; - } - } - return false; - }); - return !affected; -} - -class NoMatchedReducerError : public ScheduleError { - public: - explicit NoMatchedReducerError(IRModule mod, PrimExpr identity, BufferStore combiner) - : mod_(std::move(mod)), identity_(std::move(identity)), combiner_(std::move(combiner)) {} - - String FastErrorString() const final { - return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " - "block. So rfactor and cross-thread reduction cannot be applied."; - } - - String DetailRenderTemplate() const final { - std::ostringstream os; - os << "No matched reducer for identity " << identity_ << " and combiner " << combiner_ - << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " - "default reducers or registering new reducers."; - return os.str(); - } - - IRModule mod() const final { return mod_; } - Array LocationsOfInterest() const final { return {}; } - - IRModule mod_; - PrimExpr identity_; - BufferStore combiner_; -}; - -std::tuple GetReducerAndCombinerLhsRhs( - const Optional& self, const PrimExpr& identity, const BufferStore& combiner) { - CommReducer reducer{nullptr}; - PrimExpr combiner_lhs{nullptr}, combiner_rhs{nullptr}; - bool matched = FromIdentityCombiner(identity, combiner, &reducer, &combiner_lhs, &combiner_rhs); - if (!matched) { - if (self.defined()) { - throw NoMatchedReducerError(self.value()->mod, identity, combiner); - } else { - LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " - "reduction block. So rfactor and cross-thread reduction cannot be applied."; - } - } - return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); -} - -/******** Commutative Reducer ********/ - -bool MatchReducer(const CommReducer& reducer, const PrimExpr& identity, const PrimExpr& combiner, - const BufferLoad& load, PrimExpr* lhs, PrimExpr* rhs) { - if (!ExprDeepEqual()(reducer->identity_element[0], identity)) { - return false; - } - PatternMatcher pattern_matcher(reducer->result[0]); - pattern_matcher.Match(combiner); - if (pattern_matcher.Success()) { - PrimExpr lhs_tmp = pattern_matcher.Eval(reducer->lhs[0]); - PrimExpr rhs_tmp = pattern_matcher.Eval(reducer->rhs[0]); - if (ExprDeepEqual()(load, lhs_tmp)) { - *lhs = std::move(lhs_tmp); - *rhs = std::move(rhs_tmp); - } - return true; - } - return false; -} - -bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner, - CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs) { - BufferLoad load(combiner->buffer, combiner->indices); - // Check reduction patterns. - for (const TypedPackedFunc& reducer_getter : GetReducerGetters()) { - CommReducer reducer = reducer_getter(identity.dtype()); - if (MatchReducer(reducer, identity, combiner->value, load, lhs, rhs)) { - *result_reducer = std::move(reducer); - return true; - } - } - return false; -} - /******** SRef Tree Related ********/ StmtSRef GetSRefTreeRoot(const StmtSRef& sref) { @@ -2072,8 +1552,8 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref); Array loops = tir::GetLoops(block_sref); - // Cond 1. The block has only one write buffer - if (block->writes.size() != 1) { + // Cond 1. The block must have at lease one write buffer + if (block->writes.size() == 0) { return false; } diff --git a/src/tir/schedule/analysis/reducer.cc b/src/tir/schedule/analysis/reducer.cc new file mode 100644 index 000000000000..50813ef3cae8 --- /dev/null +++ b/src/tir/schedule/analysis/reducer.cc @@ -0,0 +1,702 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/******** Pattern Matcher ********/ + +/*! + * \brief PrimExpr pattern matcher. + * + * It is different from the pattern matcher in arith/pattern_match.h, which is dedicated + * for compile-time constant patterns. This pattern matcher can work on dynamic user-specific + * patterns. + * + * The code below shows how to use the pattern matcher. + * + * \code + * + * Var x("x"), y("y"); + * // use PrimExpr to declare patterns, x, y are holes that can be filled with + * PatternMatcher pattern_matcher(x + y); + * // expr = C[i, j] + A[i, k] * B[k, j], which is the expr we want to match + * pattern_matcher.Match(expr); + * + * if (pattern_matcher.Success()) { + * pattern_matcher.Eval(x) // C[i, j] + * pattern_matcher.Eval(y) // A[i, k] * B[k, j] + * } + * + * \endcode + */ +class PatternMatcher : public ExprVisitor { + public: + explicit PatternMatcher(Array pattern) : pattern_(std::move(pattern)) {} + + void VisitExpr_(const VarNode* op) final { + auto it = filled_map_.find(op); + if (it == filled_map_.end()) { + filled_map_[op] = expr_to_match_; + } else { + ExprDeepEqual equal; + if (it->second.same_as(expr_to_match_) || equal(it->second, expr_to_match_)) return; + match_success_ = false; + } + } + + void VisitExpr_(const LoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer_var.same_as(ptr->buffer_var)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->predicate; + VisitExpr(op->predicate); + expr_to_match_ = ptr->index; + VisitExpr(op->index); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const LetNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->var; + VisitExpr(op->var); + expr_to_match_ = ptr->value; + VisitExpr(op->value); + expr_to_match_ = ptr->body; + VisitExpr(op->body); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const CallNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->op.same_as(ptr->op)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->args.size(); ++i) { + expr_to_match_ = ptr->args[i]; + VisitExpr(op->args[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + +#define TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OpName) \ + void VisitExpr_(const OpName* op) { \ + const auto* ptr = expr_to_match_.as(); \ + if (ptr == nullptr) { \ + match_success_ = false; \ + } else { \ + PrimExpr current = expr_to_match_; \ + expr_to_match_ = ptr->a; \ + VisitExpr(op->a); \ + expr_to_match_ = ptr->b; \ + VisitExpr(op->b); \ + std::swap(expr_to_match_, current); \ + } \ + } + + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AddNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(SubNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MulNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(DivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(ModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorDivNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(FloorModNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MinNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(MaxNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(EQNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(NENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(LENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GTNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(GENode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(AndNode); + TVM_DECLARE_PATTERN_MATCHER_BIN_OP(OrNode); + + void VisitExpr_(const CastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!runtime::TypeEqual(op->dtype, ptr->dtype)) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const NotNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->a; + VisitExpr(op->a); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const SelectNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->condition; + VisitExpr(op->condition); + expr_to_match_ = ptr->true_value; + VisitExpr(op->true_value); + expr_to_match_ = ptr->false_value; + VisitExpr(op->false_value); + std::swap(expr_to_match_, tmp); + } + } + + void VisitExpr_(const RampNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->base; + VisitExpr(op->base); + expr_to_match_ = ptr->stride; + VisitExpr(op->stride); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const BroadcastNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->lanes != ptr->lanes) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + expr_to_match_ = ptr->value; + VisitExpr(op->value); + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const ShuffleNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (op->vectors.size() != ptr->vectors.size() || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + for (size_t i = 0; i < op->vectors.size(); ++i) { + expr_to_match_ = ptr->vectors[i]; + VisitExpr(op->vectors[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void VisitExpr_(const IntImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const FloatImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const StringImmNode* op) final { + const auto* ptr = expr_to_match_.as(); + match_success_ = ptr != nullptr && op->value == ptr->value; + } + + void VisitExpr_(const BufferLoadNode* op) final { + const auto* ptr = expr_to_match_.as(); + if (ptr == nullptr) { + match_success_ = false; + } else { + if (!op->buffer.same_as(ptr->buffer) || op->indices.size() != ptr->indices.size()) { + match_success_ = false; + } else { + PrimExpr tmp = expr_to_match_; + for (size_t i = 0; i < op->indices.size(); ++i) { + expr_to_match_ = ptr->indices[i]; + VisitExpr(op->indices[i]); + } + std::swap(expr_to_match_, tmp); + } + } + } + + void Match(const Array& exprs_to_match) { + this->match_success_ = true; + this->filled_map_.clear(); + + ICHECK_EQ(pattern_.size(), exprs_to_match.size()); + int n_buffers = pattern_.size(); + for (int i = 0; i < n_buffers; ++i) { + this->expr_to_match_ = exprs_to_match[i]; + this->operator()(pattern_[i]); + } + } + + PrimExpr Eval(const Var& var) { + auto it = filled_map_.find(var.operator->()); + ICHECK(it != filled_map_.end()) << "Unknown pattern variable"; + ICHECK(match_success_) << "Match failed"; + return it->second; + } + + bool Success() const { return match_success_; } + + private: + bool match_success_{true}; + Array pattern_; + PrimExpr expr_to_match_; + std::unordered_map filled_map_; +}; + +/******** Reduction Block Related ********/ + +static const char* kRFactorCrossThreadReductionApplicableBlockDef = + R"(Definition of a reduction block that is applicable by RFactor and Cross-Thread Reduction: +1) The block init should be a single BufferStore or a SeqStmt of BufferStores +2) The buffers initialized in the block init should be all different +3) The number of consecutive LetStmts in the block body (if any) should equal the number of BufferStores in the block init +4) The variables of the LetStmts in the block body should be all different +5) The body of the innermost LetStmt should be a single BufferStore or a SeqStmt of BufferStores +6) The number of BufferStores under the block body should equal the number of BufferStores in the block init, and thereby equal the number of LetStmts above +7) The variables bound by the LetStmts in the block body must all directly serve as values of the BufferStores inside, and the stored values of the BufferStores can only be those variables +8) The variables stored by the BufferStores in the block body should be all different +9) The buffers written by the BufferStores in the block body should be all different +10) The buffers initialized in the block init and written in the block body should match +11) The buffers written by the block should have same shape +12) The indices of all BufferStores in the reduction block should be the same)"; + +void ErrorRFactorCrossThreadReductionNotApplicable(const Optional& self, Block block, + int violated_cond) { + class RFactorNotApplicableError : public ScheduleError { + public: + explicit RFactorNotApplicableError(IRModule mod, Block block, int violated_cond) + : mod_(std::move(mod)), block_(std::move(block)), violated_cond_(violated_cond) {} + + String FastErrorString() const final { + return "ScheduleError: RFactor cannot be applied to the block since the block does not meet " + "the requirements"; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "RFactor cannot be applied to block {0}, because the block violates condition #" + << violated_cond_ << ".\n" + << kRFactorCrossThreadReductionApplicableBlockDef; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {block_}; } + + IRModule mod_; + Block block_; + int violated_cond_; + }; + + if (self.defined()) { + throw RFactorNotApplicableError(self.value()->mod, std::move(block), violated_cond); + } else { + LOG(FATAL) << "ValueError: Cross-thread reduction cannot be applied to the block " + << block->name_hint << " because the block violates the condition #" << violated_cond + << ".\n" + << kRFactorCrossThreadReductionApplicableBlockDef; + } +} + +/*! + * \brief Extract the BufferStores, which serve as the reduction updates, from the given LetStmt and + * the BufferStores inside. And meanwhile set the buffer order of the reduction + * \param self The schedule state, used for error reporting + * \param block The reduction block, used for error reporting + * \param let The LetStmt from which the reduction updates are extracted + * \param n_buffers The number of buffers participating in the reduction + * \param updates The extracted reduction updates + * \param buf2index A mapping from reduction buffers to their indices of the reduction order + * \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block + */ +void ExtractReductionUpdates(const Optional& self, Block block, + const LetStmtNode* let, int n_buffers, Array* updates, + std::unordered_map* buf2index) { + std::unordered_map var2index; + Array let_values; + let_values.reserve(n_buffers); + updates->resize(n_buffers); + + // Step 1. + // - Extract the BufferStore values from the LetStmts. + // - Construct the mapping from let variables to the index. + for (int i = 0; i < n_buffers; ++i) { + if (let == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); + } + + let_values.push_back(let->value); + auto insert_result = var2index.insert(std::make_pair(let->var.get(), i)); + if (!insert_result.second) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/4); + } + if (i != n_buffers - 1) { + let = let->body.as(); + } + } + + // There should be no more LetStmt. + if (let->body->IsInstance()) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/3); + } + + // Now `let` is expected to be the innermost LetStmt, whose body should either be a SeqStmt or + // a BufferStore + const auto* p_seq = let->body.as(); + const auto* p_buf_store = let->body.as(); + if (p_seq == nullptr && p_buf_store == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); + } + SeqStmt seq = + p_seq != nullptr ? GetRef(p_seq) : SeqStmt({GetRef(p_buf_store)}); + if (static_cast(seq->seq.size()) != n_buffers) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/6); + } + + // Step 2. + // - Create BufferStores according to the variables being stored. + // - Construct the mapping from reduction buffers to the index. + for (const Stmt& stmt : seq->seq) { + const auto* buf_store = stmt.as(); + if (buf_store == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/5); + } + const auto* var = buf_store->value.as(); + if (var == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7); + } + auto it = var2index.find(var); + if (it == var2index.end()) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/7); + } + int idx = it->second; + if ((*updates)[idx].defined()) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/8); + } + updates->Set(idx, BufferStore(buf_store->buffer, let_values[idx], buf_store->indices)); + auto insert_result = buf2index->insert(std::make_pair(buf_store->buffer.get(), idx)); + if (!insert_result.second) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/9); + } + } + for (int i = 0; i < n_buffers; ++i) { + ICHECK((*updates)[i].defined()); + } +} + +std::pair, Array> GetInitValuesAndUpdatesFromReductionBlock( + const Optional& self, Block block) { + Array inits; + Array updates; + + // Step 1. Extract the BufferStores serving as block inits. + if (const auto* init = block->init.as()) { + inits.push_back(GetRef(init)); + } else if (const auto* seq_init = block->init.as()) { + std::unordered_set init_buffers; + for (const Stmt& stmt : seq_init->seq) { + init = stmt.as(); + if (init == nullptr) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1); + } + auto insert_result = init_buffers.insert(init->buffer.get()); + if (!insert_result.second) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/2); + } + inits.push_back(GetRef(init)); + } + } else { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/1); + } + + // Step 2. Extract the block updates, in the form of BufferStores. + int n_buffers = inits.size(); + std::unordered_map buf2index; + if (const auto* update = block->body.as()) { + updates.push_back(GetRef(update)); + buf2index[update->buffer.get()] = 0; + } else { + const auto* let = block->body.as(); + ExtractReductionUpdates(self, block, let, n_buffers, &updates, &buf2index); + } + ICHECK_EQ(updates.size(), n_buffers); + + // Step 3. Set the init values according to the buffer order in `updates`, with the help of the + // mapping `buf2index`. + Array init_values; + init_values.resize(n_buffers); + + // - Check all buffers have the same shape + // - Check all indices of the BufferStores are the same + // - Check buffers written in the block init and the block body can match + // - Check buffers do not duplicate + const Array& expected_shape = updates[0]->buffer->shape; + const Array& expected_indices = updates[0]->indices; + ICHECK_EQ(expected_shape.size(), expected_indices.size()); + int n_dim = expected_indices.size(); + arith::Analyzer ana; + for (int i = 0; i < n_buffers; ++i) { + if (static_cast(updates[i]->buffer->shape.size()) != n_dim) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11); + } + if (static_cast(inits[i]->indices.size()) != n_dim || + static_cast(updates[i]->indices.size()) != n_dim) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12); + } + for (int d = 0; d < n_dim; ++d) { + if (!ana.CanProveEqual(updates[i]->buffer->shape[d], expected_shape[d])) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/11); + } + if (!ana.CanProveEqual(inits[i]->indices[d], expected_indices[d]) || + !ana.CanProveEqual(updates[i]->indices[d], expected_indices[d])) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/12); + } + } + + auto it = buf2index.find(inits[i]->buffer.get()); + if (it == buf2index.end()) { + ErrorRFactorCrossThreadReductionNotApplicable(self, std::move(block), /*violated_cond=*/10); + } + int idx = it->second; + ICHECK(updates[idx]->buffer.same_as(inits[i]->buffer)); + ICHECK(!init_values[idx].defined()); + init_values.Set(idx, inits[i]->value); + } + for (int i = 0; i < n_buffers; ++i) { + ICHECK(init_values[i].defined()); + } + + return std::make_pair(init_values, updates); +} + +bool ContainsOnlyDataParAndReductionBlockIter(const Array& iters) { + for (const IterVar& iter_var : iters) { + if (iter_var->iter_type != kDataPar && iter_var->iter_type != kCommReduce) { + return false; + } + } + return true; +} + +bool ReductionIterNotIndexOutputBuffer(const Block& block) { + // Step 1. Collect the reduction block iters. + std::unordered_set reduction_block_iters; + reduction_block_iters.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + if (iter_var->iter_type == kCommReduce) { + reduction_block_iters.insert(iter_var->var.get()); + } + } + // Step 2. Check if the reduction block iters are used to index the output buffer. + std::unordered_set buffer_written; + buffer_written.reserve(block->writes.size()); + for (const BufferRegion& write_region : block->writes) { + buffer_written.insert(write_region->buffer.get()); + } + auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool { + return UsesVar(expr, [&](const VarNode* var) { // + return reduction_block_iters.count(var); + }); + }; + bool affected = false; + PreOrderVisit(block->body, [&](const ObjectRef& obj) { + if (affected) { + return false; + } + const auto* store = obj.as(); + if (!store) { + return true; + } + ICHECK(buffer_written.count(store->buffer.get())) + << "ValueError: The buffer \"" << store->buffer + << "\" is written in the block but is not in the block's signature"; + for (const PrimExpr& index : store->indices) { + if (f_uses_reduction_block_var(index)) { + affected = true; + return false; + } + } + return false; + }); + return !affected; +} + +class NoMatchedReducerError : public ScheduleError { + public: + explicit NoMatchedReducerError(IRModule mod, Array identities, + Array combiners) + : mod_(std::move(mod)), + identities_(std::move(identities)), + combiners_(std::move(combiners)) {} + + String FastErrorString() const final { + return "ScheduleError: No matched reducer for the identity and the combiner of this reduction " + "block. So rfactor and cross-thread reduction cannot be applied."; + } + + String DetailRenderTemplate() const final { + std::ostringstream os; + os << "No matched reducer for identity " << identities_ << " and combiner " << combiners_ + << "In this case rfactor cannot be applied. You can check tvm::tir::ReducerRegistry for " + "default reducers or registering new reducers."; + return os.str(); + } + + IRModule mod() const final { return mod_; } + Array LocationsOfInterest() const final { return {}; } + + IRModule mod_; + Array identities_; + Array combiners_; +}; + +std::tuple, Array> GetReducerAndCombinerLhsRhs( + const Optional& self, const Array& identities, + const Array& combiners) { + CommReducer reducer{nullptr}; + Array combiner_lhs, combiner_rhs; + bool matched = + FromIdentityCombiner(identities, combiners, &reducer, &combiner_lhs, &combiner_rhs); + if (!matched) { + if (self.defined()) { + throw NoMatchedReducerError(self.value()->mod, identities, combiners); + } else { + LOG(FATAL) << "ValueError: No matched reducer for the identity and the combiner of the " + "reduction block. So rfactor and cross-thread reduction cannot be applied."; + } + } + return std::make_tuple(std::move(reducer), std::move(combiner_lhs), std::move(combiner_rhs)); +} + +/******** Commutative Reducer ********/ + +bool MatchReducer(const CommReducer& reducer, const Array& identities, + const Array& combined_values, const Array& buf_loads, + Array* lhs, Array* rhs) { + ExprDeepEqual equal; + ICHECK_EQ(identities.size(), combined_values.size()); + int n_buffers = identities.size(); + for (int i = 0; i < n_buffers; ++i) { + if (!equal(reducer->identity_element[i], identities[i])) { + return false; + } + } + + PatternMatcher pattern_matcher(reducer->result); + pattern_matcher.Match(combined_values); + Array lhs_tmp, rhs_tmp; + lhs_tmp.reserve(n_buffers); + rhs_tmp.reserve(n_buffers); + if (!pattern_matcher.Success()) { + return false; + } + + for (int i = 0; i < n_buffers; ++i) { + PrimExpr l = pattern_matcher.Eval(reducer->lhs[i]); + PrimExpr r = pattern_matcher.Eval(reducer->rhs[i]); + if (!equal(buf_loads[i], l)) { + return false; + } + lhs_tmp.push_back(l); + rhs_tmp.push_back(r); + } + *lhs = std::move(lhs_tmp); + *rhs = std::move(rhs_tmp); + return true; +} + +bool FromIdentityCombiner(const Array& identities, const Array& combiners, + CommReducer* result_reducer, Array* lhs, Array* rhs) { + int n = identities.size(); + Array buf_loads; + Array stored_values; + buf_loads.reserve(n); + stored_values.reserve(n); + + for (int i = 0; i < n; ++i) { + buf_loads.push_back(BufferLoad(combiners[i]->buffer, combiners[i]->indices)); + stored_values.push_back(combiners[i]->value); + } + + // Check reduction patterns. + for (const TypedPackedFunc(Array)>& reducer_getter : + GetReducerGetters()) { + Optional reducer = reducer_getter(identities); + if (!reducer.defined()) { + continue; + } + if (MatchReducer(reducer.value(), identities, stored_values, buf_loads, lhs, rhs)) { + *result_reducer = reducer.value(); + return true; + } + } + return false; +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 1198e67d710a..2dc47fa15bea 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -297,29 +297,85 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, */ struct ReducerRegistry { ReducerRegistry() - : reducer_getters{CreateReducerGetter([](const Var& x, const Var& y) { return x + y; }, - [](DataType dtype) { return make_const(dtype, 0); }), - CreateReducerGetter([](const Var& x, const Var& y) { return x * y; }, - [](DataType dtype) { return make_const(dtype, 1); }), - CreateReducerGetter([](const Var& x, const Var& y) { return min(x, y); }, - [](DataType dtype) { return max_value(dtype); }), - CreateReducerGetter([](const Var& x, const Var& y) { return max(x, y); }, - [](DataType dtype) { return min_value(dtype); })} {} - - static void RegisterReducer(TypedPackedFunc combiner_getter, - TypedPackedFunc identity_getter) { + : reducer_getters{CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{x[0] + y[0]}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, 0)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{x[0] * y[0]}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, 1)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{min(x[0], y[0])}; + }, + [](const Array& values) { + return Array{max_value(values[0]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/1, + [](const Array& x, const Array& y) { + return Array{max(x[0], y[0])}; + }, + [](const Array& values) { + return Array{min_value(values[0]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = Select(x[1] >= y[1], x[0], y[0]); + PrimExpr val = Select(x[1] >= y[1], x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + min_value(values[1]->dtype)}; + }), + CreateReducerGetter( + /*n_buffers=*/2, + [](const Array& x, const Array& y) { + PrimExpr idx = Select(x[1] <= y[1], x[0], y[0]); + PrimExpr val = Select(x[1] <= y[1], x[1], y[1]); + return Array{idx, val}; + }, + [](const Array& values) { + return Array{make_const(values[0]->dtype, -1), + max_value(values[1]->dtype)}; + })} {} + + static void RegisterReducer( + int n_buffers, TypedPackedFunc(Array, Array)> combiner_getter, + TypedPackedFunc(Array)> identity_getter) { ReducerRegistry::Global()->reducer_getters.push_back(ReducerRegistry::CreateReducerGetter( - std::move(combiner_getter), std::move(identity_getter))); + n_buffers, std::move(combiner_getter), std::move(identity_getter))); } - static TypedPackedFunc CreateReducerGetter( - TypedPackedFunc combiner_getter, - TypedPackedFunc identity_getter) { - return [combiner_getter = std::move(combiner_getter), - identity_getter = std::move(identity_getter)](DataType dtype) -> CommReducer { - Var lhs("x", dtype); - Var rhs("y", dtype); - return CommReducer({lhs}, {rhs}, {combiner_getter(lhs, rhs)}, {identity_getter(dtype)}); + static TypedPackedFunc(Array)> CreateReducerGetter( + int n_buffers, TypedPackedFunc(Array, Array)> combiner_getter, + TypedPackedFunc(Array)> identity_getter) { + return [n_buffers, // + combiner_getter = std::move(combiner_getter), // + identity_getter = std::move(identity_getter) // + ](Array values) -> Optional { + if (static_cast(values.size()) != n_buffers) { + return NullOpt; + } + Array lhs; + Array rhs; + for (int i = 0; i < n_buffers; ++i) { + lhs.push_back(Var("x" + std::to_string(i), values[i]->dtype)); + rhs.push_back(Var("y" + std::to_string(i), values[i]->dtype)); + } + return CommReducer(lhs, rhs, combiner_getter(lhs, rhs), identity_getter(values)); }; } @@ -328,10 +384,10 @@ struct ReducerRegistry { return &instance; } - std::vector> reducer_getters; + std::vector(Array)>> reducer_getters; }; -std::vector> GetReducerGetters() { +std::vector(Array)>> GetReducerGetters() { return ReducerRegistry::Global()->reducer_getters; } @@ -508,44 +564,57 @@ std::unordered_map GetLoopVar2LoopMap(const Array& loo } /*! - * \brief Create the intermediate rfactor buffer, which the rfactor block writes to and the + * \brief Create the intermediate rfactor buffers, which the rfactor block writes to and the * write-back block reads from - * \param buffer The buffer written by the reduction block + * \param buf_stores The BufferStores of the original block, where the rfactor buffers will be + * created from * \param factor_axis The `factor_axis` parameter of rfactor * \param rf_loop The rfactor loop * \return The new created intermediate rfactor buffer */ -Buffer CreateRFactorBuffer(const Buffer& buffer, int factor_axis, const ForNode* rf_loop) { - Array rf_shape = buffer->shape; - rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); - - ObjectPtr n = make_object(*buffer.get()); - n->shape = rf_shape; - n->name = buffer->name + ".rf"; - n->data = buffer->data.copy_with_suffix(".rf"); - return Buffer(n); +Array CreateRFactorBuffers(const Array& buf_stores, int factor_axis, + const ForNode* rf_loop) { + Array rf_buffers; + rf_buffers.reserve(buf_stores.size()); + for (const BufferStore& buf_store : buf_stores) { + Buffer buffer = buf_store->buffer; + Array rf_shape = buffer->shape; + rf_shape.insert(rf_shape.begin() + factor_axis, rf_loop->extent); + + ObjectPtr n = make_object(*buffer.get()); + n->shape = rf_shape; + n->name = buffer->name + ".rf"; + n->data = buffer->data.copy_with_suffix(".rf"); + rf_buffers.push_back(Buffer(n)); + } + return rf_buffers; } /*! * \brief The base class of the rfactor/write-back block creator, which creates the blocks in four * steps: * 1) Create the new block iters and the their iter bindings - * 2) Create the reduction update of the new block + * 2) Create the body and init of the new block * 3) Create the read/write regions of the new block * 4) Create the new block and the new block-realize */ class BaseBlockCreator { public: explicit BaseBlockCreator(BlockRealize old_block_realize, For rf_loop, - BufferStore old_reduction_update, CommReducer reducer, Buffer rf_buffer, - bool is_rf_block) + Array old_reduction_updates, CommReducer reducer, + Array rf_buffers, bool is_rf_block) : old_block_realize_(std::move(old_block_realize)), rf_loop_(std::move(rf_loop)), - old_reduction_update_(std::move(old_reduction_update)), + old_reduction_updates_(std::move(old_reduction_updates)), reducer_(std::move(reducer)), - rf_buffer_(std::move(rf_buffer)), + rf_buffers_(std::move(rf_buffers)), + n_buffers_(static_cast(rf_buffers_.size())), is_rf_block_(is_rf_block) { n_block_iters_ = static_cast(old_block_realize_->iter_values.size()); + update_buffers_.reserve(n_buffers_); + update_indices_.reserve(n_buffers_); + update_lhs_.reserve(n_buffers_); + update_rhs_.reserve(n_buffers_); } void CreateBlock() { @@ -560,7 +629,15 @@ class BaseBlockCreator { break; } } - CreateReductionUpdate(has_reduce_iter); + + // The pre-processing finds out the buffers written in the block, the indices of the buffer + // accesses, and the reduction LHS and RHS of the stored values. + PreProcess(); + Stmt block_body = Substitute(CreateBlockBody(has_reduce_iter), var_map_); + Optional block_init = CreateBlockInit(has_reduce_iter); + if (block_init.defined()) { + block_init = Substitute(block_init.value(), var_map_); + } CreateReadWriteRegions(); String new_block_name = old_block_realize_->block->name_hint; @@ -569,17 +646,13 @@ class BaseBlockCreator { new_block_name = new_block_name + "_rf"; predicate = old_block_realize_->predicate; } - Optional init_block = - has_reduce_iter ? BufferStore(new_reduction_update_->buffer, reducer_->identity_element[0], - new_reduction_update_->indices) - : Optional(NullOpt); new_block_ = Block( /*iter_vars=*/iter_vars_, /*reads=*/read_regions_, /*writes=*/write_regions_, /*name_hint=*/new_block_name, - /*body=*/new_reduction_update_, - /*init=*/init_block, + /*body=*/std::move(block_body), + /*init=*/std::move(block_init), /*alloc_buffers=*/{}, /*match_buffers=*/{}, /*annotations=*/old_block_realize_->block->annotations); @@ -589,9 +662,58 @@ class BaseBlockCreator { private: virtual void CreateAdditionalIter() = 0; virtual void CreateNormalIters(int idx) = 0; - virtual void CreateReductionUpdate(bool has_reduce_iter) = 0; + virtual void PreProcess() = 0; virtual void CreateReadWriteRegions() = 0; + Stmt CreateBlockBody(bool has_reduce_iter) { + Array buf_stores; + buf_stores.reserve(n_buffers_); + + // Case 1. If the block has no reduction iterator, we just store the RHS values into the + // buffers. + if (!has_reduce_iter) { + for (int i = 0; i < n_buffers_; ++i) { + buf_stores.push_back(BufferStore(update_buffers_[i], update_rhs_[i], update_indices_[i])); + } + return n_buffers_ > 1 ? SeqStmt(buf_stores) : buf_stores[0]; + } + + // Case 2. If the reduction is for single buffer, the block body is a single BufferStore. + Array stored_values = (*reducer_.get())(update_lhs_, update_rhs_); + if (n_buffers_ == 1) { + return BufferStore(update_buffers_[0], stored_values[0], update_indices_[0]); + } + + // Case 3. In case the reduction is for multiple buffers, we should create the reduction with + // LetStmt so that the reduction execution generates correct results. + Array let_vars; + let_vars.reserve(n_buffers_); + for (int i = 0; i < n_buffers_; ++i) { + Var var("v_" + update_buffers_[i]->name, PrimType(stored_values[i]->dtype)); + let_vars.push_back(var); + buf_stores.push_back(BufferStore(update_buffers_[i], var, update_indices_[i])); + } + Stmt body = SeqStmt(buf_stores); + for (int i = n_buffers_ - 1; i >= 0; --i) { + body = LetStmt(let_vars[i], stored_values[i], std::move(body)); + } + return body; + } + + Optional CreateBlockInit(bool has_reduce_iter) { + if (!has_reduce_iter) { + return NullOpt; + } + + Array inits; + inits.reserve(n_buffers_); + for (int i = 0; i < n_buffers_; ++i) { + inits.push_back( + BufferStore(update_buffers_[i], reducer_->identity_element[i], update_indices_[i])); + } + return n_buffers_ > 1 ? SeqStmt(inits) : inits[0]; + } + public: /*! \brief The new created block */ Block new_block_; @@ -607,12 +729,19 @@ class BaseBlockCreator { int n_block_iters_; /*! \brief The rfactor loop */ For rf_loop_; - /*! \brief The update BufferStore of the old block */ - BufferStore old_reduction_update_; + /*! \brief The update BufferStores of the old block */ + Array old_reduction_updates_; /*! \brief The matched commutative reducer */ CommReducer reducer_; - /*! \brief The intermediate rfactor buffer */ - Buffer rf_buffer_; + /*! \brief The intermediate rfactor buffers */ + Array rf_buffers_; + /*! \brief The number of rfactor buffers. */ + const int n_buffers_; + /*! + * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced + * by the expressions in future substitution for the two blocks + */ + Map var_map_; /*! \brief Whether we are creating the rfactor block or the write-back block */ bool is_rf_block_; @@ -620,13 +749,14 @@ class BaseBlockCreator { std::vector iter_vars_; /*! \brief The new block iter bindings of the new created block-realize */ std::vector iter_values_; - /*! - * \brief A mapping which maps old block iters to new expressions. The old iters will be replaced - * by the expressions in future substitution for the two blocks - */ - Map var_map_; - /*! \brief The update BufferStore of the new created block */ - BufferStore new_reduction_update_; + /*! \brief The buffers updated in this block */ + Array update_buffers_; + /*! \brief The indices of the buffers updated in this block, respectively */ + Array> update_indices_; + /*! \brief The LHS values of the reduction in this block */ + Array update_lhs_; + /*! \brief THe RHS values of the reduction in this block */ + Array update_rhs_; /*! \brief The read regions of the new created block */ Array read_regions_; /*! \brief The write regions of the new created block */ @@ -658,13 +788,13 @@ class BaseBlockCreator { class RFactorBlockCreator : public BaseBlockCreator { public: explicit RFactorBlockCreator(BlockRealize old_block_realize, For rf_loop, - BufferStore old_reduction_update, CommReducer reducer, - Buffer rf_buffer, + Array old_reduction_updates, CommReducer reducer, + Array rf_buffers, std::unordered_map loop_vars2loop, - int factor_axis, PrimExpr combiner_rhs) + int factor_axis, Array combiner_rhs) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), - std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), - true), + std::move(old_reduction_updates), std::move(reducer), + std::move(rf_buffers), true), loop_vars2loop_(std::move(loop_vars2loop)), factor_axis_(factor_axis), combiner_rhs_(std::move(combiner_rhs)) {} @@ -718,41 +848,38 @@ class RFactorBlockCreator : public BaseBlockCreator { var_map_.Set(old_iter->var, Substitute(old_binding, loop_var2block_binding_)); } - void CreateReductionUpdate(bool has_reduce_iter) final { - rf_buf_access_indices_ = old_reduction_update_->indices; + void PreProcess() final { + // The accessed indices for all reduction buffers are the same. + rf_buf_access_indices_ = old_reduction_updates_[0]->indices; rf_buf_access_indices_.insert(rf_buf_access_indices_.begin() + factor_axis_, additional_iter_->var); - PrimExpr rhs{nullptr}; - if (has_reduce_iter) { - rhs = (*reducer_.get())({BufferLoad(rf_buffer_, rf_buf_access_indices_)}, {combiner_rhs_})[0]; - } else { - rhs = combiner_rhs_; + for (int i = 0; i < n_buffers_; ++i) { + update_buffers_.push_back(rf_buffers_[i]); + update_indices_.push_back(rf_buf_access_indices_); + update_lhs_.push_back(BufferLoad(update_buffers_[i], rf_buf_access_indices_)); + update_rhs_.push_back(combiner_rhs_[i]); } - new_reduction_update_ = BufferStore(rf_buffer_, rhs, rf_buf_access_indices_); - new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); } void CreateReadWriteRegions() final { + Map buffer_map; + for (int i = 0; i < n_buffers_; ++i) { + buffer_map.Set(old_reduction_updates_[i]->buffer, rf_buffers_[i]); + } const Block& old_block = old_block_realize_->block; - read_regions_ = CreateRegions(old_block->reads); - write_regions_ = CreateRegions(old_block->writes); - } - - Array CreateRegions(const Array& old_regions) { - Array new_regions; - new_regions.reserve(old_regions.size()); - for (const BufferRegion& buffer_region : old_regions) { - if (buffer_region->buffer.same_as(old_reduction_update_->buffer)) { - Array region = buffer_region->region; - region.insert(region.begin() + factor_axis_, - Range::FromMinExtent(additional_iter_->var, 1)); - new_regions.push_back(BufferRegion(rf_buffer_, Substitute(region, var_map_))); - } else { - new_regions.push_back( - BufferRegion(buffer_region->buffer, Substitute(buffer_region->region, var_map_))); - } + read_regions_.reserve(old_block->reads.size()); + for (const BufferRegion& read_region : old_block->reads) { + read_regions_.push_back( + BufferRegion(read_region->buffer, Substitute(read_region->region, var_map_))); + } + write_regions_.reserve(old_block->writes.size()); + for (const BufferRegion& write_region : old_block->writes) { + Array region = write_region->region; + region.insert(region.begin() + factor_axis_, Range::FromMinExtent(additional_iter_->var, 1)); + Optional rf_buffer = buffer_map.Get(write_region->buffer); + ICHECK(rf_buffer.defined()); + write_regions_.push_back(BufferRegion(rf_buffer.value(), Substitute(region, var_map_))); } - return new_regions; } public: @@ -767,8 +894,8 @@ class RFactorBlockCreator : public BaseBlockCreator { std::unordered_map loop_vars2loop_; /*! \brief The factor_axis specified for rfactor */ int factor_axis_; - /*! \brief The rhs of the combiner in the reduction update of the old block */ - PrimExpr combiner_rhs_; + /*! \brief The RHS values of the reduction in the old block */ + Array combiner_rhs_; /*! * \brief A mapping which maps loop vars to new created block iters. This map is used to * substitute the loop vars which appear in the bindings of some old block iters with the new @@ -784,12 +911,13 @@ class RFactorBlockCreator : public BaseBlockCreator { class WriteBackBlockCreator : public BaseBlockCreator { public: explicit WriteBackBlockCreator(BlockRealize old_block_realize, For rf_loop, - BufferStore old_reduction_update, CommReducer reducer, - Buffer rf_buffer, IterVar rf_additional_iter, - PrimExpr combiner_lhs, Array rf_buf_access_indices) + Array old_reduction_updates, CommReducer reducer, + Array rf_buffers, IterVar rf_additional_iter, + Array combiner_lhs, + Array rf_buf_access_indices) : BaseBlockCreator(std::move(old_block_realize), std::move(rf_loop), - std::move(old_reduction_update), std::move(reducer), std::move(rf_buffer), - false), + std::move(old_reduction_updates), std::move(reducer), + std::move(rf_buffers), false), rf_additional_iter_(std::move(rf_additional_iter)), combiner_lhs_(std::move(combiner_lhs)) { iter_vars_.reserve(n_block_iters_); @@ -817,39 +945,40 @@ class WriteBackBlockCreator : public BaseBlockCreator { } } - void CreateReductionUpdate(bool has_reduce_iter) final { - wb_lhs_ = Downcast(Substitute(combiner_lhs_, var_map_)); - wb_rhs_ = - Downcast(Substitute(BufferLoad(rf_buffer_, rf_buf_access_indices_), var_map_)); - new_reduction_update_ = - BufferStore(old_reduction_update_->buffer, (*reducer_.get())({wb_lhs_}, {wb_rhs_})[0], - old_reduction_update_->indices); - new_reduction_update_ = Downcast(Substitute(new_reduction_update_, var_map_)); + void PreProcess() final { + for (int i = 0; i < n_buffers_; ++i) { + PrimExpr rhs = BufferLoad(rf_buffers_[i], rf_buf_access_indices_); + update_buffers_.push_back(old_reduction_updates_[i]->buffer); + update_indices_.push_back(old_reduction_updates_[i]->indices); + update_lhs_.push_back(Substitute(combiner_lhs_[i], var_map_)); + update_rhs_.push_back(Substitute(std::move(rhs), var_map_)); + } } void CreateReadWriteRegions() final { - read_regions_.push_back(CreateRegion(wb_rhs_)); - write_regions_.push_back(CreateRegion(wb_lhs_)); + CreateRegion(update_rhs_, true); + CreateRegion(update_lhs_, false); } - static BufferRegion CreateRegion(const BufferLoad& load) { - Array region; - region.reserve(load->indices.size()); - for (const PrimExpr& index : load->indices) { - region.push_back(Range::FromMinExtent(index, 1)); + void CreateRegion(const Array& buf_loads, bool is_read) { + Array& buf_regions = is_read ? read_regions_ : write_regions_; + for (const PrimExpr& expr : buf_loads) { + const auto* buf_load = expr.as(); + ICHECK(buf_load != nullptr); + Array region; + region.reserve(buf_load->indices.size()); + for (const PrimExpr& index : buf_load->indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + buf_regions.push_back(BufferRegion(buf_load->buffer, std::move(region))); } - return BufferRegion(load->buffer, std::move(region)); } private: /*! \brief The new created additional block iter of the rfactor block */ IterVar rf_additional_iter_; - /*! \brief The lhs of the combiner in the reduction update of the old block */ - PrimExpr combiner_lhs_; - /*! \brief The lhs of the combiner of the write-back block */ - BufferLoad wb_lhs_; - /*! \brief The rhs of the combiner of the write-back block */ - BufferLoad wb_rhs_; + /*! \brief The LHS values of the reduction in the old block */ + Array combiner_lhs_; }; /*! @@ -924,14 +1053,16 @@ class BlockReplacer : public StmtMutator { BlockRealize wb_block_realize, BlockRealize old_block_realize, For rf_loop, std::unordered_set reduce_loop_vars, std::unordered_map loop_vars2loop, - const Buffer& rf_buffer) { + const Array& rf_buffers) { BlockReplacer replacer(std::move(rf_body), std::move(outermost_loop), std::move(wb_block_realize), std::move(old_block_realize), std::move(rf_loop), std::move(reduce_loop_vars), std::move(loop_vars2loop)); Block new_scope_root = Downcast(replacer(std::move(scope_root_block))); BlockNode* p = new_scope_root.CopyOnWrite(); - p->alloc_buffers.push_back(rf_buffer); + for (const Buffer& rf_buffer : rf_buffers) { + p->alloc_buffers.push_back(rf_buffer); + } return new_scope_root; } @@ -1040,13 +1171,19 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // commutative reducer, combiner lhs and combiner rhs from the reduction identity and the // reduction combiner. The lhs will be used when constructing the write-back block, and the rhs // will be used when constructing the rfactor block. - auto [init, update] = GetBufferStoresFromReductionBlock(self, block); - auto [reducer, combiner_lhs, combiner_rhs] = - GetReducerAndCombinerLhsRhs(self, init->value, update); + Array init_values{nullptr}; + Array updates{nullptr}; + CommReducer reducer{nullptr}; + Array combiner_lhs{nullptr}; + Array combiner_rhs{nullptr}; + std::tie(init_values, updates) = GetInitValuesAndUpdatesFromReductionBlock(self, block); + std::tie(reducer, combiner_lhs, combiner_rhs) = + GetReducerAndCombinerLhsRhs(self, init_values, updates); // Step 6. Check whether `factor_axis` is in a correct range, and convert it to non-negative if it // is negative. - factor_axis = FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, update->buffer, factor_axis); + factor_axis = + FactorAxisOutOfRangeError::CheckAndUpdate(self->mod, updates[0]->buffer, factor_axis); // ***************************************************** // * IR Manipulation * @@ -1056,17 +1193,17 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax // Step 1. Create the intermediate buffer (a.k.a. rfactor buffer), which has an additional // dimension that specified by `factor_axis` and `rf_loop`. - Buffer rf_buffer = CreateRFactorBuffer(update->buffer, factor_axis, rf_loop); + Array rf_buffers = CreateRFactorBuffers(updates, factor_axis, rf_loop); // Step 2. Create the rfactor block. - RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), update, reducer, - rf_buffer, loop_vars2loop, factor_axis, + RFactorBlockCreator rf_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + rf_buffers, loop_vars2loop, factor_axis, std::move(combiner_rhs)); rf_block_creator.CreateBlock(); // Step 3. Create the write-back block. - WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), update, reducer, - rf_buffer, std::move(rf_block_creator.additional_iter_), + WriteBackBlockCreator wb_block_creator(block_realize, GetRef(rf_loop), updates, reducer, + rf_buffers, std::move(rf_block_creator.additional_iter_), std::move(combiner_lhs), std::move(rf_block_creator.rf_buf_access_indices_)); wb_block_creator.CreateBlock(); @@ -1082,7 +1219,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax Block old_scope_root_block = GetRef(scope_root->StmtAs()); Block new_scope_root_block = BlockReplacer::Replace( old_scope_root_block, rf_body, loops[0], wb_block_creator.new_block_realize_, block_realize, - GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffer); + GetRef(rf_loop), reduce_loop_vars, loop_vars2loop, rf_buffers); self->Replace( scope_root, new_scope_root_block, {{old_scope_root_block, new_scope_root_block}, {block, wb_block_creator.new_block_}}); @@ -1157,8 +1294,9 @@ TVM_REGISTER_INST_KIND_TRAITS(DecomposeReductionTraits); /******** FFI ********/ TVM_REGISTER_GLOBAL("tir.schedule.RegisterReducer") - .set_body_typed([](PackedFunc combiner_getter, PackedFunc identity_getter) { - ReducerRegistry::RegisterReducer(std::move(combiner_getter), std::move(identity_getter)); + .set_body_typed([](int n_buffers, PackedFunc combiner_getter, PackedFunc identity_getter) { + ReducerRegistry::RegisterReducer(n_buffers, std::move(combiner_getter), + std::move(identity_getter)); }); } // namespace tir diff --git a/src/tir/transforms/lower_cross_thread_reduction.cc b/src/tir/transforms/lower_cross_thread_reduction.cc index 04b025b5f9ae..c10555e74d07 100644 --- a/src/tir/transforms/lower_cross_thread_reduction.cc +++ b/src/tir/transforms/lower_cross_thread_reduction.cc @@ -111,70 +111,66 @@ bool IsReductionBlock(const BlockRealize& realize, const Map& loop_r } /*! - * \brief Create an intermediate buffer with specified name and data type - * \param name The specified name - * \param dtype The specified data type - * \return The created buffer + * \brief Create intermediate buffers according to the input buffers and buffer kind + * \param reduction_buffers The old reduction buffers which provide the buffer names and data types + * \param is_cross_thread_buffer A boolean indicating whether to create buffers for the cross-thread + * computation results or not, which is used for determine the buffer name prefix + * \return The created buffers */ -Buffer MakeScratchpad(String name, const DataType& dtype) { - return Buffer(/*ptr=*/Var(name, PointerType(PrimType(dtype), "local")), - /*dtype=*/dtype, - /*shape=*/{Integer(1)}, - /*strides=*/{Integer(1)}, - /*elem_offset=*/PrimExpr{nullptr}, - /*name=*/name, - /*data_alignment=*/0, - /*offset_factor=*/0, - /*buffer_type=*/kDefault); -} - -/*! - * \brief Remove the BufferRegions whose buffer is the input buffer - * \param buffer_regions The array of BufferRegions to be - * \param buffer_to_remove The specified buffer - * \return The mutated array of BufferRegions, no longer containing BufferRegion of the input buffer - */ -Array RemoveBufferFromBufferRegions(const Array& buffer_regions, - const Buffer& buffer_to_remove) { - Array res; - res.reserve(buffer_regions.size()); - for (const BufferRegion& buffer_region : buffer_regions) { - if (!buffer_region->buffer.same_as(buffer_to_remove)) { - res.push_back(buffer_region); - } +Array MakeScratchpads(const Array& reduction_buffers, bool is_cross_thread_buffer) { + Array new_buffers; + new_buffers.reserve(reduction_buffers.size()); + for (const Buffer& buffer : reduction_buffers) { + String name = is_cross_thread_buffer ? "cross" : "in"; + name = name + "_thread_" + buffer->name; + new_buffers.push_back(Buffer(/*ptr=*/Var(name, PointerType(PrimType(buffer->dtype), "local")), + /*dtype=*/buffer->dtype, + /*shape=*/{Integer(1)}, + /*strides=*/{Integer(1)}, + /*elem_offset=*/PrimExpr{nullptr}, + /*name=*/name, + /*data_alignment=*/0, + /*offset_factor=*/0, + /*buffer_type=*/kDefault)); } - return res; + return new_buffers; } /*! - * \brief Substitute a given source buffer with a given target buffer in statements or expressions + * \brief Substitute given source buffers with given target buffers respectively in the input + * statement */ class BufferReplacer : private StmtExprMutator { public: - static Stmt Run(Buffer src_buffer, Buffer tgt_buffer, Stmt stmt) { - return BufferReplacer(src_buffer, tgt_buffer)(std::move(stmt)); + static Stmt Run(Array src_buffers, Array tgt_buffers, Stmt stmt) { + Map buffer_map; + ICHECK_EQ(src_buffers.size(), tgt_buffers.size()); + int n_buffers = src_buffers.size(); + for (int i = 0; i < n_buffers; ++i) { + buffer_map.Set(src_buffers[i], tgt_buffers[i]); + } + return BufferReplacer(buffer_map)(std::move(stmt)); } private: - explicit BufferReplacer(Buffer src_buffer, Buffer tgt_buffer) - : src_buffer_(std::move(src_buffer)), tgt_buffer_(std::move(tgt_buffer)) {} + explicit BufferReplacer(Map buffer_map) : buffer_map_(std::move(buffer_map)) {} PrimExpr VisitExpr_(const BufferLoadNode* load) final { - return load->buffer.same_as(src_buffer_) ? BufferLoad(tgt_buffer_, {0}) - : GetRef(load); + auto it = buffer_map_.find(load->buffer); + return it != buffer_map_.end() ? BufferLoad((*it).second, {0}) : GetRef(load); } Stmt VisitStmt_(const BufferStoreNode* store) final { - if (store->buffer.same_as(src_buffer_)) { + auto it = buffer_map_.find(store->buffer); + if (it != buffer_map_.end()) { PrimExpr value = StmtExprMutator::VisitExpr(store->value); - return BufferStore(tgt_buffer_, value, {0}); + return BufferStore((*it).second, std::move(value), {0}); } else { return StmtMutator::VisitStmt_(store); } } - Buffer src_buffer_; - Buffer tgt_buffer_; + Map buffer_map_; }; /*! @@ -231,25 +227,40 @@ class InThreadReducerMaker : private StmtMutator { /*! * \brief Create the lowered allreduce block transformed from the input reduction block - * \param reduction_block The input reduction block - * \param it_buffer The buffer to store in-thread reduction results - * \param ct_buffer The buffer to store cross-thread reduction results + * \param realize The block-realize which contains the old reduction block + * \param it_buffers The buffers to store in-thread reduction results + * \param ct_buffers The buffers to store cross-thread reduction results + * \param wb_buffers The buffers to store the final reduction results + * \param old_wb_indices The indices used to access the write-back buffers when storing the final + * reduction results into the write-back buffers * \param reducer The reduction function - * \param combiner_rhs The RHS of the combiner + * \param combiner_rhs The RHS values of the combiner * \param reduction_loops The reduction loops */ -Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional& it_buffer, - const Buffer& ct_buffer, const CommReducer& reducer, - const PrimExpr& combiner_rhs, +Stmt TransformReductionBlock(const BlockRealizeNode* realize, // + const Optional>& it_buffers, // + const Array& ct_buffers, // + const Array& wb_buffers, // + const Array& old_wb_indices, // + const CommReducer& reducer, // + const Array& combiner_rhs, // const std::vector& reduction_loops) { + int n_buffers = wb_buffers.size(); const BlockNode* block = realize->block.get(); - Buffer wb_buffer = block->writes[0]->buffer; - Array wb_region = block->writes[0]->region; - BufferRegion ct_buffer_region(ct_buffer, {Range::FromMinExtent(0, 1)}); - Optional it_buffer_region = NullOpt; - if (it_buffer.defined()) { - it_buffer_region = BufferRegion(it_buffer.value(), {Range::FromMinExtent(0, 1)}); + auto f_create_buffer_regions = [](Array buffers) { + Array regions; + regions.reserve(buffers.size()); + for (const Buffer& buffer : buffers) { + regions.push_back(BufferRegion(buffer, {Range::FromMinExtent(0, 1)})); + } + return regions; + }; + + Array ct_buffer_regions = f_create_buffer_regions(ct_buffers); + Optional> it_buffer_regions = NullOpt; + if (it_buffers.defined()) { + it_buffer_regions = f_create_buffer_regions(it_buffers.value()); } // In total, the block is transformed into at most 4 statements // - Stmt 1: initialize the buffer for in-thread reduction @@ -259,35 +270,35 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional stmts; stmts.reserve(4); // Stmt 1: initialize the buffer for in-thread reduction - if (it_buffer.defined()) { - BufferStore init = Downcast(block->init); - stmts.push_back(BlockRealize( - /*iter_values=*/{}, - /*predicate=*/const_true(), - /*block=*/ - Block(/*iter_vars=*/{}, - /*reads=*/{}, - /*writes=*/{it_buffer_region.value()}, - /*name_hint=*/block->name_hint + "_in_thread_init", - /*body=*/ - BufferStore(/*buffer=*/it_buffer.value(), - /*value=*/init->value, - /*indices=*/{Integer(0)})))); + if (it_buffers.defined()) { + Array inits; + inits.reserve(n_buffers); + for (int i = 0; i < n_buffers; ++i) { + inits.push_back( + BufferStore(it_buffers.value()[i], reducer->identity_element[i], {Integer(0)})); + } + stmts.push_back(BlockRealize(/*iter_values=*/{}, + /*predicate=*/const_true(), + /*block=*/ + Block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/it_buffer_regions.value(), + /*name_hint=*/block->name_hint + "_in_thread_init", + /*body=*/n_buffers > 1 ? SeqStmt(inits) : inits[0]))); } // Stmt 2: do in-thread reduction { Optional new_realize = NullOpt; // If need to generate in-thread reduction, - // then replace `wb_buffer` with `it_buffer` accordingly in given BlockRealize + // then replace `wb_buffers` with `it_buffers` accordingly in given BlockRealize // otherwise, directly remove given BlockRealize - if (it_buffer.defined()) { + if (it_buffers.defined()) { ObjectPtr new_block = make_object(*block); - new_block->reads = RemoveBufferFromBufferRegions(std::move(new_block->reads), wb_buffer); - new_block->reads.push_back(it_buffer_region.value()); - new_block->writes = {it_buffer_region.value()}; + new_block->reads = std::move(new_block->reads); + new_block->writes = it_buffer_regions.value(); new_block->name_hint = new_block->name_hint + "_in_thread"; new_block->body = - BufferReplacer::Run(wb_buffer, it_buffer.value(), std::move(new_block->body)); + BufferReplacer::Run(wb_buffers, it_buffers.value(), std::move(new_block->body)); new_block->init = NullOpt; ObjectPtr n = make_object(*realize); n->block = Block(new_block); @@ -303,19 +314,23 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional parameters; parameters.reserve(reduction_loops.size() + 4); - // 1-st argument: size - parameters.push_back(make_const(DataType::UInt(32), 1)); - // 2-nd argument: source - if (it_buffer.defined()) { - parameters.push_back(BufferLoad(it_buffer.value(), {Integer(0)})); + // 1-st argument: number of buffers + parameters.push_back(make_const(DataType::UInt(32), n_buffers)); + // Next `n_buffers` arguments: sources + if (it_buffers.defined()) { + for (int i = 0; i < n_buffers; ++i) { + parameters.push_back(BufferLoad(it_buffers.value()[i], {Integer(0)})); + } } else { - parameters.push_back(combiner_rhs); + parameters.insert(parameters.end(), combiner_rhs.begin(), combiner_rhs.end()); } - // 3-rd argument: predicate + // Next argument: predicate parameters.push_back(const_true()); - // 4-th argument: destination - parameters.push_back(BufferLoad(ct_buffer, {0})); - // next arguments: all the reduction threads + // Next `n_buffers` arguments: destinations + for (int i = 0; i < n_buffers; ++i) { + parameters.push_back(BufferLoad(ct_buffers[i], {0})); + } + // Next arguments: all the reduction threads for (const ForNode* reduction_loop : reduction_loops) { if (reduction_loop->thread_binding.defined()) { parameters.push_back(reduction_loop->loop_var); @@ -325,14 +340,14 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optional iter_vars{nullptr}; Array bindings{nullptr}; Array reads{nullptr}; - if (it_buffer.defined()) { + if (it_buffers.defined()) { iter_vars = Array{}; bindings = Array{}; - reads = {it_buffer_region.value()}; + reads = it_buffer_regions.value(); } else { iter_vars = block->iter_vars; bindings = realize->iter_values; - reads = {RemoveBufferFromBufferRegions(block->reads, wb_buffer)}; + reads = block->reads; } stmts.push_back(BlockRealize( /*iter_values=*/std::move(bindings), @@ -340,7 +355,7 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalname_hint + "_cross_thread", /*body=*/ AttrStmt(/*node=*/reducer, @@ -376,21 +391,31 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, const Optionalvar, new_iter_var->var); } } - BufferStore update = Downcast(block->body); - update = Downcast(Substitute(std::move(update), var_map)); + Array wb_updates; + Array wb_regions; + wb_updates.reserve(n_buffers); + wb_regions.reserve(n_buffers); + int n_dim = static_cast(old_wb_indices.size()); + Array region = Substitute(block->writes[0]->region, var_map); + Array wb_indices; + wb_indices.reserve(n_dim); + for (int d = 0; d < n_dim; ++d) { + wb_indices.push_back(Substitute(old_wb_indices[d], var_map)); + } + for (int i = 0; i < n_buffers; ++i) { + wb_updates.push_back( + BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices)); + wb_regions.push_back(BufferRegion(wb_buffers[i], region)); + } stmts.push_back(BlockRealize( /*iter_values=*/std::move(bindings), /*predicate=*/const_true(), /*block=*/ - Block( - /*iter_vars=*/std::move(iter_vars), - /*reads=*/{std::move(ct_buffer_region)}, - /*writes=*/{BufferRegion(wb_buffer, Substitute(wb_region, var_map))}, - /*name_hint=*/block->name_hint + "_write_back", - /*body=*/ - BufferStore(/*buffer=*/wb_buffer, - /*value=*/BufferLoad(ct_buffer, {Integer(0)}), - /*indices=*/update->indices)))); + Block(/*iter_vars=*/std::move(iter_vars), + /*reads=*/std::move(ct_buffer_regions), + /*writes=*/std::move(wb_regions), + /*name_hint=*/block->name_hint + "_write_back", + /*body=*/n_buffers > 1 ? SeqStmt(wb_updates) : wb_updates[0]))); } // Final step: Wrap all the above four statements with the reduction loops bound to threadIdx Stmt new_stmt = SeqStmt::Flatten(std::move(stmts)); @@ -447,18 +472,23 @@ class CrossThreadReductionTransformer : public StmtMutator { return need ? reduction_loops : std::vector{}; } - // Given that the input block needs cross-thread reduction, check if cross-thread reduction can - // be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread - // reduction). - std::tuple CheckCanApplyCrossThreadReduction( - const BlockNode* block, const std::vector& reduction_loops) const { - // Condition 1. The block being applied cross-thread reduction should write to single buffer. - CHECK_EQ(block->writes.size(), 1) - << "ValueError: Cross-thread reduction requires the block to only " - "write to single buffer. However, the block " - << block->name_hint << " writes to " << block->writes.size() << " buffer(s)."; - - // Condition 2. All the reduction-related loops should be the deepest among all statements + /*! + * \brief Given that the input block needs cross-thread reduction, check if cross-thread reduction + * can be applied to the block (i.e., the block satisfies all necessary conditions of cross-thread + * reduction) + * \param block The block to be checked + * \param reduction_loops The reduction loops above the block + * \return A tuple consisting of five elements: + * - an integer which indicates the number of reduction loops that are bound to thread axes, + * - the detected commutative reducer of the reduction, + * - the reduction buffers which store the reduction results, + * - the RHS values of the reduction updates, + * - the indices which is used to access the reduction buffers when storing the reduction results + */ + std::tuple, Array, Array> + CheckCanApplyCrossThreadReduction(const BlockNode* block, + const std::vector& reduction_loops) const { + // Condition 1. All the reduction-related loops should be the deepest among all statements // outside the block (ignoring SeqStmt here). int n_deepest_reduction_loops = 0; for (auto rit = statement_stack_.rbegin() + 1; rit != statement_stack_.rend(); ++rit) { @@ -480,7 +510,7 @@ class CrossThreadReductionTransformer : public StmtMutator { << " needs cross-thread reduction, while the reduction-related loops outside of it are not " "the deepest statements, which violates the condition."; - // Condition 3. All the reduction-related loops that are bound to thread axes should only be + // Condition 2. All the reduction-related loops that are bound to thread axes should only be // bound to `threadIdx.x/y/z`. int n_bound_reduction_loops = 0; for (const ForNode* reduction_loop : reduction_loops) { @@ -493,16 +523,26 @@ class CrossThreadReductionTransformer : public StmtMutator { } } - // Condition 4. Get the `init` identity and the `update` combiner of the reduction. They should - // both be BufferStores with the same buffer and indices; - // Extract the commutative reducer, combiner lhs and combiner rhs from the reduction identity - // and the reduction combiner. - auto [init, update] = GetBufferStoresFromReductionBlock(NullOpt, GetRef(block)); - auto [reducer, combiner_lhs, combiner_rhs] = - GetReducerAndCombinerLhsRhs(NullOpt, init->value, update); - (void)combiner_lhs; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 + // Condition 3. Get the identity values of the block init and the BufferStore block combiner + // updates of the reduction. Extract the commutative reducer, combiner lhs and combiner rhs from + // the reduction identities and the reduction combiner. + Array init_values{nullptr}; + Array updates{nullptr}; + CommReducer reducer{nullptr}; + Array combiner_lhs{nullptr}; + Array combiner_rhs{nullptr}; + std::tie(init_values, updates) = + GetInitValuesAndUpdatesFromReductionBlock(NullOpt, GetRef(block)); + std::tie(reducer, combiner_lhs, combiner_rhs) = + GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates); + + Array reduction_buffers; + reduction_buffers.reserve(updates.size()); + for (const BufferStore& buf_store : updates) { + reduction_buffers.push_back(buf_store->buffer); + } - // Condition 5. The block should be the last block under the first reduction-related loop. + // Condition 4. The block should be the last block under the first reduction-related loop. bool visit = false; PreOrderVisit(GetRef(reduction_loops[0]), [block, &visit](const ObjectRef& obj) { if (const auto* realize = obj.as()) { @@ -515,7 +555,11 @@ class CrossThreadReductionTransformer : public StmtMutator { } return true; }); - return std::make_tuple(n_bound_reduction_loops, reducer, combiner_rhs); + return std::make_tuple(n_bound_reduction_loops, // + std::move(reducer), // + std::move(reduction_buffers), // + std::move(combiner_rhs), // + updates[0]->indices); } Stmt VisitStmt(const Stmt& stmt) final { @@ -570,10 +614,14 @@ class CrossThreadReductionTransformer : public StmtMutator { if (reduction_loops.empty()) { return StmtMutator::VisitStmt_(realize); } - ++reduction_id_; // Step 2. Check whether cross-thread reduction can be applied. If no, throw an exception on // which condition the block violates. - auto [n_bound_reduction_loops, reducer, combiner_rhs] = + int n_bound_reduction_loops = 0; + CommReducer reducer{nullptr}; + Array reduction_buffers{nullptr}; + Array combiner_rhs{nullptr}; + Array wb_indices{nullptr}; + std::tie(n_bound_reduction_loops, reducer, reduction_buffers, combiner_rhs, wb_indices) = CheckCanApplyCrossThreadReduction(block, reduction_loops); // Step 3. Before doing the cross-thread reduction, in-thread reduction is needed when // - not all the reduction-related loops are bound to thread axes, or @@ -581,31 +629,30 @@ class CrossThreadReductionTransformer : public StmtMutator { bool need_in_thread_reduction = n_bound_reduction_loops < static_cast(reduction_loops.size()) || !is_one(realize->predicate); - // Step 4. Create intermediate buffers, storing them in `ct_buffer` and - // `it_buffer`. Let the scope block allocate these new buffers. - std::vector& new_buffers = block2new_buffers_[block_stack_.back()]; - DataType dtype = block->writes[0]->buffer->dtype; - Buffer ct_buffer = MakeScratchpad("cross_thread_" + std::to_string(reduction_id_), dtype); - new_buffers.push_back(ct_buffer); - Optional it_buffer = NullOpt; + // Step 4. Create intermediate buffers, storing them in `ct_buffers` and + // `it_buffers`. Let the scope block allocate these new buffers. + Array& new_buffers = block2new_buffers_[block_stack_.back()]; + Array ct_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/true); + new_buffers.insert(new_buffers.end(), ct_buffers.begin(), ct_buffers.end()); + Optional> it_buffers = NullOpt; if (need_in_thread_reduction) { - it_buffer = MakeScratchpad("in_thread_" + std::to_string(reduction_id_), dtype); - new_buffers.push_back(it_buffer.value()); + it_buffers = MakeScratchpads(reduction_buffers, /*is_cross_thread_buffer=*/false); + new_buffers.insert(new_buffers.end(), it_buffers.value().begin(), it_buffers.value().end()); } // Step 5. Transform. - loop2new_stmt_[reduction_loops[0]] = TransformReductionBlock( - realize, it_buffer, ct_buffer, reducer, combiner_rhs, reduction_loops); + loop2new_stmt_[reduction_loops[0]] = + TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices, + reducer, combiner_rhs, reduction_loops); // Step 6. Return an empty statement, because the transformation result will be inserted when // returning to the first reduction-related loop. return Stmt{nullptr}; } private: - int reduction_id_ = -1; std::vector statement_stack_; std::vector loop_stack_; std::vector block_stack_; - std::unordered_map> block2new_buffers_; + std::unordered_map> block2new_buffers_; std::unordered_map loop2new_stmt_; Map loop_range_map_; arith::Analyzer analyzer_; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py index 17f42654fcf7..70b49944ba0f 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -119,5 +119,171 @@ def cpu_matmul_2( ) +def test_cpu_argmax(): + @T.prim_func + def argmax( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], + ) -> None: + for i0, i1 in T.grid(128, 128): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + @T.prim_func + def argmax_0( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[128, "int32"], + argmax_v1: T.Buffer[128, "float32"], + ) -> None: + for i0, i1 in T.grid(128, 128): + with T.block("argmax"): + i, k = T.axis.remap("SR", [i0, i1]) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + @T.prim_func + def argmax_1( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[128, "int32"], + argmax_v1: T.Buffer[128, "float32"], + ) -> None: + argmax_v0_rf = T.alloc_buffer([128, 16], dtype="int32") + argmax_v1_rf = T.alloc_buffer([128, 16], dtype="float32") + for i0, i1_0, i1_1 in T.grid(128, 8, 16): + with T.block("argmax_rf"): + vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) + T.reads(idx[i, vi1_0 * 16 + vi1_1], val[i, vi1_0 * 16 + vi1_1]) + T.writes(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) + with T.init(): + argmax_v0_rf[i, vi1_1] = -1 + argmax_v1_rf[i, vi1_1] = T.float32(-3.4028234663852886e38) + v_argmax_v0_rf: T.int32 = T.Select( + argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1], + argmax_v0_rf[i, vi1_1], + idx[i, vi1_0 * 16 + vi1_1], + ) + v_argmax_v1_rf: T.float32 = T.Select( + argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 16 + vi1_1], + argmax_v1_rf[i, vi1_1], + val[i, vi1_0 * 16 + vi1_1], + ) + argmax_v0_rf[i, vi1_1] = v_argmax_v0_rf + argmax_v1_rf[i, vi1_1] = v_argmax_v1_rf + for i0, i1_1 in T.grid(128, 16): + with T.block("argmax"): + vi1_1, i = T.axis.remap("RS", [i1_1, i0]) + T.reads(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) + T.writes(argmax_v0[i], argmax_v1[i]) + T.block_attr({"meta_schedule.random_compute_producer": 1}) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v0[i], argmax_v0_rf[i, vi1_1] + ) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v1[i], argmax_v1_rf[i, vi1_1] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + @T.prim_func + def argmax_2( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[128, "int32"], + argmax_v1: T.Buffer[128, "float32"], + ) -> None: + # body + # with T.block("root") + argmax_v0_rf = T.alloc_buffer([128, 8], dtype="int32") + argmax_v1_rf = T.alloc_buffer([128, 8], dtype="float32") + for i0, i1_0, i1_1 in T.grid(128, 8, 16): + with T.block("argmax_rf"): + vi1_0, i, vi1_1 = T.axis.remap("SSR", [i1_0, i0, i1_1]) + T.reads(idx[i, vi1_0 * 16 + vi1_1], val[i, vi1_0 * 16 + vi1_1]) + T.writes(argmax_v0_rf[i, vi1_0], argmax_v1_rf[i, vi1_0]) + with T.init(): + argmax_v0_rf[i, vi1_0] = -1 + argmax_v1_rf[i, vi1_0] = T.float32(-3.4028234663852886e38) + v_argmax_v0_rf: T.int32 = T.Select( + argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1], + argmax_v0_rf[i, vi1_0], + idx[i, vi1_0 * 16 + vi1_1], + ) + v_argmax_v1_rf: T.float32 = T.Select( + argmax_v1_rf[i, vi1_0] >= val[i, vi1_0 * 16 + vi1_1], + argmax_v1_rf[i, vi1_0], + val[i, vi1_0 * 16 + vi1_1], + ) + argmax_v0_rf[i, vi1_0] = v_argmax_v0_rf + argmax_v1_rf[i, vi1_0] = v_argmax_v1_rf + for i0, i1_0 in T.grid(128, 8): + with T.block("argmax"): + vi1_0, i = T.axis.remap("RS", [i1_0, i0]) + T.reads(argmax_v0_rf[i, vi1_0], argmax_v1_rf[i, vi1_0]) + T.writes(argmax_v0[i], argmax_v1[i]) + T.block_attr({"meta_schedule.random_compute_producer": 1}) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v0[i], argmax_v0_rf[i, vi1_0] + ) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_0], argmax_v1[i], argmax_v1_rf[i, vi1_0] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + decision_0 = [] # type: ignore + decision_1 = [ + ("SamplePerfectTile", [8, 16]), + ] + decision_2 = [ + ("SamplePerfectTile", [8, 16]), + ] + mod = argmax + actual = ms.TuneContext( + mod=mod, + target=Target("llvm --num-cores=32"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ms.schedule_rule.AddRFactor()], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[argmax_0, argmax_1, argmax_2], + expected_decisions=[decision_0, decision_1, decision_2], + ) + + if __name__ == "__main__": test_cpu_matmul() + test_cpu_argmax() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index a0ca47c09a34..ab8df6678b0b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -572,7 +572,106 @@ def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "floa ) +@T.prim_func +def argmax( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1 in T.grid(128, 128): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +def test_gpu_argmax(): + @T.prim_func + def argmax_0( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[128, "int32"], + argmax_v1: T.Buffer[128, "float32"], + ) -> None: + # body + # with T.block("root") + for i0, i1 in T.grid(128, 128): + with T.block("argmax"): + i, k = T.axis.remap("SR", [i0, i1]) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + @T.prim_func + def argmax_1( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[128, "int32"], + argmax_v1: T.Buffer[128, "float32"], + ) -> None: + # body + # with T.block("root") + for i0, i1_0 in T.grid(128, 2): + for i1_1 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 64 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k] + ) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + decision_0 = [] # type: ignore + decision_1 = [ + ("SampleCategorical", 4), + ] + + mod = argmax + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3090", host="llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) + ], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[argmax_0, argmax_1], + expected_decisions=[decision_0, decision_1], + ) + + if __name__ == "__main__": test_gpu_softmax_mn() test_gpu_softmax_mn_after_inline() test_gpu_batch_norm_bmn() + test_gpu_argmax() diff --git a/tests/python/unittest/test_tir_schedule_rfactor.py b/tests/python/unittest/test_tir_schedule_rfactor.py index 4078b1e89682..f6db79f3ed23 100644 --- a/tests/python/unittest/test_tir_schedule_rfactor.py +++ b/tests/python/unittest/test_tir_schedule_rfactor.py @@ -29,9 +29,9 @@ @T.prim_func def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + C = T.match_buffer(c, [128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): with T.block("update"): @@ -44,12 +44,30 @@ def transformed_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + (A[vi, vk] * B[vj, vk]) +@T.prim_func +def transformed_matmul_with_let(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + C = T.match_buffer(c, [128, 128], dtype="float32") + + for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): + with T.block("update"): + vi, vj = T.axis.remap("SS", [i0, i1]) + vk = T.axis.R(128, i2_outer * 32 + i2_inner_outer * 4 + i2_inner_inner) + T.reads([A[vi, vk], B[vj, vk]]) + T.writes([C[vi, vj]]) + with T.init(): + C[vi, vj] = 0.0 + v_C: T.float32 = C[vi, vj] + (A[vi, vk] * B[vj, vk]) + C[vi, vj] = v_C + + @T.prim_func def matmul_rfactor(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, [128, 128]) - B = T.match_buffer(b, [128, 128]) - C = T.match_buffer(c, [128, 128]) - C_rf = T.alloc_buffer([4, 128, 128]) + A = T.match_buffer(a, [128, 128], dtype="float32") + B = T.match_buffer(b, [128, 128], dtype="float32") + C = T.match_buffer(c, [128, 128], dtype="float32") + C_rf = T.alloc_buffer([4, 128, 128], dtype="float32") for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in T.grid(128, 128, 4, 8, 4): with T.block("update_rf"): @@ -436,6 +454,20 @@ def rowsum_wrong_reduce_pattern2(a: T.handle, b: T.handle) -> None: B[vi] = B[vi] - A[vi, vk] +@T.prim_func +def rowsum_init_not_bufferstore(a: T.handle, b: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128,)) + + for i, k in T.grid(128, 128): + with T.block("B"): + vi, vk = T.axis.remap("SR", [i, k]) + with T.init(): + v_init: T.float32 = T.float32(0) + B[vi] = v_init + B[vi] = B[vi] + A[vi, vk] + + @T.prim_func def rowsum_transformed(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, (128, 128)) @@ -654,6 +686,453 @@ def rfactor_spatial_only_after( B[ax0, ax1, ax2, ax3] = B[ax0, ax1, ax2, ax3] + B_rf[ax0, ax1, ax2, ax3, vi4] +@T.prim_func +def argmax_split( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmin_split_init_update_reordered( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmin_v0: T.Buffer[(128,), "int32"], + argmin_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmin"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmin_v0[i], argmin_v1[i]) + with T.init(): + argmin_v1[i] = T.max_value("float32") + argmin_v0[i] = -1 + v_argmin_v0: T.int32 = T.Select(argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k]) + v_argmin_v1: T.float32 = T.Select(argmin_v1[i] <= val[i, k], argmin_v1[i], val[i, k]) + argmin_v1[i] = v_argmin_v1 + argmin_v0[i] = v_argmin_v0 + + +@T.prim_func +def argmax_split_different_shape( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(256,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_different_indices( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i + 1] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i + 1] = v_argmax_v1 + + +@T.prim_func +def argmax_split_init_not_bufferstore( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + v1_init: T.float32 = T.min_value("float32") + argmax_v1[i] = v1_init + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_init_buffer_duplicate( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v0[i] = -1 + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_letstmt_fewer_than_init( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + + +@T.prim_func +def argmax_split_letstmt_more_than_init( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_let_body_neither_seqstmt_nor_bufferstore( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + T.evaluate(0) + + +@T.prim_func +def argmax_split_init_update_inconsistent_bufferstore_number( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_body_seq_not_bufferstore( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + T.evaluate(0) + + +@T.prim_func +def argmax_split_body_bufferstore_value_not_var( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_body_bufferstore_value_unbound_var( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + v_unbound = T.var("int32") + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_unbound + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_one_let_var_used_multi_times( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "int32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "int32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("int32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v0 + + +@T.prim_func +def argmax_split_body_one_buffer_updated_multi_times( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "int32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "int32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("int32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v0[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_init_buffer_not_match( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v0_1: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v0_1[i], argmax_v1[i]) + with T.init(): + argmax_v0_1[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k]) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmax_split_rfactor( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + argmax_v0_rf = T.alloc_buffer([128, 32], dtype="int32") + argmax_v1_rf = T.alloc_buffer([128, 32], dtype="float32") + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmax_rf"): + vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) + T.reads(idx[i, vi1_0 * 32 + vi1_1], val[i, vi1_0 * 32 + vi1_1]) + T.writes(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) + with T.init(): + argmax_v0_rf[i, vi1_1] = -1 + argmax_v1_rf[i, vi1_1] = T.min_value("float32") + v_argmax_v0_rf: T.int32 = T.Select( + argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 32 + vi1_1], + argmax_v0_rf[i, vi1_1], + idx[i, vi1_0 * 32 + vi1_1], + ) + v_argmax_v1_rf: T.float32 = T.Select( + argmax_v1_rf[i, vi1_1] >= val[i, vi1_0 * 32 + vi1_1], + argmax_v1_rf[i, vi1_1], + val[i, vi1_0 * 32 + vi1_1], + ) + argmax_v0_rf[i, vi1_1] = v_argmax_v0_rf + argmax_v1_rf[i, vi1_1] = v_argmax_v1_rf + for i0, i1_1 in T.grid(128, 32): + with T.block("argmax"): + vi1_1, i = T.axis.remap("RS", [i1_1, i0]) + T.reads(argmax_v0_rf[i, vi1_1], argmax_v1_rf[i, vi1_1]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.min_value("float32") + v_argmax_v0: T.int32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v0[i], argmax_v0_rf[i, vi1_1] + ) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= argmax_v1_rf[i, vi1_1], argmax_v1[i], argmax_v1_rf[i, vi1_1] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def argmin_split_rfactor( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmin_v0: T.Buffer[(128,), "int32"], + argmin_v1: T.Buffer[(128,), "float32"], +) -> None: + argmin_v0_rf = T.alloc_buffer([128, 32], dtype="int32") + argmin_v1_rf = T.alloc_buffer([128, 32], dtype="float32") + for i0, i1_0, i1_1 in T.grid(128, 4, 32): + with T.block("argmin_rf"): + vi1_1, i, vi1_0 = T.axis.remap("SSR", [i1_1, i0, i1_0]) + T.reads(idx[i, vi1_0 * 32 + vi1_1], val[i, vi1_0 * 32 + vi1_1]) + T.writes(argmin_v0_rf[i, vi1_1], argmin_v1_rf[i, vi1_1]) + with T.init(): + argmin_v0_rf[i, vi1_1] = -1 + argmin_v1_rf[i, vi1_1] = T.max_value("float32") + v_argmin_v0_rf: T.int32 = T.Select( + argmin_v1_rf[i, vi1_1] <= val[i, vi1_0 * 32 + vi1_1], + argmin_v0_rf[i, vi1_1], + idx[i, vi1_0 * 32 + vi1_1], + ) + v_argmin_v1_rf: T.float32 = T.Select( + argmin_v1_rf[i, vi1_1] <= val[i, vi1_0 * 32 + vi1_1], + argmin_v1_rf[i, vi1_1], + val[i, vi1_0 * 32 + vi1_1], + ) + argmin_v0_rf[i, vi1_1] = v_argmin_v0_rf + argmin_v1_rf[i, vi1_1] = v_argmin_v1_rf + for i0, i1_1 in T.grid(128, 32): + with T.block("argmin"): + vi1_1, i = T.axis.remap("RS", [i1_1, i0]) + T.reads(argmin_v0_rf[i, vi1_1], argmin_v1_rf[i, vi1_1]) + T.writes(argmin_v0[i], argmin_v1[i]) + with T.init(): + argmin_v0[i] = -1 + argmin_v1[i] = T.max_value("float32") + v_argmin_v0: T.int32 = T.Select( + argmin_v1[i] <= argmin_v1_rf[i, vi1_1], argmin_v0[i], argmin_v0_rf[i, vi1_1] + ) + v_argmin_v1: T.float32 = T.Select( + argmin_v1[i] <= argmin_v1_rf[i, vi1_1], argmin_v1[i], argmin_v1_rf[i, vi1_1] + ) + argmin_v0[i] = v_argmin_v0 + argmin_v1[i] = v_argmin_v1 + + # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg @@ -668,6 +1147,17 @@ def test_reduction_rfactor_matmul(): verify_trace_roundtrip(s, mod=transformed_matmul) +def test_reduction_rfactor_matmul_with_let(): + s = tir.Schedule(transformed_matmul_with_let, debug_mask="all") + update = s.get_block("update") + _, _, _, _, kii = s.get_loops(update) + rf_block = s.rfactor(kii, 0) + tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("update_rf"))) + assert s.get(update).same_as(s.get(s.get_block("update"))) + verify_trace_roundtrip(s, mod=transformed_matmul_with_let) + + def test_reduction_rfactor_square_sum(): s = tir.Schedule(square_sum, debug_mask="all") C = s.get_block("C") @@ -773,6 +1263,13 @@ def test_reduction_rfactor_wrong_reduce_pattern2(): s.rfactor(k, 0) +def test_reduction_rfactor_init_not_bufferstore(): + s = tir.Schedule(rowsum_init_not_bufferstore, debug_mask="all") + _, k = s.get_loops(s.get_block("B")) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(k, 0) + + def test_reduction_rfactor_wrong_loops1(): s = tir.Schedule(rowsum, debug_mask="all") i, _ = s.get_loops(s.get_block("B")) @@ -852,10 +1349,146 @@ def test_reduction_rfactor_spatial_only(): s = tir.Schedule(rfactor_spatial_only, debug_mask="all") block = s.get_block(name="acc", func_name="main") _, _, _, _, loop, _ = s.get_loops(block) - s.rfactor(loop=loop, factor_axis=4) + rf_block = s.rfactor(loop=loop, factor_axis=4) tvm.ir.assert_structural_equal(s.mod["main"], rfactor_spatial_only_after) + assert s.get(rf_block).same_as(s.get(s.get_block("acc_rf"))) + assert s.get(block).same_as(s.get(s.get_block("acc"))) verify_trace_roundtrip(s, mod=rfactor_spatial_only) +def test_reduction_rfactor_argmax(): + s = tir.Schedule(argmax_split, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + rf_block = s.rfactor(ki, 1) + tvm.ir.assert_structural_equal(s.mod["main"], argmax_split_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("argmax_rf"))) + assert s.get(argmax).same_as(s.get(s.get_block("argmax"))) + verify_trace_roundtrip(s, mod=argmax_split) + + +def test_reduction_rfactor_argmin_init_update_reordeded(): + s = tir.Schedule(argmin_split_init_update_reordered, debug_mask="all") + argmin = s.get_block("argmin") + _, _, ki = s.get_loops(argmin) + rf_block = s.rfactor(ki, 1) + tvm.ir.assert_structural_equal(s.mod["main"], argmin_split_rfactor) + assert s.get(rf_block).same_as(s.get(s.get_block("argmin_rf"))) + assert s.get(argmin).same_as(s.get(s.get_block("argmin"))) + verify_trace_roundtrip(s, mod=argmin_split_init_update_reordered) + + +def test_reduction_rfactor_argmax_reduction_buffer_different_shape(): + s = tir.Schedule(argmax_split_different_shape, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_different_access_indices(): + s = tir.Schedule(argmax_split_different_indices, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_init_not_bufferstore(): + s = tir.Schedule(argmax_split_init_not_bufferstore, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_init_buffer_duplicate(): + s = tir.Schedule(argmax_split_init_buffer_duplicate, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_letstmt_fewer_than_init(): + s = tir.Schedule(argmax_split_letstmt_fewer_than_init, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_letstmt_more_than_init(): + s = tir.Schedule(argmax_split_letstmt_more_than_init, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_let_body_neither_seqstmt_nor_bufferstore(): + s = tir.Schedule(argmax_split_let_body_neither_seqstmt_nor_bufferstore, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_init_update_inconsistent_bufferstore_number(): + s = tir.Schedule(argmax_split_init_update_inconsistent_bufferstore_number, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_body_seq_not_bufferstore(): + s = tir.Schedule(argmax_split_body_seq_not_bufferstore, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_body_bufferstore_value_not_var(): + s = tir.Schedule(argmax_split_body_bufferstore_value_not_var, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_body_bufferstore_value_unbound_var(): + s = tir.Schedule(argmax_split_body_bufferstore_value_unbound_var, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_one_let_var_used_multi_times(): + s = tir.Schedule(argmax_split_one_let_var_used_multi_times, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_body_one_buffer_updated_multi_times(): + s = tir.Schedule(argmax_split_body_one_buffer_updated_multi_times, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + +def test_reduction_rfactor_argmax_init_buffer_not_match(): + s = tir.Schedule(argmax_split_init_buffer_not_match, debug_mask="all") + argmax = s.get_block("argmax") + _, _, ki = s.get_loops(argmax) + with pytest.raises(tvm.tir.ScheduleError): + s.rfactor(ki, 1) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 9b5937ac6efd..ff1353d2265e 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=missing-function-docstring,missing-module-docstring import sys import pytest @@ -22,6 +23,8 @@ from tvm import te from tvm.script import tir as T +# pylint: disable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + def _check(original, transformed): mod = tvm.IRModule.from_expr(original) @@ -44,7 +47,7 @@ def loop_split(a: T.handle, b: T.handle) -> None: with T.block("B"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -67,7 +70,7 @@ def lowered_loop_split(a: T.handle, b: T.handle) -> None: with T.block("B_normal_reduction"): vi = T.axis.S(128, i) vk = T.axis.R(128, ko * 32 + ki) - T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.reads([A[vi, vk]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] with T.block("B_cross_thread_reduction"): @@ -103,7 +106,7 @@ def no_normal_reduction(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="threadIdx.x"): with T.block("B"): vi, vk = T.axis.remap("SR", [i, k]) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -148,7 +151,7 @@ def two_bound_loops(a: T.handle, b: T.handle) -> None: with T.block("B"): vi = T.axis.spatial(128, i) vk = T.axis.reduce(128, ko * 32 + ki) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -196,7 +199,7 @@ def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: with T.block("B_rf"): vk0 = T.axis.spatial(16, k0o * 4 + k0i0) vi, vk1 = T.axis.remap("SR", [i, k1]) - T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.reads([A[vi, vk0, vk1]]) T.writes([B_rf_local[vk0, vi]]) with T.init(): B_rf_local[vk0, vi] = T.float32(0) @@ -205,7 +208,7 @@ def multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> None: with T.block("B"): vk0 = T.axis.reduce(16, k0o * 4 + k0i1) vi = T.axis.spatial(16, i) - T.reads([B[vi], B_rf_local[vk0, vi]]) + T.reads([B_rf_local[vk0, vi]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -229,7 +232,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No with T.block("B_rf"): vk0 = T.axis.spatial(16, k0o * 4 + k0i0) vi, vk1 = T.axis.remap("SR", [i, k1]) - T.reads([B_rf_local[vk0, vi], A[vi, vk0, vk1]]) + T.reads([A[vi, vk0, vk1]]) T.writes([B_rf_local[vk0, vi]]) with T.init(): B_rf_local[vk0, vi] = T.float32(0) @@ -238,7 +241,7 @@ def lowered_multiple_blocks_under_reduction_loop(a: T.handle, b: T.handle) -> No with T.block("B_normal_reduction"): vk0 = T.axis.reduce(16, k0o * 4 + k0i1) vi = T.axis.spatial(16, i) - T.reads([B_rf_local[vk0, vi], normal_reduce_temp0[0]]) + T.reads([B_rf_local[vk0, vi]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + B_rf_local[vk0, vi] with T.block("B_cross_thread_reduction"): @@ -276,7 +279,7 @@ def with_block_predicate(a: T.handle, b: T.handle) -> None: vi = T.axis.spatial(128, i) vk = T.axis.reduce(120, ko * 32 + ki) T.where(ko * 32 + ki < 120) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -300,7 +303,7 @@ def lowered_with_block_predicate(a: T.handle, b: T.handle) -> None: vi = T.axis.spatial(128, i) vk = T.axis.reduce(120, ko * 32 + ki) T.where(ko * 32 + ki < 120) - T.reads([A[vi, vk], normal_reduce_temp0[0]]) + T.reads([A[vi, vk]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = normal_reduce_temp0[0] + A[vi, vk] with T.block("B_cross_thread_reduction"): @@ -341,7 +344,7 @@ def single_reduction_loop_with_block_predicate( i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(T_softmax_maxelem_shared[i0_1], A[i0_1, k]) + T.reads(A[i0_1, k]) T.writes(T_softmax_maxelem_shared[i0_1]) with T.init(): T_softmax_maxelem_shared[i0_1] = T.float32(-3.4028234663852886e38) @@ -354,9 +357,7 @@ def single_reduction_loop_with_block_predicate( i0_2 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads( - T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2] - ) + T.reads(A[i0_2, k], T_softmax_maxelem_shared[i0_2]) T.writes(T_softmax_expsum_shared[i0_2]) with T.init(): T_softmax_expsum_shared[i0_2] = T.float32(0) @@ -401,7 +402,7 @@ def lowered_single_reduction_loop_with_block_predicate( i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(A[i0_1, k], in_thread_0[0]) + T.reads(A[i0_1, k]) T.writes(in_thread_0[0]) in_thread_0[0] = T.max(in_thread_0[0], A[i0_1, k]) with T.block("T_softmax_maxelem_cross_thread"): @@ -439,7 +440,7 @@ def lowered_single_reduction_loop_with_block_predicate( i0_3 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax1_1) T.where(ax1_0 * 512 + ax1_1 < 256) - T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3], in_thread_1[0]) + T.reads(A[i0_3, k], T_softmax_maxelem_shared[i0_3]) T.writes(in_thread_1[0]) in_thread_1[0] = in_thread_1[0] + T.exp( A[i0_3, k] - T_softmax_maxelem_shared[i0_3], dtype="float32" @@ -492,7 +493,7 @@ def reducer_max(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="threadIdx.x"): with T.block("B"): vi, vk = T.axis.remap("SR", [i, k]) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.min_value("float32") @@ -534,7 +535,7 @@ def zero_rank_buffer(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="threadIdx.x"): with T.block("B"): vk = T.axis.reduce(128, k) - T.reads([B[()], A[vk]]) + T.reads([A[vk]]) T.writes([B[()]]) with T.init(): B[()] = T.float32(0) @@ -590,7 +591,7 @@ def reduction_loop_not_deepest(a: T.handle, b: T.handle) -> None: for i in T.serial(0, 128): with T.block("B"): vi, vk = T.axis.remap("SR", [i, k]) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -605,7 +606,7 @@ def reduction_loop_bound_to_blockidx(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="blockIdx.x"): with T.block("B"): vi, vk = T.axis.remap("SR", [i, k]) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -620,7 +621,7 @@ def different_access_indices(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="threadIdx.x"): with T.block("B"): vi, vj, vk = T.axis.remap("SSR", [i, j, k]) - T.reads([B[vi, vj], A[vi, vj, vk]]) + T.reads([A[vi, vj, vk]]) T.writes( [ B[ @@ -642,7 +643,7 @@ def invalid_reducer(a: T.handle, b: T.handle) -> None: for k in T.thread_binding(0, 128, thread="threadIdx.x"): with T.block("B"): vi, vk = T.axis.remap("SR", [i, k]) - T.reads([B[vi], A[vi, vk]]) + T.reads([A[vi, vk]]) T.writes([B[vi]]) with T.init(): B[vi] = T.float32(0) @@ -661,7 +662,7 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: with T.block("T_softmax_maxelem"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) - T.reads([T_softmax_maxelem_shared[i0_1], A[i0_1, k]]) + T.reads([A[i0_1, k]]) T.writes([T_softmax_maxelem_shared[i0_1]]) with T.init(): T_softmax_maxelem_shared[i0_1] = T.min_value("float32") @@ -675,7 +676,6 @@ def softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) T.reads( [ - T_softmax_expsum_shared[i0_2], A[i0_2, k], T_softmax_maxelem_shared[i0_2], ] @@ -729,7 +729,7 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: with T.block("T_softmax_maxelem_normal_reduction"): i0_1 = T.axis.spatial(256, i0) k = T.axis.reduce(256, ax0_0 * 32 + ax0_1) - T.reads([A[i0_1, k], normal_reduce_temp0[0]]) + T.reads([A[i0_1, k]]) T.writes([normal_reduce_temp0[0]]) normal_reduce_temp0[0] = T.max(normal_reduce_temp0[0], A[i0_1, k]) with T.block("T_softmax_maxelem_cross_thread_reduction"): @@ -768,7 +768,6 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: [ A[i0_3, k], T_softmax_maxelem_shared[i0_3], - normal_reduce_temp1[0], ] ) T.writes([normal_reduce_temp1[0]]) @@ -821,6 +820,191 @@ def lowered_softmax(var_A: T.handle, var_T_softmax_norm: T.handle) -> None: ) +@T.prim_func +def argmax_split( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0 in T.grid(128, 4): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("argmax"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmax_v0[i], argmax_v1[i]) + with T.init(): + argmax_v0[i] = -1 + argmax_v1[i] = T.float32(-3.4028234663852886e38) + v_argmax_v0: T.int32 = T.Select(argmax_v1[i] >= val[i, k], argmax_v0[i], idx[i, k]) + v_argmax_v1: T.float32 = T.Select( + argmax_v1[i] >= val[i, k], argmax_v1[i], val[i, k] + ) + argmax_v0[i] = v_argmax_v0 + argmax_v1[i] = v_argmax_v1 + + +@T.prim_func +def lowered_argmax_split( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmax_v0: T.Buffer[(128,), "int32"], + argmax_v1: T.Buffer[(128,), "float32"], +) -> None: + cross_thread_argmax_v0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local") + cross_thread_argmax_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_argmax_v0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local") + in_thread_argmax_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0 in T.serial(128): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("argmax_in_thread_init"): + T.reads() + T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) + in_thread_argmax_v0[0] = -1 + in_thread_argmax_v1[0] = T.float32(-3.4028234663852886e38) + for i1_0 in T.serial(4): + with T.block("argmax_in_thread"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) + v_argmax_v0: T.int32 = T.Select( + in_thread_argmax_v1[0] >= val[i, k], in_thread_argmax_v0[0], idx[i, k] + ) + v_argmax_v1: T.float32 = T.Select( + in_thread_argmax_v1[0] >= val[i, k], in_thread_argmax_v1[0], val[i, k] + ) + in_thread_argmax_v0[0] = v_argmax_v0 + in_thread_argmax_v1[0] = v_argmax_v1 + with T.block("argmax_cross_thread"): + T.reads(in_thread_argmax_v0[0], in_thread_argmax_v1[0]) + T.writes(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0]) + T.attr( + T.comm_reducer( + lambda x0, x1, y0, y1: ( + T.Select(x1 >= y1, x0, y0), + T.Select(x1 >= y1, x1, y1), + ), + [-1, T.float32(-3.4028234663852886e38)], + ), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(2), + in_thread_argmax_v0[0], + in_thread_argmax_v1[0], + True, + cross_thread_argmax_v0[0], + cross_thread_argmax_v1[0], + i1_1, + dtype="handle", + ) + ) + with T.block("argmax_write_back"): + i = T.axis.spatial(128, i0) + T.reads(cross_thread_argmax_v0[0], cross_thread_argmax_v1[0]) + T.writes(argmax_v0[i], argmax_v1[i]) + argmax_v0[i] = cross_thread_argmax_v0[0] + argmax_v1[i] = cross_thread_argmax_v1[0] + + +@T.prim_func +def argmin_split_init_update_reordered( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmin_v0: T.Buffer[(128,), "int32"], + argmin_v1: T.Buffer[(128,), "float32"], +) -> None: + for i0, i1_0 in T.grid(128, 4): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("argmin"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(argmin_v0[i], argmin_v1[i]) + with T.init(): + argmin_v1[i] = T.float32(3.4028234663852886e38) + argmin_v0[i] = -1 + v_argmin_v0: T.int32 = T.Select(argmin_v1[i] <= val[i, k], argmin_v0[i], idx[i, k]) + v_argmin_v1: T.float32 = T.Select( + argmin_v1[i] <= val[i, k], argmin_v1[i], val[i, k] + ) + argmin_v1[i] = v_argmin_v1 + argmin_v0[i] = v_argmin_v0 + + +@T.prim_func +def lowered_argmin_split_init_update_reordered( + idx: T.Buffer[(128, 128), "int32"], + val: T.Buffer[(128, 128), "float32"], + argmin_v0: T.Buffer[(128,), "int32"], + argmin_v1: T.Buffer[(128,), "float32"], +) -> None: + cross_thread_argmin_v0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local") + cross_thread_argmin_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + in_thread_argmin_v0 = T.alloc_buffer([1], dtype="int32", strides=[1], scope="local") + in_thread_argmin_v1 = T.alloc_buffer([1], dtype="float32", strides=[1], scope="local") + for i0 in T.serial(128): + for i1_1 in T.thread_binding(32, thread="threadIdx.x"): + with T.block("argmin_in_thread_init"): + T.reads() + T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) + in_thread_argmin_v0[0] = -1 + in_thread_argmin_v1[0] = T.float32(3.4028234663852886e38) + for i1_0 in T.serial(4): + with T.block("argmin_in_thread"): + i = T.axis.spatial(128, i0) + k = T.axis.reduce(128, i1_0 * 32 + i1_1) + T.reads(idx[i, k], val[i, k]) + T.writes(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) + v_argmin_v0: T.int32 = T.Select( + in_thread_argmin_v1[0] <= val[i, k], in_thread_argmin_v0[0], idx[i, k] + ) + v_argmin_v1: T.float32 = T.Select( + in_thread_argmin_v1[0] <= val[i, k], in_thread_argmin_v1[0], val[i, k] + ) + in_thread_argmin_v1[0] = v_argmin_v1 + in_thread_argmin_v0[0] = v_argmin_v0 + with T.block("argmin_cross_thread"): + T.reads(in_thread_argmin_v0[0], in_thread_argmin_v1[0]) + T.writes(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0]) + T.attr( + T.comm_reducer( + lambda x0, x1, y0, y1: ( + T.Select(x1 <= y1, x0, y0), + T.Select(x1 <= y1, x1, y1), + ), + [-1, T.float32(3.4028234663852886e38)], + ), + "reduce_scope", + T.reinterpret(T.uint64(0), dtype="handle"), + ) + T.evaluate( + T.tvm_thread_allreduce( + T.uint32(2), + in_thread_argmin_v0[0], + in_thread_argmin_v1[0], + True, + cross_thread_argmin_v0[0], + cross_thread_argmin_v1[0], + i1_1, + dtype="handle", + ) + ) + with T.block("argmin_write_back"): + i = T.axis.spatial(128, i0) + T.reads(cross_thread_argmin_v0[0], cross_thread_argmin_v1[0]) + T.writes(argmin_v0[i], argmin_v1[i]) + argmin_v0[i] = cross_thread_argmin_v0[0] + argmin_v1[i] = cross_thread_argmin_v1[0] + + +# pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg + + def test_loop_split(): _check(loop_split, lowered_loop_split) @@ -880,6 +1064,14 @@ def test_softmax(): _check(softmax, lowered_softmax) +def test_argmax_split(): + _check(argmax_split, lowered_argmax_split) + + +def test_argmin_split_init_update_reordered(): + _check(argmin_split_init_update_reordered, lowered_argmin_split_init_update_reordered) + + def test_lower_te(): a = te.placeholder((32, 2, 2)) k1 = te.reduce_axis((0, 2), "k1")