Skip to content

Commit e4ffdc3

Browse files
razarmehr authored and kulinseth committed
Refactor blit copies and reuse scalar tensor's MPS storage in BinaryOps (#52)

* Refactor blit copies and use commitAndContinue in commitAndWait()

* Remove commitAndContinue from commitAndWait()

* [MPS] Add commitAndContinue to non blocking blits. (#51)

* Merge commitAndContinue changes from mps_master

* Don't use getMPSGraphTensorFromScalar() if tensor is already on MPS device in Binary Ops.
  This improves performance by preventing copy from GPU to CPU and back to GPU again.

Co-authored-by: Kulin Seth <kulin_seth@apple.com>
1 parent 09a9bff commit e4ffdc3

File tree

4 files changed

+61
-90
lines changed

4 files changed

+61
-90
lines changed

aten/src/ATen/mps/MPSStream.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@ typedef void* MTLDevice_t;
3434
namespace at {
3535
namespace mps {
3636

37-
#define USE_MPSCOMMANDBUFFER 1
38-
3937
//-----------------------------------------------------------------
4038
// MPSStream
4139
//-----------------------------------------------------------------
@@ -44,6 +42,13 @@ class TORCH_API MPSStream
4442
{
4543
public:
4644
enum Unchecked { UNCHECKED };
45+
46+
enum class SyncType {
47+
NONE, // no commit to command buffer
48+
COMMIT, // commit and flush the command buffer
49+
COMMIT_AND_WAIT, // flush and wait for command buffer execution to finish
50+
COMMIT_AND_CONTINUE,// commit and continue with a new underlying command buffer
51+
};
4752
/// Construct a MPSStream from a Stream. This construction is checked,
4853
/// and will raise an error if the Stream is not, in fact, a MPS stream.
4954
explicit MPSStream(Stream stream);
@@ -57,7 +62,10 @@ class TORCH_API MPSStream
5762
void commitAndWait();
5863
void commitAndContinue();
5964
void synchronize();
60-
65+
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
66+
size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType = SyncType::NONE);
67+
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
68+
size_t length, size_t srcOffset, size_t dstOffset, bool non_blocking);
6169
void flush();
6270
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);
6371

@@ -74,7 +82,7 @@ class TORCH_API MPSStream
7482
private:
7583
Stream _stream;
7684
MTLCommandQueue_t _commandQueue = nil;
77-
MTLCommandBuffer_t _commandBuffer = nil;
85+
MPSCommandBuffer* _commandBuffer = nil;
7886
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
7987
void _flush(bool commitAndWait) const;
8088

aten/src/ATen/mps/MPSStream.mm

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
namespace at {
66
namespace mps {
77

8+
#define USE_MPSCOMMANDBUFFER 1
9+
810
//-----------------------------------------------------------------
911
// MPSStream
1012
//-----------------------------------------------------------------
@@ -46,9 +48,13 @@
4648
}
4749

4850
void MPSStream::commit(bool doFlush) {
51+
#if USE_MPSCOMMANDBUFFER
52+
[commandBuffer() commitAndContinue];
53+
#else
4954
if (doFlush) {
5055
flush();
5156
}
57+
#endif
5258
}
5359

5460
void MPSStream::commitAndWait() {
@@ -81,6 +87,41 @@
8187
[_commandBuffer release];
8288
}
8389

90+
void MPSStream::copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
91+
size_t length, size_t srcOffset, size_t dstOffset, SyncType syncType) {
92+
dispatch_sync(_serialQueue, ^() {
93+
@autoreleasepool {
94+
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer() blitCommandEncoder];
95+
96+
[blitEncoder copyFromBuffer:srcBuffer
97+
sourceOffset:(NSUInteger)srcOffset
98+
toBuffer:dstBuffer
99+
destinationOffset:(NSUInteger)dstOffset
100+
size:(NSUInteger)length];
101+
[blitEncoder endEncoding];
102+
switch(syncType) {
103+
case SyncType::NONE:
104+
// typically in GPU to GPU copies we won't commit explicitly
105+
break;
106+
case SyncType::COMMIT:
107+
commit(true);
108+
break;
109+
case SyncType::COMMIT_AND_WAIT:
110+
commitAndWait();
111+
break;
112+
case SyncType::COMMIT_AND_CONTINUE:
113+
commitAndContinue();
114+
break;
115+
}
116+
}
117+
});
118+
}
119+
120+
void MPSStream::copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer, size_t length,
121+
size_t srcOffset, size_t dstOffset, bool non_blocking) {
122+
copy(srcBuffer, dstBuffer, length, srcOffset, dstOffset,
123+
!non_blocking ? SyncType::COMMIT_AND_WAIT : SyncType::COMMIT);
124+
}
84125

