Skip to content

Commit

Permalink
Refactor metal_device_api
Browse files Browse the repository at this point in the history
- Rename function GetStream -> CastStreamOrGetCurrent
- Add several checks on device id
- When we use `SetStream` with nullptr, then the default stream will be
  associated with the device.
  • Loading branch information
echuraev committed Jun 17, 2021
1 parent a5e7cac commit 0fa2aea
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
Expand Down Expand Up @@ -205,11 +204,10 @@ int GetWarpSize(id<MTLDevice> dev) {
};
}

Stream* GetStream(TVMStreamHandle stream, int device_id) {
if (stream != nullptr)
return static_cast<Stream*>(stream);
else
return MetalThreadEntry::ThreadLocal()->stream[device_id];
Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) {
if (stream != nullptr) return static_cast<Stream*>(stream);
ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr);
return MetalThreadEntry::ThreadLocal()->stream[device_id];
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
Expand All @@ -219,7 +217,7 @@ int GetWarpSize(id<MTLDevice> dev) {
this->Init();
Device dev = dev_from;
if (dev_from.device_type == kDLCPU) dev = dev_to;
Stream* s = GetStream(stream, dev.device_id);
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
Expand Down Expand Up @@ -281,6 +279,7 @@ int GetWarpSize(id<MTLDevice> dev) {
}

TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}
Expand All @@ -296,7 +295,7 @@ int GetWarpSize(id<MTLDevice> dev) {

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
AUTORELEASEPOOL {
Stream* s = GetStream(stream, dev.device_id);
Stream* s = CastStreamOrGetCurrent(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
Expand All @@ -308,6 +307,8 @@ int GetWarpSize(id<MTLDevice> dev) {
}

void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id;
ICHECK(stream != nullptr);
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
}

Expand Down

0 comments on commit 0fa2aea

Please sign in to comment.