Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aten/src/ATen/mps/MPSStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ typedef void* MTLDevice_t;
namespace at {
namespace mps {

#define USE_MPSCOMMANDBUFFER 1

//-----------------------------------------------------------------
// MPSStream
//-----------------------------------------------------------------
Expand All @@ -53,6 +55,7 @@ class TORCH_API MPSStream
MTLCommandBuffer_t commandBuffer();
void commit(bool flush);
void commitAndWait();
void commitAndContinue();
void synchronize();

void flush();
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/mps/MPSStream.mm
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@
_commandBuffer = nil;
}

void MPSStream::commitAndContinue() {
assert(_commandBuffer);
[_commandBuffer commitAndContinue];
}

void MPSStream::flush() {
if (_commandBuffer) {
[_commandBuffer commit];
Expand All @@ -76,7 +81,6 @@
[_commandBuffer release];
}

#define USE_MPSCOMMANDBUFFER 1

void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
dispatch_sync(_serialQueue, ^() {
Expand Down
14 changes: 13 additions & 1 deletion aten/src/ATen/native/mps/operations/Copy.mm
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
[blitEncoder endEncoding];

if (non_blocking) {
#if USE_MPSCOMMANDBUFFER
stream->commitAndContinue();
#else
stream->commit(true);
#endif
} else {
stream->commitAndWait();
}
Expand Down Expand Up @@ -215,7 +219,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
size:(NSUInteger)size];
[blitEncoder endEncoding];
if (non_blocking) {
#if USE_MPSCOMMANDBUFFER
stream->commitAndContinue();
#else
stream->commit(true);
#endif
} else {
stream->commitAndWait();
}
Expand Down Expand Up @@ -300,7 +308,11 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
size:src_size];
[blitEncoder endEncoding];
// GPU to GPU copy needs flushing only, and no synchronization with CPU is necessary
stream->commit(true);
#if USE_MPSCOMMANDBUFFER
stream->commitAndContinue();
#else
stream->commit(true);
#endif
}
});
} else {
Expand Down