[Vulkan][Refactor] Move ownership of per-CPU-thread objects to VulkanDeviceAPI #8196

Merged: 6 commits, Jun 11, 2021
175 changes: 175 additions & 0 deletions src/runtime/thread_map.h
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef TVM_RUNTIME_THREAD_MAP_H_
#define TVM_RUNTIME_THREAD_MAP_H_

#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <thread>
#include <unordered_map>
#include <utility>

namespace tvm {
namespace runtime {

/*! \brief Container to hold one value per thread
*
* Similar to thread_local, but intended for variables that cannot be
* declared thread_local, such as non-static class member variables. All member
* functions are thread-safe to call. If only the current thread's
* value is accessed, no additional synchronization is required. If
* another thread's stored values are accessed, external
* synchronization may be required.
*
* Calls that only require access to already-existing values will not
* block each other. Calls that require constructing a new value will
* block any other calls.
*
* \tparam T The object type to be held. For instantiation of
* ThreadMap<T> and for calls to ThreadMap<T>::Get, only a forward
* declaration is required. For calls to ThreadMap<T>::GetOrMake, a
* full class definition is required.
*/
template <typename T>
class ThreadMap {
public:
ThreadMap() {}

/*! \brief Return the current thread's stored object, if it exists.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
const T* Get() const { return this->Get(std::this_thread::get_id()); }

/*! \brief Return the stored object for a given thread, if it exists.
*
* \param id The thread whose object should be returned.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
const T* Get(std::thread::id id) const {
std::shared_lock<std::shared_timed_mutex> lock(mutex_);
auto res = values_.find(id);
if (res == values_.end()) {
return nullptr;
} else {
return res->second.get();
}
}

/*! \brief Return the current thread's stored object, if it exists.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
T* Get() { return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get()); }

/*! \brief Return the stored object for a given thread, if it exists.
*
* \param id The thread whose object should be returned.
*
* \return If it exists, a pointer to the stored object. Otherwise,
* returns nullptr.
*/
T* Get(std::thread::id id) {
return const_cast<T*>(const_cast<const ThreadMap<T>*>(this)->Get(id));
}

/*! \brief Return the current thread's stored object, making it if
* necessary.
*
* Since this method can modify the stored map, there is no
* const overload available.
*
* \tparam Params Types of the stored object's constructor arguments
*
* \return A reference to the stored object
*/
template <typename... Params>
T& GetOrMake(Params&&... params) {
return GetOrMake(std::this_thread::get_id(), std::forward<Params>(params)...);
}

/*! \brief Return the stored object for a given thread, making it if
* necessary
*
* Since this method can modify the stored map, there is no
* const overload available.
*
* \tparam Params Types of the stored object's constructor arguments
*
* \param id The thread whose object should be returned.
*
* \param params Arguments to the stored object's constructor. Only
* used if the specified thread does not currently exist in the map.
*
* \return A reference to the stored object
*/
template <typename... Params>
T& GetOrMake(std::thread::id id, Params&&... params) {
// Try to get stored value first, which would only require shared
// access.
if (T* output = Get(id)) {
return *output;
}

// Not in map, need exclusive lock to write
std::unique_lock<std::shared_timed_mutex> lock(mutex_);

// Check again, in case another thread got the unique lock first
// and already constructed the object.
auto res = values_.find(id);
if (res != values_.end()) {
return *res->second;
}

// No value exists, make one and return it.
std::unique_ptr<T>& new_val = values_[id] =
std::make_unique<T>(std::forward<Params>(params)...);
return *new_val;
}

/*! \brief Clears all values held by the ThreadMap
*
* Calling Clear() invalidates any pointers/references previously
* returned by Get/GetOrMake.
*
*/
void Clear() {
std::unique_lock<std::shared_timed_mutex> lock(mutex_);
values_.clear();
}

private:
//! \brief Mutex to protect values_
mutable std::shared_timed_mutex mutex_;

//! \brief Map containing stored values
std::unordered_map<std::thread::id, std::unique_ptr<T>> values_;
};

} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_THREAD_MAP_H_
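
As a usage illustration of the container described above, here is a minimal sketch, assuming a hypothetical PerThreadState struct and that thread_map.h is reachable on the include path; it is not part of the diff. Each thread lazily constructs its own entry through GetOrMake and then uses it without additional synchronization.

// Illustrative usage sketch (not part of this change). PerThreadState is a
// hypothetical stand-in for the real per-thread Vulkan objects.
#include <thread>
#include <vector>

#include "thread_map.h"

struct PerThreadState {
  explicit PerThreadState(int device_id) : device_id(device_id) {}
  int device_id;
  int counter{0};
};

class Example {
 public:
  // Each calling thread lazily constructs its own PerThreadState on first
  // use; later calls from the same thread return the same object without
  // taking the exclusive lock.
  void Run(int device_id) {
    PerThreadState& state = per_thread_.GetOrMake(device_id);
    state.counter++;
  }

 private:
  tvm::runtime::ThreadMap<PerThreadState> per_thread_;
};

int main() {
  Example example;
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) {
    workers.emplace_back([&example, i] { example.Run(i); });
  }
  for (auto& t : workers) {
    t.join();
  }
  return 0;
}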
124 changes: 111 additions & 13 deletions src/runtime/vulkan/vulkan_buffer.cc
@@ -19,27 +19,125 @@

#include "vulkan_buffer.h"

#include <utility>

#include "vulkan_device_api.h"
#include "vulkan_thread_entry.h"

namespace tvm {
namespace runtime {
namespace vulkan {

VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) {
VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO};
info.size = nbytes;
// Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to
// specify the queue families.
info.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
info.usage = usage;
return info;
}

VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage,
uint32_t mem_type_index)
: device_(device) {
// Create a buffer
VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage);
VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer));

// Allocate memory
VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO};
mem_info.allocationSize = buffer_info.size;
mem_info.memoryTypeIndex = mem_type_index;

VkMemoryDedicatedAllocateInfoKHR dedicated_info = {
VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR};

bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize);
if (use_dedicated_allocation) {
dedicated_info.buffer = buffer;
mem_info.pNext = &dedicated_info;
}

VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory));

// Bind the buffer to the allocated memory
VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0));
}

VulkanBuffer::~VulkanBuffer() {
if (buffer) {
vkDestroyBuffer(device_, buffer, nullptr);
}
if (memory) {
vkFreeMemory(device_, memory, nullptr);
}
}

VulkanBuffer::VulkanBuffer(VulkanBuffer&& other)
: device_(other.device_), buffer(other.buffer), memory(other.memory) {
other.device_ = VK_NULL_HANDLE;
other.buffer = VK_NULL_HANDLE;
other.memory = VK_NULL_HANDLE;
}

VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) {
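// Swap members rather than copy-and-null: any buffer/memory previously
// owned by *this ends up in `other` and is released by its destructor.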
std::swap(device_, other.device_);
std::swap(buffer, other.buffer);
std::swap(memory, other.memory);
return *this;
}

bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer,
VkDeviceSize* nbytes) {
if (device.get_buffer_memory_requirements_2_functions) {
// Which buffer to request information about
VkBufferMemoryRequirementsInfo2KHR req_info2 = {
VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR};
req_info2.buffer = buffer;

// What information to request
VkMemoryDedicatedRequirementsKHR dedicated_req;
dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR;
dedicated_req.pNext = 0;

VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR};
req2.pNext = &dedicated_req;

device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR(
device, &req_info2, &req2);
if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) {
*nbytes = req2.memoryRequirements.size;
return true;
}
}

return false;
}

VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes,
VkBufferUsageFlags usage, uint32_t mem_type_index)
: vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) {
VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr));
}

VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() {
if (host_addr) {
vkUnmapMemory(vk_buf.device_, vk_buf.memory);
}
}

VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other)
: vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) {
other.host_addr = nullptr;
other.size = 0;
}

VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) {
std::swap(vk_buf, other.vk_buf);
std::swap(host_addr, other.host_addr);
std::swap(size, other.size);

return *this;
}

} // namespace vulkan
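
The move constructor and move assignment above transfer ownership of the Vulkan handles without a double free: the constructor nulls out the source's handles, while the assignment swaps members so that whatever *this previously held is released when the moved-from object is destroyed. Below is a standalone sketch of the same idiom, illustrative only, using a toy heap-allocated handle rather than the actual Vulkan types.

#include <cstdio>
#include <utility>

// Toy RAII wrapper mirroring the VulkanBuffer move idiom.
class Handle {
 public:
  explicit Handle(int id) : payload_(new int(id)) {}
  ~Handle() { delete payload_; }

  Handle(const Handle&) = delete;
  Handle& operator=(const Handle&) = delete;

  // Move constructor: steal the pointer and leave the source empty,
  // so its destructor is a no-op for this resource.
  Handle(Handle&& other) : payload_(other.payload_) { other.payload_ = nullptr; }

  // Move assignment via swap: the resource previously owned by *this
  // ends up in `other` and is released when `other` is destroyed.
  Handle& operator=(Handle&& other) {
    std::swap(payload_, other.payload_);
    return *this;
  }

  int value() const { return payload_ ? *payload_ : -1; }

 private:
  int* payload_;
};

int main() {
  Handle a(1);
  Handle b(2);
  a = std::move(b);        // a's previous resource now lives in b
  Handle c(std::move(a));  // a is left empty
  std::printf("c=%d a=%d\n", c.value(), a.value());  // prints: c=2 a=-1
  return 0;
}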