Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove view from graph cache when underlying buffer is freed #26

Merged
merged 2 commits into from
Jun 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions aten/src/ATen/mps/MPSAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@
namespace at {
namespace mps {

// Callback interface for observing the lifecycle of buffers managed by the
// MPS heap allocator. Implementations are registered via
// REGISTER_MPS_ALLOCATOR_CALLBACK and are invoked by the allocator
// (see MPSHeapAllocatorImpl::trigger_memory_callbacks) when a buffer
// changes state.
class IMpsAllocatorCallback {
public:
// Buffer lifecycle states. NOTE(review): in the code visible here only
// FREED and RELEASED are ever fired (from free_buffer/release_buffer);
// ALLOCATED and RECYCLED appear reserved for future use — confirm.
enum class EventType {
ALLOCATED, // buffer got allocated to be used immediately
RECYCLED, // buffer pulled from free list to be reused
FREED, // buffer put to free list for future recycling
RELEASED, // buffer memory released
};
virtual ~IMpsAllocatorCallback() = default;
// Invoked with the buffer's base pointer and the event that occurred.
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};

// MPS allocator will execute every registered callback when a block of memory is freed.
C10_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__);

namespace HeapAllocator {

#define MB(x) round_page(x * 1048576UL)
Expand Down Expand Up @@ -209,6 +226,7 @@ class MPSHeapAllocatorImpl
void release_buffers(BufferPool& pool);
bool release_available_cached_buffers(const AllocParams& p);
bool release_cached_buffers();
void trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event);

BufferPool& get_pool(size_t Size, bool useShared) {
return Size <= kMaxSmallAlloc ? (useShared ? m_small_pool_shared : m_small_pool_private) :
Expand Down
12 changes: 12 additions & 0 deletions aten/src/ATen/mps/MPSAllocator.mm
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
namespace at {
namespace mps {

C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

namespace HeapAllocator {

HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& p)
Expand Down Expand Up @@ -125,6 +127,7 @@
void MPSHeapAllocatorImpl::free_buffer(BufferBlock* buffer_block)
{
TORCH_INTERNAL_ASSERT(buffer_block->in_use);
trigger_memory_callbacks(buffer_block, IMpsAllocatorCallback::EventType::FREED);
buffer_block->in_use = false;
BufferPool *pool = buffer_block->heap->pool;
// Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into.
Expand All @@ -141,6 +144,12 @@
return it->second;
}

// Notifies every registered IMpsAllocatorCallback that `buffer_block`
// underwent `event` (FREED / RELEASED at the visible call sites), passing
// the buffer's base pointer so observers can match it against their own
// bookkeeping.
// NOTE(review): Create(name) constructs a fresh callback instance on every
// event — presumably callbacks are cheap, stateless shims; confirm before
// registering a stateful or expensive-to-construct callback.
void MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) {
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event);
}
}

bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr)
{
std::lock_guard<std::mutex> lock(m_mutex);
Expand All @@ -167,6 +176,8 @@

void MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap)
{
trigger_memory_callbacks(buffer_block, IMpsAllocatorCallback::EventType::RELEASED);

HeapBlock *heap = buffer_block->heap;
BufferPool *pool = heap->pool;
m_total_allocated_memory -= buffer_block->size;
Expand Down Expand Up @@ -318,6 +329,7 @@ static bool isEnvVarEnabled(const char *envvar) {
static MPSAllocator s_mps_shared_alloc(true);
return s_mps_shared_alloc;
}

MPSAllocator& _getPrivateAllocator() {
static mps::MPSAllocator s_mps_private_alloc(false);
return s_mps_private_alloc;
Expand Down
31 changes: 27 additions & 4 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,11 +102,10 @@ void printTensorNDArray(const Tensor& t);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType);

string get_mem_format_string(c10::MemoryFormat memory_format);

using MPSCacheKey = int64_t;
using MPSCacheKey = uint64_t;

// derive this class to cache a graph and its inputs/ouputs
// can be used to store any NSObject
Expand All @@ -126,7 +125,6 @@ struct MPSCachedGraph
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs

struct MPSGraphCache
{
typedef MPSCachedGraph * (^CreateCachedGraphBlock)();
Expand Down Expand Up @@ -158,7 +156,7 @@ struct MPSGraphCache
MPSGraphCache(const MPSGraphCache&) = delete;
void operator=(const MPSGraphCache&) = delete;

MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock, void* view_ptr = nullptr) {

__block MPSCachedGraph * result = nil;

Expand All @@ -176,6 +174,9 @@ struct MPSGraphCache
result = createCacheBlock();
CacheEntry entry(key, result);
cache_.emplace(hash, entry);
if (view_ptr) {
views_list.insert(std::make_pair(view_ptr, hash));
}
}
});
return result;
Expand All @@ -197,13 +198,35 @@ struct MPSGraphCache
});
return result;
}