85126
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
86127
dispatch_sync(_serialQueue, ^() {

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,13 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
9494
Placeholder selfPlaceholder;
9595
Placeholder otherPlaceholder;
9696

97-
if (is_self_scalar) {
97+
if (is_self_scalar && !self.is_mps()) {
9898
feeds[cachedGraph->primaryTensor] = getMPSGraphTensorFromScalar(mpsStream, self.item(), getMPSScalarType(self.scalar_type()));
9999
} else {
100100
selfPlaceholder = Placeholder(cachedGraph->primaryTensor, self);
101101
feeds[selfPlaceholder.getMPSGraphTensor()] = selfPlaceholder.getMPSGraphTensorData();
102102
}
103-
if (is_other_scalar) {
103+
if (is_other_scalar && !other.is_mps()) {
104104
feeds[cachedGraph->secondaryTensor] = getMPSGraphTensorFromScalar(mpsStream, other.item(), getMPSScalarType(other.scalar_type()));
105105
} else {
106106
otherPlaceholder = Placeholder(cachedGraph->secondaryTensor, other);

aten/src/ATen/native/mps/operations/Copy.mm

Lines changed: 6 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -136,31 +136,8 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
136136
// 4 bytes alignment required on macos for blits.
137137
TORCH_CHECK(destOffset % 4 == 0, "Unaligned blit request");
138138

139-
dispatch_sync(stream->queue(), ^() {
140-
@autoreleasepool {
141-
id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
142-
id<MTLBlitCommandEncoder> blitEncoder =
143-
[commandBuffer blitCommandEncoder];
144-
145-
[blitEncoder copyFromBuffer:sourceBuffer
146-
sourceOffset:(NSUInteger)storage_byte_offset
147-
toBuffer:destBuffer
148-
destinationOffset:(NSUInteger)destOffset
149-
size:(NSUInteger)src_size];
150-
[blitEncoder endEncoding];
151-
152-
if (non_blocking) {
153-
#if USE_MPSCOMMANDBUFFER
154-
stream->commitAndContinue();
155-
#else
156-
stream->commit(true);
157-
#endif
158-
} else {
159-
stream->commitAndWait();
160-
}
161-
[destBuffer release];
162-
}
163-
});
139+
stream->copy_and_sync(sourceBuffer, destBuffer, src_size, storage_byte_offset, destOffset, non_blocking);
140+
[destBuffer release];
164141
}
165142
if (!dst.is_same(dst_)) {
166143
dst_.copy_(dst, non_blocking);
@@ -206,29 +183,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
206183
if (src_.is_view() || !src_.is_contiguous())
207184
sourceOffset += src_.storage_offset() * src_.itemsize();
208185

209-
dispatch_sync(stream->queue(), ^() {
210-
@autoreleasepool {
211-
id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
212-
id<MTLBlitCommandEncoder> blitEncoder =
213-
[commandBuffer blitCommandEncoder];
214-
215-
[blitEncoder copyFromBuffer:sourceBuffer
216-
sourceOffset:(NSUInteger)sourceOffset
217-
toBuffer:destBuffer
218-
destinationOffset:(NSUInteger)dst_byte_offset
219-
size:(NSUInteger)size];
220-
[blitEncoder endEncoding];
221-
if (non_blocking) {
222-
#if USE_MPSCOMMANDBUFFER
223-
stream->commitAndContinue();
224-
#else
225-
stream->commit(true);
226-
#endif
227-
} else {
228-
stream->commitAndWait();
229-
}
230-
}
231-
});
186+
stream->copy_and_sync(sourceBuffer, destBuffer, size, sourceOffset, dst_byte_offset, non_blocking);
232187
[sourceBuffer release];
233188
}
234189

@@ -237,23 +192,7 @@ void copy_cast_mps(at::Tensor& dst, const at::Tensor& src,
237192

238193
void copy_blit_mps(void* dst, const void* src, size_t size) {
239194
MPSStream* stream = getCurrentMPSStream();
240-
id<MTLBuffer> sourceBuffer = (id<MTLBuffer>)(src);
241-
id<MTLBuffer> destBuffer = (id<MTLBuffer>)(dst);
242-
dispatch_sync(stream->queue(), ^() {
243-
@autoreleasepool {
244-
id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
245-
id<MTLBlitCommandEncoder> blitEncoder =
246-
[commandBuffer blitCommandEncoder];
247-
248-
[blitEncoder copyFromBuffer:sourceBuffer
249-
sourceOffset:0
250-
toBuffer:destBuffer
251-
destinationOffset:0
252-
size:size];
253-
[blitEncoder endEncoding];
254-
stream->commitAndWait();
255-
}
256-
});
195+
stream->copy_and_sync((id<MTLBuffer>)(src), (id<MTLBuffer>)(dst), size, 0, 0, true);
257196
}
258197

259198
static at::Tensor& copy_kernel_mps(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
@@ -294,27 +233,10 @@ void copy_blit_mps(void* dst, const void* src, size_t size) {
294233
id<MTLBuffer> destBuffer = getMTLBufferStorage(dst_);
295234
id<MTLBuffer> sourceBuffer = getMTLBufferStorage(src);
296235
const size_t src_size = src.nbytes();
297-
298236
if (src.dtype() == dst_.dtype()) {
299237
MPSStream* stream = getCurrentMPSStream();
300-
dispatch_sync(stream->queue(), ^() {
301-
@autoreleasepool {
302-
id<MTLCommandBuffer> commandBuffer = stream->commandBuffer();
303-
id<MTLBlitCommandEncoder> blitEncoder = [commandBuffer blitCommandEncoder];
304-
[blitEncoder copyFromBuffer:sourceBuffer
305-
sourceOffset:src_byte_offset
306-
toBuffer:destBuffer
307-
destinationOffset:dst_byte_offset
308-
size:src_size];
309-
[blitEncoder endEncoding];
310-
// GPU to GPU copy needs flushing only, and no synchronization with CPU is necessary
311-
#if USE_MPSCOMMANDBUFFER
312-
stream->commitAndContinue();
313-
#else
314-
stream->commit(true);
315-
#endif
316-
}
317-
});
238+
// for GPU to GPU copies we only encode to stream's command buffer (no flushing)
239+
stream->copy(sourceBuffer, destBuffer, src_size, src_byte_offset, dst_byte_offset);
318240
} else {
319241
copy_cast_mps(dst_, src, destBuffer, sourceBuffer);
320242
}

0 commit comments

Comments (0)