Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][Meta-Schedule] Tuple-reduction scheduling support #11639

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/meta_schedule/schedule_rule/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,13 @@ class CrossThreadReductionNode : public ScheduleRuleNode {
*/
std::tuple<bool, tir::LoopRV, tir::BlockRV, tir::LoopRV> GetComputeTargetLoopAndBlock(
const tir::Schedule& sch, const tir::BlockRV& block_rv) {
// Step 0. Due to technical reason of some primitives (e.g., compute-at), if the block is doing
// a tuple reduction, fusion is temporarily not supported.
junrushao marked this conversation as resolved.
Show resolved Hide resolved
if (sch->Get(block_rv)->writes.size() != 1) {
return std::make_tuple(false, tir::LoopRV{nullptr}, tir::BlockRV{nullptr},
tir::LoopRV{nullptr});
}

// Step 1. Get all the consumers of the input block.
Array<tir::BlockRV> consumers = sch->GetConsumers(block_rv);

Expand Down
48 changes: 24 additions & 24 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,14 @@ std::pair<Optional<StmtSRef>, bool> GetBufferDefiningSite(const StmtSRef& block_
/******** Reduction Block Related ********/

/*!
* \brief Convert the `init` and `body` of the input block to BufferStores
* \param self The schedule state
* \param block The block to be analyzed
* \return The BufferStores of the `init` and `body` of the input block
* \throw ScheduleError If the `init` or `body` is not BufferStore, or they don't write to the same
* buffer
* \brief Get the init values and the BufferStore updates from the input reduction block
* \param self The schedule state, used for error reporting
* \param block The block from which the init values and BufferStore updates are extracted from
* \return The extracted init values and BufferStore updates
* \throw ScheduleError If rfactor or cross-thread reduction cannot be applied to the block
*/
std::pair<BufferStore, BufferStore> GetBufferStoresFromReductionBlock(
const Optional<ScheduleState>& self, const Block& block);
std::pair<Array<PrimExpr>, Array<BufferStore>> GetInitValuesAndUpdatesFromReductionBlock(
const Optional<ScheduleState>& self, Block block);

/*!
* \brief Check whether the input array of IterVars only contains data-parallel and reduction block
Expand All @@ -484,16 +483,17 @@ bool ContainsOnlyDataParAndReductionBlockIter(const Array<IterVar>& iters);
bool ReductionIterNotIndexOutputBuffer(const Block& block);

/*!
* \brief Given a reduction identity and a reduction combiner, detect the corresponding commutative
* reducer, and extract the combiner lhs and combiner rhs
* \brief Given a list of reduction identities and a list of reduction combiners, detect the
* corresponding commutative reducer, and extract the combiner LHS values and combiner RHS values
* \param self The schedule state
* \param identity The reduction identity to be analyzed
* \param combiner The reduction combiner to be analyzed
* \return The corresponding CommReducer, the combiner lhs and the combiner rhs
* \param identities The reduction identities to be analyzed
* \param combiners The reduction combiners to be analyzed
* \return The corresponding CommReducer, combiner LHS values and combiner RHS values
* \throw ScheduleError If no corresponding commutative reducer can be matched
*/
std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
const Optional<ScheduleState>& self, const PrimExpr& identity, const BufferStore& combiner);
std::tuple<CommReducer, Array<PrimExpr>, Array<PrimExpr>> GetReducerAndCombinerLhsRhs(
const Optional<ScheduleState>& self, const Array<PrimExpr>& identities,
const Array<BufferStore>& combiners);

/******** Commutative Reducer ********/

Expand All @@ -502,20 +502,20 @@ std::tuple<CommReducer, PrimExpr, PrimExpr> GetReducerAndCombinerLhsRhs(
* \return The list of the registered reducer-getter functions
* \sa ReducerRegistry
*/
std::vector<runtime::TypedPackedFunc<CommReducer(DataType)>> GetReducerGetters();
std::vector<runtime::TypedPackedFunc<Optional<CommReducer>(Array<PrimExpr>)>> GetReducerGetters();

/*!
* \brief Given the input identity and the combiner BufferStore of a reduction, extract the
* corresponding commutative reducer and its lhs, rhs if possible.
* \param identity The identity of the reduction
* \param combiner The combiner of the reduction
* \brief Given the input identities and the combiner BufferStores of a reduction, extract the
* corresponding commutative reducer, LHS values and RHS values, if possible.
* \param identities The identities of the reduction
* \param combiners The combiners of the reduction
* \param result_reducer The extracted CommReducer
* \param lhs The extracted lhs of the reducer
* \param rhs The extracted rhs of the reducer
* \param lhs The extracted LHS values of the reducer
* \param rhs The extracted RHS values of the reducer
* \return A boolean indicating whether a corresponding commutative reducer is found
*/
bool FromIdentityCombiner(const PrimExpr& identity, const BufferStore& combiner,
CommReducer* result_reducer, PrimExpr* lhs, PrimExpr* rhs);
bool FromIdentityCombiner(const Array<PrimExpr>& identities, const Array<BufferStore>& combiners,
CommReducer* result_reducer, Array<PrimExpr>* lhs, Array<PrimExpr>* rhs);

/******** Misc ********/

Expand Down
Loading