Skip to content

Commit daf7524

Browse files
authored
[MPS] Add commitAndContinue to non blocking blits. (#51)
1 parent e4c3074 commit daf7524

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

aten/src/ATen/mps/MPSStream.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ typedef void* MTLDevice_t;
3434
namespace at {
3535
namespace mps {
3636

37+
#define USE_MPSCOMMANDBUFFER 1
38+
3739
//-----------------------------------------------------------------
3840
// MPSStream
3941
//-----------------------------------------------------------------
@@ -53,6 +55,7 @@ class TORCH_API MPSStream
5355
MTLCommandBuffer_t commandBuffer();
5456
void commit(bool flush);
5557
void commitAndWait();
58+
void commitAndContinue();
5659
void synchronize();
5760

5861
void flush();

aten/src/ATen/mps/MPSStream.mm

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@
5959
_commandBuffer = nil;
6060
}
6161

62+
void MPSStream::commitAndContinue() {
63+
assert(_commandBuffer);
64+
[_commandBuffer commitAndContinue];
65+
}
66+
6267
void MPSStream::flush() {
6368
if (_commandBuffer) {
6469
[_commandBuffer commit];
@@ -76,7 +81,6 @@
7681
[_commandBuffer release];
7782
}
7883

79-
#define USE_MPSCOMMANDBUFFER 1
8084

8185
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
8286
dispatch_sync(_serialQueue, ^() {

aten/src/ATen/native/mps/operations/Copy.mm

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
150150
[blitEncoder endEncoding];
151151

152152
if (non_blocking) {
153+
#if USE_MPSCOMMANDBUFFER
154+
stream->commitAndContinue();
155+
#else
153156
stream->commit(true);
157+
#endif
154158
} else {
155159
stream->commitAndWait();
156160
}
@@ -215,7 +219,11 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
215219
size:(NSUInteger)size];
216220
[blitEncoder endEncoding];
217221
if (non_blocking) {
222+
#if USE_MPSCOMMANDBUFFER
223+
stream->commitAndContinue();
224+
#else
218225
stream->commit(true);
226+
#endif
219227
} else {
220228
stream->commitAndWait();
221229
}
@@ -300,7 +308,11 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
300308
size:src_size];
301309
[blitEncoder endEncoding];
302310
// GPU to GPU copy needs flushing only, and no synchronization with CPU is necessary
303-
stream->commit(true);
311+
#if USE_MPSCOMMANDBUFFER
312+
stream->commitAndContinue();
313+
#else
314+
stream->commit(true);
315+
#endif
304316
}
305317
});
306318
} else {

0 commit comments

Comments
 (0)