// Removes every cached view graph whose underlying buffer is `ptr`.
// Called from the MPS allocator callback when the buffer is freed or
// released, so stale view graphs don't outlive their storage.
void FindAndRemoveViewEntry(void* ptr) {
  // this may find multiple view entries with the same buffer pointer
  auto views_range = views_list.equal_range(ptr);
  if (views_range.first == views_range.second)
    return;
  for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) {
    MPSCacheKey hash = view_it->second;
    // find the cache entry associated with the hash
    auto cache_it = cache_.find(hash);
    if (cache_it != cache_.end()) {
      // BUGFIX: delete the cached graph BEFORE erasing the entry.
      // unordered_map::erase invalidates the erased iterator, so the
      // original order (erase, then dereference cache_it to delete)
      // was undefined behavior.
      delete cache_it->second.cachedGraph_;
      cache_.erase(cache_it);
    }
  }
  // this erase-by-key removes all pairs in the multimap with this key
  views_list.erase(ptr);
}

private:
MPSGraphCache() {
serialQueue_ = dispatch_queue_create("cache queue", DISPATCH_QUEUE_SERIAL);
}

static MPSGraphCache* _instance_cache;
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
// list of buffers associated with view entries in the cache
// note that multiple view cache entries could use the same buffer pointer
std::unordered_multimap<void*, MPSCacheKey> views_list;
dispatch_queue_t serialQueue_ = nullptr;

};
Expand Down
30 changes: 16 additions & 14 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright © 2022 Apple Inc.

#include <ATen/native/mps/OperationUtils.h>
#include <ATen/mps/MPSAllocator.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -287,9 +288,6 @@ void printTensorNDArray(const Tensor& t) {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
IntArrayRef size_;
IntArrayRef stride_;
int64_t storage_offset_;
};

CachedGraph* cachedGraph = static_cast<CachedGraph *>(mpsCachedGraph);
Expand Down Expand Up @@ -451,17 +449,6 @@ void resize_tensor(Tensor* output) {
return mpsGraph;
}

MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType) {
// Bool is not handled by constantWithScalar
MPSGraphTensor* constPlaceHolder = [mpsGraph constantWithScalar:value
shape:mpsShape
dataType:(dataType == MPSDataTypeBool ? MPSDataTypeFloat32 : dataType)];
if (dataType == MPSDataTypeBool)
return [mpsGraph castTensor:constPlaceHolder toType:MPSDataTypeBool name:@"ConstantBoolTensor"];

return constPlaceHolder;
}

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil
dataType:dataType
Expand Down Expand Up @@ -499,6 +486,21 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {

MPSGraphCache* MPSGraphCache::_instance_cache = nullptr;

class MPSGraphCacheCallback : public IMpsAllocatorCallback {
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { }

void executeMPSAllocatorCallback(void* ptr, EventType event) override {
if (event == EventType::FREED || event == EventType::RELEASED) {
graph_cache->FindAndRemoveViewEntry(ptr);
}
}
private:
MPSGraphCache* graph_cache;
};

REGISTER_MPS_ALLOCATOR_CALLBACK("mps_graph_cache_callback", MPSGraphCacheCallback);

} // namespace mps
} // namespace native
} // namespace at
53 changes: 19 additions & 34 deletions aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -123,45 +123,30 @@ Tensor as_strided_tensorimpl_mps(const Tensor& self, IntArrayRef size,
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
IntArrayRef size_;
IntArrayRef stride_;
int64_t storage_offset_;
};

MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
string lookup_key = mps::getStridedKey(self, size, stride, storage_offset);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(lookup_key));

if(!cachedGraph) {
string insert_key = mps::getStridedKey(self,size, stride, storage_offset);
CachedGraph* insertCachedGraph = static_cast<CachedGraph *>(cache_->LookUp(insert_key));
if (!insertCachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(insert_key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

// Self is the input tensor we are creating view of
MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : getMPSShape(self)
dataType : getMPSDataType(self.scalar_type())
name : nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size,
stride,
storage_offset,
inputTensor,
self);
newCachedGraph->size_ = size;
newCachedGraph->stride_ = stride;
newCachedGraph->storage_offset_ = storage_offset;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
string key = mps::getStridedKey(self, size, stride, storage_offset);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if (!cachedGraph) {
cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);

// Self is the input tensor we are creating view of
MPSGraphTensor* inputTensor = [mpsGraph placeholderWithShape : getMPSShape(self)
dataType : getMPSDataType(self.scalar_type())
name : nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = chainViewOperation(mpsGraph, size, stride,
storage_offset, inputTensor, self);
}
return newCachedGraph;
}, self.storage().data());
}
}
}
Expand Down