Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix high memory consumption in view ops (#81610) #71

Merged
merged 1 commit into from
Jul 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/mps/MPSStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class TORCH_API MPSStream
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset, bool non_blocking);
void flush();
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);

/// Get the MPS device index that this stream is associated with.
c10::DeviceIndex device_index() const { return _stream.device_index(); }
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/mps/MPSStream.mm
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
}

void MPSStream::synchronize(SyncType syncType) {
if (!_commandBuffer)
return;
switch(syncType) {
case SyncType::NONE:
// typically in GPU to GPU copies we won't commit explicitly
Expand Down Expand Up @@ -134,14 +136,16 @@
!non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
}

void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType) {
dispatch_sync(_serialQueue, ^() {
#if USE_MPSCOMMANDBUFFER
[mpsGraph encodeToCommandBuffer:commandBuffer()
feeds:feeds
targetOperations:nil
resultsDictionary:results
executionDescriptor:_executionDescriptor];
// mostly the syncType is NONE, but in some cases we may want to sync and wait (e.g., gatherViewTensor)
synchronize(syncType);
#else
commit(true);
[mpsGraph runAsyncWithMTLCommandQueue:_commandQueue
Expand Down
14 changes: 9 additions & 5 deletions aten/src/ATen/native/mps/operations/View.mm
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
}

// initializes the MTLBuffers for tensor data and runs the MPSGraph for the view op
static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output, bool needsScatter)
static Tensor& runViewGraph(ViewCachedGraph* cachedGraph, const at::Tensor& src, Tensor& output,
bool needsScatter, bool requires_sync = false)
{
const id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
const id<MTLBuffer> outputBuffer = getMTLBufferStorage(output);
Expand Down Expand Up @@ -71,7 +72,8 @@
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
cachedGraph->outputTensor : outputTensorData
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
stream->executeMPSGraph(cachedGraph->graph(), feeds, results,
requires_sync ? SyncType::COMMIT_AND_WAIT : SyncType::NONE);
}
return output;
}
Expand Down Expand Up @@ -225,11 +227,13 @@ Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst)
return Tensor();
}

bool requires_sync = false;
Tensor output;
if (!dst.has_storage())
if (!dst.has_storage()) {
output = at::native::empty_mps(src.sizes(), src.scalar_type(), c10::nullopt, kMPS);

return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false);
requires_sync = true;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here the Tensor we allocate (output) would get freed once it gets out of this function scope - does commitAndContinue retain the resource (the MTLBuffer of the output Tensor)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Returning from runMPSGraph() doesn't mean we have committed anything. The commitAndContinue mode keeps accumulating ops in the commandBuffer (and retains references to their buffers), then commits them based on a heuristic to increase GPU utilization.

}
return runViewGraph(cachedGraph, src, dst.has_storage() ? dst : output, /*needsScatter*/ false, requires_sync);
}

Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output)
Expand Down