MPS: make allocator callbacks return void and tidy the graph-cache helpers.

Summary of the changes carried by this patch:
  * IMpsAllocatorCallback::executeMPSAllocatorCallback and
    MPSHeapAllocatorImpl::trigger_memory_callbacks now return void instead of bool.
  * MPSGraphCache::FindAndRemoveViewEntry returns void, exits early when no view
    entry matches the pointer, and views_list becomes an unordered_multimap.
  * The mpsGraphConstantPlaceHolder helper is removed (declaration and definition).
  * Copy.mm takes the gatherViewTensor path when the source is non-contiguous
    rather than when it is a view.

Review fix folded into the added lines of FindAndRemoveViewEntry: the cached
graph is deleted BEFORE the map entry is erased, because erase() invalidates
the iterator that the delete expression dereferences.

diff --git a/aten/src/ATen/mps/MPSAllocator.h b/aten/src/ATen/mps/MPSAllocator.h
index b450bd29c7ebd..72c0024807255 100644
--- a/aten/src/ATen/mps/MPSAllocator.h
+++ b/aten/src/ATen/mps/MPSAllocator.h
@@ -35,7 +35,7 @@ class IMpsAllocatorCallback {
     RELEASED, // buffer memory released
   };
   virtual ~IMpsAllocatorCallback() = default;
-  virtual bool executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
+  virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
 };
 
 // MPS allocator will execute every registered callback when a block of memory is freed.
@@ -226,7 +226,7 @@ class MPSHeapAllocatorImpl
   void release_buffers(BufferPool& pool);
   bool release_available_cached_buffers(const AllocParams& p);
   bool release_cached_buffers();
-  bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event);
+  void trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event);
 
   BufferPool& get_pool(size_t Size, bool useShared) {
     return Size <= kMaxSmallAlloc ? (useShared ? m_small_pool_shared : m_small_pool_private) :
diff --git a/aten/src/ATen/mps/MPSAllocator.mm b/aten/src/ATen/mps/MPSAllocator.mm
index 4dac9ede70802..873a78fffce8e 100644
--- a/aten/src/ATen/mps/MPSAllocator.mm
+++ b/aten/src/ATen/mps/MPSAllocator.mm
@@ -144,12 +144,10 @@
   return it->second;
 }
 
-bool MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) {
-  bool result = false;
-  for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
-    result |= MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event);
-  }
-  return result;
+void MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) {
+  for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
+    MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event);
+  }
 }
 
 bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr)
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 97c74bc6f583f..66bebaaf273f4 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -102,7 +102,6 @@ void printTensorNDArray(const Tensor& t);
 MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
 MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
 MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
-MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType);
 
 string get_mem_format_string(c10::MemoryFormat memory_format);
 
@@ -200,26 +199,22 @@ struct MPSGraphCache
     return result;
   }
 
-  bool FindAndRemoveViewEntry(void* ptr) {
-    bool removed_entry = false;
-
+  void FindAndRemoveViewEntry(void* ptr) {
     // this may find multiple view entries with the same buffer pointers
     auto views_range = views_list.equal_range(ptr);
-    if (views_range.first != views_range.second) {
-      for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) {
-        MPSCacheKey hash = view_it->second;
-        // find the cache entry associated with the hash
-        auto cache_it = cache_.find(hash);
-        if (cache_it != cache_.end()) {
-          cache_.erase(cache_it);
-          delete cache_it->second.cachedGraph_;
-          removed_entry = true;
-        }
+    if (views_range.first == views_range.second)
+      return;
+    for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) {
+      MPSCacheKey hash = view_it->second;
+      // find the cache entry associated with the hash
+      auto cache_it = cache_.find(hash);
+      if (cache_it != cache_.end()) {
+        delete cache_it->second.cachedGraph_;  // free the graph first: erase() invalidates cache_it
+        cache_.erase(cache_it);
       }
-      // this erase-by-key will remove all pairs in the list with the same key
-      views_list.erase(ptr);
     }
-    return removed_entry;
+    // this erase-by-key will remove all pairs in the list with the same key
+    views_list.erase(ptr);
   }
 
  private:
@@ -231,7 +226,7 @@ struct MPSGraphCache
   std::unordered_map cache_;
   // list of buffers associated with view entries in the cache
   // note that multiple view cache entries could use the same buffer pointer
-  std::multimap views_list;
+  std::unordered_multimap views_list;
 
   dispatch_queue_t serialQueue_ = nullptr;
 };
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index b65ccf1ec5fd0..caee96d71acca 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -449,17 +449,6 @@ void resize_tensor(Tensor* output) {
   return mpsGraph;
 }
 
-MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType) {
-  // Bool is not handled by constantWithScalar
-  MPSGraphTensor* constPlaceHolder = [mpsGraph constantWithScalar:value
-                                                            shape:mpsShape
-                                                         dataType:(dataType == MPSDataTypeBool ? MPSDataTypeFloat32 : dataType)];
-  if (dataType == MPSDataTypeBool)
-    return [mpsGraph castTensor:constPlaceHolder toType:MPSDataTypeBool name:@"ConstantBoolTensor"];
-
-  return constPlaceHolder;
-}
-
 MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
   return [mpsGraph placeholderWithShape:nil
                                dataType:dataType
@@ -501,11 +490,10 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
 public:
   MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { }
 
-  bool executeMPSAllocatorCallback(void* ptr, EventType event) override {
+  void executeMPSAllocatorCallback(void* ptr, EventType event) override {
     if (event == EventType::FREED || event == EventType::RELEASED) {
-      return graph_cache->FindAndRemoveViewEntry(ptr);
+      graph_cache->FindAndRemoveViewEntry(ptr);
     }
-    return false;
   }
 private:
   MPSGraphCache* graph_cache;
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index c0e3bee8a6856..677a179ef1045 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -252,7 +252,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
 
   auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
   id sourceBuffer = __builtin_bit_cast(id, src_.storage().data());
-  if (src_.is_view()) {
+  if (!src_.is_contiguous()) {
     id gatherTensor = gatherViewTensor(src_, sourceBuffer);
     if (gatherTensor) {
       sourceBuffer = gatherTensor;