Skip to content

Commit

Permalink
[RUNTIME][REFACTOR] Use new to avoid exit-time de-allocation order pr…
Browse files Browse the repository at this point in the history
…oblem in DeviceAPI
  • Loading branch information
tqchen committed Aug 18, 2020
1 parent 4644991 commit b68f759
Show file tree
Hide file tree
Showing 19 changed files with 58 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class DeviceAPIManager {
DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
static DeviceAPIManager* inst = new DeviceAPIManager();
return inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing) {
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/cpu_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;

static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CPUDeviceAPI();
return inst;
}
};
Expand All @@ -99,7 +101,7 @@ void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
}

TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
DeviceAPI* ptr = CPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
10 changes: 6 additions & 4 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<CUDADeviceAPI>& Global() {
static std::shared_ptr<CUDADeviceAPI> inst = std::make_shared<CUDADeviceAPI>();
static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CUDADeviceAPI();
return inst;
}

Expand All @@ -230,12 +232,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
8 changes: 5 additions & 3 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class HexagonDeviceAPI : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final;
void FreeWorkspace(TVMContext ctx, void* ptr) final;

static const std::shared_ptr<HexagonDeviceAPI>& Global() {
static std::shared_ptr<HexagonDeviceAPI> inst = std::make_shared<HexagonDeviceAPI>();
static HexagonDeviceAPI* Global() {
// NOTE: explicitly use new to avoid destruction of global state
// Global state will be recycled by OS as the process exits.
static HexagonDeviceAPI* inst = new HexagonDeviceAPI();
return inst;
}
};
Expand Down Expand Up @@ -121,7 +123,7 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
}

TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
DeviceAPI* ptr = HexagonDeviceAPI::Global();
*rv = ptr;
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class MetalWorkspace final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<MetalWorkspace>& Global();
static MetalWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
namespace runtime {
namespace metal {

const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
static std::shared_ptr<MetalWorkspace> inst = std::make_shared<MetalWorkspace>();
const MetalWorkspace* MetalWorkspace::Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}

Expand Down Expand Up @@ -273,7 +275,7 @@ int GetWarpSize(id<MTLDevice> dev) {
MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MetalWorkspace::Global().get();
DeviceAPI* ptr = MetalWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/micro/micro_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class MicroDeviceAPI final : public DeviceAPI {
* \brief obtain a global singleton of MicroDeviceAPI
* \return global shared pointer to MicroDeviceAPI
*/
static const std::shared_ptr<MicroDeviceAPI>& Global() {
static std::shared_ptr<MicroDeviceAPI> inst = std::make_shared<MicroDeviceAPI>();
static MicroDeviceAPI* Global() {
static MicroDeviceAPI* inst = new MicroDeviceAPI();
return inst;
}

Expand All @@ -155,7 +155,7 @@ class MicroDeviceAPI final : public DeviceAPI {

// register device that can be obtained from Python frontend
TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MicroDeviceAPI::Global().get();
DeviceAPI* ptr = MicroDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/aocl/aocl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AOCLWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for AOCL */
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/aocl/aocl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& AOCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<AOCLWorkspace>();
OpenCLWorkspace* AOCLWorkspace::Global() {
static OpenCLWorkspace* inst = new AOCLWorkspace();
return inst;
}

Expand All @@ -49,7 +49,7 @@ typedef dmlc::ThreadLocalStore<AOCLThreadEntry> AOCLThreadStore;
AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = AOCLWorkspace::Global().get();
DeviceAPI* ptr = AOCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class OpenCLWorkspace : public DeviceAPI {
virtual OpenCLThreadEntry* GetThreadEntry();

// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand All @@ -265,7 +265,7 @@ class OpenCLThreadEntry {
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
OpenCLThreadEntry(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device)
: pool(device_type, device) {
context.device_id = 0;
context.device_type = device_type;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
OpenCLWorkspace* OpenCLWorkspace::Global() {
static OpenCLWorkspace* inst = new OpenCLWorkspace();
return inst;
}

Expand Down Expand Up @@ -276,7 +276,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
}

TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global().get();
DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/sdaccel/sdaccel_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SDAccelWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for SDAccel*/
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/sdaccel/sdaccel_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<SDAccelWorkspace>();
OpenCLWorkspace* SDAccelWorkspace::Global() {
static OpenCLWorkspace* inst = new SDAccelWorkspace();
return inst;
}

Expand All @@ -47,7 +47,7 @@ typedef dmlc::ThreadLocalStore<SDAccelThreadEntry> SDAccelThreadStore;
SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = SDAccelWorkspace::Global().get();
DeviceAPI* ptr = SDAccelWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<ROCMDeviceAPI>& Global() {
static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
static ROCMDeviceAPI* Global() {
static ROCMDeviceAPI* inst = new ROCMDeviceAPI();
return inst;
}

Expand All @@ -197,7 +197,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
DeviceAPI* ptr = ROCMDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ class VulkanDeviceAPI final : public DeviceAPI {
VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
}

static const std::shared_ptr<VulkanDeviceAPI>& Global() {
static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
static VulkanDeviceAPI* Global() {
static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
return inst;
}

Expand Down Expand Up @@ -1159,7 +1159,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModul
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
DeviceAPI* ptr = VulkanDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/workspace_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class WorkspacePool::Pool {
std::vector<Entry> allocated_;
};

WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device)
: device_type_(device_type), device_(device) {}

WorkspacePool::~WorkspacePool() {
Expand All @@ -143,7 +143,7 @@ WorkspacePool::~WorkspacePool() {
TVMContext ctx;
ctx.device_type = device_type_;
ctx.device_id = static_cast<int>(i);
array_[i]->Release(ctx, device_.get());
array_[i]->Release(ctx, device_);
delete array_[i];
}
}
Expand All @@ -156,7 +156,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
if (array_[ctx.device_id] == nullptr) {
array_[ctx.device_id] = new Pool();
}
return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);
return array_[ctx.device_id]->Alloc(ctx, device_, size);
}

void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/workspace_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class TVM_DLL WorkspacePool {
/*!
* \brief Create pool with specific device type and device.
* \param device_type The device type.
* \param device The device API.
* \param device_api The device API.
*/
WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device);
WorkspacePool(DLDeviceType device_type, DeviceAPI* device_api);
/*! \brief destructor */
~WorkspacePool();
/*!
Expand All @@ -73,7 +73,7 @@ class TVM_DLL WorkspacePool {
/*! \brief device type this pool support */
DLDeviceType device_type_;
/*! \brief The device API */
std::shared_ptr<DeviceAPI> device_;
DeviceAPI* device_;
};

} // namespace runtime
Expand Down
6 changes: 3 additions & 3 deletions vta/runtime/device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class VTADeviceAPI final : public DeviceAPI {

void FreeWorkspace(TVMContext ctx, void* data) final;

static const std::shared_ptr<VTADeviceAPI>& Global() {
static std::shared_ptr<VTADeviceAPI> inst = std::make_shared<VTADeviceAPI>();
static VTADeviceAPI* Global() {
static VTADeviceAPI* inst = new VTADeviceAPI();
return inst;
}
};
Expand All @@ -88,7 +88,7 @@ void VTADeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
static TVM_ATTRIBUTE_UNUSED auto& __register_dev__ =
::tvm::runtime::Registry::Register("device_api.ext_dev", true)
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VTADeviceAPI::Global().get();
DeviceAPI* ptr = VTADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
6 changes: 3 additions & 3 deletions web/emcc/webgpu_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ class WebGPUDeviceAPI : public DeviceAPI {
WebGPUThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<WebGPUDeviceAPI>& Global() {
static std::shared_ptr<WebGPUDeviceAPI> inst = std::make_shared<WebGPUDeviceAPI>();
static WebGPUDeviceAPI* Global() {
static WebGPUDeviceAPI* inst = new WebGPUDeviceAPI();
return inst;
}

Expand Down Expand Up @@ -222,7 +222,7 @@ Module WebGPUModuleLoadBinary(void* strm) {
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(WebGPUModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.webgpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = WebGPUDeviceAPI::Global().get();
DeviceAPI* ptr = WebGPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down

0 comments on commit b68f759

Please sign in to comment.