From de6a3977adfee22504bfcd3a42aa70b16866ee64 Mon Sep 17 00:00:00 2001 From: Ramin Azarmehr Date: Fri, 3 Jun 2022 23:37:05 -0400 Subject: [PATCH 1/2] Remove view from graph cache when underlying buffer is freed (#78074) --- aten/src/ATen/mps/MPSAllocator.h | 18 +++++++ aten/src/ATen/mps/MPSAllocator.mm | 14 ++++++ aten/src/ATen/native/mps/OperationUtils.h | 34 +++++++++++-- aten/src/ATen/native/mps/OperationUtils.mm | 20 ++++++-- aten/src/ATen/native/mps/operations/Copy.mm | 55 ++++++++------------- 5 files changed, 100 insertions(+), 41 deletions(-) diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index fe40317189e45..b450bd29c7ebd 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -26,6 +26,23 @@ namespace at { namespace mps { +class IMpsAllocatorCallback { + public: + enum class EventType { + ALLOCATED, // buffer got allocated to be used immediately + RECYCLED, // buffer pulled from free list to be reused + FREED, // buffer put to free list for future recycling + RELEASED, // buffer memory released + }; + virtual ~IMpsAllocatorCallback() = default; + virtual bool executeMPSAllocatorCallback(void* ptr, EventType event) = 0; +}; + +// MPS allocator will execute every registered callback when a block of memory is freed. +C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); +#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \ + C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__); + namespace HeapAllocator { #define MB(x) round_page(x * 1048576UL) @@ -209,6 +226,7 @@ class MPSHeapAllocatorImpl void release_buffers(BufferPool& pool); bool release_available_cached_buffers(const AllocParams& p); bool release_cached_buffers(); + bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event); BufferPool& get_pool(size_t Size, bool useShared) { return Size <= kMaxSmallAlloc ? (useShared ? 
m_small_pool_shared : m_small_pool_private) : diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 3acf255bd274e..4dac9ede70802 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -8,6 +8,8 @@ namespace at { namespace mps { +C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback); + namespace HeapAllocator { HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& p) @@ -125,6 +127,7 @@ void MPSHeapAllocatorImpl::free_buffer(BufferBlock* buffer_block) { TORCH_INTERNAL_ASSERT(buffer_block->in_use); + trigger_memory_callbacks(buffer_block, IMpsAllocatorCallback::EventType::FREED); buffer_block->in_use = false; BufferPool *pool = buffer_block->heap->pool; // Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into. @@ -141,6 +144,14 @@ return it->second; } +bool MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) { + bool result = false; + for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) { + result |= MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event); + } + return result; +} + bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr) { std::lock_guard lock(m_mutex); @@ -167,6 +178,8 @@ void MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) { + trigger_memory_callbacks(buffer_block, IMpsAllocatorCallback::EventType::RELEASED); + HeapBlock *heap = buffer_block->heap; BufferPool *pool = heap->pool; m_total_allocated_memory -= buffer_block->size; @@ -318,6 +331,7 @@ static bool isEnvVarEnabled(const char *envvar) { static MPSAllocator s_mps_shared_alloc(true); return s_mps_shared_alloc; } + MPSAllocator& _getPrivateAllocator() { static mps::MPSAllocator s_mps_private_alloc(false); return s_mps_private_alloc; diff --git a/aten/src/ATen/native/mps/OperationUtils.h 
b/aten/src/ATen/native/mps/OperationUtils.h index 7860fcb2de35c..97c74bc6f583f 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -106,7 +106,7 @@ MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double val string get_mem_format_string(c10::MemoryFormat memory_format); -using MPSCacheKey = int64_t; +using MPSCacheKey = uint64_t; // derive this class to cache a graph and its inputs/ouputs // can be used to store any NSObject @@ -126,7 +126,6 @@ struct MPSCachedGraph // TODO: Improve the overall design of MPSGraphCache. // https://github.com/pytorch/pytorch/issues/77176 // Cache holding various keys mapped to graphs - struct MPSGraphCache { typedef MPSCachedGraph * (^CreateCachedGraphBlock)(); @@ -158,7 +157,7 @@ struct MPSGraphCache MPSGraphCache(const MPSGraphCache&) = delete; void operator=(const MPSGraphCache&) = delete; - MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) { + MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) { __block MPSCachedGraph * result = nil; @@ -176,6 +175,9 @@ struct MPSGraphCache result = createCacheBlock(); CacheEntry entry(key, result); cache_.emplace(hash, entry); + if (view_ptr) { + views_list.insert(std::make_pair(view_ptr, hash)); + } } }); return result; @@ -197,6 +199,29 @@ struct MPSGraphCache }); return result; } + + bool FindAndRemoveViewEntry(void* ptr) { + bool removed_entry = false; + + // this may find multiple view entries with the same buffer pointers + auto views_range = views_list.equal_range(ptr); + if (views_range.first != views_range.second) { + for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) { + MPSCacheKey hash = view_it->second; + // find the cache entry associated with the hash + auto cache_it = cache_.find(hash); + if (cache_it != cache_.end()) { + cache_.erase(cache_it); + delete 
cache_it->second.cachedGraph_; + removed_entry = true; + } + } + // this erase-by-key will remove all pairs in the list with the same key + views_list.erase(ptr); + } + return removed_entry; + } + private: MPSGraphCache() { serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL); @@ -204,6 +229,9 @@ struct MPSGraphCache static MPSGraphCache* _instance_cache; std::unordered_map cache_; + // list of buffers associated with view entries in the cache + // note that multiple view cache entries could use the same buffer pointer + std::multimap views_list; dispatch_queue_t serialQueue_ = nullptr; }; diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 301ce65d89759..b65ccf1ec5fd0 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -1,6 +1,7 @@ // Copyright © 2022 Apple Inc. #include +#include namespace at { namespace native { @@ -287,9 +288,6 @@ void printTensorNDArray(const Tensor& t) { CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; - IntArrayRef size_; - IntArrayRef stride_; - int64_t storage_offset_; }; CachedGraph* cachedGraph = static_cast(mpsCachedGraph); @@ -499,6 +497,22 @@ string get_mem_format_string(c10::MemoryFormat memory_format) { MPSGraphCache* MPSGraphCache::_instance_cache = nullptr; +class MPSGraphCacheCallback : public IMpsAllocatorCallback { +public: + MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { } + + bool executeMPSAllocatorCallback(void* ptr, EventType event) override { + if (event == EventType::FREED || event == EventType::RELEASED) { + return graph_cache->FindAndRemoveViewEntry(ptr); + } + return false; + } +private: + MPSGraphCache* graph_cache; +}; + +REGISTER_MPS_ALLOCATOR_CALLBACK("mps_graph_cache_callback", MPSGraphCacheCallback); + } // namespace mps } // namespace native } // namespace at diff --git 
a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index 0473cce56a2fb..c0e3bee8a6856 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -123,45 +123,30 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size, CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {} MPSGraphTensor* inputTensor_ = nil; MPSGraphTensor* outputTensor_ = nil; - IntArrayRef size_; - IntArrayRef stride_; - int64_t storage_offset_; }; MPSGraphCache* cache_ = MPSGraphCache::getInstance(); @autoreleasepool { - string lookup_key = mps::getStridedKey(self, size, stride, storage_offset); - CachedGraph* cachedGraph = static_cast(cache_->LookUp(lookup_key)); - - if(!cachedGraph) { - string insert_key = mps::getStridedKey(self,size, stride, storage_offset); - CachedGraph* insertCachedGraph = static_cast(cache_->LookUp(insert_key)); - if (!insertCachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(insert_key, ^ MPSCachedGraph * () { - CachedGraph *newCachedGraph = nil; - @autoreleasepool { - MPSGraph* mpsGraph = make_mps_graph(); - newCachedGraph = new CachedGraph(mpsGraph); - - // Self is the input tensor we are creating view of - MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : getMPSShape(self) - dataType : getMPSDataType(self.scalar_type()) - name : nil]; - newCachedGraph->inputTensor_ = inputTensor; - newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size, - stride, - storage_offset, - inputTensor, - self); - newCachedGraph->size_ = size; - newCachedGraph->stride_ = stride; - newCachedGraph->storage_offset_ = storage_offset; - } - return newCachedGraph; - }); - cachedGraph = static_cast(tmpCachedGraph); - } + string key = mps::getStridedKey(self, size, stride, storage_offset); + CachedGraph* cachedGraph = static_cast(cache_->LookUp(key)); + if (!cachedGraph) { + cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + CachedGraph 
*newCachedGraph = nil; + @autoreleasepool { + MPSGraph* mpsGraph = make_mps_graph(); + newCachedGraph = new CachedGraph(mpsGraph); + + // Self is the input tensor we are creating view of + MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : getMPSShape(self) + dataType : getMPSDataType(self.scalar_type()) + name : nil]; + newCachedGraph->inputTensor_ = inputTensor; + newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size, stride, + storage_offset, inputTensor, self); + } + return newCachedGraph; + }, self.storage().data()); } } } @@ -267,7 +252,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, auto storage_byte_offset = src_.storage_offset() * src_.itemsize(); id sourceBuffer = __builtin_bit_cast(id, src_.storage().data()); - if (!src_.is_contiguous()) { + if (src_.is_view()) { id gatherTensor = gatherViewTensor(src_, sourceBuffer); if (gatherTensor) { sourceBuffer = gatherTensor; From 22857dd391274658e8e8e3de8d3be1114bf540c5 Mon Sep 17 00:00:00 2001 From: Ramin Azarmehr Date: Mon, 6 Jun 2022 13:58:29 -0400 Subject: [PATCH 2/2] Remove is_view() and the return value from the allocator's callback --- aten/src/ATen/mps/MPSAllocator.h | 4 +-- aten/src/ATen/mps/MPSAllocator.mm | 10 +++---- aten/src/ATen/native/mps/OperationUtils.h | 31 +++++++++------------ aten/src/ATen/native/mps/OperationUtils.mm | 16 ++--------- aten/src/ATen/native/mps/operations/Copy.mm | 2 +- 5 files changed, 22 insertions(+), 41 deletions(-) diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h index b450bd29c7ebd..72c0024807255 100644 --- a/aten/src/ATen/mps/MPSAllocator.h +++ b/aten/src/ATen/mps/MPSAllocator.h @@ -35,7 +35,7 @@ class IMpsAllocatorCallback { RELEASED, // buffer memory released }; virtual ~IMpsAllocatorCallback() = default; - virtual bool executeMPSAllocatorCallback(void* ptr, EventType event) = 0; + virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0; }; // MPS allocator will execute every 
registered callback when a block of memory is freed. @@ -226,7 +226,7 @@ class MPSHeapAllocatorImpl void release_buffers(BufferPool& pool); bool release_available_cached_buffers(const AllocParams& p); bool release_cached_buffers(); - bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event); + void trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event); BufferPool& get_pool(size_t Size, bool useShared) { return Size <= kMaxSmallAlloc ? (useShared ? m_small_pool_shared : m_small_pool_private) : diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm index 4dac9ede70802..873a78fffce8e 100644 --- a/aten/src/ATen/mps/MPSAllocator.mm +++ b/aten/src/ATen/mps/MPSAllocator.mm @@ -144,12 +144,10 @@ return it->second; } -bool MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) { - bool result = false; - for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) { - result |= MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event); - } - return result; +void MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) { + for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) { + MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event); + } } bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr) diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 97c74bc6f583f..66bebaaf273f4 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -102,7 +102,6 @@ void printTensorNDArray(const Tensor& t); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType 
dataType, MPSShape* mpsShape); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); -MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType); string get_mem_format_string(c10::MemoryFormat memory_format); @@ -200,26 +199,22 @@ struct MPSGraphCache return result; } - bool FindAndRemoveViewEntry(void* ptr) { - bool removed_entry = false; - + void FindAndRemoveViewEntry(void* ptr) { // this may find multiple view entries with the same buffer pointers auto views_range = views_list.equal_range(ptr); - if (views_range.first != views_range.second) { - for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) { - MPSCacheKey hash = view_it->second; - // find the cache entry associated with the hash - auto cache_it = cache_.find(hash); - if (cache_it != cache_.end()) { - cache_.erase(cache_it); - delete cache_it->second.cachedGraph_; - removed_entry = true; - } + if (views_range.first == views_range.second) + return; + for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) { + MPSCacheKey hash = view_it->second; + // find the cache entry associated with the hash + auto cache_it = cache_.find(hash); + if (cache_it != cache_.end()) { + delete cache_it->second.cachedGraph_; // must precede erase(), which invalidates cache_it + cache_.erase(cache_it); } - // this erase-by-key will remove all pairs in the list with the same key - views_list.erase(ptr); } - return removed_entry; + // this erase-by-key will remove all pairs in the list with the same key + views_list.erase(ptr); } private: @@ -231,7 +226,7 @@ struct MPSGraphCache std::unordered_map cache_; // list of buffers associated with view entries in the cache // note that multiple view cache entries could use the same buffer pointer - std::multimap views_list; + std::unordered_multimap views_list; dispatch_queue_t serialQueue_ = nullptr; }; diff --git a/aten/src/ATen/native/mps/OperationUtils.mm
b/aten/src/ATen/native/mps/OperationUtils.mm index b65ccf1ec5fd0..caee96d71acca 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -449,17 +449,6 @@ void resize_tensor(Tensor* output) { return mpsGraph; } -MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType) { - // Bool is not handled by constantWithScalar - MPSGraphTensor* constPlaceHolder = [mpsGraph constantWithScalar:value - shape:mpsShape - dataType:(dataType == MPSDataTypeBool ? MPSDataTypeFloat32 : dataType)]; - if (dataType == MPSDataTypeBool) - return [mpsGraph castTensor:constPlaceHolder toType:MPSDataTypeBool name:@"ConstantBoolTensor"]; - - return constPlaceHolder; -} - MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) { return [mpsGraph placeholderWithShape:nil dataType:dataType @@ -501,11 +490,10 @@ string get_mem_format_string(c10::MemoryFormat memory_format) { public: MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { } - bool executeMPSAllocatorCallback(void* ptr, EventType event) override { + void executeMPSAllocatorCallback(void* ptr, EventType event) override { if (event == EventType::FREED || event == EventType::RELEASED) { - return graph_cache->FindAndRemoveViewEntry(ptr); + graph_cache->FindAndRemoveViewEntry(ptr); } - return false; } private: MPSGraphCache* graph_cache; diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm index c0e3bee8a6856..677a179ef1045 100644 --- a/aten/src/ATen/native/mps/operations/Copy.mm +++ b/aten/src/ATen/native/mps/operations/Copy.mm @@ -252,7 +252,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src, auto storage_byte_offset = src_.storage_offset() * src_.itemsize(); id sourceBuffer = __builtin_bit_cast(id, src_.storage().data()); - if (src_.is_view()) { + if (!src_.is_contiguous()) { id gatherTensor = 
gatherViewTensor(src_, sourceBuffer); if (gatherTensor) { sourceBuffer = gatherTensor;