From 02529b0c4b8f9a2c36f0a67398876f63df7a69a9 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Tue, 18 Aug 2020 16:08:03 -0700 Subject: [PATCH] [RUNTIME][REFACTOR] Use new to avoid exit-time de-allocation order problem in DeviceAPI (#6292) --- src/runtime/c_runtime_api.cc | 4 ++-- src/runtime/cpu_device_api.cc | 8 +++++--- src/runtime/cuda/cuda_device_api.cc | 11 ++++++----- src/runtime/hexagon/hexagon_device_api.cc | 8 +++++--- src/runtime/metal/metal_common.h | 2 +- src/runtime/metal/metal_device_api.mm | 8 +++++--- src/runtime/metal/metal_module.mm | 4 ++-- src/runtime/micro/micro_device_api.cc | 6 +++--- src/runtime/opencl/aocl/aocl_common.h | 2 +- src/runtime/opencl/aocl/aocl_device_api.cc | 6 +++--- src/runtime/opencl/aocl/aocl_module.cc | 6 ++---- src/runtime/opencl/opencl_common.h | 9 ++++----- src/runtime/opencl/opencl_device_api.cc | 6 +++--- src/runtime/opencl/opencl_module.cc | 4 ++-- src/runtime/opencl/sdaccel/sdaccel_common.h | 2 +- src/runtime/opencl/sdaccel/sdaccel_device_api.cc | 6 +++--- src/runtime/opencl/sdaccel/sdaccel_module.cc | 4 ++-- src/runtime/rocm/rocm_device_api.cc | 6 +++--- src/runtime/vulkan/vulkan.cc | 6 +++--- src/runtime/workspace_pool.cc | 6 +++--- src/runtime/workspace_pool.h | 6 +++--- vta/runtime/device_api.cc | 6 +++--- web/emcc/webgpu_runtime.cc | 6 +++--- 23 files changed, 68 insertions(+), 64 deletions(-) diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 7320e3d72a1c..0794348412cb 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -106,8 +106,8 @@ class DeviceAPIManager { ~DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); } // Global static variable. static DeviceAPIManager* Global() { - static DeviceAPIManager inst; - return &inst; + static DeviceAPIManager* inst = new DeviceAPIManager(); + return inst; } // Get or initialize API. DeviceAPI* GetAPI(int type, bool allow_missing) { diff --git a/src/runtime/cpu_device_api.cc b/src/runtime/cpu_device_api.cc index c70a4f29ccbe..5474b758ca9c 100644 --- a/src/runtime/cpu_device_api.cc +++ b/src/runtime/cpu_device_api.cc @@ -80,8 +80,10 @@ class CPUDeviceAPI final : public DeviceAPI { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static CPUDeviceAPI* Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static auto* inst = new CPUDeviceAPI(); return inst; } }; @@ -99,7 +101,7 @@ void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { } TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CPUDeviceAPI::Global().get(); + DeviceAPI* ptr = CPUDeviceAPI::Global(); *rv = static_cast(ptr); }); } // namespace runtime diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 33c48c70674c..b69ecf26808e 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -207,9 +207,10 @@ class CUDADeviceAPI final : public DeviceAPI { CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); - if (inst.use_count() == 0) inst = std::make_shared(); + static CUDADeviceAPI* Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static auto* inst = new CUDADeviceAPI(); return inst; } @@ -231,12 +232,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {} CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = CUDADeviceAPI::Global().get(); + DeviceAPI* ptr = CUDADeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index fd6f32374005..a89015707f99 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -42,8 +42,10 @@ class HexagonDeviceAPI : public DeviceAPI { void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final; void FreeWorkspace(TVMContext ctx, void* ptr) final; - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static HexagonDeviceAPI* Global() { + // NOTE: explicitly use new to avoid destruction of global state + // Global state will be recycled by OS as the process exits. + static HexagonDeviceAPI* inst = new HexagonDeviceAPI(); return inst; } }; @@ -121,7 +123,7 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) { } TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = HexagonDeviceAPI::Global().get(); + DeviceAPI* ptr = HexagonDeviceAPI::Global(); *rv = ptr; }); } // namespace runtime diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h index ca369d46e5ba..634ee305153b 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/runtime/metal/metal_common.h @@ -91,7 +91,7 @@ class MetalWorkspace final : public DeviceAPI { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final; void FreeWorkspace(TVMContext ctx, void* data) final; // get the global workspace - static const std::shared_ptr& Global(); + static MetalWorkspace* Global(); }; /*! \brief Thread local workspace */ diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index f2a2930810e5..fddeadf86f62 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -28,8 +28,10 @@ namespace runtime { namespace metal { -const std::shared_ptr& MetalWorkspace::Global() { - static std::shared_ptr inst = std::make_shared(); +MetalWorkspace* MetalWorkspace::Global() { + // NOTE: explicitly use new to avoid exit-time destruction of global state + // Global state will be recycled by OS as the process exits. + static MetalWorkspace* inst = new MetalWorkspace(); return inst; } @@ -273,7 +275,7 @@ int GetWarpSize(id dev) { MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MetalWorkspace::Global().get(); + DeviceAPI* ptr = MetalWorkspace::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 9bdebf3d06c1..8d10ff210d8d 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -73,7 +73,7 @@ void SaveToBinary(dmlc::Stream* stream) final { } // get a from primary context in device_id id GetPipelineState(size_t device_id, const std::string& func_name) { - metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get(); + metal::MetalWorkspace* w = metal::MetalWorkspace::Global(); CHECK_LT(device_id, w->devices.size()); // start lock scope. std::lock_guard lock(mutex_); @@ -168,7 +168,7 @@ void SaveToBinary(dmlc::Stream* stream) final { void Init(MetalModuleNode* m, ObjectPtr sptr, const std::string& func_name, size_t num_buffer_args, size_t num_pack_args, const std::vector& thread_axis_tags) { - w_ = metal::MetalWorkspace::Global().get(); + w_ = metal::MetalWorkspace::Global(); m_ = m; sptr_ = sptr; func_name_ = func_name; diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc index 68480786ac87..3812ec072cd8 100644 --- a/src/runtime/micro/micro_device_api.cc +++ b/src/runtime/micro/micro_device_api.cc @@ -140,8 +140,8 @@ class MicroDeviceAPI final : public DeviceAPI { * \brief obtain a global singleton of MicroDeviceAPI * \return global shared pointer to MicroDeviceAPI */ - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static MicroDeviceAPI* Global() { + static MicroDeviceAPI* inst = new MicroDeviceAPI(); return inst; } @@ -155,7 +155,7 @@ class MicroDeviceAPI final : public DeviceAPI { // register device that can be obtained from Python frontend TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MicroDeviceAPI::Global().get(); + DeviceAPI* ptr = MicroDeviceAPI::Global(); *rv = static_cast(ptr); }); } // namespace runtime diff --git a/src/runtime/opencl/aocl/aocl_common.h b/src/runtime/opencl/aocl/aocl_common.h index 1b98d4b2d221..ae1a4a8cc31f 100644 --- a/src/runtime/opencl/aocl/aocl_common.h +++ b/src/runtime/opencl/aocl/aocl_common.h @@ -42,7 +42,7 @@ class AOCLWorkspace final : public OpenCLWorkspace { bool IsOpenCLDevice(TVMContext ctx) final; OpenCLThreadEntry* GetThreadEntry() final; // get the global workspace - static const std::shared_ptr& Global(); + static OpenCLWorkspace* Global(); }; /*! \brief Thread local workspace for AOCL */ diff --git a/src/runtime/opencl/aocl/aocl_device_api.cc b/src/runtime/opencl/aocl/aocl_device_api.cc index 07057ff29716..5432507087ca 100644 --- a/src/runtime/opencl/aocl/aocl_device_api.cc +++ b/src/runtime/opencl/aocl/aocl_device_api.cc @@ -31,8 +31,8 @@ namespace cl { OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); } -const std::shared_ptr& AOCLWorkspace::Global() { - static std::shared_ptr inst = std::make_shared(); +OpenCLWorkspace* AOCLWorkspace::Global() { + static OpenCLWorkspace* inst = new AOCLWorkspace(); return inst; } @@ -49,7 +49,7 @@ typedef dmlc::ThreadLocalStore AOCLThreadStore; AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = AOCLWorkspace::Global().get(); + DeviceAPI* ptr = AOCLWorkspace::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/opencl/aocl/aocl_module.cc b/src/runtime/opencl/aocl/aocl_module.cc index 747188cf7b2d..cb8653356169 100644 --- a/src/runtime/opencl/aocl/aocl_module.cc +++ b/src/runtime/opencl/aocl/aocl_module.cc @@ -39,12 +39,10 @@ class AOCLModuleNode : public OpenCLModuleNode { explicit AOCLModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} - const std::shared_ptr& GetGlobalWorkspace() final; + cl::OpenCLWorkspace* GetGlobalWorkspace() final; }; -const std::shared_ptr& AOCLModuleNode::GetGlobalWorkspace() { - return cl::AOCLWorkspace::Global(); -} +cl::OpenCLWorkspace* AOCLModuleNode::GetGlobalWorkspace() { return cl::AOCLWorkspace::Global(); } Module AOCLModuleCreate(std::string data, std::string fmt, std::unordered_map fmap, std::string source) { diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index a892bff75342..aab0c27cb39b 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -245,7 +245,7 @@ class OpenCLWorkspace : public DeviceAPI { virtual OpenCLThreadEntry* GetThreadEntry(); // get the global workspace - static const std::shared_ptr& Global(); + static OpenCLWorkspace* Global(); }; /*! \brief Thread local workspace */ @@ -265,8 +265,7 @@ class OpenCLThreadEntry { /*! \brief workspace pool */ WorkspacePool pool; // constructor - OpenCLThreadEntry(DLDeviceType device_type, std::shared_ptr device) - : pool(device_type, device) { + OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device) : pool(device_type, device) { context.device_id = 0; context.device_type = device_type; } @@ -298,7 +297,7 @@ class OpenCLModuleNode : public ModuleNode { /*! * \brief Get the global workspace */ - virtual const std::shared_ptr& GetGlobalWorkspace(); + virtual cl::OpenCLWorkspace* GetGlobalWorkspace(); const char* type_key() const final { return workspace_->type_key.c_str(); } @@ -315,7 +314,7 @@ class OpenCLModuleNode : public ModuleNode { private: // The workspace, need to keep reference to use it in destructor. // In case of static destruction order problem. - std::shared_ptr workspace_; + cl::OpenCLWorkspace* workspace_; // the binary data std::string data_; // The format diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 5753c1d0f76b..83944cd4a83e 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -31,8 +31,8 @@ namespace cl { OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); } -const std::shared_ptr& OpenCLWorkspace::Global() { - static std::shared_ptr inst = std::make_shared(); +OpenCLWorkspace* OpenCLWorkspace::Global() { + static OpenCLWorkspace* inst = new OpenCLWorkspace(); return inst; } @@ -276,7 +276,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic } TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = OpenCLWorkspace::Global().get(); + DeviceAPI* ptr = OpenCLWorkspace::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 95d0481c31d5..590a446efe64 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -40,7 +40,7 @@ class OpenCLWrappedFunc { void Init(OpenCLModuleNode* m, ObjectPtr sptr, OpenCLModuleNode::KTRefEntry entry, std::string func_name, std::vector arg_size, const std::vector& thread_axis_tags) { - w_ = m->GetGlobalWorkspace().get(); + w_ = m->GetGlobalWorkspace(); m_ = m; sptr_ = sptr; entry_ = entry; @@ -110,7 +110,7 @@ OpenCLModuleNode::~OpenCLModuleNode() { } } -const std::shared_ptr& OpenCLModuleNode::GetGlobalWorkspace() { +cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() { return cl::OpenCLWorkspace::Global(); } diff --git a/src/runtime/opencl/sdaccel/sdaccel_common.h b/src/runtime/opencl/sdaccel/sdaccel_common.h index 803cbe67b9a7..feeab0bc89ce 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_common.h +++ b/src/runtime/opencl/sdaccel/sdaccel_common.h @@ -42,7 +42,7 @@ class SDAccelWorkspace final : public OpenCLWorkspace { bool IsOpenCLDevice(TVMContext ctx) final; OpenCLThreadEntry* GetThreadEntry() final; // get the global workspace - static const std::shared_ptr& Global(); + static OpenCLWorkspace* Global(); }; /*! \brief Thread local workspace for SDAccel*/ diff --git a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc index 6bac0c916aad..ebe387b1ddb3 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_device_api.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_device_api.cc @@ -31,8 +31,8 @@ namespace cl { OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); } -const std::shared_ptr& SDAccelWorkspace::Global() { - static std::shared_ptr inst = std::make_shared(); +OpenCLWorkspace* SDAccelWorkspace::Global() { + static OpenCLWorkspace* inst = new SDAccelWorkspace(); return inst; } @@ -47,7 +47,7 @@ typedef dmlc::ThreadLocalStore SDAccelThreadStore; SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = SDAccelWorkspace::Global().get(); + DeviceAPI* ptr = SDAccelWorkspace::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/opencl/sdaccel/sdaccel_module.cc b/src/runtime/opencl/sdaccel/sdaccel_module.cc index b4edca32a998..36dabd1e0292 100644 --- a/src/runtime/opencl/sdaccel/sdaccel_module.cc +++ b/src/runtime/opencl/sdaccel/sdaccel_module.cc @@ -39,10 +39,10 @@ class SDAccelModuleNode : public OpenCLModuleNode { explicit SDAccelModuleNode(std::string data, std::string fmt, std::unordered_map fmap, std::string source) : OpenCLModuleNode(data, fmt, fmap, source) {} - const std::shared_ptr& GetGlobalWorkspace() final; + cl::OpenCLWorkspace* GetGlobalWorkspace() final; }; -const std::shared_ptr& SDAccelModuleNode::GetGlobalWorkspace() { +cl::OpenCLWorkspace* SDAccelModuleNode::GetGlobalWorkspace() { return cl::SDAccelWorkspace::Global(); } diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index e1a14c7dcf1c..7f5bc99380a4 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -174,8 +174,8 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static ROCMDeviceAPI* Global() { + static ROCMDeviceAPI* inst = new ROCMDeviceAPI(); return inst; } @@ -197,7 +197,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + DeviceAPI* ptr = ROCMDeviceAPI::Global(); *rv = static_cast(ptr); }); } // namespace runtime diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index 9e730b7fd8b1..568672591497 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -340,8 +340,8 @@ class VulkanDeviceAPI final : public DeviceAPI { VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data); } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static VulkanDeviceAPI* Global() { + static VulkanDeviceAPI* inst = new VulkanDeviceAPI(); return inst; } @@ -1159,7 +1159,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModul TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VulkanDeviceAPI::Global().get(); + DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); }); diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index 8ee905e4ea84..49a4c961159d 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -134,7 +134,7 @@ class WorkspacePool::Pool { std::vector allocated_; }; -WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr device) +WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device) : device_type_(device_type), device_(device) {} WorkspacePool::~WorkspacePool() { @@ -143,7 +143,7 @@ WorkspacePool::~WorkspacePool() { TVMContext ctx; ctx.device_type = device_type_; ctx.device_id = static_cast(i); - array_[i]->Release(ctx, device_.get()); + array_[i]->Release(ctx, device_); delete array_[i]; } } @@ -156,7 +156,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) { if (array_[ctx.device_id] == nullptr) { array_[ctx.device_id] = new Pool(); } - return array_[ctx.device_id]->Alloc(ctx, device_.get(), size); + return array_[ctx.device_id]->Alloc(ctx, device_, size); } void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) { diff --git a/src/runtime/workspace_pool.h b/src/runtime/workspace_pool.h index 288da7d10483..887afc5cbb57 100644 --- a/src/runtime/workspace_pool.h +++ b/src/runtime/workspace_pool.h @@ -47,9 +47,9 @@ class TVM_DLL WorkspacePool { /*! * \brief Create pool with specific device type and device. * \param device_type The device type. - * \param device The device API. + * \param device_api The device API. */ - WorkspacePool(DLDeviceType device_type, std::shared_ptr device); + WorkspacePool(DLDeviceType device_type, DeviceAPI* device_api); /*! \brief destructor */ ~WorkspacePool(); /*! @@ -73,7 +73,7 @@ class TVM_DLL WorkspacePool { /*! \brief device type this pool support */ DLDeviceType device_type_; /*! \brief The device API */ - std::shared_ptr device_; + DeviceAPI* device_; }; } // namespace runtime diff --git a/vta/runtime/device_api.cc b/vta/runtime/device_api.cc index 298403ca840d..0fea7ba5e364 100644 --- a/vta/runtime/device_api.cc +++ b/vta/runtime/device_api.cc @@ -66,8 +66,8 @@ class VTADeviceAPI final : public DeviceAPI { void FreeWorkspace(TVMContext ctx, void* data) final; - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static VTADeviceAPI* Global() { + static VTADeviceAPI* inst = new VTADeviceAPI(); return inst; } }; @@ -88,7 +88,7 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) { static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ = ::tvm::runtime::Registry::Register("device_api.ext_dev", true) .set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = VTADeviceAPI::Global().get(); + DeviceAPI* ptr = VTADeviceAPI::Global(); *rv = static_cast(ptr); }); } // namespace runtime diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc index 7f0b0d9f72cb..54601e37d037 100644 --- a/web/emcc/webgpu_runtime.cc +++ b/web/emcc/webgpu_runtime.cc @@ -132,8 +132,8 @@ class WebGPUDeviceAPI : public DeviceAPI { WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = std::make_shared(); + static WebGPUDeviceAPI* Global() { + static WebGPUDeviceAPI* inst = new WebGPUDeviceAPI(); return inst; } @@ -222,7 +222,7 @@ Module WebGPUModuleLoadBinary(void* strm) { TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary); TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = WebGPUDeviceAPI::Global().get(); + DeviceAPI* ptr = WebGPUDeviceAPI::Global(); *rv = static_cast(ptr); });