diff --git a/apps/README.md b/apps/README.md deleted file mode 100644 index 01630a3ee8c1..000000000000 --- a/apps/README.md +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - - - - - -# TVM Application Extensions and Examples -This folder contains various extension projects using TVM, -they also serve as examples on how to use TVM in your own project. - - -- [extension](extension) How to extend TVM C++ api along with python API. -- [ios_rpc](ios_rpc) iOS RPC server. -- [android_rpc](android_rpc) Android RPC server. -- [benchmark](benchmark) Example end to end compilation benchmarks -- [howto_deploy](howto_deploy) Tutorial on how to deploy TVM with minimum code dependency. -- [wasm_standalone](wasm-standalone) WebAssembly standalone for deep learning framework with TVM runtime. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 1d218c6a7c61..e90c49fc4f6d 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -89,8 +89,7 @@ class PrimFuncNode : public BaseFuncNode { * normal statements, making buffer_map as first class citizen of PrimFunc * will make program analysis much easier. * - * Prior to buffer flattening, which is performed either in - * StorageFlatten for TE-based schedules or in FlattenBuffer for + * Prior to buffer flattening, which is performed in FlattenBuffer for * TIR-based schedules, these buffer objects are used directly in * the body of the function. After buffer flattening, these buffer * objects remain unflattened for use in argument validation, but diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index b03b4d3a12a1..ea39bca444c7 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -58,55 +58,6 @@ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, int opt_level, String name, tvm::Array<runtime::String> required, bool traceable = false); -/*! - * \brief Inject prefetch instructions into stmt. - * - * \return The pass. - */ -TVM_DLL Pass InjectPrefetch(); - -// TODO(tvm-team): consolidate configs to the PassContext -/*! - * \brief Flatten the multi-dimensional read/write - * to single dimensional Load/Store - * - * \param cache_line_size The size of CPU cache line. - * \param create_bound_attribute Whether to create bound attributes. - * - * \return The Pass - */ -TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = false); - -/*! - * \brief Inject copy intrinsics with optional pad. - * - * \param pragma_key The pragma key for hint of copy. - * \param fintrin The function with signature - * - * Stmt fintrin(Buffer src, - * Buffer dst, - * Array<Expr> pad_before, - * Array<Expr> pad_after, - * Expr pad_value) - * \return The pass. - */ -TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); - -/*! - * \brief Detect and insert sync points to co-processor. - * - * \return The pass. - */ -TVM_DLL Pass CoProcSync(); - -/*! - * \brief Lift common attrs with attr_key to outer scope. - * - * \param attr_key The attribute key to be checked. - * \return The pass. - */ -TVM_DLL Pass LiftAttrScope(String attr_key); - /*! * \brief partition loops in the stmt. * @@ -573,15 +524,6 @@ TVM_DLL Pass LowerOpaqueBlock(); */ TVM_DLL Pass FlattenBuffer(); -/* - * \brief Flatten the multi-dimensional read/write - * to two dimensional texture Load/Store and realize - * texture buffer allocations.
- * - * \return The Pass - */ -TVM_DLL Pass TextureFlatten(); - /* * \brief Lower VTCM allocations * diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index b7141bae30de..c8019c922981 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -31,11 +31,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I pass_ctx = tvm.transform.PassContext.current() config = pass_ctx.config passes = [ - tir.transform.InjectPrefetch(), - tir.transform.TextureFlatten(), - tir.transform.StorageFlatten( - 64, bool(config.get("tir.instrument_bound_checkers", False)) - ), tir.transform.LowerCrossThreadReduction(), tir.transform.LowerInitBlock(), tir.transform.PlanAndUpdateBufferAllocationLocation(), diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 99a2e1e66485..eb38c5f77507 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -48,112 +48,6 @@ def _transform(func, mod, ctx): return _fpass.prim_func_pass(_transform, opt_level=0, name="Apply") # type: ignore -def InjectPrefetch(): - """Inject prefetch instructions into stmt. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectPrefetch() # type: ignore - - -def ApplyLayoutTransforms(): - """Reshape buffers that appear in the "layout_transform_map" - fucntion attribute. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - - """ - return _ffi_api.ApplyLayoutTransforms() # type: ignore - - -def StorageFlatten(cache_line_size, create_bound_attribute: bool = False): - """Flatten the multi-dimensional read/write to 1D. - - - Parameters - ---------- - cache_line_size: int - The size of CPU cache line. - - create_bound_attribute: - Whether to create bound attributes. - - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.StorageFlatten(cache_line_size, create_bound_attribute) # type: ignore - - -def TextureFlatten(): - """Flatten the multi-dimensional read/write to 2D. - - - Parameters - ---------- - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.TextureFlatten() # type: ignore - - -def InjectCopyIntrin(pragma_key: str, fintrin): - """Inject virtual thread loops. - - Parameters - ---------- - pragma_key : str - The pragma key for hint of copy. - - fintrin : function - The function with signature copyintrin(src, dst, pad_before, pad_after, pad_value) - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.InjectCopyIntrin(pragma_key, fintrin) # type: ignore - - -def CoProcSync(): - """Detect and insert sync points to co-processor. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.CoProcSync() # type: ignore - - -def LiftAttrScope(attr_key: str): - """Lift common attrs with attr_key to outer scope. - - Parameters - ---------- - attr_key : str - The attribute key to be checked. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LiftAttrScope(attr_key) # type: ignore - - def LoopPartition(): """Inject virtual thread loops. @@ -682,7 +576,7 @@ def NarrowDataType(target_bits: int): Note ---- - Run this pass after StorageFlatten. + Run this pass after FlattenBuffer. 
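    Examples
    --------
    A minimal ordering sketch, assuming ``mod`` is an IRModule of PrimFuncs
    whose blocks have already been lowered upstream; the passes named here
    are the public ``tvm.tir.transform`` API:

    .. code-block:: python

        import tvm

        seq = tvm.transform.Sequential(
            [
                tvm.tir.transform.FlattenBuffer(),
                tvm.tir.transform.NarrowDataType(32),
            ]
        )
        mod = seq(mod)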
""" return _ffi_api.NarrowDataType(target_bits) # type: ignore diff --git a/src/README.md b/src/README.md deleted file mode 100644 index 8a5368f03c65..000000000000 --- a/src/README.md +++ /dev/null @@ -1,37 +0,0 @@ - - - - - - - - - - - - - - - - - -# Code Organization - -Header files in include are public APIs that share across modules. -There can be internal header files within each module that sit in src. - -## Modules -- arith: Arithmetic expression and set simplification. -- auto\_scheduler: The template-free auto-tuning module. -- autotvm: The template-based auto-tuning module. -- contrib: Contrib extension libraries. -- driver: Compilation driver APIs. -- ir: Common IR infrastructure. -- node: The base infra for IR/AST nodes that is dialect independent. -- relay: Relay IR, high-level optimizations. -- runtime: Minimum runtime related codes. -- support: Internal support utilities. -- target: Hardware targets. -- tir: Tensor IR, low-level optimizations. -- te: Tensor expression DSL. -- topi: Tensor Operator Inventory. diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc index 17283063543d..95af6d4dfa30 100644 --- a/src/meta_schedule/postproc/verify_gpu_code.cc +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -153,10 +153,6 @@ class VerifyGPUCodeNode : public PostprocNode { try { auto pass_list = Array(); // Phase 1 - // First three passes are not needed in TIR schedule. - // pass_list.push_back(tir::transform::InjectPrefetch()); - // pass_list.push_back(tir::transform::TextureFlatten()); - // pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); pass_list.push_back(tir::transform::LowerCrossThreadReduction()); pass_list.push_back(tir::transform::LowerInitBlock()); pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index fa2cd6b09d13..36638576d387 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -141,8 +141,7 @@ int CodeGenStackVM::GetVarID(const VarNode* v) const { void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) { ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto index = op->indices[0]; this->Push(op->buffer->data); @@ -160,8 +159,7 @@ void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) { void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) { ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto index = op->indices[0]; this->Push(op->buffer->data); diff --git a/src/tir/transforms/bound_checker.cc b/src/tir/transforms/bound_checker.cc index 358f864d3a24..616b47f29403 100644 --- a/src/tir/transforms/bound_checker.cc +++ b/src/tir/transforms/bound_checker.cc @@ -40,7 +40,7 @@ namespace tvm { namespace tir { // TODO(Lunderberg): Move this pass to be before -// StorageFlatten/FlattenBuffer. That will simplify this pass, +// FlattenBuffer. That will simplify this pass, // because it can check directly against the buffer limits. 
class BoundCollector : public StmtVisitor { public: diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc deleted file mode 100644 index 65ee33d2dad6..000000000000 --- a/src/tir/transforms/coproc_sync.cc +++ /dev/null @@ -1,661 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file coproc_sync.cc - */ -#include -#include -#include -#include -#include - -#include -#include - -#include "ir_utils.h" -#include "storage_access.h" - -namespace tvm { -namespace tir { - -// Visitor to find touched set by co-processor scope. -class CoProcTouchedBuffer : public StmtExprVisitor { - public: - void VisitExpr_(const BufferLoadNode* op) final { - if (in_scope_) { - touched_[op->buffer->data.get()].coproc = true; - } else { - touched_[op->buffer->data.get()].normal = true; - } - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const BufferStoreNode* op) final { - if (in_scope_) { - touched_[op->buffer->data.get()].coproc = true; - } else { - touched_[op->buffer->data.get()].normal = true; - } - StmtExprVisitor::VisitStmt_(op); - } - void VisitExpr_(const CallNode* op) final { - if (op->op.same_as(builtin::tvm_access_ptr())) { - const VarNode* buffer = op->args[1].as(); - if (in_scope_) { - touched_[buffer].coproc = true; - } else { - touched_[buffer].normal = true; - } - } - StmtExprVisitor::VisitExpr_(op); - } - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::coproc_scope && !in_scope_) { - in_scope_ = true; - IterVar iv = Downcast(op->node); - coproc_.insert(iv); - StmtExprVisitor::VisitStmt_(op); - in_scope_ = false; - } else { - StmtExprVisitor::VisitStmt_(op); - } - } - - // Touch Entry - struct TouchEntry { - bool normal{false}; - bool coproc{false}; - }; - std::unordered_map touched_; - std::unordered_set coproc_; - - private: - bool in_scope_{false}; -}; - -// Synchronization planning with co-processor. -class CoProcSyncPlanner : public StorageAccessVisitor { - public: - explicit CoProcSyncPlanner(const std::unordered_set& touched, - const std::string& coproc_name) - : touched_(touched), coproc_name_(coproc_name) {} - - void Plan(const Stmt& stmt) { - this->VisitStmt(stmt); - PlanSync(scope_.back(), nullptr, true); - if (sync_.size() == 0) { - sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync"); - } - } - - // Write synchronization to be inserted before or after stmt. 
- std::unordered_map> sync_; - - protected: - bool Enabled(const VarNode* buf, const StorageScope& scope) const final { - return touched_.count(buf); - } - - // Plan the sync - std::vector Summarize(std::vector seq, const ForNode* loop) final { - return PlanSync(seq, loop, false); - } - - private: - // Plan write synchronization if write is not coherent - std::vector PlanSync(std::vector seq, const ForNode* loop, - bool force_sync_at_end) { - // detect write barriers - // access by the co-processor. - std::vector co_access; - bool contain_sync = false; - - auto find_conflict = [&](const AccessEntry& acc) { - for (const AccessEntry& x : co_access) { - if (x.buffer.same_as(acc.buffer) && - ((acc.type == kRead && x.type == kWrite) || acc.type == kWrite)) { - return true; - } - } - return false; - }; - for (size_t i = 0; i < seq.size(); ++i) { - const StmtEntry& s = seq[i]; - bool sync_write = false; - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_write = true; - break; - } - if (acc.type == kSync) { - co_access.clear(); - contain_sync = true; - } - } - if (sync_write) { - ICHECK_NE(i, 0U); - sync_[seq[i - 1].stmt] = GetSync(co_access); - co_access.clear(); - contain_sync = true; - } - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() != 0) { - co_access.push_back(acc); - } - } - } - bool sync_at_end = force_sync_at_end; - if (loop != nullptr && !sync_at_end) { - // loop carray dependency - for (size_t i = 0; i < seq.size(); ++i) { - const StmtEntry& s = seq[i]; - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() == 0 && find_conflict(acc)) { - sync_at_end = true; - break; - } - } - if (sync_.count(s.stmt) || sync_at_end) break; - } - } - if (sync_at_end && co_access.size() != 0) { - ICHECK_NE(seq.size(), 0); - contain_sync = true; - sync_[seq.back().stmt] = GetSync(co_access); - co_access.clear(); - } - if (contain_sync) { - AccessEntry e; - e.type = kSync; - co_access.insert(co_access.begin(), e); - } - return co_access; - } - // Add write Synchronization - std::vector GetSync(const std::vector& co_access) { - // Does not consider memory coherence, need runtime. - ICHECK_NE(co_access.size(), 0U); - ICHECK_EQ(co_access[0].threads.size(), 1U); - return GetSync(coproc_name_ + ".coproc_sync"); - } - - std::vector GetSync(std::string sync_name) { - return {Evaluate(Call(DataType::Int(32), Op::Get("tir." + sync_name), {}))}; - } - - const std::unordered_set& touched_; - std::string coproc_name_; -}; - -// Detect memory barriers when coproc read/write memory -class CoProcBarrierDetector : public StorageAccessVisitor { - public: - explicit CoProcBarrierDetector(const std::unordered_set& touched, - const std::string& coproc_name) - : touched_(touched) { - read_barrier_name_ = "tir." + coproc_name + ".coproc_read_barrier"; - write_barrier_name_ = "tir." 
+ coproc_name + ".coproc_write_barrier"; - } - - void PlanReadBarrier(const Stmt& stmt) { - read_barrier_ = true; - this->VisitStmt(stmt); - PlanReadBarrier(scope_.back(), nullptr); - } - void PlanWriteBarrier(const Stmt& stmt) { - read_barrier_ = false; - this->VisitStmt(stmt); - PlanWriteBarrier(scope_.back(), nullptr); - } - - std::unordered_map> barrier_before_; - std::unordered_map> barrier_after_; - - protected: - bool Enabled(const VarNode* buf, const StorageScope& scope) const final { - return touched_.count(buf); - } - - // Plan the sync - std::vector Summarize(std::vector seq, const ForNode* loop) final { - if (read_barrier_) { - return PlanReadBarrier(seq, loop); - } else { - return PlanWriteBarrier(seq, loop); - } - } - - private: - // Plan write barrier at Read after write point. - std::vector PlanWriteBarrier(std::vector seq, const ForNode* loop) { - std::vector read_seq; - std::unordered_map> write_set; - - auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = write_set.find(acc.buffer.get()); - if (it != write_set.end()) { - ICHECK_NE(i, 0U); - barrier_after_[seq[i - 1].stmt].push_back(MakeBarrier(write_barrier_name_, it->second)); - write_set.erase(it); - } - }; - for (size_t i = 0; i < seq.size(); ++i) { - const StmtEntry& s = seq[i]; - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() == 0 && acc.type == kRead) { - fupdate(i, acc); - read_seq.push_back(acc); - } - } - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() != 0 && acc.type == kWrite) { - write_set[acc.buffer.get()].push_back(acc); - } - } - } - // loop carry - if (loop != nullptr) { - for (const AccessEntry& acc : read_seq) { - fupdate(seq.size(), acc); - } - } - for (const auto& kv : write_set) { - read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end()); - } - return read_seq; - } - - std::vector PlanReadBarrier(std::vector seq, const ForNode* loop) { - std::vector write_seq; - std::unordered_map> read_set; - - auto fupdate = [&](size_t i, const AccessEntry& acc) { - auto it = read_set.find(acc.buffer.get()); - if (it != read_set.end()) { - ICHECK_NE(i, seq.size()); - barrier_before_[seq[i].stmt].push_back(MakeBarrier(read_barrier_name_, it->second)); - read_set.erase(it); - } - }; - - for (size_t i = seq.size(); i != 0; --i) { - const StmtEntry& s = seq[i - 1]; - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() == 0 && acc.type == kWrite) { - fupdate(i, acc); - write_seq.push_back(acc); - } - } - for (const AccessEntry& acc : s.access) { - if (acc.threads.size() != 0 && acc.type == kRead) { - read_set[acc.buffer.get()].push_back(acc); - } - } - } - // loop carry - if (loop != nullptr) { - for (const AccessEntry& acc : write_seq) { - fupdate(0, acc); - } - } - for (const auto& kv : read_set) { - write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end()); - } - return write_seq; - } - - Stmt MakeBarrier(const std::string& func, const std::vector& wvec) { - // insert write point - Array wset; - for (const AccessEntry& acc : wvec) { - ICHECK(acc.dtype == wvec[0].dtype); - ICHECK_EQ(acc.touched.size(), 1) << "CoProcBarrierDetector expects flat memory"; - wset.push_back(acc.touched[0]); - } - Range none; - Range r = arith::Union(wset).CoverRange(none); - ICHECK(r.defined()) << "Cannot deduce write range of " << wvec[0].buffer; - PrimExpr min = r->min; - PrimExpr extent = r->extent; - return Evaluate(Call(DataType::Int(32), Op::Get(func), - {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent})); - } - // Write barrier name 
- bool read_barrier_{false}; - std::string read_barrier_name_; - std::string write_barrier_name_; - const std::unordered_set& touched_; -}; - -class CoProcInstDepDetector : public StmtVisitor { - public: - explicit CoProcInstDepDetector(const IterVar& coproc_axis, const std::string& coproc_name) - : coproc_axis_(coproc_axis) { - sync_push_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_push"); - sync_pop_op_ = Op::Get("tir." + coproc_name + ".coproc_dep_pop"); - } - - void Plan(const Stmt& stmt) { - this->VisitStmt(stmt); - if (last_state_.node != nullptr) { - MatchFixEnterPop(first_state_); - MatchFixExitPush(last_state_); - } - } - - void VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::coproc_scope && op->node.same_as(coproc_axis_)) { - const IntImmNode* ctx_id = op->value.as(); - ICHECK(ctx_id != nullptr); - curr_state_.clear(); - curr_state_.node = op->body.get(); - curr_state_.enter_ctx.insert(ctx_id->value); - curr_state_.exit_ctx.insert(ctx_id->value); - UpdateState(); - } else { - StmtVisitor::VisitStmt_(op); - } - } - - void VisitStmt_(const ForNode* op) final { - SyncState temp_first, temp_last; - std::swap(first_state_, temp_first); - std::swap(last_state_, temp_last); - this->VisitStmt(op->body); - curr_state_.clear(); - if (last_state_.node != nullptr) { - curr_state_.node = op; - ICHECK(first_state_.node != nullptr); - // loop carry dependency - InjectSync(last_state_, first_state_, &(curr_state_.exit_push), &(curr_state_.enter_pop)); - curr_state_.enter_ctx = first_state_.enter_ctx; - curr_state_.exit_ctx = last_state_.exit_ctx; - } - std::swap(first_state_, temp_first); - std::swap(last_state_, temp_last); - if (curr_state_.node != nullptr) { - UpdateState(); - } - } - - void VisitStmt_(const IfThenElseNode* op) final { - SyncState temp_first, temp_last, curr_state; - std::swap(first_state_, temp_first); - std::swap(last_state_, temp_last); - { - // then stmt - this->VisitStmt(op->then_case); - if (last_state_.node != nullptr) { - curr_state.node = op; - MatchFixEnterPop(first_state_); - MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); - } - first_state_.clear(); - last_state_.clear(); - } - if (op->else_case) { - this->VisitStmt(op->else_case.value()); - if (last_state_.node != nullptr) { - curr_state.node = op; - MatchFixEnterPop(first_state_); - MatchFixExitPush(last_state_); - curr_state.enter_ctx.insert(first_state_.enter_ctx.begin(), first_state_.enter_ctx.end()); - curr_state.exit_ctx.insert(last_state_.exit_ctx.begin(), last_state_.exit_ctx.end()); - } - } - // update in the trace. - std::swap(first_state_, temp_first); - std::swap(last_state_, temp_last); - std::swap(curr_state_, curr_state); - if (curr_state_.node != nullptr) { - UpdateState(); - } - } - - void VisitStmt_(const WhileNode* op) final { - // TODO(masahi): Do we need a special handling for While nodes? - LOG(FATAL) << "WhileNode not supported in CoProcSync."; - } - - // insert before is stored in reverse order - // the first element is closest to the node. - std::unordered_map> insert_before_; - std::unordered_map> insert_after_; - - private: - // state in the sync entry - struct SyncState { - // The statement of the state. - const Object* node{nullptr}; - // Set of all possible contexts in the entering moment. - std::unordered_set enter_ctx; - // Set of all possible contexts in the exit moment. 
- std::unordered_set exit_ctx; - // existing pop performed at enter - std::vector> enter_pop; - // existing push performed at exit - std::vector> exit_push; - // clear the state - void clear() { - node = nullptr; - enter_ctx.clear(); - exit_ctx.clear(); - enter_pop.clear(); - exit_push.clear(); - } - }; - // inject proper sync into the pair - // record the push/pop sequence that could be possibly un-matched. - // return the push/pop message at enter/exit of the Block - // after considering the existing unmatcheded events and added events - void InjectSync(const SyncState& prev, const SyncState& next, - std::vector>* prev_exit_push, - std::vector>* next_enter_pop) { - prev_exit_push->clear(); - next_enter_pop->clear(); - // quick path - if (prev.exit_push.size() == 0 && next.enter_pop.size() == 0 && prev.exit_ctx.size() == 1 && - next.enter_ctx.size() == 1) { - int from = *prev.exit_ctx.begin(); - int to = *next.enter_ctx.begin(); - if (from != to) { - insert_after_[prev.node].emplace_back(MakePush(from, to)); - insert_before_[next.node].emplace_back(MakePop(from, to)); - prev_exit_push->emplace_back(std::make_pair(from, to)); - next_enter_pop->emplace_back(std::make_pair(from, to)); - } - return; - } - // complicate path. - std::vector> vpush = prev.exit_push; - std::vector> vpop = next.enter_pop; - std::vector> pending; - for (int from : prev.exit_ctx) { - for (int to : next.enter_ctx) { - if (from != to) { - pending.emplace_back(std::make_pair(from, to)); - } - } - } - // policy 1 - std::vector prev_after, next_before; - for (const std::pair& p : pending) { - if (std::find(prev.exit_push.begin(), prev.exit_push.end(), p) == prev.exit_push.end()) { - vpush.push_back(p); - prev_after.emplace_back(MakePush(p.first, p.second)); - } - if (std::find(next.enter_pop.begin(), next.enter_pop.end(), p) == next.enter_pop.end()) { - vpop.push_back(p); - next_before.emplace_back(MakePop(p.first, p.second)); - } - } - // fix pending - for (const std::pair& p : vpush) { - if (std::find(vpop.begin(), vpop.end(), p) == vpop.end()) { - prev_after.emplace_back(MakePop(p.first, p.second)); - } else { - prev_exit_push->push_back(p); - } - } - for (const std::pair& p : vpop) { - if (std::find(vpush.begin(), vpush.end(), p) == vpush.end()) { - next_before.emplace_back(MakePush(p.first, p.second)); - } else { - next_enter_pop->push_back(p); - } - } - if (prev_after.size() != 0) { - auto& v1 = insert_after_[prev.node]; - v1.insert(v1.end(), prev_after.begin(), prev_after.end()); - } - if (next_before.size() != 0) { - auto& v2 = insert_before_[next.node]; - v2.insert(v2.end(), next_before.begin(), next_before.end()); - } - } - - void MatchFixEnterPop(const SyncState& state) { - if (state.enter_pop.size() == 0) return; - auto& vec = insert_before_[state.node]; - for (const std::pair& p : state.enter_pop) { - vec.push_back(MakePush(p.first, p.second)); - } - } - - void MatchFixExitPush(const SyncState& state) { - if (state.exit_push.size() == 0) return; - auto& vec = insert_after_[state.node]; - for (const std::pair& p : state.exit_push) { - vec.push_back(MakePop(p.first, p.second)); - } - } - - void UpdateState() { - if (last_state_.node != nullptr) { - std::vector> t1, t2; - InjectSync(last_state_, curr_state_, &t1, &t2); - std::swap(last_state_, curr_state_); - } else { - ICHECK(first_state_.node == nullptr); - first_state_ = curr_state_; - last_state_ = curr_state_; - } - } - - Stmt MakePush(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_push_op_, - {make_const(DataType::Int(32), from), 
make_const(DataType::Int(32), to)})); - } - Stmt MakePop(int from, int to) { - return Evaluate(Call(DataType::Int(32), sync_pop_op_, - {make_const(DataType::Int(32), from), make_const(DataType::Int(32), to)})); - } - // sync states. - SyncState first_state_, last_state_, curr_state_; - // Variables - IterVar coproc_axis_; - Op sync_push_op_, sync_pop_op_; -}; - -class CoProcSyncInserter : public StmtMutator { - public: - Stmt Insert(Stmt stmt) { - CoProcTouchedBuffer visitor; - visitor(stmt); - if (visitor.coproc_.size() == 0) return stmt; - std::unordered_set touched; - - for (const auto& kv : visitor.touched_) { - if (kv.second.normal && kv.second.coproc) { - touched.insert(kv.first); - } - } - ICHECK_EQ(visitor.coproc_.size(), 1U); - std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint; - // plan sync. - CoProcSyncPlanner sync_planner(touched, coproc_name); - sync_planner.Plan(stmt); - for (const auto& kv : sync_planner.sync_) { - auto& vec = insert_after_[kv.first]; - vec.insert(vec.end(), kv.second.begin(), kv.second.end()); - } - // Detect barrier - CoProcBarrierDetector barrier_detector(touched, coproc_name); - barrier_detector.PlanReadBarrier(stmt); - barrier_detector.PlanWriteBarrier(stmt); - for (const auto& kv : barrier_detector.barrier_before_) { - auto& vec = insert_before_[kv.first]; - vec.insert(vec.end(), kv.second.begin(), kv.second.end()); - } - for (const auto& kv : barrier_detector.barrier_after_) { - auto& vec = insert_after_[kv.first]; - vec.insert(vec.end(), kv.second.begin(), kv.second.end()); - } - // Detect barrier - CoProcInstDepDetector sync_detector(*visitor.coproc_.begin(), coproc_name); - sync_detector.Plan(stmt); - for (const auto& kv : sync_detector.insert_before_) { - auto& vec = insert_before_[kv.first]; - vec.insert(vec.end(), kv.second.begin(), kv.second.end()); - } - for (const auto& kv : sync_detector.insert_after_) { - auto& vec = insert_after_[kv.first]; - vec.insert(vec.end(), kv.second.begin(), kv.second.end()); - } - return operator()(std::move(stmt)); - } - - Stmt VisitStmt(const Stmt& stmt) final { - auto it_before = insert_before_.find(stmt.get()); - auto it_after = insert_after_.find(stmt.get()); - Stmt new_stmt = StmtMutator::VisitStmt(stmt); - - return SeqStmt::Flatten( - it_before != insert_before_.end() ? it_before->second : std::vector(), new_stmt, - it_after != insert_after_.end() ? it_after->second : std::vector()); - } - - private: - // insert before is stored in reverse order - // the first element is closest to the node. - std::unordered_map> insert_before_; - std::unordered_map> insert_after_; -}; - -Stmt CoProcSync(Stmt stmt) { return CoProcSyncInserter().Insert(std::move(stmt)); } - -namespace transform { - -Pass CoProcSync() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = CoProcSyncInserter().Insert(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.CoProcSync", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.CoProcSync").set_body_typed(CoProcSync); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc deleted file mode 100644 index f7b14f49977f..000000000000 --- a/src/tir/transforms/inject_copy_intrin.cc +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \brief Replace certain copy with copy intrinsics. - * \file copy_intrin_rewrite.cc - */ -#include -#include -#include -#include -#include -#include - -#include "../../arith/pattern_match.h" -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -using runtime::PackedFunc; - -class CopyIntrinInjector : public StmtMutator { - public: - CopyIntrinInjector(const std::string& pragma_key, const PackedFunc& flower_copy_fromto) - : pragma_key_(attr::pragma_scope_prefix + pragma_key), - flower_copy_fromto_(flower_copy_fromto) {} - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == pragma_key_) { - Stmt ret; - std::string error_info; - ICHECK(MatchCopyPattern(op->body, &ret, &error_info)) - << "Cannot match copy pattern. The error is " << error_info << " The body is " - << op->body; - return ret; - } - return StmtMutator::VisitStmt_(op); - } - - private: - bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) { - using namespace arith; - Stmt body = stmt; - - // strip the loops - std::vector loops; - while (const ForNode* op = body.as()) { - if (!is_zero(op->min)) { - *error_info = "the 'min' value of body 'Fonode' is 0."; - return false; - } - loops.push_back(op); - body = op->body; - } - auto store = body.as(); - if (store == nullptr) { - *error_info = "the body is not a 'BufferStoreNode'"; - return false; - } - // Expr sel_cond, sel_true_value, sel_false_value; - // match select or if - PVar sel_cond, sel_true_value, sel_false_value; - bool has_cond = if_then_else(sel_cond, sel_true_value, sel_false_value).Match(store->value) || - select(sel_cond, sel_true_value, sel_false_value).Match(store->value); - - const CastNode* cast = store->value.as(); - auto load = store->value.as(); - if (0 == loops.size()) { - ICHECK(!has_cond); - } - // for now only support true condition matching - if (has_cond) { - load = sel_true_value.Eval().as(); - } - // cast can be part of the pattern - if (cast != nullptr) { - load = cast->value.as(); - } - if (load == nullptr) { - *error_info = "the 'BufferLoadNode' of body is a nullptr."; - return false; - } - if (load->dtype.lanes() != 1) return false; - Array loop_vars; - for (const ForNode* op : loops) { - loop_vars.push_back(op->loop_var); - } - // TODO(Lunderberg): Move this pass to be before - // StorageFlatten/FlattenBuffer. That will simplify the - // implementation, since the pre-flattened indices/strides can be - // used directly. - ICHECK((store->indices.size() == 1) && (load->indices.size() == 1)) - << "InjectDoubleBuffer expects flat 1-d buffers. 
" - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; - - Array store_strides = arith::DetectLinearEquation(store->indices[0], loop_vars); - Array load_strides = arith::DetectLinearEquation(load->indices[0], loop_vars); - if (load_strides.size() == 0 || store_strides.size() == 0) return false; - Array dst_shape; - const size_t loop_var_size = loop_vars.size(); - if (loop_var_size == 0) { - dst_shape.push_back(make_const(DataType::Int(32), 1)); - } else { - for (const ForNode* op : loops) { - dst_shape.push_back(op->extent); - } - } - Array src_shape = dst_shape; - Array pad_before, pad_after; - PrimExpr pad_value; - PrimExpr src_elem_offset = load_strides[loop_var_size]; - if (has_cond) { - Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); - pad_value = sel_false_value.Eval(); - if (clip_bound.size() == 0) { - *error_info = "the size of clip bound is 0."; - return false; - } - ICHECK_EQ(src_shape.size(), loop_vars.size()); - ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2); - for (size_t i = 0; i < src_shape.size(); ++i) { - PrimExpr min_value = clip_bound[2 * i]; - PrimExpr max_value = clip_bound[2 * i + 1]; - DataType t = loop_vars[i].dtype(); - PrimExpr svalue = src_shape[i]; - if (min_value.defined()) { - PrimExpr pbefore = analyzer_.Simplify(Max(min_value, make_zero(t))); - src_elem_offset = src_elem_offset + pbefore * load_strides[i]; - svalue = svalue - pbefore; - pad_before.push_back(pbefore); - } else { - pad_before.push_back(make_zero(t)); - } - if (max_value.defined()) { - PrimExpr pafter = analyzer_.Simplify( - max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); - svalue = svalue - pafter; - pad_after.push_back(pafter); - } else { - pad_after.push_back(make_zero(t)); - } - src_shape.Set(i, analyzer_.Simplify(svalue)); - } - src_elem_offset = analyzer_.Simplify(src_elem_offset); - } - ICHECK_EQ(load_strides.size(), store_strides.size()); - ICHECK_EQ(load_strides.size(), loop_var_size + 1); - Array src_strides(load_strides.begin(), load_strides.begin() + loop_var_size); - Array dst_strides(store_strides.begin(), store_strides.begin() + loop_var_size); - if (loop_var_size == 0) { - src_strides.push_back(make_const(DataType::Int(32), 1)); - dst_strides.push_back(make_const(DataType::Int(32), 1)); - } - Buffer dst = store->buffer; - { - auto writer = dst.CopyOnWrite(); - writer->shape = dst_shape; - writer->strides = dst_strides; - writer->elem_offset = store_strides[loop_var_size]; - } - - Buffer src = load->buffer; - { - auto writer = src.CopyOnWrite(); - writer->shape = src_shape; - writer->strides = src_strides; - writer->elem_offset = src_elem_offset; - } - *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); - if (!out->defined()) { - *error_info = "flower function did not return correct stmt"; - return false; - } - return true; - } - - // pragma key - std::string pragma_key_; - // function to lower copy intrinsics. 
- const PackedFunc& flower_copy_fromto_; - // arith analyzer - arith::Analyzer analyzer_; -}; - -Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, - const PackedFunc& flower_copy_fromto) { - return CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(stmt)); -} - -namespace transform { - -Pass InjectCopyIntrin(String pragma_key, PackedFunc flower_copy_fromto) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectCopyIntrin", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.InjectCopyIntrin").set_body_typed(InjectCopyIntrin); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 4e2e79db26da..52e4d44b615a 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -110,8 +110,7 @@ class DoubleBufferInjector : public StmtExprMutator { entry.scope = GetPtrStorageScope(op->buffer_var); ICHECK_EQ(op->extents.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; entry.stride = op->extents[0]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); @@ -188,8 +187,7 @@ class DoubleBufferInjector : public StmtExprMutator { ICHECK(e.switch_write_var.defined()); ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto writer = node.CopyOnWrite(); writer->buffer = GetRemappedBuffer(node->buffer, e.stride); @@ -208,8 +206,7 @@ class DoubleBufferInjector : public StmtExprMutator { ICHECK(e.switch_read_var.defined()); ICHECK_EQ(node->indices.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto writer = node.CopyOnWrite(); writer->buffer = GetRemappedBuffer(node->buffer, e.stride); @@ -228,12 +225,11 @@ class DoubleBufferInjector : public StmtExprMutator { ICHECK(stride.defined()); // TODO(Lunderberg): Move this pass to before - // StorageFlatten/FlattenBuffer. That will simplify the + // FlattenBuffer. That will simplify the // implementation, to be the insertion of a new dimension for the // buffer, rather than adjusting the other indices. ICHECK_EQ(buf->shape.size(), 1) << "InjectDoubleBuffer expects flat 1-d buffers. " - << "Has StorageFlatten (TE-based schedules) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; // Stride gives the distance between the two halves of the // double-buffer, not the stride of the buffer's index. diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc deleted file mode 100644 index f20577e3a01b..000000000000 --- a/src/tir/transforms/inject_prefetch.cc +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file inject_prefetch.cc - */ -// Inject prefetch op in HalideIR -#include -#include -#include -#include -#include -#include -#include - -#include - -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -using arith::DomainTouched; -using arith::IntSet; - -class PrefetchInjector : public StmtMutator { - public: - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - if (op && op->attr_key == attr::prefetch_scope) { - Buffer buffer = Downcast(op->node); - ICHECK_NE(loop_nest_.size(), 0U); - Region domain = DomainTouched(op->body, buffer, true, false); - Region region; - - auto iter_var = loop_nest_.back().get(); - vectorized_[iter_var] = IntSet::SinglePoint(loop_nest_.back() + op->value); - - for (Range r : domain) { - if (!r.defined()) { - LOG(WARNING) << "Cannot decide prefetch region for " << buffer; - return op->body; - } - Range res(EvalSet(r, vectorized_).CoverRange(none)); - region.push_back(Range::FromMinExtent(res->min, res->extent)); - } - - vectorized_.erase(iter_var); - - Stmt prefetch = Prefetch(buffer, region); - return SeqStmt({prefetch, op->body}); - } - return ret; - } - - Stmt VisitStmt_(const ForNode* op) final { - auto& var = op->loop_var; - loop_nest_.push_back(var); - if (op->kind == ForKind::kVectorized) { - vectorized_[var.get()] = IntSet::Interval(op->min, (op->min + op->extent) - 1); - } - Stmt ret = StmtMutator::VisitStmt_(op); - if (op->kind == ForKind::kVectorized) { - vectorized_.erase(var.get()); - } - loop_nest_.pop_back(); - return ret; - } - - private: - std::vector loop_nest_; - std::unordered_map vectorized_; - static const Range none; -}; - -const Range PrefetchInjector::none; - -Stmt InjectPrefetch(Stmt stmt) { return PrefetchInjector()(std::move(stmt)); } - -namespace transform { - -Pass InjectPrefetch() { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - // Only apply this pass to TIR from TE schedules - if (IsFromLegacyTESchedule(f)) { - auto* n = f.CopyOnWrite(); - n->body = PrefetchInjector()(std::move(n->body)); - return f; - } else { - return f; - } - }; - return CreatePrimFuncPass(pass_func, 0, "tir.InjectPrefetch", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.InjectPrefetch").set_body_typed(InjectPrefetch); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 6fad8b378cd8..d9fc74f8ad18 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -399,7 +399,7 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { // place v on highest dimension. // TODO(Lunderberg): Move pass to apply before - // StorageFlatten/FlattenBuffer. Would rewrite the Buffer to + // FlattenBuffer. Would rewrite the Buffer to // add the injected virtual thread as the first index. 
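For context on where such rewrites originate: binding a loop to a virtual thread axis in the TIR schedule is what InjectVirtualThread later lowers. A hedged sketch using the public `tvm.tir.Schedule` API, with a toy `add_one` function as an illustrative assumption:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in range(16):
        with T.block("add"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi] + T.float32(1)


sch = tvm.tir.Schedule(tvm.IRModule({"main": add_one}))
(i,) = sch.get_loops(sch.get_block("add"))
outer, inner = sch.split(i, factors=[4, 4])
# Bind the outer loop to a virtual thread; InjectVirtualThread unrolls the
# vthread axis during lowering and offsets any flat allocations it touches.
sch.bind(outer, "vthread.x")
print(sch.mod.script())
```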
ICHECK_EQ(extents.size(), 1) << "InjectVirtualThread expects rewritten allocations to be flat memory."; @@ -507,9 +507,7 @@ class VirtualThreadInjector : public arith::IRMutatorWithAnalyzer { } } - Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "Need to call StorageFlatten first"; - } + Stmt VisitStmt_(const ProducerStoreNode* op) final { LOG(FATAL) << "Should not appear in TIR"; } }; namespace transform { diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc deleted file mode 100644 index b340a94937f3..000000000000 --- a/src/tir/transforms/lift_attr_scope.cc +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * - * \brief Lift specified AttrStmt scope to outer if - * the body contains the same scope. - * \file lift_attr_scope.cc - */ -#include -#include -#include - -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -// NOTE: this optimization can only be applied -// to a few specified attr keys -class AttrScopeLifter : public StmtMutator { - public: - explicit AttrScopeLifter(std::string attr_key) : attr_key_(attr_key) {} - - Stmt Lift(Stmt stmt) { - stmt = operator()(std::move(stmt)); - if (attr_node_.defined()) { - stmt = AttrStmt(attr_node_, attr_key_, attr_value_, stmt); - } - return stmt; - } - - // do not go beyond - Stmt VisitStmt_(const AllocateNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - if (attr_node_.defined()) { - Stmt body = AttrStmt(attr_node_, attr_key_, attr_value_, op->body); - // undefine them - attr_node_ = ObjectRef(); - attr_value_ = PrimExpr(); - return Allocate(op->buffer_var, op->dtype, op->extents, op->condition, body); - } else { - return stmt; - } - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr_key_) { - attr_node_ = op->node; - attr_value_ = op->value; - return op->body; - } else { - return StmtMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const SeqStmtNode* op) final { - // remember the decorations. - std::vector attr_node; - std::vector attr_value; - - auto fmutate = [&](const Stmt& s) { - attr_node_ = ObjectRef(); - attr_value_ = PrimExpr(); - Stmt ret = this->VisitStmt(s); - attr_node.push_back(attr_node_); - attr_value.push_back(attr_value_); - return ret; - }; - Stmt ret = StmtMutator::VisitSeqStmt_(op, true, fmutate); - if (attr_node.size() == 0) return ret; - - op = ret.as(); - ICHECK(op != nullptr); - Array reorg; - // check if all decorations are common. 
- for (size_t begin = 0; begin < attr_node.size();) { - size_t end = begin + 1; - while (end < attr_node.size() && attr_node[end].same_as(attr_node[begin]) && - ValueSame(attr_value[end], attr_value[begin])) { - ++end; - } - // covers everything - // lift attr to parent. - if (begin == 0 && end == attr_node.size()) { - attr_node_ = attr_node[0]; - attr_value_ = attr_value[0]; - return ret; - } - // construct subsegments. - Array seq; - for (size_t i = begin; i < end; ++i) { - seq.push_back(op->seq[i]); - } - Stmt stmt = SeqStmt::Flatten(seq); - if (attr_node[begin].defined()) { - stmt = AttrStmt(attr_node[begin], attr_key_, attr_value[begin], stmt); - } - reorg.push_back(stmt); - begin = end; - } - attr_node_ = ObjectRef(); - attr_value_ = PrimExpr(); - return SeqStmt::Flatten(reorg); - } - - Stmt VisitStmt_(const IfThenElseNode* op) final { - if (!op->else_case) { - return StmtMutator::VisitStmt_(op); - } - Stmt then_case = this->VisitStmt(op->then_case); - ObjectRef first_node; - PrimExpr first_value; - std::swap(first_node, attr_node_); - std::swap(first_value, attr_value_); - Stmt else_case = this->VisitStmt(op->else_case.value()); - if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && - first_value.defined() && attr_node_.same_as(first_node) && - ValueSame(attr_value_, first_value)) { - if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); - } else { - return IfThenElse(op->condition, then_case, else_case); - } - } else { - if (first_node.defined()) { - then_case = AttrStmt(first_node, attr_key_, first_value, then_case); - } - if (attr_node_.defined()) { - else_case = AttrStmt(attr_node_, attr_key_, attr_value_, else_case); - // undefine them - attr_node_ = ObjectRef(); - attr_value_ = PrimExpr(); - } - if (then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { - return GetRef(op); - } else { - return IfThenElse(op->condition, then_case, else_case); - } - } - } - - Stmt VisitStmt_(const WhileNode* op) final { - // TODO(masahi): Do we need a special handling for While nodes? 
- LOG(FATAL) << "WhileNode not supported in LiftAttrScope."; - } - - private: - // value comparison that also compares content of int constant - static bool ValueSame(const PrimExpr& a, const PrimExpr& b) { - if (a.same_as(b)) return true; - if (!a.defined() || !b.defined()) return false; - if (a->type_index() != b->type_index()) return false; - if (a.dtype() != b.dtype()) return false; - if (const IntImmNode* op = a.as()) { - return op->value == b.as()->value; - } - return false; - } - - std::string attr_key_; - ObjectRef attr_node_; - PrimExpr attr_value_; -}; - -Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { - return AttrScopeLifter(attr_key).Lift(std::move(stmt)); -} - -namespace transform { - -Pass LiftAttrScope(String attr_key) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tir.LiftAttrScope", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.LiftAttrScope").set_body_typed(LiftAttrScope); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 870235954689..4a364c0ecb8b 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -132,8 +132,7 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { } ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has StorageFlatten (TE-based schedule) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; PrimExpr index = op->indices[0]; if (op->value.dtype().lanes() != 1) { @@ -294,8 +293,7 @@ class WarpAccessRewriter : protected StmtExprMutator { if (store->buffer->data.get() == buffer_) { ICHECK_EQ(store->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has StorageFlatten (TE-based schedule) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto [local_index, group] = SplitIndexByGroup(store->indices[0]); (void)group; // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81767 @@ -315,8 +313,7 @@ class WarpAccessRewriter : protected StmtExprMutator { } ICHECK_EQ(op->indices.size(), 1) << "Expected flat memory to use as warp memory. " - << "Has StorageFlatten (TE-based schedule) or " - << "FlattenBuffer (TIR-based schedules) been run?"; + << "Has FlattenBuffer been run?"; auto [local_index, group] = SplitIndexByGroup(op->indices[0]); // invariance: local index must do not contain warp id diff --git a/src/tir/transforms/merge_shared_memory_allocations.cc b/src/tir/transforms/merge_shared_memory_allocations.cc index bd9ff371517f..85f102cb4177 100644 --- a/src/tir/transforms/merge_shared_memory_allocations.cc +++ b/src/tir/transforms/merge_shared_memory_allocations.cc @@ -350,7 +350,7 @@ class SharedMemoryRewriter : public StmtExprMutator { ICHECK_EQ(node->indices.size(), 1) << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " - << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + << "FlattenBuffer"; Array indices = {node->indices[0] + this->GetBufferOffset(node->buffer->data, node->buffer->dtype)}; @@ -374,7 +374,7 @@ class SharedMemoryRewriter : public StmtExprMutator { << "Buffer " << buffer << " has shape " << buffer->shape << ". 
" << "MergeSharedMemoryAllocations expects flat memory buffers, " << "and is to be run after " - << "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)"; + << "FlattenBuffer"; auto writer = buffer.CopyOnWrite(); writer->data = merged_buf_var_; } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc deleted file mode 100644 index 87c6d48639f0..000000000000 --- a/src/tir/transforms/storage_flatten.cc +++ /dev/null @@ -1,1947 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file storage_flatten.cc - * \brief Flattens storage from multi-dimensional array to 1D buffer access - */ -// The pass definition originates from Halide pipeline. - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "../../arith/ir_visitor_with_analyzer.h" -#include "../../runtime/thread_storage_scope.h" -#include "arg_binder.h" -#include "ir_utils.h" - -namespace tvm { -namespace tir { - -using arith::IRVisitorWithAnalyzer; -using runtime::StorageRank; -using runtime::StorageScope; -using runtime::ThreadScope; - -/* Make buffer realize extents and buffer shapes consistent - * - * For external buffers, verify that the extents of BufferRealize - * nodes match the shape of the external buffer. For internal - * buffers, rewrite the shape of the Buffer objects to match the - * extent of the BufferRealize, and rewrite indices of - * BufferLoad/BufferStore nodes to match. 
- */ -class BufferShapeLegalize : public StmtExprMutator { - public: - static transform::Pass Pass() { - auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { - IRVisitorWithAnalyzer bound_analyzer; - - bound_analyzer(func->body); - - auto pass = BufferShapeLegalize(func->buffer_map, &bound_analyzer); - - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); - if (auto map = func->attrs.GetAttr>>("layout_transform_map")) { - func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value())); - } - return func; - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferShapeLegalize", {}); - } - - explicit BufferShapeLegalize(const Map& extern_buffer_map, - IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer) { - for (auto kv : extern_buffer_map) { - Buffer buf = kv.second; - extern_buffers_.insert(buf); - - BufferEntry remap; - remap.remap_to = buf; - remap.index_offsets = Array(buf->shape.size(), 0); - remap.in_scope = true; - buf_map_[buf] = remap; - } - } - - Map> UpdateIndexMap(const Map>& orig) { - Map> output; - for (const auto& kv : orig) { - auto it = buf_map_.find(kv.first); - if (it != buf_map_.end()) { - output.Set(it->second.remap_to, kv.second); - } else { - output.Set(kv.first, kv.second); - } - } - return output; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_remap_.find(op); - if (it != var_remap_.end()) { - return it->second; - } else { - return GetRef(op); - } - } - - Stmt VisitStmt_(const BufferRealizeNode* op) final { - // BufferRealizeNode for an external buffer serves as an - // annotation of the external buffers, and should not be changed. - // Instead, verify that the bounds match the external - // buffer. - if (extern_buffers_.count(op->buffer)) { - CHECK_EQ(op->buffer->shape.size(), op->bounds.size()) - << "External buffer realize has mismatched dimension"; - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); - - for (size_t i = 0; i < op->bounds.size(); i++) { - PrimExpr eq = bound_analyzer_->Simplify(op->buffer->shape[i] == op->bounds[i]->extent); - std::ostringstream ss; - ss << "Dim " << i << " of external buffer " << op->buffer->name << " has shape " - << op->buffer->shape[i] << ", but is only realized for extent " << op->bounds[i]->extent; - if (auto eq_int = eq.as()) { - ICHECK(eq_int->value) << ss.str(); - } else { - stmt = AssertStmt(eq, tvm::tir::StringImm(ss.str()), stmt); - } - } - return stmt; - } - - // Compute the new buffer shape, new realization bounds, and the - // offsets to be applied to buffer access. 
- Array realized_shape; - Array index_offsets; - Array new_bounds; - for (size_t i = 0; i < op->bounds.size(); i++) { - const Range& bound = op->bounds[i]; - realized_shape.push_back(bound->extent); - index_offsets.push_back(bound->min); - new_bounds.push_back({0, bound->extent}); - } - - if (op->buffer->shape.size()) { - ICHECK_EQ(op->buffer->shape.size(), realized_shape.size()) - << "Inconsistency between dimension of buffer " << op->buffer - << " and dimension of its realized bounds."; - } - - Buffer key = op->buffer; - - Buffer buf = op->buffer; - auto write_ptr = buf.CopyOnWrite(); - write_ptr->shape = realized_shape; - - { - BufferEntry remap; - remap.remap_to = buf; - remap.index_offsets = index_offsets; - remap.in_scope = true; - buf_map_[key] = remap; - } - - Stmt stmt = BufferRealize(buf, new_bounds, op->condition, this->VisitStmt(op->body), op->span); - - buf_map_.at(key).in_scope = false; - - return stmt; - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); - } - - template - Node VisitBufferAccess(Node node) { - auto it = buf_map_.find(node->buffer); - if (it != buf_map_.end()) { - const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot access an out-of-scope buffer"; - - Array indices = node->indices; - if (entry.index_offsets.size()) { - ICHECK_GE(entry.index_offsets.size(), indices.size()) - << "Cannot bind buffer to a shape of lower dimension."; - - Array new_indices; - - // Pad leading indices with zero, matching the "fuzzy_match" - // behavior from ArgBinder::BindBuffer. - size_t diff = entry.index_offsets.size() - indices.size(); - for (size_t i = 0; i < diff; i++) { - new_indices.push_back(0); - } - - // Offset indices used to access buffers of a reduced size. - for (size_t i = 0; i < indices.size(); i++) { - PrimExpr offset = entry.index_offsets[i + diff]; - new_indices.push_back(indices[i] - offset); - } - indices = new_indices; - } - - auto write_ptr = node.CopyOnWrite(); - write_ptr->indices = indices; - write_ptr->buffer = entry.remap_to; - } - return node; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->node->IsInstance()) { - // Visit body before checking internal_buf_map_, because we - // don't know if the BufferNode needs to be changed until we - // look in the body for a BufferRealizeNode with different - // extents. - Stmt body = this->VisitStmt(op->body); - - Buffer buffer = Downcast(op->node); - auto it = buf_map_.find(buffer); - if (it != buf_map_.end()) { - buffer = it->second.remap_to; - return AttrStmt(it->second.remap_to, op->attr_key, op->value, body); - } - return AttrStmt(buffer, op->attr_key, op->value, body); - - } else if (op->attr_key == attr::buffer_bind_scope) { - return HandleBufferBindScope(op); - } - - return StmtExprMutator::VisitStmt_(op); - } - - private: - // Any buffers that give views into a resized buffer should be - // updated, both to refer to the resized buffer and to have the view - // window updated. For example, suppose B1 is a 1-D buffer of size - // 100 which is only realized on the range (10,50), and buffer V1 is - // a view into B1[25:35]. When B1 is replaced with B2, a buffer of - // size 40 realized on the range (0,40), V1 must be replaced to be a - // view into B2[15:25]. 
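-  // In tvm_tuple terms (values illustrative, matching the example
-  // above): the bind annotation tvm_tuple(25, 10) over B1 is rewritten
-  // to tvm_tuple(15, 10) over B2, i.e. begin 25 minus B1's realize
-  // offset 10, with the extent left unchanged.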
-  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
-    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
-    ICHECK_EQ(arr.size(), 2U);
-    Buffer buffer = Downcast<Buffer>(arr[0]);
-    ICHECK(buffer.defined());
-    Buffer target = Downcast<Buffer>(arr[1]);
-    ICHECK(target.defined());
-
-    auto it = buf_map_.find(target);
-    ICHECK(it != buf_map_.end()) << "attr::buffer_bind_scope target " << target << " not in scope.";
-    const BufferEntry& target_remap = it->second;
-
-    ICHECK(target_remap.in_scope) << "Cannot bind " << buffer->name
-                                  << " to the out-of-scope buffer " << target_remap.remap_to->name;
-
-    Call tuple = Downcast<Call>(op->value);
-    ICHECK(tuple.defined() && tuple->op.same_as(builtin::tvm_tuple()));
-
-    Array<PrimExpr> new_tuple_args;
-    Array<PrimExpr> realized_begins;
-    Array<PrimExpr> view_shape;
-    ICHECK_EQ(tuple->args.size(), target_remap.index_offsets.size() * 2)
-        << "attr::buffer_bind_scope to define " << buffer << " as a view into " << target
-        << " does not match the dimensionality of " << target;
-    for (size_t i = 0; i < target_remap.index_offsets.size(); i++) {
-      PrimExpr parent_begin = tuple->args[2 * i];
-      PrimExpr view_extent = tuple->args[2 * i + 1];
-      // Offset the begin of the buffer view by the offset of the target buffer.
-      new_tuple_args.push_back(parent_begin - target_remap.index_offsets[i]);
-      // Keep the extent of the buffer view the same.
-      new_tuple_args.push_back(view_extent);
-      // Use the extent of the buffer view to define the buffer view's shape.
-      view_shape.push_back(view_extent);
-      // Within the buffer view, indices start at 0.
-      realized_begins.push_back(0);
-    }
-
-    // If a view is bound to a buffer of higher dimensionality, the
-    // leading dimensions should be padded out with extents of 1.
-    ICHECK_GE(view_shape.size(), buffer->shape.size())
-        << "Cannot bind " << buffer << " to a shape of lower dimension.";
-    if (view_shape.size() > buffer->shape.size()) {
-      size_t diff = view_shape.size() - buffer->shape.size();
-      Array<PrimExpr> padded_shape;
-      for (size_t i = 0; i < diff; i++) {
-        padded_shape.push_back(1);
-      }
-      for (auto dim : buffer->shape) {
-        padded_shape.push_back(dim);
-      }
-      view_shape = std::move(padded_shape);
-    }
-
-    // If a buffer has strides defined, and is being remapped into a
-    // shape with additional dimensions, then define dummy values for
-    // the strides.
-    Array<PrimExpr> realized_strides = buffer->strides;
-    if ((realized_strides.size() > 0) && (realized_strides.size() != view_shape.size())) {
-      ICHECK_GE(view_shape.size(), realized_strides.size())
-          << "Cannot bind the strides of " << buffer << " to a shape of lower dimension";
-      size_t diff = view_shape.size() - buffer->strides.size();
-
-      Array<PrimExpr> updated_strides;
-      for (size_t i = 0; i < diff; i++) {
-        updated_strides.push_back(Var("stride", buffer->shape[0].dtype()));
-      }
-      for (auto stride : buffer->strides) {
-        updated_strides.push_back(stride);
-      }
-      realized_strides = updated_strides;
-    }
-
-    Buffer key = buffer;
-
-    auto write_ptr = buffer.CopyOnWrite();
-    write_ptr->shape = view_shape;
-    write_ptr->strides = realized_strides;
-
-    {
-      BufferEntry remap;
-      remap.index_offsets = realized_begins;
-      remap.remap_to = buffer;
-      remap.in_scope = true;
-      buf_map_[key] = remap;
-    }
-
-    // Define remappings of any Variables referencing Buffer internals
-    // (e.g. Store/Load nodes).  Passing fuzzy_match=true allows the
-    // remapped buffer to have a different number of dimensions.
- ArgBinder binder(&var_remap_); - binder.BindBuffer(key, buffer, key->name, true); - - Stmt body = this->VisitStmt(op->body); - body = MergeNest(binder.asserts(), body); - body = MergeNest(binder.init_nest(), body); - - Stmt stmt = AttrStmt(Array{buffer, target_remap.remap_to}, op->attr_key, - Call(tuple->dtype, tuple->op, new_tuple_args, tuple->span), body); - - for (const Var& v : binder.defs()) { - var_remap_.erase(v.get()); - } - - buf_map_.at(key).in_scope = false; - return stmt; - } - - std::unordered_map var_remap_; - - std::unordered_set extern_buffers_; - - struct BufferEntry { - Buffer remap_to; - Array index_offsets; - bool in_scope; - }; - - std::unordered_map buf_map_; - - IRVisitorWithAnalyzer* bound_analyzer_; -}; - -/* Apply dimension alignment restrictions - * - * Buffers annotated with attr::buffer_dim_align may need to have - * strides defined such that they are no longer in a compact shape. - * After this pass, buffers have stride definitions to include these - * alignment restrictions, and attr::buffer_dim_align annotations have - * been removed. - */ -class BufferStrideLegalize : public StmtExprMutator { - public: - static transform::Pass Pass() { - auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { - IRVisitorWithAnalyzer bound_analyzer; - - bound_analyzer(func->body); - - auto pass = BufferStrideLegalize(func->buffer_map, &bound_analyzer); - - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); - fptr->buffer_map = pass.UpdatedExternBufferMap(); - if (auto map = func->attrs.GetAttr>>("layout_transform_map")) { - func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value())); - } - return func; - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferStrideLegalize", {}); - } - - explicit BufferStrideLegalize(const Map& extern_buffer_map, - IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer) { - for (auto kv : extern_buffer_map) { - Buffer buf = kv.second; - Buffer with_strides = WithStrides(buf); - { - BufferEntry entry; - entry.remap_to = with_strides; - entry.in_scope = true; - buf_map_[buf] = entry; - } - updated_extern_buffer_map_.Set(kv.first, with_strides); - } - } - - Map> UpdateIndexMap(const Map>& orig) { - Map> output; - for (const auto& kv : orig) { - auto it = buf_map_.find(kv.first); - if (it != buf_map_.end()) { - output.Set(it->second.remap_to, kv.second); - } else { - output.Set(kv.first, kv.second); - } - } - return output; - } - - Map UpdatedExternBufferMap() const { return updated_extern_buffer_map_; } - - Buffer WithStrides(Buffer buf) { - auto cache_key = buf; - - auto it = buf_map_.find(cache_key); - if (it != buf_map_.end()) { - const BufferEntry& entry = it->second; - ICHECK(entry.in_scope) << "Cannot annotate an out-of-scope buffer"; - return entry.remap_to; - } - - Array shape = buf->shape; - - if (buf->strides.size()) { - ICHECK_EQ(buf->strides.size(), buf->shape.size()) - << "Buffer " << buf << " has inconsistent strides/shape."; - } else if (dim_align_.count(buf) == 0) { - // Keeping this to have matched behavior to previous version. - // There are many parts of the codebase that assume that a - // strided array cannot be compact. For example, - // ArgBinder::BindBuffer and tir.Specialize. To avoid breaking - // these, do not define the strides unless required for a - // non-compact array. - } else if (shape.size() == 0) { - // Can't define the strides for a buffer without a known shape. 
- } else { - // With everything checked, can now define the updated strides - std::vector rstrides; - const std::vector& avec = dim_align_[buf]; - int first_dim = 0; - PrimExpr stride = make_const(shape[first_dim].dtype(), 1); - for (size_t i = shape.size(); i != 0; --i) { - size_t dim = i - 1; - if (dim < avec.size() && avec[dim].align_factor != 0) { - PrimExpr factor = make_const(stride.dtype(), avec[dim].align_factor); - PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); - stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor); - stride = bound_analyzer_->Simplify(stride); - } - rstrides.push_back(stride); - stride = stride * shape[dim]; - } - - buf.CopyOnWrite()->strides = Array(rstrides.rbegin(), rstrides.rend()); - } - - BufferEntry entry; - entry.remap_to = buf; - entry.in_scope = true; - buf_map_[cache_key] = entry; - - return buf; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::buffer_dim_align) { - auto buffer = Downcast(op->node); - const CallNode* tuple = op->value.as(); - ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple())); - auto& vinfo = dim_align_[buffer]; - int dim = tuple->args[0].as()->value; - if (static_cast(dim) >= vinfo.size()) { - vinfo.resize(dim + 1); - } - vinfo[dim].align_factor = tuple->args[1].as()->value; - vinfo[dim].align_offset = tuple->args[2].as()->value; - - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::buffer_bind_scope) { - Array arr = Downcast>(op->node); - ICHECK_EQ(arr.size(), 2U); - Buffer source = Downcast(arr[0]); - Buffer target_with_strides = WithStrides(Downcast(arr[1])); - Buffer source_with_strides = WithStrides(source); - - Stmt body = this->VisitStmt(op->body); - - buf_map_[source].in_scope = false; - - return AttrStmt(Array{source_with_strides, target_with_strides}, op->attr_key, - op->value, body, op->span); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - // AllocateNodes may be present from tvm.tir.ir_builder. This can - // be simplified in the future by having AllocateNode hold a buffer, - // rather than a buffer_var. 
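-  // (Worked example for WithStrides above, numbers illustrative: a
-  // shape [16, 16] buffer with buffer_dim_align(dim=0, factor=2,
-  // offset=1) gets strides [17, 1], since the row stride 16 is bumped
-  // by indexmod(2 + 1 - indexmod(16, 2), 2) == 1 -- the usual padding
-  // trick for avoiding shared-memory bank conflicts.)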
- Stmt VisitStmt_(const AllocateNode* op) final { - buffer_var_defines_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateConstNode* op) final { - buffer_var_defines_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const LetStmtNode* op) final { - if (op->var.dtype().is_handle()) { - buffer_var_defines_.insert(op->var.get()); - } - return StmtExprMutator::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const LetNode* op) final { - if (op->var.dtype().is_handle()) { - buffer_var_defines_.insert(op->var.get()); - } - return StmtExprMutator::VisitExpr_(op); - } - - Stmt VisitStmt_(const BufferRealizeNode* op) final { - Buffer key = op->buffer; - Buffer with_strides = WithStrides(op->buffer); - - Stmt stmt = StmtExprMutator::VisitStmt_(op); - - buf_map_[key].in_scope = false; - op = stmt.as(); - ICHECK(op); - - return BufferRealize(with_strides, op->bounds, op->condition, op->body, op->span); - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); - } - - template - Node VisitBufferAccess(Node node) { - auto it = buf_map_.find(node->buffer); - ICHECK(it == buf_map_.end() || it->second.in_scope) - << "Cannot access a buffer " << node->buffer->name << ", out of scope"; - - auto with_strides = WithStrides(node->buffer); - if (!with_strides.same_as(node->buffer)) { - node.CopyOnWrite()->buffer = with_strides; - } - - return node; - } - - private: - Map updated_extern_buffer_map_; - - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - - // Dimension alignment - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; - - struct BufferEntry { - Buffer remap_to; - bool in_scope; - }; - - std::unordered_map buf_map_; - - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. - std::unordered_set buffer_var_defines_; - - IRVisitorWithAnalyzer* bound_analyzer_; -}; - -/* Use the scope of IterVar to determine storage scope. - * - * For buffers that do not have an explicit storage scope defined, a - * reasonable storage scope may be defined based on the thread scope - * that contains the buffer's allocation. All other buffers without a - * scope are assigned to global scope. - */ -class ThreadScopePropagate : public StmtExprMutator { - public: - static transform::Pass Pass() { - auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { - auto pass = ThreadScopePropagate(func->buffer_map); - - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); - if (auto map = func->attrs.GetAttr>>("layout_transform_map")) { - func = WithAttr(std::move(func), "layout_transform_map", pass.UpdateIndexMap(map.value())); - } - return func; - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.ThreadScopePropagate", {}); - } - - explicit ThreadScopePropagate(const Map& extern_buffer_map) { - // External buffers shouldn't be overwritten, even if they have a - // BufferRealizeNode. 
- for (auto kv : extern_buffer_map) { - external_buffers_.insert(kv.second); - } - } - - Map> UpdateIndexMap(const Map>& orig) { - Map> output; - for (const auto& kv : orig) { - auto it = buf_remap_.find(kv.first->data); - if (it != buf_remap_.end()) { - output.Set(it->second, kv.second); - } else { - output.Set(kv.first, kv.second); - } - } - return output; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = buf_remap_.find(GetRef(op)); - if (it != buf_remap_.end()) { - return it->second->data; - } else { - return GetRef(op); - } - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - ICHECK_NE(op->attr_key, attr::buffer_dim_align) - << "StorageFlattener assumes that all buffers have accurate strides, " - << "and all buffer_dim_align annotations are removed. " - << "Please run BufferStrideLegalize first."; - - if (op->attr_key == attr::thread_extent) { - IterVar iv = Downcast(op->node); - ThreadScope ts = ThreadScope::Create(iv->thread_tag); - curr_thread_scope_.push_back(ts); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - curr_thread_scope_.pop_back(); - return stmt; - } else if (op->attr_key == attr::buffer_bind_scope) { - return HandleBufferBindScope(op); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - Stmt VisitStmt_(const BufferRealizeNode* op) final { - Var old_var = op->buffer->data; - - // Don't remap buffers that already have an explicit scope, - // or external buffers. - std::string str_scope = GetPtrStorageScope(old_var); - if ((str_scope.length() > 0) || external_buffers_.count(op->buffer)) { - return StmtExprMutator::VisitStmt_(op); - } - - ICHECK_EQ(buf_remap_.count(old_var), 0) - << "Buffer var " << op->buffer->data << " appears in multiple BufferRealize nodes"; - - StorageScope skey; - if (curr_thread_scope_.size() == 0) { - skey.rank = StorageRank::kGlobal; - } else { - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); - } - - auto ptr_type = old_var->type_annotation.as(); - ICHECK(ptr_type); - Var new_var(old_var->name_hint, PointerType(ptr_type->element_type, skey.to_string()), - old_var->span); - - Buffer buf = op->buffer; - buf.CopyOnWrite()->data = new_var; - - buf_remap_[old_var] = buf; - - Stmt body = this->VisitStmt(op->body); - return BufferRealize(buf, op->bounds, op->condition, body, op->span); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - ICHECK(op); - - auto it = buf_remap_.find(op->buffer->data); - if (it != buf_remap_.end()) { - return BufferLoad(it->second, op->indices, op->predicate, op->span); - } else { - return expr; - } - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op); - - auto it = buf_remap_.find(op->buffer->data); - if (it != buf_remap_.end()) { - return BufferStore(it->second, op->value, op->indices, op->predicate, op->span); - } else { - return stmt; - } - } - - private: - // If the rewritten buffers are part of a buffer_bind_scope, either - // as the buffer view or as the buffer being viewed, then the - // buffer_bind_scope must be rewritten to refer to the updated - // buffers. 
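-  // (Recapping the propagation above with an illustrative example: a
-  // scope-less buffer realized under attr [threadIdx.x] thread_extent
-  // has its data var retyped from Pointer(float32) to, e.g.,
-  // Pointer(float32, "local"), while a realize outside any
-  // thread_extent defaults to "global".)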
- Stmt HandleBufferBindScope(const AttrStmtNode* op) { - Array arr = Downcast>(op->node); - ICHECK_EQ(arr.size(), 2U); - Buffer buffer = Downcast(arr[0]); - ICHECK(buffer.defined()); - Buffer target = Downcast(arr[1]); - ICHECK(target.defined()); - - bool needs_rewrite = false; - - { - auto it = buf_remap_.find(buffer->data); - if (it != buf_remap_.end()) { - needs_rewrite = true; - buffer = it->second; - } - } - - { - auto it = buf_remap_.find(target->data); - if (it != buf_remap_.end()) { - needs_rewrite = true; - target = it->second; - } - } - - if (needs_rewrite) { - Stmt body = this->VisitStmt(op->body); - return AttrStmt(Array{buffer, target}, op->attr_key, op->value, body); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - std::unordered_map buf_remap_; - std::unordered_set external_buffers_; - - // The current thread scope. - std::vector curr_thread_scope_; -}; - -/* Map buffer binds to their source buffer - * - * Buffers defined using an attr::buffer_bind_scope annotation are - * views into some linked buffer, potentially into some restricted - * subregion of that buffer. This pass identifies such buffers, then - * rewrites all access of the bound buffers to be access into the - * linked buffer. - */ -class BufferBindUnwrapper : public StmtExprMutator { - public: - static transform::Pass Pass() { - auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { - IRVisitorWithAnalyzer bound_analyzer; - - bound_analyzer(func->body); - - auto pass = BufferBindUnwrapper(func->buffer_map, &bound_analyzer); - - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); - return func; - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.BufferBindUnwrapper", {}); - } - - explicit BufferBindUnwrapper(const Map& extern_buffer_map, - IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer) { - for (auto kv : extern_buffer_map) { - BufferEntry e; - e.buffer = kv.second; - e.external = true; - var_to_buffer_[kv.second->data.get()] = kv.second; - buf_map_[kv.second.get()] = std::move(e); - } - } - - Map> UpdateIndexMap(const Map>& orig) { - Map> output; - for (const auto& kv : orig) { - const BufferEntry& e = GetBufferEntry(kv.first); - - if (e.remap) { - output.Set(e.remap->target, kv.second); - } else { - output.Set(kv.first, kv.second); - } - } - return output; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - ICHECK_NE(op->attr_key, attr::buffer_dim_align) - << "BufferBindUnwrapper assumes that all buffers have accurate strides, " - << "and all buffer_dim_align annotations are removed. " - << "Please run BufferStrideLegalize first."; - - if (op->attr_key == attr::buffer_bind_scope) { - return HandleBufferBindScope(op); - } else { - return StmtExprMutator::VisitStmt_(op); - } - } - - PrimExpr VisitExpr_(const VarNode* op) final { - ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined. " - << "(e.g. 
use of buffer.elem_offset for a non-flat buffer)"; - - auto it = var_remap_.find(op); - if (it != var_remap_.end()) { - return it->second; - } else { - return GetRef(op); - } - } - - Array remap_indices(Array indices, Array begins, - Array extents) { - ICHECK_EQ(begins.size(), extents.size()); - - if (begins.size() == 0) { - return indices; - } - - ICHECK_EQ(begins.size(), indices.size()); - - Array out; - for (size_t i = 0; i < begins.size(); i++) { - out.push_back(begins[i] + indices[i]); - } - return out; - } - - Array remap_bounds(Array bounds, Array begins, Array extents) { - ICHECK_EQ(begins.size(), extents.size()); - - if (begins.size() == 0) { - return bounds; - } - - ICHECK_EQ(begins.size(), bounds.size()); - - Array out; - for (size_t i = 0; i < begins.size(); i++) { - out.push_back(Range::FromMinExtent(bounds[i]->min + begins[i], bounds[i]->extent)); - } - return out; - } - - // AllocateNodes may be present from tvm.tir.ir_builder. This can - // be simplified in the future by having AllocateNode hold a buffer, - // rather than a buffer_var. - Stmt VisitStmt_(const AllocateNode* op) final { - buffer_var_defines_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateConstNode* op) final { - buffer_var_defines_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const LetStmtNode* op) final { - if (op->var.dtype().is_handle()) { - buffer_var_defines_.insert(op->var.get()); - } - return StmtExprMutator::VisitStmt_(op); - } - - PrimExpr VisitExpr_(const LetNode* op) final { - if (op->var.dtype().is_handle()) { - buffer_var_defines_.insert(op->var.get()); - } - return StmtExprMutator::VisitExpr_(op); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - - const BufferEntry& e = GetBufferEntry(op->buffer); - - if (e.remap) { - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " - "storage flatten pass."; - return BufferLoad(e.remap->target, - remap_indices(op->indices, e.remap->begins, e.remap->extents), - op->predicate, op->span); - } else { - return expr; - } - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - - const BufferEntry& e = GetBufferEntry(op->buffer); - - if (e.remap) { - ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in " - "storage flatten pass."; - return BufferStore(e.remap->target, op->value, - remap_indices(op->indices, e.remap->begins, e.remap->extents), - op->predicate, op->span); - } else { - return stmt; - } - } - - Stmt VisitStmt_(const BufferRealizeNode* op) final { - const auto& key = op->buffer.get(); - - bool is_external = false; - - if (buf_map_.count(key)) { - ICHECK(buf_map_.at(key).external) - << "BufferRealize node for internal buffer " << op->buffer << " occurred multiple times."; - - is_external = true; - } else { - BufferEntry e; - e.bounds = op->bounds; - e.buffer = op->buffer; - var_to_buffer_[op->buffer->data.get()] = op->buffer; - buf_map_[key] = std::move(e); - } - - Stmt stmt = StmtExprMutator::VisitStmt_(op); - - if (is_external) { - buf_map_[key].in_scope = false; - } - - return stmt; - } - - Stmt VisitStmt_(const PrefetchNode* op) final { - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - ICHECK(op != nullptr); - - const BufferEntry& e = GetBufferEntry(op->buffer); - - ICHECK(e.in_scope) << "Read a 
buffer that is already out of scope";
-    ICHECK_EQ(e.buffer->shape.size(), op->bounds.size())
-        << "Prefetch dim should be the same as buffer dim";
-
-    if (e.remap) {
-      return Prefetch(e.remap->target, remap_bounds(op->bounds, e.remap->begins, e.remap->extents),
-                      op->span);
-    } else {
-      return stmt;
-    }
-  }
-
- private:
-  // Read the mapping from a buffer view to the actual buffer.  This
-  // allows all later BufferStore/BufferLoad nodes to reference the
-  // actual buffer, rather than the buffer view.
-  Stmt HandleBufferBindScope(const AttrStmtNode* op) {
-    // Unpack information from Attribute node
-    RemapInfo remap;
-
-    Array<ObjectRef> arr = Downcast<Array<ObjectRef>>(op->node);
-    ICHECK_EQ(arr.size(), 2U);
-    const Buffer source = Downcast<Buffer>(arr[0]);
-    ICHECK(source.defined());
-    remap.target = Downcast<Buffer>(arr[1]);
-    ICHECK(remap.target.defined());
-    const CallNode* tuple = op->value.as<CallNode>();
-    ICHECK(tuple && tuple->op.same_as(builtin::tvm_tuple()));
-
-    for (size_t i = 0; i < tuple->args.size(); i += 2) {
-      remap.begins.push_back(tuple->args[i]);
-      remap.extents.push_back(tuple->args[i + 1]);
-    }
-
-    // Determine bounds in the target buffer
-    auto it = buf_map_.find(remap.target.get());
-    ICHECK(it != buf_map_.end()) << "Cannot define " << source << " as a view into " << remap.target
-                                 << ", " << remap.target << " was not defined.";
-    const BufferEntry& target_info = it->second;
-    ICHECK(target_info.in_scope) << "Cannot define " << source << " as a view into " << remap.target
-                                 << ", " << remap.target << " is out of scope.";
-    ICHECK_EQ(remap.begins.size(), target_info.buffer->shape.size())
-        << "Incorrect number of arguments in buffer_bind_scope attribute.  "
-        << "Expected (min_0, extent_0, min_1, extent_1, ..., min_N, extent_N).";
-
-    if (target_info.bounds.size() > 0) {
-      Array<PrimExpr> mapped_begins;
-      for (size_t i = 0; i < target_info.buffer->shape.size(); ++i) {
-        mapped_begins.push_back(remap.begins[i] - target_info.bounds[i]->min);
-      }
-      remap.begins = std::move(mapped_begins);
-    }
-
-    ICHECK(target_info.remap == nullptr)
-        << "buffer_bind_scope defines " << source << " as a view into " << remap.target
-        << ", which is itself a buffer view.  "
-        << "Indirect remapping is not currently supported.";
-
-    for (size_t i = 0; i < remap.begins.size(); i++) {
-      remap.begins.Set(i, bound_analyzer_->Simplify(remap.begins[i]));
-      remap.extents.Set(i, bound_analyzer_->Simplify(remap.extents[i]));
-    }
-
-    // Add a buffer remap entry
-    {
-      BufferEntry source_info;
-      source_info.buffer = source;
-      source_info.remap = std::make_unique<RemapInfo>(remap);
-
-      var_to_buffer_[source->data.get()] = source;
-      buf_map_[source.get()] = std::move(source_info);
-    }
-
-    // Define remappings of any remaining Variables (e.g. Store/Load nodes).
-    ArgBinder binder(&var_remap_);
-
-    // Define a view that represents the source's view into the target
-    // buffer.  This Buffer object is only used to define the mapping
-    // to the target buffer, and never actually appears in the TIR
-    // graph.
-    Buffer view = remap.target.MakeSlice(remap.begins, remap.extents);
-    if (source->strides.size() == 0) {
-      ICHECK_EQ(view->strides.size(), 0U)
-          << "Cannot bind a compact buffer " << source << " to a strided buffer " << view
-          << " with strides " << view->strides;
-    } else {
-      // Add explicit strides to the view, in order to bind to source.strides[i].
- view = view.MakeStrideView(); - } - - // Match integer bits of source->elem_offset and view->elem_offset - // as is required by ArgBinder::Bind_ - if (view->elem_offset.defined() && source->elem_offset.dtype() != view->elem_offset.dtype()) { - view.CopyOnWrite()->elem_offset = cast(source->elem_offset.dtype(), view->elem_offset); - } - - // Bind any variables that reference the view (e.g. elem_offset, - // strides, shape). Pass fuzzy_match=false, because all shape - // transformations should have been handled in - // BufferShapeLegalize. - binder.BindBuffer(source, view, source->name, false); - if (auto* elem_offset_var = source->elem_offset.as()) { - if (!view->elem_offset.defined()) { - illegal_vars_.insert(elem_offset_var); - } - } - - // Apply the remaps - Stmt body = op->body; - body = MergeNest(binder.asserts(), body); - body = MergeNest(binder.init_nest(), body); - body = this->VisitStmt(body); - // remove the binds - for (const Var& v : binder.defs()) { - var_remap_.erase(v.get()); - } - return body; - } - - struct RemapInfo { - Buffer target; - Array begins; - Array extents; - }; - - // The buffer entry in the flatten map - struct BufferEntry { - // The storage buffer - Buffer buffer; - // the bounds of realization, can be null, means everything - Region bounds; - // Whether the buffer is external - bool external{false}; - // Whether we are within the allocation scope of the buffer. - bool in_scope{true}; - - // The buffer to which the storage buffer should be remapped. - std::unique_ptr remap{nullptr}; - }; - - const BufferEntry& GetBufferEntry(Buffer buffer) { - if (buf_map_.count(buffer.get())) { - const BufferEntry& e = buf_map_[buffer.get()]; - ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; - return e; - } else if (buffer_var_defines_.count(buffer->data.get())) { - // The buffer var was defined, but the buffer hasn't been seen - // before. - BufferEntry entry; - entry.buffer = buffer; - var_to_buffer_[buffer->data.get()] = buffer; - buf_map_[buffer.get()] = std::move(entry); - return buf_map_[buffer.get()]; - } else if (var_remap_.count(buffer->data.get())) { - // The buffer var is an alias of a bound buffer. Only - // supported if the bound buffer has no offsets. In this - // case, we just need to make a new aliasing buffer that - // shares the remapped data variable. - Var old_var = buffer->data; - Var new_var = Downcast(var_remap_[old_var.get()]); - - { - ICHECK(var_to_buffer_.count(old_var.get())) - << "Cannot find remap information for aliased buffer var " << old_var->name_hint - << ", required to verify this alias is legal."; - const Buffer& aliased_buffer = var_to_buffer_[old_var.get()]; - const BufferEntry& entry = buf_map_[aliased_buffer.get()]; - if (entry.remap) { - for (const auto& begin : entry.remap->begins) { - ICHECK(is_zero(begin)) << "Aliasing of buffer with offset is not supported"; - } - } - } - - { - Buffer new_buf = buffer; - new_buf.CopyOnWrite()->data = new_var; - - RemapInfo remap_info; - remap_info.target = new_buf; - remap_info.begins = Array(buffer->shape.size(), 0); - remap_info.extents = buffer->shape; - - BufferEntry entry; - entry.buffer = buffer; - entry.remap = std::make_unique(remap_info); - entry.in_scope = true; - var_to_buffer_[buffer->data.get()] = buffer; - buf_map_[buffer.get()] = std::move(entry); - } - return buf_map_[buffer.get()]; - } else if (var_to_buffer_.count(buffer->data.get())) { - // This buffer is an alias of a known buffer, with no remaps. 
A - // buffer entry should be generated and returned. - BufferEntry entry; - entry.buffer = buffer; - entry.in_scope = true; - var_to_buffer_[buffer->data.get()] = buffer; - buf_map_[buffer.get()] = std::move(entry); - - return buf_map_[buffer.get()]; - } else { - LOG(FATAL) << "Can't work around the undefined buffer"; - } - } - - // The buffer assignment map - // Variable remap - std::unordered_map var_remap_; - // Variables that may not occur within the body. - std::unordered_set illegal_vars_; - // Buffer map - std::unordered_map buf_map_; - // Map from Var to the Buffer they occurred in. In case of aliased - // buffers, contains the first buffer. - std::unordered_map var_to_buffer_; - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. - std::unordered_set buffer_var_defines_; - // Analyzer for the variable bounds, used to simplify the bounds populator. We really need the - // analyzer from it. However - IRVisitorWithAnalyzer* bound_analyzer_; -}; - -class ApplyLayoutTransforms : public StmtExprMutator { - public: - static transform::Pass Pass() { - auto pass_func = [](PrimFunc func, IRModule m, transform::PassContext ctx) { - auto lookup = func->attrs.GetAttr>>("layout_transform_map"); - - if (!lookup) { - return func; - } - - Map> layout_transforms = lookup.value(); - - auto fptr = func.CopyOnWrite(); - - auto mutator = ApplyLayoutTransforms(layout_transforms); - fptr->buffer_map = mutator.UpdateExternBufferMap(fptr->buffer_map); - fptr->body = mutator(std::move(fptr->body)); - - return WithoutAttr(std::move(func), "layout_transform_map"); - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.ApplyLayoutTransforms", {}); - } - - explicit ApplyLayoutTransforms(Map> layout_transforms) - : layout_transforms_(layout_transforms) {} - - Map UpdateExternBufferMap(const Map& buffer_map) { - Map output; - for (const auto& kv : buffer_map) { - output.Set(kv.first, GetBufferRemap(kv.second, true)); - } - return output; - } - - Stmt VisitStmt_(const BufferRealizeNode* op) final { - // Call once so that load/store nodes can read from the cached - // value. - GetBufferRemap(op->buffer, true); - - auto realize = Downcast(StmtExprMutator::VisitStmt_(op)); - - auto lookup = layout_transforms_.Get(op->buffer); - if (lookup) { - auto write_ptr = realize.CopyOnWrite(); - write_ptr->buffer = GetBufferRemap(op->buffer, true); - - Array transforms = lookup.value(); - for (const auto& transform : transforms) { - write_ptr->bounds = transform->MapRanges(realize->bounds, &analyzer); - } - } - - return std::move(realize); - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - return VisitBufferAccess(std::move(node)); - } - - template - Node VisitBufferAccess(Node node) { - auto lookup = layout_transforms_.Get(node->buffer); - if (lookup) { - auto write_ptr = node.CopyOnWrite(); - - write_ptr->buffer = GetBufferRemap(node->buffer); - - Array transforms = lookup.value(); - for (const auto& transform : transforms) { - write_ptr->indices = transform->MapIndices(node->indices, &analyzer); - } - } - return node; - } - - private: - //! \brief Given a buffer, return the buffer it should be remapped into. 
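-  // (Illustrative: given a transform T = lambda i, j: (j, i) attached
-  // via "layout_transform_map", a buffer of shape [4, 8] is remapped
-  // to shape [8, 4] through MapShape, and an access buf[i, j] is
-  // rewritten to buf'[j, i] through MapIndices.)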
- Buffer GetBufferRemap(Buffer buf, bool allow_alloc = false) { - auto key = buf.get(); - auto it = buf_map_.find(key); - if (it != buf_map_.end()) { - return it->second; - } - - ICHECK(allow_alloc) << "Buffer " << buf << " accessed before declaration."; - - auto lookup = layout_transforms_.Get(buf); - if (lookup) { - Array transforms = lookup.value(); - - auto write_ptr = buf.CopyOnWrite(); - for (const auto& transform : transforms) { - write_ptr->shape = transform->MapShape(buf->shape, &analyzer); - } - } - - buf_map_[key] = buf; - return buf; - } - - std::unordered_map buf_map_; - - Map> layout_transforms_; - arith::Analyzer analyzer; -}; - -class StorageFlattener : public StmtExprMutator { - public: - static transform::Pass Pass(int cache_line_size, bool create_bound_attributes) { - auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) { - IRVisitorWithAnalyzer bound_analyzer; - - bound_analyzer(func->body); - - auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, - &bound_analyzer); - - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); - // The buffers in func->buffer_map are deliberately left - // unflattened, as they are used for validation of user-provided - // arguments. The flattened buffers used in the updated - // function body alias the argument buffers. - return func; - }; - return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); - } - - explicit StorageFlattener(const Map& extern_buffer_map, int cache_line_size, - bool create_bound_attributes, IRVisitorWithAnalyzer* bound_analyzer) - : bound_analyzer_(bound_analyzer), create_bound_attributes_(create_bound_attributes) { - for (auto kv : extern_buffer_map) { - BufferEntry e; - e.buffer = kv.second; - e.flattened_buffer = e.buffer.GetFlattenedBuffer(); - // TODO(Lunderberg): Move the handling of boolean into a - // dedicated pass. - - // Boolean tensors are backed by a Int8 array. - if (e.buffer->dtype == DataType::Bool()) { - { - auto writer = e.buffer.CopyOnWrite(); - writer->dtype = DataType::Int(8); - } - { - auto writer = e.flattened_buffer.CopyOnWrite(); - writer->dtype = DataType::Int(8); - } - } - e.external = true; - buffer_var_defines_.insert(kv.second->data.get()); - buf_map_[kv.second] = e; - } - cache_line_size_ = cache_line_size; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - ICHECK_NE(op->attr_key, attr::buffer_dim_align) - << "StorageFlattener assumes that all buffers have accurate strides, " - << "and all buffer_dim_align annotations are removed. " - << "Please run BufferStrideLegalize first."; - - ICHECK_NE(op->attr_key, attr::buffer_bind_scope) - << "StorageFlattener assumes that all buffer binds have already been applied. " - << "Please run BufferBindUnwrapper first."; - - if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { - auto buffer = Downcast(op->node); - Stmt body = this->VisitStmt(op->body); - const auto& entry = GetBufferEntry(buffer); - body = AttrStmt(entry.flattened_buffer->data, op->attr_key, op->value, std::move(body)); - return body; - } - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - if (create_bound_attributes_) shape_collector_.clear(); - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - - const BufferEntry& e = GetBufferEntry(op->buffer); - - // Handle casts from the value's dtype to the dtype of the backing - // array. 
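-    // (Illustrative: a boolean store buf[i] = x > 0 becomes
-    // buf_i8[i] = Cast(int8, x > 0) against the int8 backing buffer;
-    // the matching load path below inserts the inverse Cast(bool, ...).)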
-    PrimExpr value = op->value;
-    if (value.dtype() == DataType::Bool()) {
-      ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8))
-          << "Expected int8 backing array for boolean tensor, but received "
-          << e.flattened_buffer->dtype;
-      value = tir::Cast(DataType::Int(8), value);
-    }
-
-    auto flattened_indices = e.buffer->ElemOffset(op->indices);
-
-    ICHECK(!op->predicate.defined()) << "Predicated buffer store is not currently supported in "
-                                        "storage flatten pass.";
-    Stmt body = BufferStore(e.flattened_buffer, value, flattened_indices, op->predicate, op->span);
-    if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
-      shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape));
-    }
-    // To create the bound attributes, the collector must have at least one item.
-    if (create_bound_attributes_ && shape_collector_.size()) {
-      for (size_t i = 0; i < shape_collector_.size(); ++i) {
-        body = AttrStmt(shape_collector_[i].first, tir::attr::buffer_bound,
-                        MakeBound(e.buffer->dtype, shape_collector_[i].second), body);
-      }
-    }
-    return body;
-  }
-
-  Stmt VisitStmt_(const DeclBufferNode* op) final {
-    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
-    const BufferEntry& entry = GetBufferEntry(node->buffer);
-    if (!entry.flattened_buffer.same_as(node->buffer)) {
-      node.CopyOnWrite()->buffer = entry.flattened_buffer;
-    }
-    return std::move(node);
-  }
-
-  // AllocateNodes may be present from tvm.tir.ir_builder.  This can
-  // be simplified in the future by having AllocateNode hold a buffer,
-  // rather than a buffer_var.
-  Stmt VisitStmt_(const AllocateNode* op) final {
-    buffer_var_defines_.insert(op->buffer_var.get());
-    auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
-    return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
-                    stmt->body, stmt->annotations, stmt->span);
-  }
-
-  Stmt VisitStmt_(const AllocateConstNode* op) final {
-    buffer_var_defines_.insert(op->buffer_var.get());
-    auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
-    ObjectRef data_or_idx;
-    if (stmt->data) {
-      data_or_idx = stmt->data.value();
-    } else if (stmt->irmod_storage_idx) {
-      data_or_idx = stmt->irmod_storage_idx.value();
-    } else {
-      LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
-                 << op->buffer_var->name_hint;
-    }
-    return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
-                         stmt->body, stmt->annotations, stmt->span);
-  }
-
-  Stmt VisitStmt_(const LetStmtNode* op) final {
-    if (op->var.dtype().is_handle()) {
-      buffer_var_defines_.insert(op->var.get());
-    }
-    return StmtExprMutator::VisitStmt_(op);
-  }
-
-  PrimExpr VisitExpr_(const LetNode* op) final {
-    if (op->var.dtype().is_handle()) {
-      buffer_var_defines_.insert(op->var.get());
-    }
-    return StmtExprMutator::VisitExpr_(op);
-  }
-
-  Stmt VisitStmt_(const BufferRealizeNode* op) final {
-    const auto& key = op->buffer;
-
-    if (buf_map_.count(key)) {
-      ICHECK(buf_map_.at(key).external)
-          << "BufferRealize for internal buffer " << op->buffer << " appears multiple times.";
-      return this->VisitStmt(op->body);
-    } else {
-      // create a buffer entry
-      BufferEntry e;
-
-      ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
-          << "Inconsistent buffer shape and realization shape for " << op->buffer;
-
-      for (size_t i = 0; i < op->bounds.size(); i++) {
-        const auto& bound = op->bounds[i];
-        const auto& dim_size = op->buffer->shape[i];
-        ICHECK(is_zero(bound_analyzer_->Simplify(bound->min)))
-            << "Buffer " << op->buffer << " has realization bounds that do
not start at zero. " - << "Please run BufferShapeLegalize first."; - ICHECK(is_one(bound_analyzer_->Simplify(bound->extent == dim_size))) - << "Buffer " << op->buffer - << " has realization extent that does not match its size. " - "Please run BufferShapeLegalize first."; - } - - StorageScope skey = StorageScope::Create(GetPtrStorageScope(op->buffer->data)); - - // use small alignment for small arrays - auto dtype = op->buffer->dtype; - size_t const_size = AllocateNode::ConstantAllocationSize(op->buffer->shape); - int align = GetTempAllocaAlignment(dtype, const_size); - if (skey.tag.length() != 0) { - MemoryInfo info = GetMemoryInfo(skey.to_string()); - if (info.defined()) { - align = (info->max_simd_bits + dtype.bits() - 1) / dtype.bits(); - ICHECK_LE(const_size * dtype.bits(), info->max_num_bits) - << "Allocation exceed bound of memory tag " << skey.to_string(); - } - } - - e.buffer = Buffer(op->buffer->data, op->buffer->dtype, op->buffer->shape, op->buffer->strides, - PrimExpr(), op->buffer->name, align, 0, kDefault, - op->buffer->axis_separators, op->buffer->span); - e.flattened_buffer = e.buffer.GetFlattenedBuffer(); - - // TODO(Lunderberg): Move the handling of boolean into a - // dedicated pass. - - // Boolean tensors are backed by a Int8 array. - if (e.flattened_buffer->dtype == DataType::Bool()) { - auto writer = e.flattened_buffer.CopyOnWrite(); - writer->dtype = DataType::Int(8); - } - - buffer_var_defines_.insert(op->buffer->data.get()); - buf_map_[key] = e; - Stmt body = this->VisitStmt(op->body); - buffer_var_defines_.erase(op->buffer->data.get()); - buf_map_[key].in_scope = false; - - Stmt ret = - Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, - make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); - - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, - MakeBound(e.buffer->dtype, e.buffer->shape), ret); - } - return ret; - } - } - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = var_remap_.find(op); - if (it != var_remap_.end()) { - return it->second; - } else { - return GetRef(op); - } - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - op = expr.as(); - - const BufferEntry& e = GetBufferEntry(op->buffer); - - if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { - shape_collector_.push_back(std::make_pair(e.buffer->data, e.buffer->shape)); - } - - ICHECK(!op->predicate.defined()) << "Predicated buffer load is not currently supported in " - "storage flatten pass."; - auto flattened_indices = e.buffer->ElemOffset(op->indices); - PrimExpr val = BufferLoad(e.flattened_buffer, flattened_indices, op->predicate, op->span); - - if (op->dtype == DataType::Bool()) { - ICHECK_EQ(e.flattened_buffer->dtype, DataType::Int(8)) - << "Expected int8 backing array for boolean tensor, but received " - << e.flattened_buffer->dtype; - val = tir::Cast(DataType::Bool(), val); - } - - return val; - } - - Stmt VisitStmt_(const PrefetchNode* op) final { - const BufferEntry& e = GetBufferEntry(op->buffer); - - ICHECK(e.in_scope) << "Cannot prefetch " << op->buffer << ", out of scope."; - ICHECK_EQ(e.buffer->shape.size(), op->bounds.size()) - << "Prefetch dim should be the same as buffer dim"; - - int block_size = 1, elem_cnt = cache_line_size_ / e.buffer->dtype.bytes(); - - int starts = op->bounds.size() - 1; - - while (starts > 0) { - auto* shape_as_int = 
e.buffer->shape[starts].as(); - if (shape_as_int == nullptr || block_size * shape_as_int->value > elem_cnt) break; - block_size *= static_cast(shape_as_int->value); - starts--; - } - PrimExpr stride(elem_cnt / block_size); - - Array args; - std::vector vars; - - for (int i = op->bounds.size() - 1; i > starts; --i) { - args.push_back(op->bounds[i]->min); - } - auto& func_name = op->buffer->name; - vars.push_back(Var("prefetch." + func_name + "." + std::to_string(starts), DataType::Int(32))); - args.push_back(op->bounds[starts]->min + stride * vars.back()); - for (int i = starts - 1; i >= 0; --i) { - vars.push_back(Var("prefetch." + func_name + "." + std::to_string(i), DataType::Int(32))); - args.push_back(vars.back() + op->bounds[i]->min); - } - - Stmt stmt = GetRef(op); - for (int i = starts; i >= 0; --i) { - if (i < starts) { - stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt); - } else { - PrimExpr load = e.buffer.vload(args, e.buffer->dtype); - PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); - PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); - stmt = Evaluate(prefetch); - PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; - stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt); - } - } - return this->VisitStmt(stmt); - } - - PrimExpr VisitExpr_(const ProducerLoadNode* op) final { - LOG(FATAL) << "ProducerLoad cannot appear in a valid TIR PrimFunc. " - << "Please run SchedulePostProcToPrimFunc first."; - return PrimExpr(); - } - - Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "ProducerStore cannot appear in a valid TIR PrimFunc. " - << "Please run SchedulePostProcToPrimFunc first."; - return Stmt(); - } - - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - LOG(FATAL) << "ProducerRealize cannot appear in a valid TIR PrimFunc. " - << "Please run SchedulePostProcToPrimFunc first."; - return Stmt(); - } - - private: - // Helper function for visiting Allocate and AllocateConst. If, in - // the future, these are updated to hold a buffer (Buffer) object - // rather than a buffer_var (Var), this function can be replaced - // with a call to GetBufferEntry. - template - Array FlattenExtents(const Node& node) { - arith::Analyzer analyzer; - - // If an allocation has extents that match the buffer - auto is_compatible_buffer = [&](const Buffer& buffer) { - if (buffer->shape.size() != node->extents.size()) { - return false; - } - for (size_t i = 0; i < buffer->shape.size(); i++) { - if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) { - return false; - } - } - - return true; - }; - - auto int_array_equal = [](const Array& a, const Array& b) { - if (a.size() != b.size()) { - return false; - } - - for (size_t i = 0; i < a.size(); i++) { - if (a[i]->value != b[i]->value) { - return false; - } - } - - return true; - }; - - Array axis_separators; - auto it = buffer_var_map_.find(node->buffer_var.get()); - if (it != buffer_var_map_.end()) { - const auto& buffers = it->second; - if (buffers.size() == 0) { - // No buffers use this allocation, treat as flat and optimize - // out later. - } else if (buffers.size() == 1) { - // Only one buffer uses this allocation, so use its axis - // separators. - axis_separators = buffers[0]->axis_separators; - } else { - // Try to find a buffer using this allocation with a matching - // shape. 
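-      // (Illustrative: once the axis separators are known, extents
-      // [2, 3, 4] with axis_separators = [1] flatten below to the 2-D
-      // physical shape [2, 12]; with no separators they flatten to
-      // [24].)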
- Buffer compatible_buffer; - for (const auto& buffer : buffers) { - if (is_compatible_buffer(buffer)) { - ICHECK(!compatible_buffer.defined() || - int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators)) - << "Cannot determine axis separators to use when flattening " - << node->buffer_var->name_hint - << ", multiple buffer objects found with conflicting axis separators"; - compatible_buffer = buffer; - } - } - ICHECK(compatible_buffer.defined()) - << "Cannot determine axis separators to use when flattening " - << node->buffer_var->name_hint << ", no buffers found with matching shape"; - axis_separators = compatible_buffer->axis_separators; - } - } - - // Use GetFlattenedBuffer to determine the flattened shape of the - // output. We only need the shape and axis separators defined, - // everything else can be dummy values. - Buffer dummy_buffer = - decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators); - return dummy_buffer.GetFlattenedBuffer()->shape; - } - - // The buffer entry in the flatten map - struct DimAlignInfo { - int align_factor{0}; - int align_offset{0}; - }; - // The buffer entry in the flatten map - struct BufferEntry { - // The buffer object - Buffer buffer; - // The updated buffer object, after flattening has been applied. - Buffer flattened_buffer; - // Whether the buffer is external - bool external{false}; - // Whether the buffer is currently in scope. - bool in_scope{true}; - }; - - bool ShapeIsValid(const Array& shape) { - // Zero-dimensional tensor does not need boundary check. - if (!shape.size()) return false; - - for (size_t i = 0; i < shape.size(); ++i) { - if (!shape[i].defined() || !shape[i].dtype().is_scalar() || is_negative_const(shape[i])) { - return false; - } - } - return true; - } - - PrimExpr MakeBound(const DataType& type, const Array& shape) { - // We have already checked the shape size to be greater then 0. - PrimExpr bound = Mul(make_const(shape[0].dtype(), type.lanes()), shape[0]); - for (size_t i = 1; i < shape.size(); ++i) { - bound = Mul(bound, Mul(make_const(bound.dtype(), type.lanes()), shape[i])); - } - Array bounds{bound}; - - return Call(DataType::Handle(), builtin::tvm_tuple(), bounds); - } - - const BufferEntry& GetBufferEntry(Buffer buffer) { - auto alloc_key = buffer->data.get(); - if (!buf_map_.count(buffer) && buffer_var_defines_.count(alloc_key)) { - BufferEntry entry; - entry.buffer = buffer; - entry.flattened_buffer = buffer.GetFlattenedBuffer(); - // Boolean tensors are backed by a Int8 array. - if (entry.flattened_buffer->dtype == DataType::Bool()) { - auto writer = entry.flattened_buffer.CopyOnWrite(); - writer->dtype = DataType::Int(8); - } - buf_map_[buffer] = std::move(entry); - } - - auto it = buf_map_.find(buffer); - ICHECK(it != buf_map_.end()) << "Cannot find allocated buffer for " << buffer; - const BufferEntry& e = it->second; - ICHECK(e.in_scope) << "Cannot access a buffer " << buffer->name << ", out of scope"; - return it->second; - } - - // The buffer assignment map - // Variable remap - std::unordered_map var_remap_; - // Set of vars that have occurred in an AllocateNode, but haven't - // yet occurred in a BufferLoad/BufferStore. - std::unordered_set buffer_var_defines_; - // Map from an allocation variable to the buffer(s) that it backs. - // Used to track the determine the axis_separators that should be - // used for flattening the extents of an AllocateNode. - std::unordered_map> buffer_var_map_; - // Buffer map - std::unordered_map buf_map_; - // Collects shapes. 
-  std::vector<std::pair<Var, Array<PrimExpr>>> shape_collector_;
-  // Bounds populator; we really only need the analyzer it carries.
-  IRVisitorWithAnalyzer* bound_analyzer_;
-  // The size of a cache line.
-  int cache_line_size_;
-  // Whether to mark loads/stores with their bounds.
-  bool create_bound_attributes_{false};
-};
-
-/*!
- * \brief Simplify assert statements.
- *
- * If an assert statement can be statically verified to be true,
- * remove the assert statement.  Otherwise, keep the assert statement
- * unmodified.
- */
-class AssertSimplifier : public StmtMutator {
- public:
-  static transform::Pass Pass() {
-    auto pass_func = [=](PrimFunc func, IRModule m, transform::PassContext ctx) {
-      IRVisitorWithAnalyzer bound_analyzer;
-
-      bound_analyzer(func->body);
-
-      auto fptr = func.CopyOnWrite();
-      fptr->body = AssertSimplifier(&bound_analyzer)(std::move(fptr->body));
-      return func;
-    };
-    return transform::CreatePrimFuncPass(pass_func, 0, "tir.AssertSimplifier", {});
-  }
-
-  explicit AssertSimplifier(IRVisitorWithAnalyzer* bound_analyzer)
-      : bound_analyzer_(bound_analyzer) {}
-
-  Stmt VisitStmt_(const AssertStmtNode* op) final {
-    Stmt stmt = StmtMutator::VisitStmt_(op);
-    op = stmt.as<AssertStmtNode>();
-
-    PrimExpr condition = bound_analyzer_->Simplify(op->condition);
-    if (is_one(condition)) {
-      return op->body;
-    }
-
-    return stmt;
-  }
-
- private:
-  IRVisitorWithAnalyzer* bound_analyzer_;
-};
-
-// The specific tensor data layout is not determined before the
-// StorageFlatten pass.  We use buffer_bind_scope to specify, ahead of
-// time, that we want to bind a subregion of a tensor to a symbolic
-// buffer, which then gets used in extern functions.
-//
-// Example:
-//
-// realize A in range [i*4, extent=10) {
-//   bind Ab to A in [i*4+1, extent=4) {
-//     call_func(Ab.ptr, Ab.shape[0])
-//   }
-// }
-//
-// After StorageFlatten:
-//
-// alloc A[10]
-// call(A + 1, 4)
-//
-// A Buffer is a protocol declaring the specific data layout and shape
-// that we expect.  So this function needs to check:
-// - that the bind range is within the realize range,
-// - that we can match the requirements of the buffer, and
-// - that variables such as Ab.ptr are remapped to their actual values.
-//
-// Here are a few possible failure cases:
-// - The buffer is declared to have a constant shape, but we try to
-//   bind it to a different one.
-// - The buffer is declared to be compact (no strides), but the bound
-//   region is a subregion of a matrix (tensor), which means it
-//   requires strides.
-//
-// We do support a few relaxed cases, such as binding a region with
-// shape [1, 1, n, m] to a buffer with shape [n, m].
-PrimFunc StorageFlatten(PrimFunc func, int cache_line_size, bool create_bound_attributes) {
-  // Only apply this pass to TIR from TE schedules.  Because this is a
-  // per-function attribute, we can't just check it once for the
-  // entire module and apply the Sequential transform.
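-  // (Hedged usage sketch: applied as a pass, e.g.
-  //   mod = tir::transform::StorageFlatten(64, false)(mod);
-  // the attribute check below makes this a no-op for PrimFuncs that
-  // did not come from a legacy TE schedule.)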
- Optional from_legacy_te_schedule = func->GetAttr("from_legacy_te_schedule", Bool(false)); - if (from_legacy_te_schedule.value()) { - auto seq = transform::Sequential( - { - BufferShapeLegalize::Pass(), - BufferStrideLegalize::Pass(), - ThreadScopePropagate::Pass(), - BufferBindUnwrapper::Pass(), - ApplyLayoutTransforms::Pass(), - StorageFlattener::Pass(cache_line_size, create_bound_attributes), - AssertSimplifier::Pass(), - }, - "tir.StorageFlatten_impl"); - GlobalVar dummy_func_name("dummy_func"); - IRModule mod(Map({{dummy_func_name, func}})); - mod = seq(mod); - return Downcast(mod->Lookup(dummy_func_name)); - } else { - return func; - } -} - -namespace transform { - -TVM_REGISTER_GLOBAL("tir.transform.ApplyLayoutTransforms") - .set_body_typed(ApplyLayoutTransforms::Pass); - -// TODO(tvm-team): consolidate configs to the PassContext -Pass StorageFlatten(int cache_line_size, bool create_bound_attributes) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - return StorageFlatten(std::move(f), cache_line_size, create_bound_attributes); - }; - return CreatePrimFuncPass(pass_func, 0, "tir.StorageFlatten", {}); -} - -TVM_REGISTER_GLOBAL("tir.transform.StorageFlatten").set_body_typed(StorageFlatten); - -} // namespace transform - -} // namespace tir -} // namespace tvm diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index 1c3f916a445d..22c347066789 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -92,7 +92,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { AllocEntry entry; entry.alloc = op; entry.level = level; - // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer, + // Since StorageRewrite occurs after FlattenBuffer, // all allocations specify the extent of physical dimensions, and // is 1 for flat memory spaces. entry.num_physical_dimensions = op->extents.size(); @@ -542,7 +542,7 @@ class StoragePlanRewriter : public StmtExprMutator { // The storage scope. StorageScope scope; // The physical dimensionality of the allocations. Since - // StorageRewrite is applied after StorageFlatten/FlattenBuffer, + // StorageRewrite is applied after FlattenBuffer, // this is size of `AllocateNode::extents`. If moved size_t ndim; // Allocs that shares this entry. diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc deleted file mode 100644 index 91e1121ea130..000000000000 --- a/src/tir/transforms/texture_flatten.cc +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index 1c3f916a445d..22c347066789 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -92,7 +92,7 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       AllocEntry entry;
       entry.alloc = op;
       entry.level = level;
-      // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+      // Since StorageRewrite occurs after FlattenBuffer,
       // all allocations specify the extents of their physical dimensions,
       // of which there is exactly one for flat memory spaces.
       entry.num_physical_dimensions = op->extents.size();
@@ -542,7 +542,7 @@ class StoragePlanRewriter : public StmtExprMutator {
     // The storage scope.
     StorageScope scope;
     // The physical dimensionality of the allocations. Since
-    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
+    // StorageRewrite is applied after FlattenBuffer,
     // this is the size of `AllocateNode::extents`. If moved
     // earlier, it would instead be the size of `BufferNode::shape`.
     size_t ndim;
     // Allocs that share this entry.
diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc
deleted file mode 100644
index 91e1121ea130..000000000000
--- a/src/tir/transforms/texture_flatten.cc
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file texture_flatten.cc
- * \brief Flattens texture storage from a multi-dimensional array
- *        to 2D (width, height) buffer access
- */
-
-#include <tvm/runtime/registry.h>
-#include <tvm/te/operation.h>
-#include <tvm/tir/builtin.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt.h>
-#include <tvm/tir/transform.h>
-
-#include <unordered_map>
-
-#include "../../arith/ir_visitor_with_analyzer.h"
-#include "../../runtime/texture.h"
-#include "../../runtime/thread_storage_scope.h"
-
-namespace tvm {
-namespace tir {
-using arith::IRVisitorWithAnalyzer;
-using runtime::ApplyTexture2DFlattening;
-using runtime::DefaultTextureLayoutSeparator;
-using runtime::IsTextureStorage;
-
-class TextureLoweringBase : public StmtExprMutator {
- public:
-  explicit TextureLoweringBase(const Map<Var, Buffer>& extern_buffer_map,
-                               IRVisitorWithAnalyzer* bound_analyzer)
-      : bound_analyzer_{bound_analyzer} {
-    for (auto kv : extern_buffer_map) {
-      extern_buf_.insert(kv.second);
-    }
-  }
-
-  inline PrimExpr SimplifyOffset(const Array<PrimExpr>& shape,
-                                 const Array<PrimExpr>& index) const {
-    PrimExpr base = make_const(DataType::Int(32), 0);
-    ICHECK_EQ(shape.size(), index.size());
-    if (index.size() > 0) {
-      PrimExpr offset = index[0];
-      for (size_t i = 1; i < index.size(); ++i) {
-        offset = bound_analyzer_->Simplify(offset * shape[i] + index[i]);
-      }
-      base = base + offset;
-    }
-    return base;
-  }
-
- protected:
-  std::string GetStorageScope(const Buffer& buffer) {
-    auto* ptr = buffer->data->type_annotation.as<PointerTypeNode>();
-    ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
-    return ptr->storage_scope;
-  }
-
-  // Set of all external input and output buffers
-  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> extern_buf_;
-  // Bound analyzer
-  IRVisitorWithAnalyzer* bound_analyzer_;
-};
-
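The `SimplifyOffset` helper above is a Horner-style linearization over one group of dimensions; the deleted `GetTextureAccessArgs` further below applies it separately to the row and column groups. A small sketch of that linearization (dims and indices are illustrative):

```python
# Horner-style linearization performed by SimplifyOffset on each
# dimension group: offset = ((i0 * d1 + i1) * d2 + i2) * ...
def linearize(dims, indices):
    offset = 0
    for d, i in zip(dims, indices):
        offset = offset * d + i
    return offset

assert linearize([4, 8], [1, 2]) == 10  # 1 * 8 + 2
```
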
-// Lower N-d storage access to 2-d texture access using the lowering convention
-// specified by the buffer's storage scope.
-class TextureFlattener : public TextureLoweringBase {
- public:
-  using StmtExprMutator::VisitStmt_;
-  explicit TextureFlattener(const Map<Var, Buffer>& extern_buffer_map,
-                            IRVisitorWithAnalyzer* bound_analyzer)
-      : TextureLoweringBase(extern_buffer_map, bound_analyzer) {}
-
-  Stmt VisitStmt_(const BufferRealizeNode* op) final {
-    if (extern_buf_.count(op->buffer)) {
-      return this->VisitStmt(op->body);
-    }
-
-    std::string storage_scope = GetStorageScope(op->buffer);
-    Var buffer_var(op->buffer->data->name_hint,
-                   PointerType(PrimType(op->buffer->dtype), String(storage_scope)));
-    let_binding_.insert({op->buffer->data, buffer_var});
-
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<BufferRealizeNode>();
-
-    // Rewrite any buffer realizations with texture storage scope to 2d texture allocations
-    if (IsTextureStorage(storage_scope)) {
-      Stmt body = this->VisitStmt(op->body);
-      ICHECK(op->bounds.size() >= 3) << "Only 2d RGBA textures are currently supported";
-      int vec_length = static_cast<int>(op->bounds.back()->extent.as<IntImmNode>()->value);
-      ICHECK(vec_length == 4 || vec_length == 1)
-          << "Inner dimension of texture must be a vector of length 1 or 4 (RGBA), was: "
-          << vec_length;
-
-      struct ShapeFromRange {
-        const Array<Range>& bounds;
-        PrimExpr operator[](size_t i) const { return bounds[i]->extent; }
-      };
-      size_t axis = DefaultTextureLayoutSeparator(op->bounds.size(), storage_scope);
-      auto texture =
-          ApplyTexture2DFlattening<PrimExpr>(ShapeFromRange{op->bounds}, op->bounds.size(), axis);
-      Array<PrimExpr> args;
-      args.push_back(StringImm(storage_scope));
-      args.push_back(IntImm(DataType::Int(64), 2));  // 2d
-      args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(),
-                          {texture.width, texture.height}));
-      stmt = LetStmt(buffer_var,
-                     Call(buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), body);
-    }
-
-    return stmt;
-  }
-
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<BufferStoreNode>();
-    std::string storage_scope = GetStorageScope(op->buffer);
-    // Lower to two-dimensional access
-    if (IsTextureStorage(storage_scope)) {
-      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
-      args.push_back(op->value);
-      stmt = Evaluate(Call(args[0]->dtype, builtin::texture2d_store(), args));
-    }
-
-    return stmt;
-  }
-
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimExpr expr = StmtExprMutator::VisitExpr_(op);
-    op = expr.as<BufferLoadNode>();
-    // Lower to two-dimensional access
-    std::string storage_scope = GetStorageScope(op->buffer);
-    if (IsTextureStorage(storage_scope)) {
-      Array<PrimExpr> args = GetTextureAccessArgs(op, op->buffer);
-      args.push_back(op->indices.back());
-      expr = Call(op->buffer->dtype, builtin::texture2d_load(), args);
-    }
-
-    return expr;
-  }
-
- protected:
-  template <typename T>
-  Array<PrimExpr> GetTextureAccessArgs(const T* op, const Buffer& buffer) {
-    Array<PrimExpr> args;
-    if (let_binding_.count(op->buffer->data)) {
-      args.push_back(let_binding_[op->buffer->data]);
-    } else {
-      args.push_back(buffer->data);
-    }
-    Array<PrimExpr> row_dims, row_indices, col_dims, col_indices;
-    for (size_t i = 0; i < op->buffer->shape.size() - 1; i++) {
-      if (i < DefaultTextureLayoutSeparator(op->buffer->shape.size(), GetStorageScope(buffer))) {
-        col_dims.push_back(op->buffer->shape[i]);
-        col_indices.push_back(op->indices[i]);
-      } else {
-        row_dims.push_back(op->buffer->shape[i]);
-        row_indices.push_back(op->indices[i]);
-      }
-    }
-    PrimExpr row_offset = SimplifyOffset(row_dims, row_indices);
-    PrimExpr col_offset = SimplifyOffset(col_dims, col_indices);
-    args.push_back(row_offset);
-    args.push_back(col_offset);
-    return args;
-  }
-
-  // Bindings to new texture vars with texture pointer scope
-  std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> let_binding_;
-};
-
-PrimFunc TextureFlatten(PrimFunc func) {
-  auto fptr = func.CopyOnWrite();
-  IRVisitorWithAnalyzer bound_analyzer;
-  bound_analyzer(fptr->body);
-  fptr->body = TextureFlattener(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
-  return func;
-}
-
-namespace transform {
-
-Pass TextureFlatten() {
-  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
-    return TextureFlatten(std::move(f));
-  };
-  return CreatePrimFuncPass(pass_func, 0, "tir.TextureFlatten", {});
-}
-
-TVM_REGISTER_GLOBAL("tir.transform.TextureFlatten").set_body_typed(TextureFlatten);
-
-}  // namespace transform
-
-}  // namespace tir
-}  // namespace tvm
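For intuition on what `ApplyTexture2DFlattening` computed for realize regions in the deleted pass, here is a rough Python sketch (simplified; the real helper lives in `runtime/texture.h`, which also decides which product becomes width versus height): dimensions before the separator axis collapse into one texture extent, and the remaining dimensions, excluding the innermost RGBA vector dimension, into the other.

```python
from functools import reduce
from operator import mul

# Rough sketch of the 2-D flattening of an N-d realize region.
def texture2d_shape(extents, axis):
    first = reduce(mul, extents[:axis], 1)     # dims before the separator
    second = reduce(mul, extents[axis:-1], 1)  # remaining dims, minus RGBA dim
    return first, second

# e.g. realize bounds [2, 4, 8, 4] with separator axis 2 -> an 8 x 8 texture
assert texture2d_shape([2, 4, 8, 4], 2) == (8, 8)
```
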
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index b4e3d67e500e..ec290e48d457 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -236,7 +236,7 @@ class VecAllocAccess : public StmtExprMutator {
     shape.Set(shape.size() - 1, analyzer_.Simplify(shape[shape.size() - 1] * var_lanes_));
 
     // TODO(Lunderberg): Move this pass to be prior to
-    // StorageFlatten/FlattenBuffer, implement by appending a
+    // FlattenBuffer, implementing it by appending a
     // dimension to the buffer. Since it is currently after the
     // flattening, the strides are not technically necessary, but
     // are updated for consistency.
@@ -780,7 +780,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&, const PrimExpr&)> {
diff --git a/tests/python/tir-transform/test_tir_transform_storage_flatten.py b/tests/python/tir-transform/test_tir_transform_storage_flatten.py
deleted file mode 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_flatten.py
+++ /dev/null
-import tvm
-import tvm.testing
-from tvm.script import tir as T
-
-
-@T.prim_func
-def tir_func(a: T.handle, b: T.handle) -> None:
-    A = T.match_buffer(a, [2, 2])
-    B = T.match_buffer(b, [2, 2])
-    A[0, 1] = B[1, 1]
-
-
-def test_flatten_tir():
-    orig_mod = tvm.IRModule({"main": tir_func})
-    mod = tvm.tir.transform.StorageFlatten(64)(orig_mod)
-    tvm.ir.assert_structural_equal(
-        orig_mod, mod
-    )  # StorageFlatten should do nothing to TIR functions
-
-
-class TestPreserveDeclBuffer(tvm.testing.CompareBeforeAfter):
-    transform = tvm.tir.transform.StorageFlatten(64)
-
-    def before():
-        T.func_attr({"from_legacy_te_schedule": True})
-        A = T.decl_buffer([16, 16], "float32")
-        for i, j in T.grid(16, 16):
-            A[i, j] = 0.0
-
-    def expected():
-        T.func_attr({"from_legacy_te_schedule": True})
-        A = T.decl_buffer([256], "float32")
-        for i, j in T.grid(16, 16):
-            A[i * 16 + j] = 0.0
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_thread_sync.py b/tests/python/tir-transform/test_tir_transform_thread_sync.py
index 4ca33424c1d5..48de01a629c7 100644
--- a/tests/python/tir-transform/test_tir_transform_thread_sync.py
+++ b/tests/python/tir-transform/test_tir_transform_thread_sync.py
@@ -22,7 +22,6 @@ def run_passes(func: tvm.tir.PrimFunc):
     mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.StorageFlatten(64)(mod)
     cuda_target = tvm.target.Target("cuda", host="llvm")
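For anyone looking for where the `[16, 16] -> [256]` behavior checked by the deleted `TestPreserveDeclBuffer` lives now: `FlattenBuffer` performs the equivalent flattening in the TIR-based pipeline. A hedged sketch (pass prerequisites can vary by pipeline stage):

```python
import tvm
from tvm.script import tir as T

# The flattening formerly exercised via StorageFlatten is now
# FlattenBuffer's job: the (16, 16) access becomes A[i * 16 + j].
@T.prim_func
def before(A: T.Buffer((16, 16), "float32")):
    for i, j in T.grid(16, 16):
        A[i, j] = 0.0

mod = tvm.tir.transform.FlattenBuffer()(tvm.IRModule.from_expr(before))
mod.show()
```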