From a5e7cac2ad011e8260cdea6dd26611d40451d69e Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Fri, 11 Jun 2021 15:35:14 +0200 Subject: [PATCH 1/2] [Metal] Fix bad stream after interrupted tuning session After an interrupted tuning session, the stream object may already have been released without a new one having been created. In that case it was not possible to run a new Metal task on the device without restarting the RPC application. Add a global function `metal.ResetGlobalState`, which should be called in the RPC application when the connection is closed. This function reinitializes the streams of the Metal devices, which guarantees that a new RPC session will work with the correct streams. --- src/runtime/metal/metal_common.h | 1 + src/runtime/metal/metal_device_api.mm | 31 +++++++++++++++++++++------ 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 7d2ef0c9367b..47a5999fdce9 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -163,6 +163,7 @@ class MetalWorkspace final : public DeviceAPI { void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + void ReinitializeStreams(); // get the global workspace static MetalWorkspace* Global(); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 43d8ccdbf6c7..842eac005b43 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -131,6 +131,23 @@ int GetWarpSize(id dev) { } } +void MetalWorkspace::ReinitializeStreams() { + std::vector& threadStreams = MetalThreadEntry::ThreadLocal()->stream; + ICHECK_EQ(default_streams_.size(), threadStreams.size()); + for (size_t i = 0; i < default_streams_.size(); ++i) { + if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i]) 
+ delete threadStreams[i]; + delete default_streams_[i]; + } + default_streams_.resize(devices.size()); + threadStreams.resize(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + Stream* stream = new Stream(devices[i]); + default_streams_[i] = stream; + threadStreams[i] = stream; + } +} + void MetalWorkspace::Init() { if (initialized_) return; std::lock_guard lock(this->mutex); @@ -141,21 +158,16 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } #endif + ReinitializeStreams(); } void MetalWorkspace::SetDevice(Device dev) { @@ -275,7 +287,10 @@ int GetWarpSize(id dev) { void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { ICHECK(stream != nullptr); + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* s = static_cast(stream); + if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s) + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr; delete s; } @@ -337,6 +352,10 @@ int GetWarpSize(id dev) { *rv = static_cast(ptr); }); +TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { + MetalWorkspace::Global()->ReinitializeStreams(); +}); + } // namespace metal } // namespace runtime } // namespace tvm From 0fa2aeac08b7820e538d9390ecc8796e66f7b32a Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Fri, 11 Jun 2021 15:39:44 +0200 Subject: [PATCH 2/2] Refactor metal_device_api - Rename function 
GetStream -> CastStreamOrGetCurrent - Add several checks on device id - `SetStream` now requires a valid device id and a non-null stream (it checks `stream != nullptr`; a nullptr is not replaced by the default stream of the device). --- src/runtime/metal/metal_device_api.mm | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 842eac005b43..0ef07b189a6b 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -205,11 +204,10 @@ int GetWarpSize(id dev) { }; } -Stream* GetStream(TVMStreamHandle stream, int device_id) { - if (stream != nullptr) - return static_cast(stream); - else - return MetalThreadEntry::ThreadLocal()->stream[device_id]; +Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) return static_cast(stream); + ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr); + return MetalThreadEntry::ThreadLocal()->stream[device_id]; } void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, @@ -219,7 +217,7 @@ int GetWarpSize(id dev) { this->Init(); Device dev = dev_from; if (dev_from.device_type == kDLCPU) dev = dev_to; - Stream* s = GetStream(stream, dev.device_id); + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! 
Cannot copy data to current stream"; } @@ -281,6 +279,7 @@ int GetWarpSize(id dev) { } TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* stream = new Stream(devices[dev.device_id]); return static_cast(stream); } @@ -296,7 +295,7 @@ int GetWarpSize(id dev) { void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { - Stream* s = GetStream(stream, dev.device_id); + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); [cb commit]; @@ -308,6 +307,8 @@ int GetWarpSize(id dev) { } void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; + ICHECK(stream != nullptr); MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); }