Skip to content

Commit

Permalink
Remove is_view() and the return value from the allocator's callback
Browse files Browse the repository at this point in the history
  • Loading branch information
razarmehr committed Jun 6, 2022
1 parent de6a397 commit 22857dd
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 41 deletions.
4 changes: 2 additions & 2 deletions aten/src/ATen/mps/MPSAllocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class IMpsAllocatorCallback {
RELEASED, // buffer memory released
};
virtual ~IMpsAllocatorCallback() = default;
virtual bool executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};

// MPS allocator will execute every registered callback when a block of memory is freed.
Expand Down Expand Up @@ -226,7 +226,7 @@ class MPSHeapAllocatorImpl
void release_buffers(BufferPool& pool);
bool release_available_cached_buffers(const AllocParams& p);
bool release_cached_buffers();
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event);
void trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event);

BufferPool& get_pool(size_t Size, bool useShared) {
return Size <= kMaxSmallAlloc ? (useShared ? m_small_pool_shared : m_small_pool_private) :
Expand Down
10 changes: 4 additions & 6 deletions aten/src/ATen/mps/MPSAllocator.mm
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,10 @@
return it->second;
}

bool MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) {
bool result = false;
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
result |= MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event);
}
return result;
void MPSHeapAllocatorImpl::trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) {
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block->buffer, event);
}
}

bool MPSHeapAllocatorImpl::isSharedBuffer(void* ptr)
Expand Down
31 changes: 13 additions & 18 deletions aten/src/ATen/native/mps/OperationUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ void printTensorNDArray(const Tensor& t);
MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape);
MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor);
MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType);

string get_mem_format_string(c10::MemoryFormat memory_format);

Expand Down Expand Up @@ -200,26 +199,22 @@ struct MPSGraphCache
return result;
}

bool FindAndRemoveViewEntry(void* ptr) {
bool removed_entry = false;

void FindAndRemoveViewEntry(void* ptr) {
// this may find multiple view entries with the same buffer pointers
auto views_range = views_list.equal_range(ptr);
if (views_range.first != views_range.second) {
for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) {
MPSCacheKey hash = view_it->second;
// find the cache entry associated with the hash
auto cache_it = cache_.find(hash);
if (cache_it != cache_.end()) {
cache_.erase(cache_it);
delete cache_it->second.cachedGraph_;
removed_entry = true;
}
if (views_range.first == views_range.second)
return;
for (auto view_it = views_range.first; view_it != views_range.second; ++view_it) {
MPSCacheKey hash = view_it->second;
// find the cache entry associated with the hash
auto cache_it = cache_.find(hash);
if (cache_it != cache_.end()) {
cache_.erase(cache_it);
delete cache_it->second.cachedGraph_;
}
// this erase-by-key will remove all pairs in the list with the same key
views_list.erase(ptr);
}
return removed_entry;
// this erase-by-key will remove all pairs in the list with the same key
views_list.erase(ptr);
}

private:
Expand All @@ -231,7 +226,7 @@ struct MPSGraphCache
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
// list of buffers associated with view entries in the cache
// note that multiple view cache entries could use the same buffer pointer
std::multimap<void*, MPSCacheKey> views_list;
std::unordered_multimap<void*, MPSCacheKey> views_list;
dispatch_queue_t serialQueue_ = nullptr;

};
Expand Down
16 changes: 2 additions & 14 deletions aten/src/ATen/native/mps/OperationUtils.mm
Original file line number Diff line number Diff line change
Expand Up @@ -449,17 +449,6 @@ void resize_tensor(Tensor* output) {
return mpsGraph;
}

MPSGraphTensor* mpsGraphConstantPlaceHolder(MPSGraph *mpsGraph, const double value, MPSShape* mpsShape, MPSDataType dataType) {
// Bool is not handled by constantWithScalar
MPSGraphTensor* constPlaceHolder = [mpsGraph constantWithScalar:value
shape:mpsShape
dataType:(dataType == MPSDataTypeBool ? MPSDataTypeFloat32 : dataType)];
if (dataType == MPSDataTypeBool)
return [mpsGraph castTensor:constPlaceHolder toType:MPSDataTypeBool name:@"ConstantBoolTensor"];

return constPlaceHolder;
}

MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType) {
return [mpsGraph placeholderWithShape:nil
dataType:dataType
Expand Down Expand Up @@ -501,11 +490,10 @@ string get_mem_format_string(c10::MemoryFormat memory_format) {
public:
MPSGraphCacheCallback() : graph_cache(MPSGraphCache::getInstance()) { }

bool executeMPSAllocatorCallback(void* ptr, EventType event) override {
void executeMPSAllocatorCallback(void* ptr, EventType event) override {
if (event == EventType::FREED || event == EventType::RELEASED) {
return graph_cache->FindAndRemoveViewEntry(ptr);
graph_cache->FindAndRemoveViewEntry(ptr);
}
return false;
}
private:
MPSGraphCache* graph_cache;
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,

auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
id<MTLBuffer> sourceBuffer = __builtin_bit_cast(id<MTLBuffer>, src_.storage().data());
if (src_.is_view()) {
if (!src_.is_contiguous()) {
id<MTLBuffer> gatherTensor = gatherViewTensor(src_, sourceBuffer);
if (gatherTensor) {
sourceBuffer = gatherTensor;
Expand Down

0 comments on commit 22857dd

Please sign in to comment.