Skip to content

Commit cc15b85

Browse files
authored
Preserve command buffer in runMPSGraph() to optimize performance (#45)
* Preserve command buffer in runMPSGraph() to optimize performance * Add macro to toggle between commandBuffer and commandQueue methods in runMPSGraph
1 parent f11b254 commit cc15b85

File tree

3 files changed

+31
-25
lines changed

3 files changed

+31
-25
lines changed

aten/src/ATen/mps/MPSStream.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ class TORCH_API MPSStream
5656
void synchronize();
5757

5858
void flush();
59+
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);
5960

6061
/// Get the MPS device index that this stream is associated with.
6162
c10::DeviceIndex device_index() const { return _stream.device_index(); }
@@ -71,6 +72,7 @@ class TORCH_API MPSStream
7172
Stream _stream;
7273
MTLCommandQueue_t _commandQueue = nil;
7374
MTLCommandBuffer_t _commandBuffer = nil;
75+
MPSGraphExecutionDescriptor *_executionDescriptor = nil;
7476
void _flush(bool commitAndWait) const;
7577

7678
dispatch_queue_t _serialQueue = nullptr;

aten/src/ATen/mps/MPSStream.mm

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,16 @@
1313
_commandQueue = [MPSDevice::getInstance()->device() newCommandQueue];
1414
TORCH_CHECK(_stream.device_type() == DeviceType::MPS);
1515
_serialQueue = dispatch_queue_create("metal gpu stream", NULL);
16+
_executionDescriptor = [MPSGraphExecutionDescriptor new];
17+
_executionDescriptor.completionHandler = ^(NSDictionary<MPSGraphTensor *,
18+
MPSGraphTensorData *> * resultsDictionary,
19+
NSError * _Nullable error) { };
1620
}
1721

1822
MPSStream::~MPSStream() {
19-
[_commandQueue autorelease];
23+
[_commandQueue release];
2024
_commandQueue = nil;
25+
[_executionDescriptor release];
2126

2227
assert(_commandBuffer == nil);
2328
}
@@ -71,6 +76,27 @@
7176
[_commandBuffer release];
7277
}
7378

79+
#define USE_MPSCOMMANDBUFFER 1
80+
81+
void MPSStream::executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
82+
dispatch_sync(_serialQueue, ^() {
83+
#if USE_MPSCOMMANDBUFFER
84+
[mpsGraph encodeToCommandBuffer:commandBuffer()
85+
feeds:feeds
86+
targetOperations:nil
87+
resultsDictionary:results
88+
executionDescriptor:_executionDescriptor];
89+
#else
90+
commit(true);
91+
[mpsGraph runAsyncWithMTLCommandQueue:_commandQueue
92+
feeds:feeds
93+
targetOperations:nil
94+
resultsDictionary:results
95+
executionDescriptor:_executionDescriptor];
96+
#endif
97+
});
98+
}
99+
74100
//-----------------------------------------------------------------
75101
// MPSStreamImpl
76102
//-----------------------------------------------------------------

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -60,30 +60,8 @@
6060
return gen;
6161
}
6262

63-
void runMPSGraph(
64-
MPSStream* mpsStream,
65-
MPSGraph* mpsGraph,
66-
NSDictionary* feeds,
67-
NSDictionary* results) {
68-
dispatch_sync(mpsStream->queue(), ^() {
69-
@autoreleasepool {
70-
mpsStream->commit(true);
71-
id<MTLCommandQueue> commandQueue = mpsStream->commandQueue();
72-
MPSGraphExecutionDescriptor *executionDescriptor = [[MPSGraphExecutionDescriptor new] autorelease];
73-
74-
executionDescriptor.completionHandler = ^(NSDictionary<MPSGraphTensor *,
75-
MPSGraphTensorData *> * resultsDictionary,
76-
NSError * _Nullable error) {
77-
};
78-
79-
[mpsGraph runAsyncWithMTLCommandQueue:commandQueue
80-
feeds:feeds
81-
targetOperations:nil
82-
resultsDictionary:results
83-
executionDescriptor:executionDescriptor];
84-
85-
}
86-
});
63+
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results) {
64+
mpsStream->executeMPSGraph(mpsGraph, feeds, results);
8765
}
8866

8967
MPSDataType getMPSDataType(ScalarType scalar_type) {

0 commit comments

Comments
 (0)