Skip to content

Commit

Permalink
[RUNTIME][REFACTOR] Use new to avoid exit-time de-allocation order pr…
Browse files Browse the repository at this point in the history
…oblem in DeviceAPI (apache#6292)
  • Loading branch information
tqchen authored and Trevor Morris committed Aug 26, 2020
1 parent 42a961d commit addde71
Show file tree
Hide file tree
Showing 23 changed files with 68 additions and 63 deletions.
4 changes: 2 additions & 2 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ class DeviceAPIManager {
DeviceAPIManager() { std::fill(api_.begin(), api_.end(), nullptr); }
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
static DeviceAPIManager* inst = new DeviceAPIManager();
return inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(int type, bool allow_missing) {
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/cpu_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;

static const std::shared_ptr<CPUDeviceAPI>& Global() {
static std::shared_ptr<CPUDeviceAPI> inst = std::make_shared<CPUDeviceAPI>();
static CPUDeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CPUDeviceAPI();
return inst;
}
};
Expand All @@ -99,7 +101,7 @@ void CPUDeviceAPI::FreeWorkspace(TVMContext ctx, void* data) {
}

TVM_REGISTER_GLOBAL("device_api.cpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CPUDeviceAPI::Global().get();
DeviceAPI* ptr = CPUDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
10 changes: 6 additions & 4 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<CUDADeviceAPI>& Global() {
static std::shared_ptr<CUDADeviceAPI> inst = std::make_shared<CUDADeviceAPI>();
static CUDADeviceAPI* Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static auto* inst = new CUDADeviceAPI();
return inst;
}

Expand All @@ -230,12 +232,12 @@ CUDAThreadEntry::CUDAThreadEntry() : pool(kDLGPU, CUDADeviceAPI::Global()) {}
CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() { return CUDAThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.gpu").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

TVM_REGISTER_GLOBAL("device_api.cpu_pinned").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = CUDADeviceAPI::Global().get();
DeviceAPI* ptr = CUDADeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
8 changes: 5 additions & 3 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class HexagonDeviceAPI : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t nbytes, DLDataType type_hint = {}) final;
void FreeWorkspace(TVMContext ctx, void* ptr) final;

static const std::shared_ptr<HexagonDeviceAPI>& Global() {
static std::shared_ptr<HexagonDeviceAPI> inst = std::make_shared<HexagonDeviceAPI>();
static HexagonDeviceAPI* Global() {
// NOTE: explicitly use new to avoid destruction of global state
// Global state will be recycled by OS as the process exits.
static HexagonDeviceAPI* inst = new HexagonDeviceAPI();
return inst;
}
};
Expand Down Expand Up @@ -121,7 +123,7 @@ inline void HexagonDeviceAPI::FreeWorkspace(TVMContext ctx, void* ptr) {
}

TVM_REGISTER_GLOBAL("device_api.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = HexagonDeviceAPI::Global().get();
DeviceAPI* ptr = HexagonDeviceAPI::Global();
*rv = ptr;
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class MetalWorkspace final : public DeviceAPI {
void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace
static const std::shared_ptr<MetalWorkspace>& Global();
static MetalWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand Down
8 changes: 5 additions & 3 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
namespace runtime {
namespace metal {

const std::shared_ptr<MetalWorkspace>& MetalWorkspace::Global() {
static std::shared_ptr<MetalWorkspace> inst = std::make_shared<MetalWorkspace>();
MetalWorkspace* MetalWorkspace::Global() {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}

Expand Down Expand Up @@ -273,7 +275,7 @@ int GetWarpSize(id<MTLDevice> dev) {
MetalThreadEntry* MetalThreadEntry::ThreadLocal() { return MetalThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.metal").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MetalWorkspace::Global().get();
DeviceAPI* ptr = MetalWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
}
// get a from primary context in device_id
id<MTLComputePipelineState> GetPipelineState(size_t device_id, const std::string& func_name) {
metal::MetalWorkspace* w = metal::MetalWorkspace::Global().get();
metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
CHECK_LT(device_id, w->devices.size());
// start lock scope.
std::lock_guard<std::mutex> lock(mutex_);
Expand Down Expand Up @@ -168,7 +168,7 @@ void SaveToBinary(dmlc::Stream* stream) final {
void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_name,
size_t num_buffer_args, size_t num_pack_args,
const std::vector<std::string>& thread_axis_tags) {
w_ = metal::MetalWorkspace::Global().get();
w_ = metal::MetalWorkspace::Global();
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/micro/micro_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ class MicroDeviceAPI final : public DeviceAPI {
* \brief obtain a global singleton of MicroDeviceAPI
* \return global shared pointer to MicroDeviceAPI
*/
static const std::shared_ptr<MicroDeviceAPI>& Global() {
static std::shared_ptr<MicroDeviceAPI> inst = std::make_shared<MicroDeviceAPI>();
static MicroDeviceAPI* Global() {
static MicroDeviceAPI* inst = new MicroDeviceAPI();
return inst;
}

Expand All @@ -155,7 +155,7 @@ class MicroDeviceAPI final : public DeviceAPI {

// register device that can be obtained from Python frontend
TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = MicroDeviceAPI::Global().get();
DeviceAPI* ptr = MicroDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/aocl/aocl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class AOCLWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for AOCL */
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/aocl/aocl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* AOCLWorkspace::GetThreadEntry() { return AOCLThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& AOCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<AOCLWorkspace>();
OpenCLWorkspace* AOCLWorkspace::Global() {
static OpenCLWorkspace* inst = new AOCLWorkspace();
return inst;
}

Expand All @@ -49,7 +49,7 @@ typedef dmlc::ThreadLocalStore<AOCLThreadEntry> AOCLThreadStore;
AOCLThreadEntry* AOCLThreadEntry::ThreadLocal() { return AOCLThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.aocl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = AOCLWorkspace::Global().get();
DeviceAPI* ptr = AOCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 2 additions & 4 deletions src/runtime/opencl/aocl/aocl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@ class AOCLModuleNode : public OpenCLModuleNode {
explicit AOCLModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
cl::OpenCLWorkspace* GetGlobalWorkspace() final;
};

const std::shared_ptr<cl::OpenCLWorkspace>& AOCLModuleNode::GetGlobalWorkspace() {
return cl::AOCLWorkspace::Global();
}
cl::OpenCLWorkspace* AOCLModuleNode::GetGlobalWorkspace() { return cl::AOCLWorkspace::Global(); }

Module AOCLModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
Expand Down
9 changes: 4 additions & 5 deletions src/runtime/opencl/opencl_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class OpenCLWorkspace : public DeviceAPI {
virtual OpenCLThreadEntry* GetThreadEntry();

// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace */
Expand All @@ -265,8 +265,7 @@ class OpenCLThreadEntry {
/*! \brief workspace pool */
WorkspacePool pool;
// constructor
OpenCLThreadEntry(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
: pool(device_type, device) {
OpenCLThreadEntry(DLDeviceType device_type, DeviceAPI* device) : pool(device_type, device) {
context.device_id = 0;
context.device_type = device_type;
}
Expand Down Expand Up @@ -298,7 +297,7 @@ class OpenCLModuleNode : public ModuleNode {
/*!
* \brief Get the global workspace
*/
virtual const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace();
virtual cl::OpenCLWorkspace* GetGlobalWorkspace();

const char* type_key() const final { return workspace_->type_key.c_str(); }

Expand All @@ -315,7 +314,7 @@ class OpenCLModuleNode : public ModuleNode {
private:
// The workspace, need to keep reference to use it in destructor.
// In case of static destruction order problem.
std::shared_ptr<cl::OpenCLWorkspace> workspace_;
cl::OpenCLWorkspace* workspace_;
// the binary data
std::string data_;
// The format
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* OpenCLWorkspace::GetThreadEntry() { return OpenCLThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& OpenCLWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<OpenCLWorkspace>();
OpenCLWorkspace* OpenCLWorkspace::Global() {
static OpenCLWorkspace* inst = new OpenCLWorkspace();
return inst;
}

Expand Down Expand Up @@ -276,7 +276,7 @@ void OpenCLWorkspace::Init(const std::string& type_key, const std::string& devic
}

TVM_REGISTER_GLOBAL("device_api.opencl").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global().get();
DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/opencl/opencl_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class OpenCLWrappedFunc {
void Init(OpenCLModuleNode* m, ObjectPtr<Object> sptr, OpenCLModuleNode::KTRefEntry entry,
std::string func_name, std::vector<size_t> arg_size,
const std::vector<std::string>& thread_axis_tags) {
w_ = m->GetGlobalWorkspace().get();
w_ = m->GetGlobalWorkspace();
m_ = m;
sptr_ = sptr;
entry_ = entry;
Expand Down Expand Up @@ -110,7 +110,7 @@ OpenCLModuleNode::~OpenCLModuleNode() {
}
}

const std::shared_ptr<cl::OpenCLWorkspace>& OpenCLModuleNode::GetGlobalWorkspace() {
cl::OpenCLWorkspace* OpenCLModuleNode::GetGlobalWorkspace() {
return cl::OpenCLWorkspace::Global();
}

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/opencl/sdaccel/sdaccel_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SDAccelWorkspace final : public OpenCLWorkspace {
bool IsOpenCLDevice(TVMContext ctx) final;
OpenCLThreadEntry* GetThreadEntry() final;
// get the global workspace
static const std::shared_ptr<OpenCLWorkspace>& Global();
static OpenCLWorkspace* Global();
};

/*! \brief Thread local workspace for SDAccel*/
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/opencl/sdaccel/sdaccel_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace cl {

OpenCLThreadEntry* SDAccelWorkspace::GetThreadEntry() { return SDAccelThreadEntry::ThreadLocal(); }

const std::shared_ptr<OpenCLWorkspace>& SDAccelWorkspace::Global() {
static std::shared_ptr<OpenCLWorkspace> inst = std::make_shared<SDAccelWorkspace>();
OpenCLWorkspace* SDAccelWorkspace::Global() {
static OpenCLWorkspace* inst = new SDAccelWorkspace();
return inst;
}

Expand All @@ -47,7 +47,7 @@ typedef dmlc::ThreadLocalStore<SDAccelThreadEntry> SDAccelThreadStore;
SDAccelThreadEntry* SDAccelThreadEntry::ThreadLocal() { return SDAccelThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.sdaccel").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = SDAccelWorkspace::Global().get();
DeviceAPI* ptr = SDAccelWorkspace::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
4 changes: 2 additions & 2 deletions src/runtime/opencl/sdaccel/sdaccel_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ class SDAccelModuleNode : public OpenCLModuleNode {
explicit SDAccelModuleNode(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source)
: OpenCLModuleNode(data, fmt, fmap, source) {}
const std::shared_ptr<cl::OpenCLWorkspace>& GetGlobalWorkspace() final;
cl::OpenCLWorkspace* GetGlobalWorkspace() final;
};

const std::shared_ptr<cl::OpenCLWorkspace>& SDAccelModuleNode::GetGlobalWorkspace() {
cl::OpenCLWorkspace* SDAccelModuleNode::GetGlobalWorkspace() {
return cl::SDAccelWorkspace::Global();
}

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
}

static const std::shared_ptr<ROCMDeviceAPI>& Global() {
static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
static ROCMDeviceAPI* Global() {
static ROCMDeviceAPI* inst = new ROCMDeviceAPI();
return inst;
}

Expand All @@ -197,7 +197,7 @@ ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }

TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
DeviceAPI* ptr = ROCMDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ class VulkanDeviceAPI final : public DeviceAPI {
VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
}

static const std::shared_ptr<VulkanDeviceAPI>& Global() {
static std::shared_ptr<VulkanDeviceAPI> inst = std::make_shared<VulkanDeviceAPI>();
static VulkanDeviceAPI* Global() {
static VulkanDeviceAPI* inst = new VulkanDeviceAPI();
return inst;
}

Expand Down Expand Up @@ -1159,7 +1159,7 @@ TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModul
TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary);

TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = VulkanDeviceAPI::Global().get();
DeviceAPI* ptr = VulkanDeviceAPI::Global();
*rv = static_cast<void*>(ptr);
});

Expand Down
6 changes: 3 additions & 3 deletions src/runtime/workspace_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class WorkspacePool::Pool {
std::vector<Entry> allocated_;
};

WorkspacePool::WorkspacePool(DLDeviceType device_type, std::shared_ptr<DeviceAPI> device)
WorkspacePool::WorkspacePool(DLDeviceType device_type, DeviceAPI* device)
: device_type_(device_type), device_(device) {}

WorkspacePool::~WorkspacePool() {
Expand All @@ -143,7 +143,7 @@ WorkspacePool::~WorkspacePool() {
TVMContext ctx;
ctx.device_type = device_type_;
ctx.device_id = static_cast<int>(i);
array_[i]->Release(ctx, device_.get());
array_[i]->Release(ctx, device_);
delete array_[i];
}
}
Expand All @@ -156,7 +156,7 @@ void* WorkspacePool::AllocWorkspace(TVMContext ctx, size_t size) {
if (array_[ctx.device_id] == nullptr) {
array_[ctx.device_id] = new Pool();
}
return array_[ctx.device_id]->Alloc(ctx, device_.get(), size);
return array_[ctx.device_id]->Alloc(ctx, device_, size);
}

void WorkspacePool::FreeWorkspace(TVMContext ctx, void* ptr) {
Expand Down
Loading

0 comments on commit addde71

Please sign in to comment.