diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 3911a5254290..1675bcce05ed 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -137,9 +137,8 @@ class ScheduleRule : public runtime::ObjectRef { * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended: * - NullOpt on CPU * - [blockIdx.x, vthread.x, threadIdx.x] on GPU - * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching. + * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. * NullOpt means disable vectorization * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. @@ -147,9 +146,8 @@ class ScheduleRule : public runtime::ObjectRef { */ TVM_DLL static ScheduleRule MultiLevelTiling(String structure, // Optional<Array<String>> tile_binds, // - bool use_tensor_core, // Optional<Integer> max_innermost_factor, // - Optional<Integer> vector_load_max_len, // + Optional<Array<Integer>> vector_load_lens, // Optional<Map<String, ObjectRef>> reuse_read, // Optional<Map<String, ObjectRef>> reuse_write); /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 0a05439b2341..edb789b0bd7f 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1364,6 +1364,20 @@ constexpr const char* pragma_loop_partition_hint = "pragma_loop_partition_hint"; /*! \brief Mark the tiling structure of blocks that are applied by rule Multi-Level-Tiling */ constexpr const char* meta_schedule_tiling_structure = "meta_schedule.tiling_structure"; +/*! + * \brief Mark that the loop should be further split and bound to environment threads to enable + * cooperative fetching. + */ +constexpr const char* meta_schedule_cooperative_fetch = "meta_schedule.cooperative_fetch"; + +/*! \brief The lower bound (inclusive) of the allowed thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_low_inclusive = "meta_schedule.thread_extent_low_inclusive"; + +/*! \brief The upper bound (inclusive) of the allowed thread extent in thread bindings */ +constexpr const char* meta_schedule_thread_extent_high_inclusive = "meta_schedule.thread_extent_high_inclusive"; + /*! 
\brief Mark the block whose producer needs to be applied by rule Random-Compute-Location */ constexpr const char* meta_schedule_random_compute_producer = "meta_schedule.random_compute_producer"; diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index ce66323fd15b..b0fe8c8bdd75 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -19,6 +19,7 @@ from .add_rfactor import AddRFactor from .auto_inline import AutoInline from .cross_thread_reduction import CrossThreadReduction +from .multi_level_tiling import MultiLevelTiling, ReuseType from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation from .schedule_rule import PyScheduleRule, ScheduleRule diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py new file mode 100644 index 000000000000..2ff49168d0c6 --- /dev/null +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Multi-level tiling with reuse.""" +from typing import Any, Dict, List, NamedTuple, Optional + +from tvm._ffi import register_object + +from .. import _ffi_api +from .schedule_rule import ScheduleRule + + +class ReuseType(NamedTuple): + """Reuse type.""" + + req: str + levels: List[int] + scope: str + + def as_dict(self) -> Dict[str, Any]: + """Return the dict representation of the reuse type.""" + return { + "req": self.req, + "levels": self.levels, + "scope": self.scope, + } + + +@register_object("meta_schedule.MultiLevelTiling") +class MultiLevelTiling(ScheduleRule): + """Multi-level tiling with reuse. + + Parameters + ---------- + structure : str + The tiling structure. Recommended: + - 'SSRSRS' on CPU + - 'SSSRRSRS' on GPU + tile_binds : Optional[List[str]] + For each level of tiles, which thread axis it is bound to. Recommended: + - None on CPU + - [blockIdx.x, vthread.x, threadIdx.x] on GPU + max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + vector_load_lens : Optional[List[int]] + The length of vector lane in vectorized cooperative fetching. + None means disable vectorization + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. 
+ """ + + def __init__( + self, + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTiling, # type: ignore # pylint: disable=no-member + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 464d2496a603..b149f20c52e3 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -19,8 +19,10 @@ AddRFactor, AutoInline, CrossThreadReduction, + MultiLevelTiling, ParallelizeVectorizeUnroll, RandomComputeLocation, + ReuseType, ScheduleRule, ) from tvm.target import Target @@ -65,6 +67,41 @@ def cross_thread_reduction(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") +def multi_level_tiling(target: Target) -> ScheduleRule: + """Default schedule rules for with multi-level tiling and reuse""" + if target.kind.name == "llvm": + return MultiLevelTiling( + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ReuseType( + req="may", + levels=[1, 2], + scope="global", + ), + ) + if target.kind.name == "cuda": + return MultiLevelTiling( + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ReuseType( + req="must", + levels=[3], + scope="local", + ), + ) + raise NotImplementedError(f"{target.kind.name} is not supported") + + def random_compute_location(target: Target) -> ScheduleRule: """Default schedule rules for with random-compute-location""" if target.kind.name == "llvm": diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc new file mode 100644 index 000000000000..d0bfff40fcbe --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -0,0 +1,416 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "../utils.h" + +namespace tvm { +namespace tir { +/*! 
+ * \brief Get the buffer dimensions for all the read buffers of a block, marking the reduction + * buffers' dimensions as -1 + * \param block_sref The block to be processed + * \return The buffer dimensions for all the read buffers of a block, except for reduction buffers + * \note The method is not designed for generic analysis and relies on assumptions in the scenario + * of multi-level tiling, so it's intentionally kept inside this file rather than in the analysis header + */ +std::vector<int> GetReadBufferNDims(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + const BufferNode* write_buffer = block->writes[0]->buffer.get(); + int n = block->reads.size(); + std::vector<int> results(n, -1); + for (int i = 0; i < n; ++i) { + const BufferNode* read_buffer = block->reads[i]->buffer.get(); + if (read_buffer != write_buffer) { + results[i] = read_buffer->shape.size(); + } + } + return results; +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::ExprRV; +using tir::IterVarType; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Configuration of data reuse type: + * 0) kNoReuse: no reuse is allowed, then no cache_read/write is performed. + * 1) kMayReuse: reuse is allowed, and the no-reuse variant is also explored. + * 2) kMustReuse: reuse is required, and the no-reuse variant is not explored. + */ +enum class ReuseType : int32_t { + kNoReuse = 0, + kMayReuse = 1, + kMustReuse = 2, +}; + +/*! + * \brief Converts a string to ReuseType. + * \param str The string to be converted. + * \return The converted ReuseType. + */ +ReuseType Str2ReuseType(const String& str) { + if (str == "no") { + return ReuseType::kNoReuse; + } else if (str == "may") { + return ReuseType::kMayReuse; + } else if (str == "must") { + return ReuseType::kMustReuse; + } else { + LOG(FATAL) << "ValueError: Unknown ReuseType: " << str; + throw; + } +} + +/*! \brief Configuration of data reuse patterns */ +struct ReuseConfig { + /*! \brief Type of data reuse: no-reuse, may-reuse or must-reuse */ + ReuseType req; + /*! \brief The levels at which the caching stages are inserted */ + std::vector<int> levels; + /*! \brief The storage scope */ + String scope; + + /*! \brief Default constructor: no data reuse */ + ReuseConfig() : req(ReuseType::kNoReuse) {} + + /*! \brief Construct from a configuration dictionary */ + explicit ReuseConfig(const Map<String, ObjectRef>& config) + : req(Str2ReuseType(Downcast<String>(config.at("req")))), + levels(support::AsVector<Integer, int>(Downcast<Array<Integer>>(config.at("levels")))), + scope(Downcast<String>(config.at("scope"))) { + ICHECK_EQ(config.size(), 3); + } +}; + +/*! \brief The state of auto scheduling for the multi-level tiling rule */ +struct State { + /*! \brief The current schedule */ + Schedule sch; + /*! \brief The block to be tiled */ + BlockRV block_rv; + /*! \brief The loop tiles */ + Array<Array<LoopRV>> tiles; + + /*! \brief Constructor */ + explicit State(Schedule sch, BlockRV block_rv, Optional<BlockRV> write_cache = NullOpt, + bool write_cache_is_added = false, Array<Array<LoopRV>> tiles = {}) + : sch(sch), block_rv(block_rv), tiles(tiles) {} +}; + +/*! 
+ * \brief Helper to apply a sub-rule to a list of auto scheduling states + * \tparam FLambda The type of the sub-rule functor + * \param states The list of states to be processed + * \param sub_rule The sub-rule to be applied to each state + * \return The list of states after applying the sub-rule + */ +template <class FLambda> +std::vector<State> SubRule(std::vector<State> states, FLambda sub_rule) { + std::vector<State> results; + for (auto&& state : states) { + std::vector<State> next = sub_rule(std::move(state)); + results.insert(results.end(), // + std::make_move_iterator(next.begin()), // + std::make_move_iterator(next.end())); + } + return results; +} + +/*! + * \brief The mega rule: multi-level tiling with data reuse + */ +class MultiLevelTilingNode : public ScheduleRuleNode { + public: + // SubRule 1. add write cache + inline std::vector<State> AddWriteReuse(State state) const; + // SubRule 2. tile the loop nest + inline std::vector<State> TileLoopNest(State state) const; + // SubRule 3. add read cache + inline std::vector<State> AddReadReuse(State state) const; + + // Inherited from ScheduleRuleNode + void InitializeWithTuneContext(const TuneContext& context) final { + if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("max_threads_per_block")) { + this->max_threads_per_block_ = v.value()->value; + if (Optional<Integer> v = context->target.value()->GetAttr<Integer>("thread_warp_size")) { + this->thread_warp_size_ = v.value()->value; + } else { + LOG(INFO) << "'thread_warp_size' is not defined in the target"; + } + } + } + + // Entry of the mega rule; Inherited from ScheduleRuleNode + Array<Schedule> Apply(const Schedule& sch, const BlockRV& block_rv) final { + if (!NeedsMultiLevelTiling(sch->state(), sch->GetSRef(block_rv))) { + return {sch}; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure); + + std::vector<State> states{State(sch, block_rv)}; + states = SubRule(std::move(states), [&](State state) { return TileLoopNest(state); }); + states = SubRule(std::move(states), [&](State state) { return AddWriteReuse(state); }); + states = SubRule(std::move(states), [&](State state) { return AddReadReuse(state); }); + Array<Schedule> results; + for (auto&& state : states) { + results.push_back(std::move(state.sch)); + } + return results; + } + + public: + /*! + * \brief The tiling structure. Recommended: + * - 'SSRSRS' on CPU + * - 'SSSRRSRS' on GPU + */ + String structure; + /*! \brief For each level of tiles, which thread axis it is bound to */ + Array<String> tile_binds; + /*! \brief The maximum size of the innermost factor */ + int max_innermost_factor; + /*! \brief The length of vector lane in vectorized cooperative fetching */ + std::vector<int> vector_load_lens; + /*! \brief Data reuse configuration for reading */ + ReuseConfig reuse_read_; + /*! \brief Data reuse configuration for writing */ + ReuseConfig reuse_write_; + /*! \brief The indices of spatial tiles in `structure` */ + std::vector<int> s_indices_; + /*! \brief The indices of reduction tiles in `structure` */ + std::vector<int> r_indices_; + /*! \brief The size of the thread warp */ + int thread_warp_size_; + /*! 
\brief The maximum number of threads to be used in a thread block */ + int max_threads_per_block_; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("structure", &structure); + v->Visit("tile_binds", &tile_binds); + v->Visit("max_innermost_factor", &max_innermost_factor); + // `vector_load_lens` is not visited + // `reuse_read_` is not visited + // `reuse_write_` is not visited + // `s_indices_` is not visited + // `r_indices_` is not visited + // `thread_warp_size_` is not visited + // `max_threads_per_block_` is not visited + } + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); +}; + +inline std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const { + const ReuseConfig& config = this->reuse_write_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + std::vector<State> results; + if (config.req == ReuseType::kMayReuse) { + // Case 1. If the write cache is already there, we don't need to add another. + Array<BlockRV> consumer_rvs = state.sch->GetConsumers(state.block_rv); + if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true); + results.push_back(std::move(new_state)); + } + results.push_back(state); + return results; + } else { + // Case 2. No write cache is added + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv); + new_state.sch->Seed(state.sch->ForkSeed()); + results.emplace_back(std::move(new_state)); + } + } + + // Case 3. Add one write cache + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*write_buffer_index=*/0, + /*storage_scope=*/config.scope); + for (int level : config.levels) { + State new_state = state; + new_state.sch = state.sch->Copy(); + new_state.sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + new_state.sch->ReverseComputeAt(write_cache, loop_rv, true); + results.push_back(std::move(new_state)); + } + return results; +} + +inline std::vector<State> MultiLevelTilingNode::TileLoopNest(State state) const { + Schedule& sch = state.sch; + const BlockRV& block_rv = state.block_rv; + // Step 1. Assuming trivial binding, pair the loops and their iter-var-types + Array<LoopRV> loops = sch->GetLoops(block_rv); + std::vector<IterVarType> iter_types = GetBlockVarTypes(sch->GetSRef(state.block_rv)); + ICHECK_EQ(loops.size(), iter_types.size()); + // Step 2. 
For each loop axis, tile it + int64_t spatial_loop_product = 1; + std::vector<std::vector<LoopRV>> tiles(s_indices_.size() + r_indices_.size()); + for (int i = 0, n = loops.size(); i < n; ++i) { + LoopRV loop = loops[i]; + const std::vector<int>* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { + idx = &s_indices_; + if (spatial_loop_product != -1) { + if (const int64_t* extent = tir::GetLoopIntExtent(sch->Get(loop).get())) { + spatial_loop_product *= *extent; + } else { + spatial_loop_product = -1; + } + } + } else if (iter_types[i] == IterVarType::kCommReduce) { + idx = &r_indices_; + } else { + continue; + } + // Do the split + int n_tiles = idx->size(); + Array<ExprRV> factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array<LoopRV> splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } + } + // Step 3. Reorder to organize the tiles + sch->Reorder(support::ConcatArrayList<LoopRV>(tiles.begin(), tiles.end())); + // Step 4. Bind the tiles to threads + int n_binds = std::min(tile_binds.size(), tiles.size()); + for (int i = 0; i < n_binds; ++i) { + LoopRV fused = sch->Fuse(tiles[i]); + sch->Bind(fused, tile_binds[i]); + tiles[i] = {fused}; + } + state.tiles = Array<Array<LoopRV>>{tiles.begin(), tiles.end()}; + if (this->thread_warp_size_ != -1) { + int64_t low_inclusive = 1; + int64_t high_inclusive = this->max_threads_per_block_; + if (spatial_loop_product > 2 * this->thread_warp_size_) { + low_inclusive = this->thread_warp_size_; + } + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_low_inclusive, + Integer(low_inclusive)); + sch->Annotate(block_rv, tir::attr::meta_schedule_thread_extent_high_inclusive, + Integer(high_inclusive)); + } + return {state}; +} + +inline std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const { + const ReuseConfig& config = this->reuse_read_; + if (config.req == ReuseType::kNoReuse) { + return {std::move(state)}; + } + ICHECK(config.req != ReuseType::kMayReuse); + const BlockRV& block_rv = state.block_rv; + std::vector<State> results; + results.reserve(config.levels.size()); + for (int level : config.levels) { + Schedule sch = state.sch->Copy(); + sch->Seed(state.sch->ForkSeed()); + const LoopRV& loop_rv = state.tiles[level - 1].back(); + // Enumerate all buffers that are read but not written + std::vector<int> read_buffer_ndims = tir::GetReadBufferNDims(sch->GetSRef(block_rv)); + for (int i = 0, n_reads = read_buffer_ndims.size(); i < n_reads; ++i) { + int buffer_ndim = read_buffer_ndims[i]; + if (buffer_ndim == -1) { + continue; + } + // Do cache_read + BlockRV cache_read_block = sch->CacheRead(block_rv, i, config.scope); + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Annotate cooperative fetching + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + ExprRV vector_load_len = + sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens), + Array<FloatImm>(n, FloatImm(DataType::Float(64), prob))); + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, + vector_load_len); + } + } + State new_state = state; + new_state.sch = sch; + results.push_back(std::move(new_state)); + } + return 
results; +} + +// Constructor + +ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds, + Optional<Integer> max_innermost_factor, + Optional<Array<Integer>> vector_load_lens, + Optional<Map<String, ObjectRef>> reuse_read, + Optional<Map<String, ObjectRef>> reuse_write) { + ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(); + n->structure = structure; + n->tile_binds = tile_binds.value_or({}); + n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; + n->vector_load_lens = vector_load_lens.defined() + ? support::AsVector<Integer, int>(vector_load_lens.value()) + : std::vector<int>(); + n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); + n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); + for (int i = 0, len = structure.size(); i < len; ++i) { + char c = structure.data()[i]; + if (c == 'S') { + n->s_indices_.push_back(i); + } else if (c == 'R') { + n->r_indices_.push_back(i); + } else { + LOG(FATAL) << "ValueError: Invalid tiling structure: " << structure; + } + } + n->thread_warp_size_ = -1; + n->max_threads_per_block_ = -1; + return ScheduleRule(n); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTiling") + .set_body_typed(ScheduleRule::MultiLevelTiling); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/support/array.h b/src/support/array.h index 95b4f58a2e22..218150f9dba0 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -100,6 +100,29 @@ inline Array<Integer> AsArray(const ShapeTuple& shape) { return result; } +/*! + * \brief Concatenate a list of arrays into a single array + * \tparam T The type of elements in the arrays + * \tparam Iterator The type of the iterator into the list of arrays + * \param begin The begin iterator to the array list + * \param end The end iterator to the array list + * \return The concatenated array + */ +template <class T, class Iterator> +inline Array<T> ConcatArrayList(Iterator begin, Iterator end) { + int size = 0; + for (Iterator it = begin; it != end; ++it) { + size += (*it).size(); + } + Array<T> result; + result.reserve(size); + for (Iterator it = begin; it != end; ++it) { + const auto& item = *it; + result.insert(result.end(), item.begin(), item.end()); + } + return result; +} + /********** Implementation details of AsVector **********/ namespace details { diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index 636cc7d0a5db..591201312cd2 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -175,6 +175,20 @@ bool IsOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& scope_root_sref); +/*! + * \brief Extracts the types of the block vars + * \param block_sref The block to be checked + * \return A vector of types of the block vars + */ +std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref); + +/*! + * \brief Checks if a block could be considered as a "write cache" + * \param block_sref The block to be checked + * \return A boolean flag indicating if the block is a write cache + */ +bool IsWriteCache(const StmtSRef& block_sref); + /******** Binding ********/ /*! * \brief Verifies if the block binding in a specific BlockRealize is an affine binding. 
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index be5e55d4ec70..1579f9154fe6 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -408,6 +408,33 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } +std::vector<IterVarType> GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + std::vector<IterVarType> results; + results.reserve(block->iter_vars.size()); + for (const IterVar& iter_var : block->iter_vars) { + results.push_back(iter_var->iter_type); + } + return results; +} + +bool IsWriteCache(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + if (block->writes.size() != 1) { + return false; + } + const BufferRegion& write_region = block->writes[0]; + for (const BufferRegion& read_region : block->reads) { + bool exists, surjective, injective, ordered, no_const_read, no_shift_read; + std::tie(exists, surjective, injective, ordered, no_const_read, no_shift_read) = + AnalyzeReadWritePattern(read_region, write_region); + if (!(injective && ordered)) { + return false; + } + } + return true; +} + /******** Binding ********/ bool IsAffineBinding(const BlockRealize& realize, const Map<Var, Range>& loop_var_ranges, diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py new file mode 100644 index 000000000000..c6a63aae7427 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -0,0 +1,280 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm.meta_schedule.testing.schedule_rule import ( + multi_level_tiling, +) +from tvm.meta_schedule.testing.space_generation import check_trace +from tvm.meta_schedule.tune_context import TuneContext +from tvm.te import create_prim_func +from tvm.meta_schedule.testing import te_workload +from tvm.target import Target + + +def _create_context(mod, target, rule) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + space_generator=PostOrderApply(), + sch_rules=[rule], + task_name="test", + ) + ctx.space_generator.initialize_with_tune_context(ctx) + for sch_rule in ctx.sch_rules: + sch_rule.initialize_with_tune_context(ctx) + return ctx + + +def test_cpu_matmul(): + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = 
ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cpu_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + "b24, = sch.get_consumers(block=b0)", + "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + "b24, = sch.get_consumers(block=b0)", + "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=1)", + ], + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", + "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7])", + "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", + "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15])", + "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", + "l22, l23 = sch.split(loop=l3, factors=[v20, v21])", + "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", + ], + ] + # pylint: enable=line-too-long + target = Target("llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 3 + check_trace(spaces, expected) + + +def test_cuda_matmul(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", + "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", + "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", + "v24, 
v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", + "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", + "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", + "l30 = sch.fuse(l9, l19)", + 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="vthread.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="threadIdx.x")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', + 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + ] + ] + # pylint: enable=line-too-long + target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +def test_cuda_matmul_relu(): + # pylint: disable=line-too-long + expected = [ + [ + 'b0 = sch.get_block(name="C", func_name="main")', + 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', + "l1, l2, l3 = sch.get_loops(block=b0)", + "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", + "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8])", + "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", + "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18])", + "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", + "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26])", + "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", + "l30 = sch.fuse(l9, l19)", + 'sch.bind(loop=l30, thread_axis="blockIdx.x")', + "l31 = sch.fuse(l10, l20)", + 'sch.bind(loop=l31, thread_axis="vthread.x")', + "l32 = sch.fuse(l11, l21)", + 'sch.bind(loop=l32, thread_axis="threadIdx.x")', + 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', + "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=1)", + 'b34 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', + "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=1)", + "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", + "l41 = 
sch.fuse(l39, l40)", + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=1)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', + ] + ] + # pylint: enable=line-too-long + target = Target("cuda", host="llvm") + ctx = _create_context( + create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ), + target=target, + rule=multi_level_tiling(target=target), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + check_trace(spaces, expected) + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + test_cuda_matmul_relu()
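For reference, the following is a minimal usage sketch, not part of this patch, showing how the new rule can be driven by hand for the CPU configuration. It only combines APIs that already appear in this diff (MultiLevelTiling, ReuseType, TuneContext, PostOrderApply, te_workload, create_prim_func); calling initialize_with_tune_context directly on the rule object mirrors what _create_context does in the test above.

# Minimal sketch (assumes the patch is applied): build the CPU variant of MultiLevelTiling
# and print the design spaces it generates for a 512x512x512 matmul.
from tvm.meta_schedule.schedule_rule import MultiLevelTiling, ReuseType
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.tune_context import TuneContext
from tvm.target import Target
from tvm.te import create_prim_func

rule = MultiLevelTiling(
    structure="SSRSRS",  # CPU tiling structure; use "SSSRRSRS" plus tile_binds on GPU
    max_innermost_factor=64,
    reuse_write=ReuseType(req="may", levels=[1, 2], scope="global"),
)
ctx = TuneContext(
    mod=create_prim_func(te_workload.matmul(n=512, m=512, k=512)),
    target=Target("llvm"),
    space_generator=PostOrderApply(),
    sch_rules=[rule],
    task_name="demo",
)
ctx.space_generator.initialize_with_tune_context(ctx)
rule.initialize_with_tune_context(ctx)
for sch in ctx.space_generator.generate_design_space(mod=ctx.mod):
    print(sch.trace)  # "may"-reuse on the write side yields 3 candidate traces, as in test_cpu_matmul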