diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 324eedafb98a1..bda088bc5f1e8 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -87,6 +87,21 @@ void MultiLevelTilingNode::InitializeWithTuneContext(const TuneContext& context)
       TVM_PY_LOG(INFO, context->logger) << "'thread_warp_size' is not defined in the target";
     }
   }
+  // Enable the async-pipeline sub-rule only on NVIDIA architectures that support
+  // asynchronous global->shared memory copy (cp.async), i.e. sm_80 (Ampere) and newer.
+  if (Optional<String> opt_sm = context->target.value()->GetAttr<String>("arch")) {
+    std::string sm = opt_sm.value();
+    if (support::StartsWith(sm, "sm_")) {
+      sm = sm.substr(3);
+      try {
+        // only sm_80 or higher supports async memcopy
+        if (std::stoi(sm) >= 80) {
+          this->stages.insert(this->stages.end(), {4, 5});
+        }
+      } catch (const std::invalid_argument& e) {
+        LOG(WARNING) << "ValueError: Unable to parse `target.arch`: " << sm
+                     << ". Details: " << e.what();
+      }
+    }
+  }
   logger = context->logger;
 }
 
@@ -115,6 +130,9 @@ std::vector<State> MultiLevelTilingNode::ApplySubRules(std::vector<State> states
   states = SubRule(std::move(states), [&](State state) { return TileLoopNest(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(std::move(state)); });
   states = SubRule(std::move(states), [&](State state) { return AddReadReuse(std::move(state)); });
+  states = SubRule(std::move(states), [&](State state) {
+    return AddAsyncPipeline(std::move(state));
+  });
   return states;
 }
 
@@ -280,6 +298,43 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
   return results;
 }
 
+std::vector<State> MultiLevelTilingNode::AddAsyncPipeline(State state) const {
+  // For arch that does not support async pipeline, this->stages will be an empty vector
+  if (r_indices_.size() < 1 || this->stages.empty()) {
+    return {state};
+  }
+  // Currently only support the default config used by ScheduleRule::DefaultCUDA
+  // @see src/meta_schedule/schedule_rule/schedule_rule.cc
+  // check the reduce loop contains exactly 3 for loops
+  // therefore it matches the annotation array size in the following code
+  tir::StmtSRef r_loop_sref = state->sch->GetSRef(state->tiles[r_indices_[0]].back());
+  const tir::ForNode* r_for_loop = TVM_SREF_TO_FOR(r_loop_sref);
+  Array<tir::Stmt> seq = Downcast<tir::SeqStmt>(r_for_loop->body)->seq;
+  if (seq.size() != 3) {
+    return {state};
+  }
+  for (auto& stmt : seq) {
+    if (!stmt.as<tir::ForNode>()) {
+      return {state};
+    }
+  }
+
+  LoopRV r_loop_fused = state->sch->Fuse(state->tiles[r_indices_[0]]);
+  std::vector<State> ret;
+  ret.push_back(state);
+  for (int stage : this->stages) {
+    State new_state = state->Copy();
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_stage,
+                             Array<Integer>{0, 0, stage - 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_order,
+                             Array<Integer>{0, 1, 2});
+    new_state->sch->Annotate(r_loop_fused, tir::attr::software_pipeline_async_stages,
+                             Array<Integer>{0});
+    ret.push_back(std::move(new_state));
+  }
+  return ret;
+}
+
 void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
                                                        const tir::BlockRV& block) const {
   // Filter out invalid vector lanes according to the data type.
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index d8725a3060b1e..ff38756ff06be 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -148,6 +148,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   std::vector<State> TileLoopNest(State state) const;
   // SubRule 3. add read cache
   std::vector<State> AddReadReuse(State state) const;
+  // SubRule 4. add async pipeline
+  std::vector<State> AddAsyncPipeline(State state) const;
 
   // Do nothing; Inherited from ScheduleRuleNode
   void InitializeWithTuneContext(const TuneContext& context) final;
@@ -192,6 +194,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   int thread_warp_size_;
   /*! \brief The maximum number of threads to be used size of a thread warp */
   int max_threads_per_block_;
+  /*! \brief All available async pipeline stages. */
+  std::vector<int> stages;
   /*! \brief The logging function */
   PackedFunc logger;
   /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */