Skip to content

Commit

Permalink
[RUNTIME][VULKAN] Seg fault in WorkspacePool's destructor (apache#5632)…
Browse files Browse the repository at this point in the history
… (apache#5636)

* [RUNTIME][VULKAN] Seg fault in WorkspacePool's destructor (apache#5632)
* fixed this issue by changing WorkspacePool's destruction order

* make line < 100 charactors long
  • Loading branch information
samwyi authored and Trevor Morris committed Jun 9, 2020
1 parent 93446e6 commit ca1ade8
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class VulkanThreadEntry {
// the instance and device get destroyed.
// The destruction need to be manually called
// to ensure the destruction order.

pool.reset();
streams_.clear();
for (const auto& kv : staging_buffers_) {
if (!kv.second) {
Expand All @@ -75,7 +77,7 @@ class VulkanThreadEntry {
}

TVMContext ctx;
WorkspacePool pool;
std::unique_ptr<WorkspacePool> pool;
VulkanStream* Stream(size_t device_id);
VulkanStagingBuffer* StagingBuffer(int device_id, size_t size);

Expand Down Expand Up @@ -331,11 +333,11 @@ class VulkanDeviceAPI final : public DeviceAPI {
}

void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(ctx, size);
}

void FreeWorkspace(TVMContext ctx, void* data) final {
VulkanThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data);
VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(ctx, data);
}

static const std::shared_ptr<VulkanDeviceAPI>& Global() {
Expand Down Expand Up @@ -999,7 +1001,8 @@ VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size
}

VulkanThreadEntry::VulkanThreadEntry()
: pool(static_cast<DLDeviceType>(kDLVulkan), VulkanDeviceAPI::Global()) {
: pool(std::make_unique<WorkspacePool>(static_cast<DLDeviceType>(kDLVulkan),
VulkanDeviceAPI::Global())) {
ctx.device_id = 0;
ctx.device_type = static_cast<DLDeviceType>(kDLVulkan);
}
Expand Down

0 comments on commit ca1ade8

Please sign in to comment.