diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
index 401a8a16de71..a591a01e718c 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.cpp
@@ -449,12 +449,14 @@ Allocate::Allocate(
     Val* buffer,
     MemoryType memory_type,
     std::vector<Val*> shape,
-    bool zero_init)
+    bool zero_init,
+    const Allocate* alias)
    : Expr(passkey, ExprType::Allocate),
      buffer_(buffer),
      memory_type_(memory_type),
      shape_(std::move(shape)),
-      zero_init_(zero_init) {
+      zero_init_(zero_init),
+      alias_(alias) {
   TORCH_INTERNAL_ASSERT(
       passkey.ir_container_->isA<kir::Kernel>(),
       "IR type only valid for Kernel container.");
@@ -484,6 +486,12 @@ Allocate::Allocate(
     size_ = FusionGuard::getCurFusion()->oneVal();
   }
 
+  if (alias_ != nullptr) {
+    TORCH_INTERNAL_ASSERT(alias_ != this, "Invalid alias");
+    TORCH_INTERNAL_ASSERT(
+        alias_->memoryType() == memory_type_, "Invalid alias");
+  }
+
   addInput(size_);
 }
diff --git a/torch/csrc/jit/codegen/cuda/kernel_ir.h b/torch/csrc/jit/codegen/cuda/kernel_ir.h
index 8032d6692ea1..f6f21a149508 100644
--- a/torch/csrc/jit/codegen/cuda/kernel_ir.h
+++ b/torch/csrc/jit/codegen/cuda/kernel_ir.h
@@ -187,7 +187,8 @@ class TORCH_CUDA_CU_API Allocate final : public Expr {
       Val* buffer,
       MemoryType memory_type,
       std::vector<Val*> shape = {},
-      bool zero_init = false);
+      bool zero_init = false,
+      const Allocate* alias = nullptr);
 
   //! Allocation of a non-dimensional buffer
   //!
@@ -225,12 +226,6 @@ class TORCH_CUDA_CU_API Allocate final : public Expr {
     return alias_;
   }
 
-  void setAlias(const Allocate* alias) {
-    TORCH_INTERNAL_ASSERT(alias != this);
-    TORCH_INTERNAL_ASSERT(alias->memoryType() == memory_type_);
-    alias_ = alias;
-  }
-
  private:
   Val* buffer_ = nullptr;
   MemoryType memory_type_ = MemoryType::Local;
diff --git a/torch/csrc/jit/codegen/cuda/lower2device.cpp b/torch/csrc/jit/codegen/cuda/lower2device.cpp
index bc7e2b88251a..92cecccbfe77 100644
--- a/torch/csrc/jit/codegen/cuda/lower2device.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower2device.cpp
@@ -396,6 +396,16 @@ bool GpuLower::hasCurrent() {
 
 void GpuLower::propagateExprInfo(const Expr* old_expr, const Expr* new_expr) {
   pred_elimination_.propagateRemovalInfo(old_expr, new_expr);
+  if (old_expr->isA<kir::Allocate>()) {
+    auto alloc_info_it =
+        localAllocationInfoMap().find(old_expr->as<kir::Allocate>());
+    if (alloc_info_it != localAllocationInfoMap().end()) {
+      auto alloc_info =
+          std::make_unique<LocalAllocationInfo>(*(alloc_info_it->second));
+      localAllocationInfoMap().emplace(
+          new_expr->as<kir::Allocate>(), std::move(alloc_info));
+    }
+  }
 }
 
 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
index 7cd4a19ac4d3..8013a526a5ff 100644
--- a/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_alias_memory.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -12,6 +13,16 @@
 #include
 #include
 
+// The goal of this pass is to change allocations to use other
+// allocations when possible. To do so, there are 3 main stages and
+// corresponding classes:
+//
+// - Analyze live ranges of tensors (class AllocationInfoMap)
+// - Find allocations of tensors that can reuse other allocations
+//   (class ReusableAllocationFinder)
+// - Replace those allocation expressions with their alias fields
+//   pointing to reused allocations (class AllocationReuseModifier)
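+//
+// As a simplified, hypothetical illustration of what the pass
+// achieves, consider two local buffers whose live ranges do not
+// overlap:
+//
+//   float T1[4];        // Allocate T1
+//   float T2[4];        // Allocate T2
+//   T1[...] = in[...];  // first write of T1
+//   T2[...] = T1[...];  // last read of T1, first write of T2
+//   out[...] = T2[...]; // last read of T2
+//
+// T1 is dead once T2 is first written, so T2's Allocate can be
+// rebuilt with its alias field pointing to T1's Allocate, and
+// codegen then emits no separate storage for T2.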
+
 namespace torch {
 namespace jit {
 namespace fuser {
@@ -150,7 +161,8 @@ class SymbolicSizePrinter : private OptOutConstDispatch {
   std::stringstream os_;
 };
 
-class BufferUseDefInfo;
+class AllocationInfoMap;
+
 //! A debug printer internal to this pass to support
 //! future expansion and inline annotation of pass info.
 class BufferReuseDebugPrinter {
@@ -167,7 +179,8 @@ class BufferReuseDebugPrinter {
  public:
   BufferReuseDebugPrinter() : ir_printer_(os_){};
 
-  std::string dumpDebugInfo() {
+  std::string dumpDebugInfo(const AllocationInfoMap* allocation_info_map) {
+    allocation_info_map_ = allocation_info_map;
     os_.clear();
     for (auto& debug_entry : debug_info_) {
       switch (debug_entry->first.line_type) {
@@ -189,9 +202,6 @@ class BufferReuseDebugPrinter {
     return os_.str();
   }
 
- private:
-  friend class BufferUseDefInfo;
-
   void pushBack(int lineno, Expr* expr) {
     makeExprEntry(lineno, expr);
   }
@@ -204,6 +214,7 @@ class BufferReuseDebugPrinter {
     makeScopeEntry(DebugLineType::END_BLOCK);
   }
 
+ private:
   void makeExprEntry(int lineno, Expr* expr) {
     auto debug_entry_ptr = std::make_unique<DebugEntry>();
     debug_entry_ptr->first.lineno = lineno;
@@ -277,7 +288,8 @@ class BufferReuseDebugPrinter {
   int indent_level_ = 0;
 
   std::vector<DebugEntryPtr> debug_info_;
-  BufferUseDefInfo* buffer_info_ = nullptr;
+
+  const AllocationInfoMap* allocation_info_map_ = nullptr;
 };
 
 //! Utility class for modeling the liveness interval.
@@ -356,8 +368,125 @@ struct ScopeInfo {
   kir::ForLoop* loop = nullptr;
 };
 
-using ScopeInfoOwningPtr = std::unique_ptr<ScopeInfo>;
-using ScopeInfoOwningPtrList = std::vector<ScopeInfoOwningPtr>;
+//! Assign an integer position to each expression to help represent
+//! scope ranges. The position starts from 1.
+class ExprPosMap {
+ public:
+  //! Get the position of an expr
+  int get(const Expr* expr) const {
+    return expr_pos_map_.at(expr);
+  }
+
+  //! Get the current position
+  int getCurrentPos() const {
+    return current_pos_;
+  }
+
+  //! Advance the position counter
+  void moveToNext() {
+    ++current_pos_;
+  }
+
+  //! Record the current position as the position of an expr
+  void setPosAtCurrent(const Expr* expr) {
+    expr_pos_map_[expr] = current_pos_;
+  }
+
+ private:
+  //! Position counter. The first expression is assigned position 1
+  int current_pos_ = 0;
+
+  //! Keep track of the positions of expressions
+  std::unordered_map<const Expr*, int> expr_pos_map_;
+};
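+
+// A hypothetical numbering, together with the scope ranges that
+// ScopeMap (below) derives from these positions:
+//
+//   exprA;        // pos 1
+//   for (...) {   // pos 2  -> loop scope start_pos = 2
+//     exprB;      // pos 3
+//     exprC;      // pos 4
+//   }             //         -> loop scope end_pos = 5 (last pos + 1)
+//   exprD;        // pos 5
+//
+// A scope therefore covers the half-open position range
+// [start_pos, end_pos).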
+
+// Create ScopeInfo for each loop
+class ScopeMap : private kir::IrVisitor {
+ public:
+  ScopeMap(const std::vector<Expr*>& exprs) {
+    global_scope_info_ = makeAndRegisterScopeInfo(nullptr);
+    handle(exprs);
+    global_scope_info_->end_pos = expr_pos_map_.getCurrentPos() + 1;
+
+    // Make sure all loops have end_pos filled
+    for (const auto& info : all_scope_info_) {
+      TORCH_INTERNAL_ASSERT(info->end_pos != -1);
+    }
+  }
+
+  using kir::IrVisitor::handle;
+
+  void handle(Expr* expr) final {
+    expr_pos_map_.moveToNext();
+    expr_pos_map_.setPosAtCurrent(expr);
+    kir::IrVisitor::handle(expr);
+  }
+
+  void handle(kir::ForLoop* for_loop) final {
+    auto loop_info = makeAndRegisterScopeInfo(for_loop);
+    kir::IrVisitor::handle(for_loop);
+    loop_info->end_pos = expr_pos_map_.getCurrentPos() + 1;
+  }
+
+  void handle(kir::IfThenElse* ite) final {
+    TORCH_INTERNAL_ASSERT(
+        false, "lower_alias_memory: no support for IfThenElse at this phase.");
+  }
+
+  //! Factory function for internal loop information data
+  ScopeInfo* makeAndRegisterScopeInfo(kir::ForLoop* loop) {
+    auto loop_info_ptr = std::make_unique<ScopeInfo>();
+    auto loop_info = loop_info_ptr.get();
+
+    // When loop is null, it corresponds to the global scope
+    loop_info->start_pos = loop == nullptr ? 0 : getExprPos(loop);
+    loop_info->end_pos = -1; // This will be filled later
+    loop_info->loop = loop;
+    all_scope_info_.emplace_back(std::move(loop_info_ptr));
+
+    if (loop != nullptr) {
+      TORCH_INTERNAL_ASSERT(
+          loop_to_scope_info_map_.emplace(loop, loop_info).second,
+          "Duplicated scope info created for loop: ",
+          loop->toString());
+    }
+
+    return loop_info;
+  }
+
+  ScopeInfo* getGlobalScopeInfo() const {
+    return global_scope_info_;
+  }
+
+  std::vector<std::unique_ptr<ScopeInfo>>&& getAllScopeInfo() {
+    return std::move(all_scope_info_);
+  }
+
+  ScopeInfo* getLoopScopeInfo(const kir::ForLoop* loop) const {
+    auto it = loop_to_scope_info_map_.find(loop);
+    TORCH_INTERNAL_ASSERT(
+        it != loop_to_scope_info_map_.end(),
+        "No scope info found for loop: ",
+        loop->toString());
+    return it->second;
+  }
+
+  int getExprPos(const Expr* expr) const {
+    return expr_pos_map_.get(expr);
+  }
+
+ private:
+  //! Owning list of collected scope info
+  std::vector<std::unique_ptr<ScopeInfo>> all_scope_info_;
+
+  //! Contains start and end position of the global scope
+  ScopeInfo* global_scope_info_ = nullptr;
+
+  //! map loop to scope info
+  std::unordered_map<const kir::ForLoop*, ScopeInfo*> loop_to_scope_info_map_;
+
+  ExprPosMap expr_pos_map_;
+};
 
 //! Utility class to record the read and write of each
 //! allocated buffer.
 //!
 //! Will probably at some point need dataflow and index analysis to precisely
 //! handle loop carried dependency.
-struct AllocationUseDefInfo {
+struct AllocationInfo {
   kir::Allocate* alloc_expr = nullptr;
-  kir::Allocate* alias_to = nullptr;
+  const kir::Allocate* alias_to = nullptr;
   bool is_inner_alias = false;
   bool should_try_alias = true;
   MemoryType mem_type = MemoryType::Local;
@@ -380,8 +509,7 @@ struct AllocationUseDefInfo {
   ScopeInfo* loop_info = nullptr;
   bool can_use_inner_alias = true;
   int alloc_pos = -1;
-  std::unique_ptr<std::vector<AllocationUseDefInfo*>> inner_alias_list_ =
-      nullptr;
+  std::unique_ptr<std::vector<AllocationInfo*>> inner_alias_list_ = nullptr;
   std::unique_ptr<BufferLiveInterval> inner_live_interval = nullptr;
   std::unique_ptr<BufferLiveIntervalPtrList> inner_subscribed_intevals =
       nullptr;
@@ -390,11 +518,6 @@ struct AllocationUseDefInfo {
       nullptr;
 };
 
-using AllocationInfoOwningPtr = std::unique_ptr<AllocationUseDefInfo>;
-using AllocationInfoOwningList = std::vector<AllocationInfoOwningPtr>;
-using AllocationInfoPtr = AllocationUseDefInfo*;
-using AllocationInfoList = std::vector<AllocationInfoPtr>;
-
 //! Analysis pass to collect the liveness info of local and shared buffers:
 //! The liveness info is illustrated as follows:
 //!
@@ -425,46 +548,54 @@
 //! Outer interval marks the beginning of the loop of first write and end of
 //! the loop of last read, both at the same loop level as the buffer
 //! allocation.
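+//!
+//! A small, hypothetical example, with expression positions in
+//! parentheses:
+//!
+//!   Alloc(T1)        (1)   // allocated at the outer loop level
+//!   For i ...        (2)
+//!     T1[...] = ...  (3)   // first write of T1
+//!   For j ...        (4)
+//!     ... = T1[...]  (5)   // last read of T1
+//!
+//! Inner live interval of T1: position 3 (first write) to
+//! position 5 (last read).
+//! Outer live interval of T1: position 2 (start of the i-loop) to
+//! the end position of the j-loop.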
-class BufferUseDefInfo {
+class AllocationInfoMap : private kir::IrVisitor {
 public:
  // Alias local memory if it exceeds this threshold
  static constexpr long kRegisterSizeThreshold = 1;

-  BufferUseDefInfo(
-      const std::vector<Expr*>& exprs,
-      BufferReuseDebugPrinter* debug_printer = nullptr)
-      : debug_printer_(debug_printer) {
-    if (debug_printer) {
-      debug_printer->buffer_info_ = this;
+  AllocationInfoMap(const std::vector<Expr*>& exprs, bool debug_print)
+      : scope_map_(exprs),
+        debug_printer_(
+            debug_print ? std::make_unique<BufferReuseDebugPrinter>()
+                        : nullptr) {
+    current_stack_.push_back(scope_map_.getGlobalScopeInfo());
+    if (debug_printer_) {
+      debug_printer_->pushScope();
+    }
+    handle(exprs);
+    if (debug_printer_) {
+      debug_printer_->popScope();
+      std::cout << debug_printer_->dumpDebugInfo(this);
     }
-    collectScopeInfo(exprs);
-    collectScopeUseDefInfo(exprs);
+    current_stack_.pop_back();
  }

-  //! Returns live interval info of buffer if previously
-  //! computed.
-  c10::optional<AllocationInfoPtr> getMaybeReuseInfoFor(
-      kir::Allocate* allocate) const {
-    auto alloc_it = map_allocate_to_info_.find(allocate);
-    if (alloc_it == map_allocate_to_info_.end()) {
+  c10::optional<AllocationInfo*> getMaybeAllocationInfo(
+      const kir::Allocate* alloc) const {
+    auto it = allocation_info_map_.find(alloc);
+    if (it == allocation_info_map_.end()) {
      return c10::nullopt;
    }
-    auto alloc = alloc_it->second;
-    return alloc;
+    return it->second;
+  }
+
+  const std::unordered_map<const kir::Allocate*, AllocationInfo*>&
+  getAllocationInfoMap() const {
+    return allocation_info_map_;
  }

-  //! Realize alias of two buffers through inner alias analysis and
-  //! keep track of the re-use.
-  void useInnerAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
+  //! Mark the tensor of "from" as an alias of the tensor of "to"
+  //! through inner alias analysis and keep track of the re-use.
+  void useInnerAlias(AllocationInfo* from, AllocationInfo* to) {
    to->inner_alias_list_->push_back(from);
    to->inner_subscribed_intevals->push_back(from->inner_live_interval.get());
    setAlias(from, to);
    from->is_inner_alias = true;
  }

-  //! Realize alias of two buffers through outer alias analysis and
-  //! keep track of the re-use.
-  void useOuterAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
+  //! Mark the tensor of "from" as an alias of the tensor of "to"
+  //! through outer alias analysis and keep track of the re-use.
+  void useOuterAlias(AllocationInfo* from, AllocationInfo* to) {
    to->outer_subscribed_intevals->push_back(from->outer_live_interval.get());
    setAlias(from, to);
  }
@@ -473,7 +604,7 @@ class BufferUseDefInfo {
  //! Initializes the inner live intervals with each
  //! allocation's inner live interval.
  void prepareInnerSharingAnalysis() {
-    for (auto it : map_allocate_to_info_) {
+    for (auto it : getAllocationInfoMap()) {
      auto alloc_info = it.second;
      // At the beginning, the only use interval for each
      // allocate is its corresponding live interval
@@ -486,7 +617,7 @@ class BufferUseDefInfo {
  //! Initializes the outer live intervals with the outer live interval
  //! of each allocation and copy inner sharing information.
  void prepareOuterSharingAnalysis() {
-    for (auto it : map_allocate_to_info_) {
+    for (auto it : getAllocationInfoMap()) {
      auto alloc_info = it.second;
      if (!alias_map_.count(alloc_info)) {
        alloc_info->outer_subscribed_intevals->push_back(
            alloc_info->outer_live_interval.get());
@@ -500,50 +631,45 @@ class BufferUseDefInfo {
      }
    }
  }

+  const std::unordered_map<AllocationInfo*, AllocationInfo*>& getAliasMap()
+      const {
+    return alias_map_;
+  }
+
 private:
-  void handle(Expr* expr) {
-    current_pos_++;
+  using kir::IrVisitor::handle;
+
+  void handle(Expr* expr) final {
    if (debug_printer_) {
-      debug_printer_->pushBack(current_pos_, expr);
+      debug_printer_->pushBack(scope_map_.getExprPos(expr), expr);
    }
-    if (auto alloc = dynamic_cast<kir::Allocate*>(expr)) {
-      handle(alloc);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    } else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else {
-      collectLivenessInfo(expr);
+    kir::IrVisitor::handle(expr);
+    if (ir_utils::isTvOp(expr)) {
+      collectLivenessInfoOfExpr(expr);
    }
  }

-  void handleScope(const std::vector<Expr*>& exprs) {
+  void handle(kir::ForLoop* for_loop) final {
+    auto loop_info = scope_map_.getLoopScopeInfo(for_loop);
+    current_stack_.push_back(loop_info);
    if (debug_printer_) {
      debug_printer_->pushScope();
    }
-    for (auto expr : exprs) {
-      handle(expr);
-    }
+    kir::IrVisitor::handle(for_loop);
    if (debug_printer_) {
      debug_printer_->popScope();
    }
-  }
-
-  void handle(kir::ForLoop* for_loop) {
-    auto loop_info = map_loop_pos_to_loop_info_.at(current_pos_);
-    current_stack_.push_back(loop_info);
-    handleScope(for_loop->body().exprs());
    current_stack_.pop_back();
  }

-  void handle(kir::IfThenElse* ite) {
+  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(
        false, "lower_alias_memory: no support for IfThenElse at this phase.");
  }

  // Generate allocation info for allocation after some pre-filtering
  // conditions.
-  void handle(kir::Allocate* alloc) {
+  void handle(kir::Allocate* alloc) final {
    if (alloc->alias()) {
      // We shouldn't really see a case like this in general, but
      // some Fusion outputs could have been aliased to inputs.
@@ -581,11 +707,12 @@ class BufferUseDefInfo {
    auto size_print = SymbolicSizePrinter::printSize(alloc);

    // Make sure we don't have conflicting information on record
-    TORCH_INTERNAL_ASSERT(!map_allocate_to_info_.count(alloc));
-    TORCH_INTERNAL_ASSERT(!map_tv_to_allocations_.count(tv->name()));
+    TORCH_INTERNAL_ASSERT(!allocation_info_map_.count(alloc));
+    TORCH_INTERNAL_ASSERT(!tv_to_allocation_map_.count(tv->name()));

    // make AllocationUseDefInfo:
-    auto alloc_info = makeUseDefInfo();
+    auto alloc_info = makeAllocationInfo();
+    alloc_info->alloc_pos = scope_map_.getExprPos(alloc);
    alloc_info->alloc_expr = alloc;
    alloc_info->mem_type = mem_type;
    alloc_info->data_type = data_type;
@@ -594,50 +721,38 @@ class BufferUseDefInfo {
    alloc_info->should_try_alias = should_try_alias;

    // record short cuts
-    map_allocate_to_info_[alloc] = alloc_info;
-    map_tv_to_allocations_[tv->name()] = alloc_info;
-  }
-
-  void collectScopeUseDefInfo(const std::vector<Expr*>& exprs) {
-    // Reset position pointer
-    resetExprCounter();
-    TORCH_INTERNAL_ASSERT(global_scope_info_ != nullptr);
-    current_stack_.push_back(global_scope_info_);
-    handleScope(exprs);
+    allocation_info_map_[alloc] = alloc_info;
+    tv_to_allocation_map_[tv->name()] = alloc_info;
  }

-  void collectScopeInfo(const std::vector<Expr*>& exprs) {
-    // Reset position pointer
-    resetExprCounter();
-    collectScopeInfoWithinLoop(exprs, nullptr);
-  }
-
-  void collectScopeInfoWithinLoop(
-      const std::vector<Expr*>& exprs,
-      kir::ForLoop* current_loop) {
-    auto loop_info = makeScopeInfo(current_loop);
-    for (auto expr : exprs) {
-      current_pos_++;
-      if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-        collectScopeInfoWithinLoop(for_loop->body().exprs(), for_loop);
-      }
-    }
-    loop_info->end_pos = current_pos_ + 1;
-  }
+  //! Factory function for internal use-def information data
+  AllocationInfo* makeAllocationInfo() {
+    auto alloc_info_ptr = std::make_unique<AllocationInfo>();
+    auto alloc_info = alloc_info_ptr.get();

-  void resetExprCounter() {
-    current_pos_ = -1;
+    alloc_info->inner_alias_list_ =
+        std::make_unique<std::vector<AllocationInfo*>>();
+    alloc_info->inner_live_interval = std::make_unique<BufferLiveInterval>();
+    alloc_info->inner_subscribed_intevals =
+        std::make_unique<BufferLiveIntervalPtrList>();
+    alloc_info->outer_live_interval = std::make_unique<BufferLiveInterval>();
+    alloc_info->outer_subscribed_intevals =
+        std::make_unique<BufferLiveIntervalPtrList>();
+    all_allocations_.emplace_back(std::move(alloc_info_ptr));
+    return alloc_info;
  }

  // Iterate over the inputs and outputs of exprs and update
  // the liveness info of local buffers if applicable.
-  void collectLivenessInfo(const Expr* expr) {
+  void collectLivenessInfoOfExpr(Expr* expr) {
    if (!ir_utils::isTvOp(expr)) {
      return;
    }

    auto out_tv = expr->outputs()[0]->as<TensorView>();

+    const auto expr_pos = scope_map_.getExprPos(expr);
+
    // Collect all tv's that resolve broadcast in this
    // expr. The current analysis isn't enough to capture
    // their liveness range.
@@ -645,7 +760,7 @@ class BufferUseDefInfo {
      auto maybe_alloc_info = getMaybeAllocInfoFromTV(input_tv);
      if (maybe_alloc_info.has_value()) {
        if (!isSerialBroadcastResolution(input_tv, out_tv)) {
-          maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_);
+          maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos);
        } else {
          // Disable inner alias info for this buffer, since line number based
          // analysis is no longer precise enough for inplace sharing
@@ -662,7 +777,7 @@ class BufferUseDefInfo {
        } else {
          // Allocate is inlined in the innermost loop,
          // so outer live interval is the same as inner.
-          maybe_alloc_info.value()->outer_live_interval->markRead(current_pos_);
+          maybe_alloc_info.value()->outer_live_interval->markRead(expr_pos);
        }
      }
    }
@@ -671,28 +786,32 @@ class BufferUseDefInfo {
      if (maybe_alloc_info.has_value()) {
        // Reductions use outputs as read-write parameters, so their
        // outputs need to be marked as read as well
-        const bool is_read_write = ir_utils::isReductionOp(expr) &&
-            std::any_of(output_tv->getMaybeRFactorDomain().begin(),
-                output_tv->getMaybeRFactorDomain().end(),
-                [](IterDomain* id) { return id->isReduction(); });
-        maybe_alloc_info.value()->inner_live_interval->markWrite(current_pos_);
+        const bool is_read_write = ir_utils::isReductionOp(expr);
+        maybe_alloc_info.value()->inner_live_interval->markWrite(expr_pos);
        if (is_read_write) {
-          maybe_alloc_info.value()->inner_live_interval->markRead(current_pos_);
+          maybe_alloc_info.value()->inner_live_interval->markRead(expr_pos);
        }
        auto outer_loop_info =
            ascendLoopNestToSameLevelAs(maybe_alloc_info.value());
        auto write_pos =
-            outer_loop_info ? outer_loop_info->start_pos : current_pos_;
+            outer_loop_info ? outer_loop_info->start_pos : expr_pos;
        maybe_alloc_info.value()->outer_live_interval->markWrite(write_pos);
        if (is_read_write) {
-          auto read_pos =
-              outer_loop_info ? outer_loop_info->end_pos : current_pos_;
+          auto read_pos = outer_loop_info ? outer_loop_info->end_pos : expr_pos;
          maybe_alloc_info.value()->outer_live_interval->markRead(read_pos);
        }
      }
    }
  }

+  c10::optional<AllocationInfo*> getMaybeAllocInfoFromTV(TensorView* tv) const {
+    auto alloc_it = tv_to_allocation_map_.find(tv->name());
+    if (alloc_it == tv_to_allocation_map_.end()) {
+      return c10::nullopt;
+    }
+    return alloc_it->second;
+  }
+
  //! Find the loop level of expr that appears in the same scope as
  //! the reference allocate. E.g.:
  //!
@@ -706,7 +825,7 @@ class BufferUseDefInfo {
  //!          expr <---- current expr (implied in current_stack_ and
  //!          current_pos_ )
  //! Assumes that expr either writes to or reads from the reference allocate.
-  ScopeInfo* ascendLoopNestToSameLevelAs(AllocationUseDefInfo* reference) {
+  ScopeInfo* ascendLoopNestToSameLevelAs(AllocationInfo* reference) {
    auto allocate_loop_info = reference->loop_info;
    if (allocate_loop_info->loop == nullptr) {
      if (current_stack_.size() > 1) {
@@ -729,106 +848,49 @@ class BufferUseDefInfo {
    return nullptr;
  }

-  c10::optional<AllocationInfoPtr> getMaybeAllocInfoFromTV(TensorView* tv) {
-    auto alloc_it = map_tv_to_allocations_.find(tv->name());
-    if (alloc_it == map_tv_to_allocations_.end()) {
-      return c10::nullopt;
-    }
-    return alloc_it->second;
-  }
-
-  //! Factory function for internal loop information data
-  ScopeInfo* makeScopeInfo(kir::ForLoop* loop) {
-    auto loop_info_ptr = std::make_unique<ScopeInfo>();
-    auto loop_info = loop_info_ptr.get();
-    loop_info->start_pos = current_pos_;
-    loop_info->end_pos = -1;
-    loop_info->loop = loop;
-    all_loop_infos_.emplace_back(std::move(loop_info_ptr));
-
-    if (loop == nullptr) {
-      TORCH_INTERNAL_ASSERT(
-          !global_scope_info_, "Should only create global scope info once!");
-      global_scope_info_ = loop_info;
-    } else {
-      map_loop_pos_to_loop_info_[current_pos_] = loop_info;
-    }
-    return loop_info;
-  }
-
-  //! Factory function for internal use-def information data
-  AllocationUseDefInfo* makeUseDefInfo() {
-    auto alloc_info_ptr = std::make_unique<AllocationUseDefInfo>();
-    auto alloc_info = alloc_info_ptr.get();
-
-    alloc_info->alloc_pos = current_pos_;
-    alloc_info->inner_alias_list_ =
-        std::make_unique<std::vector<AllocationUseDefInfo*>>();
-    alloc_info->inner_live_interval = std::make_unique<BufferLiveInterval>();
-    alloc_info->inner_subscribed_intevals =
-        std::make_unique<BufferLiveIntervalPtrList>();
-    alloc_info->outer_live_interval = std::make_unique<BufferLiveInterval>();
-    alloc_info->outer_subscribed_intevals =
-        std::make_unique<BufferLiveIntervalPtrList>();
-    all_allocations_.emplace_back(std::move(alloc_info_ptr));
-    return alloc_info;
-  }
-
-  // Realize buffer alias and keep track of the alias info.
-  void setAlias(AllocationInfoPtr from, AllocationInfoPtr to) {
+  //! Mark the tensor of "from" as an alias of the tensor of "to".
+  void setAlias(AllocationInfo* from, AllocationInfo* to) {
    alias_map_[from] = to;
-    from->alloc_expr->setAlias(to->alloc_expr);
    from->alias_to = to->alloc_expr;
  }

 private:
  friend BufferReuseDebugPrinter;
-  friend class SerialBroadcastIntervalExpansion;

-  //! Allocation sites that will participate in this analysis
-  std::unordered_map<const kir::Allocate*, AllocationInfoPtr>
-      map_allocate_to_info_;
+  const ScopeMap scope_map_;

  //! Map TensorView name to Allocate node.
  //! Note: this assumes that each tensor view is only allocated once.
-  std::unordered_map<StmtNameType, AllocationInfoPtr> map_tv_to_allocations_;
-
-  //! Keeps track of all the allocations that have been set to alias
-  std::unordered_map<AllocationInfoPtr, AllocationInfoPtr> alias_map_;
-
-  //! Keep track of stack:
-  std::vector<ScopeInfo*> current_stack_;
-
-  //! Contains start and end position of the global scope
-  ScopeInfo* global_scope_info_ = nullptr;
+  std::unordered_map<StmtNameType, AllocationInfo*> tv_to_allocation_map_;

-  //! map loop start position to loop info
-  std::unordered_map<int, ScopeInfo*> map_loop_pos_to_loop_info_;
+  //! Allocation sites that will participate in this analysis
+  std::unordered_map<const kir::Allocate*, AllocationInfo*>
+      allocation_info_map_;

  //! Owning list of collected allocation info
-  AllocationInfoOwningList all_allocations_;
+  std::vector<std::unique_ptr<AllocationInfo>> all_allocations_;

-  //! Owning list of collected allocation info
-  ScopeInfoOwningPtrList all_loop_infos_;
+  //! Keep track of stack
+  std::vector<ScopeInfo*> current_stack_;

-  //! Position counter when iterating through the exprs list
-  int current_pos_ = -1;
+  //! Keeps track of all the allocations that have been set to alias
+  std::unordered_map<AllocationInfo*, AllocationInfo*> alias_map_;

  //! Debug info:
-  BufferReuseDebugPrinter* debug_printer_ = nullptr;
+  std::unique_ptr<BufferReuseDebugPrinter> debug_printer_ = nullptr;
 };

 void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) {
-  TORCH_INTERNAL_ASSERT(buffer_info_ != nullptr);
+  TORCH_INTERNAL_ASSERT(allocation_info_map_ != nullptr);
  std::string message_header(" \033[1;32m^^^^^ ---Buffer Reuse Info--- ");
  std::string message_end("  \033[0m\n");
-  if (!buffer_info_->map_allocate_to_info_.count(alloc)) {
+  if (!allocation_info_map_->getMaybeAllocationInfo(alloc).has_value()) {
    // This buffer is not considered for any sharing, either
    // because of unsupported op or size below threshold.
    return;
  }
-  auto alloc_info = buffer_info_->map_allocate_to_info_.at(alloc);
+  auto alloc_info = allocation_info_map_->getMaybeAllocationInfo(alloc).value();

  indent() << message_header;
  if (alloc_info->alias_to) {
@@ -838,7 +900,7 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) {
      os_ << "(outer) ";
    }
    os_ << " alias to alloc at pos "
-        << buffer_info_->getMaybeReuseInfoFor(alloc_info->alias_to)
+        << allocation_info_map_->getMaybeAllocationInfo(alloc_info->alias_to)
               .value()
               ->alloc_pos
        << " ";
@@ -851,6 +913,8 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) {
  if (alloc_info->can_use_inner_alias) {
    os_ << "inner live interval: ";
    os_ << alloc_info->inner_live_interval->toString() << " , ";
+  } else {
+    os_ << "cannot use inner alias, ";
  }
  os_ << "size expr : " << alloc_info->size_expr << " , "
     << "outer live interval: " << alloc_info->outer_live_interval->toString();
@@ -858,43 +922,52 @@ void BufferReuseDebugPrinter::printAllocInfo(const kir::Allocate* alloc) {
 }

 //! Reuse Allocation nodes via pointer aliasing
-class AllocateReuseModifier {
+class ReusableAllocationFinder : private kir::IrVisitor {
 public:
-  static void modify(const std::vector<Expr*>& exprs) {
-    AllocateReuseModifier modifier(exprs);
-  }
-
-  static void debugPrint(const std::vector<Expr*>& exprs) {
-    BufferReuseDebugPrinter debug_printer;
-    AllocateReuseModifier modifier(exprs, &debug_printer);
-    std::cout << debug_printer.dumpDebugInfo();
-  }
-
- private:
-  AllocateReuseModifier(
+  static void find(
      const std::vector<Expr*>& exprs,
-      BufferReuseDebugPrinter* debug_printer_ = nullptr)
-      : buffer_info_(exprs, debug_printer_) {
+      AllocationInfoMap& allocation_info_map) {
    // Perform in-place sharing first and then outer-liveness-based
    // sharing. Outer liveness info remains valid when some buffers
    // already alias through in-place re-use, but the reverse order
    // would not be safe.
-    buffer_info_.prepareInnerSharingAnalysis();
-    handleScope(exprs);
+    ReusableAllocationFinder finder_inner_alias(
+        exprs, allocation_info_map, true);
+    ReusableAllocationFinder finder_outer_alias(
+        exprs, allocation_info_map, false);
+    return;
+  }
+
+ private:
+  ReusableAllocationFinder(
+      const std::vector<Expr*>& exprs,
+      AllocationInfoMap& allocation_info_map,
+      bool inner_aliasing_pass)
+      : allocation_info_map_(allocation_info_map),
+        inner_aliasing_pass_(inner_aliasing_pass) {
+    if (inner_aliasing_pass_) {
+      allocation_info_map_.prepareInnerSharingAnalysis();
+    } else {
+      allocation_info_map_.prepareOuterSharingAnalysis();
+    }
+
+    current_visible_buffer_stack_.emplace_back(
+        std::make_unique<std::vector<AllocationInfo*>>());

-    inner_aliasing_pass_ = false;
+    handle(exprs);

-    buffer_info_.prepareOuterSharingAnalysis();
-    handleScope(exprs);
+    current_visible_buffer_stack_.pop_back();
  }

-  // Second visit of an allocate op
-  void handle(kir::Allocate* allocate) {
+  using kir::IrVisitor::handle;
+
+  void handle(kir::Allocate* allocate) final {
    // Check if this allocation site is one that
    // we want to re-use or replace with an alias
-    auto maybe_alloc_info = buffer_info_.getMaybeReuseInfoFor(allocate);
+    auto maybe_alloc_info =
+        allocation_info_map_.getMaybeAllocationInfo(allocate);
    if (maybe_alloc_info.has_value() &&
        maybe_alloc_info.value()->alias_to == nullptr) {
      // Try to re-use existing allocates
@@ -908,7 +981,7 @@ class AllocateReuseModifier {
    }
  }

-  bool tryReuseOtherAllocate(AllocationInfoPtr alloc_info) {
+  bool tryReuseOtherAllocate(AllocationInfo* alloc_info) {
    if (!alloc_info->should_try_alias) {
      return false;
    }
@@ -973,11 +1046,12 @@ class AllocateReuseModifier {
      }

      if (alloc_info->alloc_expr->buffer()->isA<TensorView>()) {
-        if (!alloc_info->alloc_expr->buffer()->isA<TensorView>()) {
+        if (!alloc_to_reuse->alloc_expr->buffer()->isA<TensorView>()) {
          continue;
        }
        auto this_tv = alloc_info->alloc_expr->buffer()->as<TensorView>();
-        auto reuse_tv = alloc_info->alloc_expr->buffer()->as<TensorView>();
+        auto reuse_tv =
+            alloc_to_reuse->alloc_expr->buffer()->as<TensorView>();
        // Check that either both tv's are vectorized accesses, or neither are.
        // Vectorized allocations require correct alignment so they can only
        // alias with other allocations with the right alignment
@@ -1009,39 +1083,17 @@ class AllocateReuseModifier {
      }

      // Now re-use the alloc here and be sure to update
-      reUseAllocation(alloc_info, alloc_to_reuse);
+      reuseAllocation(alloc_info, alloc_to_reuse);
      return true;
    }

    return false;
  }

-  void handle(Expr* expr) {
-    if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
-      handle(ite);
-    } else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
-      handle(for_loop);
-    } else if (auto allocate = dynamic_cast<kir::Allocate*>(expr)) {
-      handle(allocate);
-    }
-  }
-
-  void handle(const kir::ForLoop* for_loop) {
-    handleScope(for_loop->body().exprs());
-  }
-
-  void handle(const kir::IfThenElse* for_loop) {
-    TORCH_INTERNAL_ASSERT(
-        false,
-        "lower_alias_memory: IfThenElse before unrolling is not yet supported");
-  }
-
-  void handleScope(const std::vector<Expr*>& exprs) {
+  void handle(kir::ForLoop* for_loop) final {
    current_visible_buffer_stack_.emplace_back(
-        std::make_unique<AllocationInfoList>());
-    for (auto expr : exprs) {
-      handle(expr);
-    }
+        std::make_unique<std::vector<AllocationInfo*>>());
+    kir::IrVisitor::handle(for_loop);
    current_visible_buffer_stack_.pop_back();
  }

@@ -1057,8 +1109,8 @@ class AllocateReuseModifier {
  //! 2. No halo in the allocated iter domains
  //! 3. Require index equivalence when sharing across broadcast
  bool isValidInnerSharing(
-      AllocationUseDefInfo* alloc_info,
-      AllocationUseDefInfo* to_reuse) {
+      AllocationInfo* alloc_info,
+      AllocationInfo* to_reuse) {
    // Disable if either of the buffers does not support inner sharing
    if (!alloc_info->can_use_inner_alias || !to_reuse->can_use_inner_alias) {
      return false;
    }
@@ -1095,7 +1147,8 @@ class AllocateReuseModifier {

    // Get information on the allocated domains of the
    // two buffers
-    auto& local_alloc_map = GpuLower::current()->localAllocationInfoMap();
+    const auto& local_alloc_map =
+        GpuLower::current()->localAllocationInfoMap();
    auto alloc_it = local_alloc_map.find(alloc_info->alloc_expr);
    auto to_reuse_it = local_alloc_map.find(to_reuse->alloc_expr);
    if (alloc_it == local_alloc_map.end() ||
@@ -1170,14 +1223,12 @@ class AllocateReuseModifier {
    return true;
  }

-  void reUseAllocation(
-      AllocationUseDefInfo* alloc_info,
-      AllocationUseDefInfo* to_reuse) {
+  void reuseAllocation(AllocationInfo* alloc_info, AllocationInfo* to_reuse) {
    // Update analysis result
    if (inner_aliasing_pass_) {
-      buffer_info_.useInnerAlias(alloc_info, to_reuse);
+      allocation_info_map_.useInnerAlias(alloc_info, to_reuse);
    } else {
-      buffer_info_.useOuterAlias(alloc_info, to_reuse);
+      allocation_info_map_.useOuterAlias(alloc_info, to_reuse);
    }
  }

@@ -1201,28 +1252,110 @@ class AllocateReuseModifier {

 private:
  // Analysis result from the first pass collecting the use-defs
-  BufferUseDefInfo buffer_info_;
+  AllocationInfoMap& allocation_info_map_;

  // Internal data keeping track of currently visible allocations as
  // the pass iterates through the expr list, grouped by the stack
  // layer of alloc ops.
-  std::vector<std::unique_ptr<AllocationInfoList>>
+  std::vector<std::unique_ptr<std::vector<AllocationInfo*>>>
      current_visible_buffer_stack_;

  // Marks state of current pass
  bool inner_aliasing_pass_ = true;
 };

+// Replace Allocate exprs as determined by the alias analysis
+class AllocationReuseModifier : private kir::ExprMutator {
+ public:
+  static std::vector<Expr*> modify(
+      const std::vector<Expr*>& exprs,
+      const AllocationInfoMap& allocation_info_map) {
+    AllocationReuseModifier modifier(exprs, allocation_info_map);
+    return modifier.exprs_;
+  }
+
+ private:
+  AllocationReuseModifier(
+      const std::vector<Expr*>& exprs,
+      const AllocationInfoMap& allocation_info_map)
+      : allocation_info_map_(allocation_info_map) {
+    traverseAndInsert(exprs);
+  }
+
+  using kir::ExprMutator::handle;
+
+  //! Replace a kir::Allocate with a new aliased Allocate
+  void handle(kir::Allocate* allocate) final {
+    auto maybe_alloc_info =
+        allocation_info_map_.getMaybeAllocationInfo(allocate);
+    if (!maybe_alloc_info.has_value()) {
+      return;
+    }
+
+    AllocationInfo* alloc_info_from = maybe_alloc_info.value();
+
+    auto alias_it = allocation_info_map_.getAliasMap().find(alloc_info_from);
+    if (alias_it == allocation_info_map_.getAliasMap().end()) {
+      return;
+    }
+
+    kir::Allocate* alloc_expr_to = alias_it->second->alloc_expr;
+
+    // Currently, we don't allow 2-hop aliasing, i.e., aliasing of an
+    // aliased tensor, so alloc_expr_to should still be the allocation
+    // expression of the aliased allocation. This assertion should be
+    // removed if 2-hop aliasing is enabled.
+    TORCH_INTERNAL_ASSERT(
+        alloc_expr_to == getMaybeNewAllocate(alloc_expr_to),
+        "Invalid updated allocation found. Original: ",
+        alloc_expr_to->toString(),
+        ". Updated: ",
+        getMaybeNewAllocate(alloc_expr_to)->toString());
+
+    kir::Allocate* old_alloc = alloc_info_from->alloc_expr;
+    kir::Allocate* new_alloc = IrBuilder::create<kir::Allocate>(
+        old_alloc->buffer(),
+        old_alloc->memoryType(),
+        old_alloc->shape(),
+        old_alloc->zeroInit(),
+        alloc_expr_to);
+
+    registerReplace(old_alloc, new_alloc);
+
+    TORCH_INTERNAL_ASSERT(old2new_.emplace(old_alloc, new_alloc).second);
+
+    // TODO: Consider a more robust way to keep the information map up-to-date
+    GpuLower::current()->propagateExprInfo(old_alloc, new_alloc);
+  }
+
+  kir::Allocate* getMaybeNewAllocate(kir::Allocate* allocate) const {
+    auto it = old2new_.find(allocate);
+    if (it == old2new_.end()) {
+      return allocate;
+    } else {
+      return it->second;
+    }
+  }
+
+ private:
+  const AllocationInfoMap& allocation_info_map_;
+
+  //! Keep track of new Allocate exprs
+  std::unordered_map<kir::Allocate*, kir::Allocate*> old2new_;
+};
+
 } // namespace

 std::vector<Expr*> reuseMemoryAllocations(const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("reuseMemoryAllocations");
+
  bool debug_print = isDebugDumpEnabled(DebugDumpOption::BufferReuseInfo);
-  if (debug_print) {
-    AllocateReuseModifier::debugPrint(exprs);
-  }
-  AllocateReuseModifier::modify(exprs);
-  return exprs;
+
+  AllocationInfoMap allocation_info_map(exprs, debug_print);
+
+  ReusableAllocationFinder::find(exprs, allocation_info_map);
+
+  return AllocationReuseModifier::modify(exprs, allocation_info_map);
 }

 } // namespace cuda
diff --git a/torch/csrc/jit/codegen/cuda/lower_allocation.h b/torch/csrc/jit/codegen/cuda/lower_allocation.h
index 45ebeac03f77..2815b1c61ce7 100644
--- a/torch/csrc/jit/codegen/cuda/lower_allocation.h
+++ b/torch/csrc/jit/codegen/cuda/lower_allocation.h
@@ -20,8 +20,8 @@ struct LocalAllocationInfo {
  bool has_halo = false;
 };

-using LocalAllocationInfoMap =
-    std::unordered_map<const kir::Allocate*, std::unique_ptr<LocalAllocationInfo>>;
+using LocalAllocationInfoMap = std::
+    unordered_map<const kir::Allocate*, std::unique_ptr<LocalAllocationInfo>>;

 //! Insert buffer allocations
 std::vector<Expr*> insertAllocations(const std::vector<Expr*>& exprs);
diff --git a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
index ff8bb43f9efa..4c02220f01ac 100644
--- a/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
+++ b/torch/csrc/jit/codegen/cuda/lower_insert_syncs.cpp
@@ -131,10 +131,6 @@ class WarSyncInserter : private kir::ExprMutator {
  //! Insert Sync nodes at the end of a given for-loop when a WAR
  //! hazard may happen.
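+  //!
+  //! A hypothetical example with a shared memory buffer:
+  //!
+  //!   for i ...:
+  //!     smem[threadIdx.x] = ...;      // write
+  //!     ... = smem[threadIdx.x + 1];  // read a neighboring thread's slot
+  //!     __syncthreads();              // inserted at the end of the loop:
+  //!                                   // the next iteration's write must not
+  //!                                   // clobber a value that another thread
+  //!                                   // has not yet read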
  WarSyncInserter(const std::vector<Expr*>& exprs) {
-    auto& lower_alloc_info_map = GpuLower::current()->localAllocationInfoMap();
-    for (const auto& entry : lower_alloc_info_map) {
-      alloc_map_.insert(entry.first);
-    }
    kir::ExprMutator::traverseAndInsert(exprs);
  }

@@ -196,6 +192,10 @@ class WarSyncInserter : private kir::ExprMutator {
    return false;
  }

+  void handle(kir::Allocate* allocate) final {
+    alloc_map_.insert(allocate);
+  }
+
  void handle(Expr* expr) final {
    // If not a tensor view expression continue with dispatch
    if (!ir_utils::isTvOp(expr)) {
diff --git a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
index bec24b486b46..1a6d7437d925 100644
--- a/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
+++ b/torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu
@@ -108,9 +108,7 @@ __device__ void sync(
      index_utils::maskedIsLast<X_BLOCK, Y_BLOCK, Z_BLOCK>(blockIdx, gridDim);
  if (last_block) {
    int64_t finished_val =
-        ((int64_t)(
-            index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim) -
-            1)) *
+        ((int64_t)(index_utils::maskedSize<X_BLOCK, Y_BLOCK, Z_BLOCK>(gridDim) - 1)) *
        ((int64_t)n_entrances);

    unsigned int ns = 8;
diff --git a/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp b/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp
index ed1a1769b821..affd8ddb1db1 100644
--- a/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp
+++ b/torch/csrc/jit/codegen/cuda/test/test_gpu2.cpp
@@ -8210,7 +8210,10 @@ TEST_F(NVFuserTest, FusionWARSyncAliasedSmem_CUDA) {
    const auto& body = loop->body().exprs();
    TORCH_CHECK(!body.empty());
    auto last_expr = dynamic_cast<kir::BlockSync*>(body.back());
-    TORCH_CHECK(last_expr != nullptr, "Invalid expr found");
+    TORCH_CHECK(
+        last_expr != nullptr,
+        "Invalid expr found: ",
+        body.back()->toString());
    TORCH_CHECK(last_expr->isWarHazardSync(), "Not a sync for WAR hazard");
  }
 }