diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index 7d2ef0c9367b..47a5999fdce9 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -163,6 +163,7 @@ class MetalWorkspace final : public DeviceAPI { void SetStream(Device dev, TVMStreamHandle stream) final; void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; + void ReinitializeStreams(); // get the global workspace static MetalWorkspace* Global(); diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 43d8ccdbf6c7..0ef07b189a6b 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -131,6 +130,23 @@ int GetWarpSize(id dev) { } } +void MetalWorkspace::ReinitializeStreams() { + std::vector& threadStreams = MetalThreadEntry::ThreadLocal()->stream; + ICHECK_EQ(default_streams_.size(), threadStreams.size()); + for (size_t i = 0; i < default_streams_.size(); ++i) { + if (threadStreams[i] != nullptr && default_streams_[i] != threadStreams[i]) + delete threadStreams[i]; + delete default_streams_[i]; + } + default_streams_.resize(devices.size()); + threadStreams.resize(devices.size()); + for (size_t i = 0; i < devices.size(); ++i) { + Stream* stream = new Stream(devices[i]); + default_streams_[i] = stream; + threadStreams[i] = stream; + } +} + void MetalWorkspace::Init() { if (initialized_) return; std::lock_guard lock(this->mutex); @@ -141,21 +157,16 @@ int GetWarpSize(id dev) { // on iPhone id d = MTLCreateSystemDefaultDevice(); devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); #else NSArray >* devs = MTLCopyAllDevices(); for (size_t i = 0; i < devs.count; ++i) { id d = [devs objectAtIndex:i]; devices.push_back(d); - Stream* stream = new Stream(d); - MetalThreadEntry::ThreadLocal()->stream.push_back(stream); - default_streams_.push_back(stream); LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String]; warp_size.push_back(GetWarpSize(d)); } #endif + ReinitializeStreams(); } void MetalWorkspace::SetDevice(Device dev) { @@ -193,11 +204,10 @@ int GetWarpSize(id dev) { }; } -Stream* GetStream(TVMStreamHandle stream, int device_id) { - if (stream != nullptr) - return static_cast(stream); - else - return MetalThreadEntry::ThreadLocal()->stream[device_id]; +Stream* CastStreamOrGetCurrent(TVMStreamHandle stream, int device_id) { + if (stream != nullptr) return static_cast(stream); + ICHECK(MetalThreadEntry::ThreadLocal()->stream[device_id] != nullptr); + return MetalThreadEntry::ThreadLocal()->stream[device_id]; } void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to, @@ -207,7 +217,7 @@ int GetWarpSize(id dev) { this->Init(); Device dev = dev_from; if (dev_from.device_type == kDLCPU) dev = dev_to; - Stream* s = GetStream(stream, dev.device_id); + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); if (s->HasErrorHappened()) { LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream"; } @@ -269,19 +279,23 @@ int GetWarpSize(id dev) { } TVMStreamHandle MetalWorkspace::CreateStream(Device dev) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* stream = new Stream(devices[dev.device_id]); return static_cast(stream); } void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) { ICHECK(stream != nullptr); + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; Stream* s = static_cast(stream); + if (MetalThreadEntry::ThreadLocal()->stream[dev.device_id] == s) + MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = nullptr; delete s; } void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) { AUTORELEASEPOOL { - Stream* s = GetStream(stream, dev.device_id); + Stream* s = CastStreamOrGetCurrent(stream, dev.device_id); // commit an empty command buffer and wait until it completes. id cb = s->GetCommandBuffer(); [cb commit]; @@ -293,6 +307,8 @@ int GetWarpSize(id dev) { } void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << dev.device_id; + ICHECK(stream != nullptr); MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast(stream); } @@ -337,6 +353,10 @@ int GetWarpSize(id dev) { *rv = static_cast(ptr); }); +TVM_REGISTER_GLOBAL("metal.ResetGlobalState").set_body_typed([]() { + MetalWorkspace::Global()->ReinitializeStreams(); +}); + } // namespace metal } // namespace runtime } // namespace tvm