diff --git a/src/runtime/thread_map.h b/src/runtime/thread_map.h new file mode 100644 index 000000000000..c3fc7e31e9bd --- /dev/null +++ b/src/runtime/thread_map.h @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_THREAD_MAP_H_ +#define TVM_RUNTIME_THREAD_MAP_H_ + +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +/*! \brief Container to hold one value per thread + * + * Similar to thread_local, but intended for use as a non-static or + * non-block variable, such as class member variables. All member + * functions are thread-safe to call. If only the current thread's + * value is accessed, no additional synchronization is required. If + * another thread's stored values are accessed, external + * synchronization may be required. + * + * Calls that only require access to already-existing values will not + * block each other. Calls that require constructing a new value will + * block any other calls. + * + * \tparam T The object type to be held. For instantiation of + * ThreadMap and for calls to ThreadMap::Get, only a forward + * declaration is required. 
For calls to ThreadMap::GetOrMake, a + * full class definition is required. + */ +template +class ThreadMap { + public: + ThreadMap() {} + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get() const { return this->Get(std::this_thread::get_id()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + const T* Get(std::thread::id id) const { + std::shared_lock lock(mutex_); + auto res = values_.find(id); + if (res == values_.end()) { + return nullptr; + } else { + return res->second.get(); + } + } + + /*! \brief Return the current thread's stored object, if it exists. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get() { return const_cast(const_cast*>(this)->Get()); } + + /*! \brief Return the stored object for a given thread, if it exists. + * + * \param id The thread whose object should be returned. + * + * \return If it exists, a pointer to the stored object. Otherwise, + * returns nullptr. + */ + T* Get(std::thread::id id) { + return const_cast(const_cast*>(this)->Get(id)); + } + + /*! \brief Return the current thread's stored object, making it if + * necessary. + * + * Since this method can modify the stored map, there is no + * non-const version available. + * + * \tparam Params Types of the stored object's constructor arguments + * + * \return A reference to the stored object + */ + template + T& GetOrMake(Params&&... params) { + return GetOrMake(std::this_thread::get_id(), std::forward(params)...); + } + + /*! \brief Return the stored object for a given thread, making it if + * necessary + * + * Since this method can modify the stored map, there is no + * non-const version available. 
+ * + * \tparam Params Types of the stored object's constructor arguments + * + * \param id The thread whose object should be returned. + * + * \param params Arguments to the stored object's constructor. Only + * used if the specified thread does not currently exist in the map. + * + * \return A reference to the stored object + */ + template + T& GetOrMake(std::thread::id id, Params&&... params) { + // Try to get stored value first, which would only require shared + // access. + if (T* output = Get(id)) { + return *output; + } + + // Not in map, need exclusive lock to write + std::unique_lock lock(mutex_); + + // Check again, in case another thread got the unique lock first + // and already constructed the object. + auto res = values_.find(id); + if (res != values_.end()) { + return *res->second; + } + + // No value exists, make one and return it. + std::unique_ptr& new_val = values_[id] = + std::make_unique(std::forward(params)...); + return *new_val; + } + + /*! \brief Clears all values held by the ThreadMap + * + * Calling Clear() invalidates any pointers/references previously + * returned by Get/GetOrMake. + * + */ + void Clear() { + std::unique_lock lock(mutex_); + values_.clear(); + } + + private: + //! \brief Mutex to protect values_ + mutable std::shared_timed_mutex mutex_; + + //! 
\brief Map containing stored values + std::unordered_map> values_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_THREAD_MAP_H_ diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/runtime/vulkan/vulkan_buffer.cc index 7059e7c617f4..ef8215c01738 100644 --- a/src/runtime/vulkan/vulkan_buffer.cc +++ b/src/runtime/vulkan/vulkan_buffer.cc @@ -19,27 +19,125 @@ #include "vulkan_buffer.h" +#include + #include "vulkan_device_api.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { namespace vulkan { -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf) { - if (buf && buf->vk_buf) { - if (buf->host_addr != nullptr) { - vkUnmapMemory(buf->device, buf->vk_buf->memory); - } - if (buf->vk_buf->memory != VK_NULL_HANDLE) { - vkFreeMemory(buf->device, buf->vk_buf->memory, nullptr); - } - if (buf->vk_buf->buffer != VK_NULL_HANDLE) { - vkDestroyBuffer(buf->device, buf->vk_buf->buffer, nullptr); +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) { + VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO}; + info.size = nbytes; + // Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to + // specify the queue families. 
+ info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + return info; +} + +VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index) + : device_(device) { + // Create a buffer + VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage); + VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer)); + + // Allocate memory + VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; + mem_info.allocationSize = buffer_info.size; + mem_info.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR dedicated_info = { + VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR}; + + bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize); + if (use_dedicated_allocation) { + dedicated_info.buffer = buffer; + mem_info.pNext = &dedicated_info; + } + + VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory)); + + // Bind the buffer to the allocated memory + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); +} + +VulkanBuffer::~VulkanBuffer() { + if (buffer) { + vkDestroyBuffer(device_, buffer, nullptr); + } + if (memory) { + vkFreeMemory(device_, memory, nullptr); + } +} + +VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) + : device_(other.device_), buffer(other.buffer), memory(other.memory) { + other.device_ = VK_NULL_HANDLE; + other.buffer = VK_NULL_HANDLE; + other.memory = VK_NULL_HANDLE; +} + +VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) { + std::swap(device_, other.device_); + std::swap(buffer, other.buffer); + std::swap(memory, other.memory); + return *this; +} + +bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes) { + if (device.get_buffer_memory_requirements_2_functions) { + // Which buffer to request information about + VkBufferMemoryRequirementsInfo2KHR req_info2 = { + 
VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR}; + req_info2.buffer = buffer; + + // What information to request + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + + VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR}; + req2.pNext = &dedicated_req; + + device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + device, &req_info2, &req2); + if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) { + *nbytes = req2.memoryRequirements.size; + return true; } - buf->host_addr = nullptr; - delete buf->vk_buf; } + + return false; +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, + VkBufferUsageFlags usage, uint32_t mem_type_index) + : vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) { + VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr)); +} + +VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() { + if (host_addr) { + vkUnmapMemory(vk_buf.device_, vk_buf.memory); + } +} + +VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&& other) + : vk_buf(std::move(other.vk_buf)), host_addr(other.host_addr), size(other.size) { + other.host_addr = nullptr; + other.size = 0; +} + +VulkanHostVisibleBuffer& VulkanHostVisibleBuffer::operator=(VulkanHostVisibleBuffer&& other) { + std::swap(vk_buf, other.vk_buf); + std::swap(host_addr, other.host_addr); + std::swap(size, other.size); + + return *this; } } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h index 77406ec2b2f8..a3e37431e434 100644 --- a/src/runtime/vulkan/vulkan_buffer.h +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -29,20 +29,120 @@ namespace tvm { namespace runtime { namespace vulkan { -struct VulkanBuffer { +class VulkanDevice; + +class VulkanBuffer { + public: + /* \brief 
Allocate memory on the device + * + * \param device Which device should have the memory allocation. + * The VulkanDevice given should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Destructor, deallocates the memory and buffer. + ~VulkanBuffer(); + + // Forbid copy assignment/constructor + VulkanBuffer(const VulkanBuffer&) = delete; + VulkanBuffer& operator=(const VulkanBuffer&) = delete; + + // Allow move assignment/constructor + VulkanBuffer(VulkanBuffer&&); + VulkanBuffer& operator=(VulkanBuffer&&); + + private: + /*! \brief Whether this buffer should be allocated using dedicated + * allocation + * + * In typical usage, there will be one VkDeviceMemory that has a + * large number of VkBuffers pointing to it. Currently, the TVM + * Vulkan runtime has a single VkBuffer for each VkDeviceMemory. In + * this case, there can be performance benefits by explicitly + * marking this as a dedicated allocation. The function returns + * true if the device supports the dedicated allocation extension, + * and the buffer either requires or has better performance with a + * dedicated allocation. + * + * \param[out] nbytes If using dedicated allocation, the number of + * bytes required for the allocation. If not using dedicated + * allocation, this value is unchanged. + * + * \returns Whether the allocation should use the dedicated + * allocation extension. 
+ */ + static bool UseDedicatedAllocation(const VulkanDevice& device, VkBuffer buffer, + VkDeviceSize* nbytes); + + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + /*! \brief Pointer to the device that owns this buffer. + * + * Assumes that the VulkanBuffer will be destructed before the + * VulkanDevice, and this will never be a dangling reference. + * Stores a VkDevice and not a VulkanDevice, because the + * VulkanDevice may be moved to a different location while the + * VulkanBuffer is alive. + */ + VkDevice device_{VK_NULL_HANDLE}; + + //! \brief Handle to the logical buffer on the device VkBuffer buffer{VK_NULL_HANDLE}; + + //! \brief Handle to the physical device memory VkDeviceMemory memory{VK_NULL_HANDLE}; + + friend class VulkanHostVisibleBuffer; }; /*! \brief A struct to represent Vulkan buffers backed by host visible memory */ -struct VulkanHostVisibleBuffer { - // A device where the buffer is allocated - VkDevice device{nullptr}; - // Vulkan buffer and memory - VulkanBuffer* vk_buf{nullptr}; - // The corresponding pointer to the host memory +class VulkanHostVisibleBuffer { + public: + /* \brief Allocate memory on the device, visible to the host + * + * \param device Which GPU device should have the memory allocation. + * The VulkanDevice specified should outlive the VulkanBuffer. + * + * \param nbytes Size of the buffer in bytes + * + * \param usage The usage flags for the buffer (e.g. transfer + * source, transfer dest, storage buffer, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index); + + //! \brief Unmap memory and deallocate. 
+ ~VulkanHostVisibleBuffer(); + + // Forbid copy assignment/constructor + VulkanHostVisibleBuffer(const VulkanHostVisibleBuffer&) = delete; + VulkanHostVisibleBuffer& operator=(const VulkanHostVisibleBuffer&) = delete; + + // Allow move assignment/constructor + VulkanHostVisibleBuffer(VulkanHostVisibleBuffer&&); + VulkanHostVisibleBuffer& operator=(VulkanHostVisibleBuffer&&); + + private: + // TODO(elunderberg): Move copy functionality into the buffer class + // so these don't need to be public. + public: + VulkanBuffer vk_buf; void* host_addr{nullptr}; - // The size of the buffer in bytes size_t size{0}; }; @@ -54,8 +154,6 @@ VulkanHostVisibleBuffer* GetOrAllocate( std::unordered_map>* buffers_ptr, bool sync_before_realloc = false); -void DeleteHostVisibleBuffer(VulkanHostVisibleBuffer* buf); - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index e92b566e0aab..5e4be8209550 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -28,7 +28,6 @@ #include "vulkan_device.h" #include "vulkan_device_api.h" #include "vulkan_instance.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -310,6 +309,14 @@ VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_ } VulkanDevice::~VulkanDevice() { + // Need to clear anything that uses this device calling + // vkDestroyDevice. Might be a sign that the VkDevice should be + // held by member variable rather than beind owned directly by + // VulkanDevice. 
+ stream_per_thread.Clear(); + staging_buffer_per_thread.Clear(); + uniform_buffer_per_thread.Clear(); + if (device_) { vkDestroyDevice(device_, nullptr); } @@ -491,6 +498,49 @@ void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) { VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_)); } +VulkanStream& VulkanDevice::ThreadLocalStream() { + return const_cast(const_cast(this)->ThreadLocalStream()); +} + +const VulkanStream& VulkanDevice::ThreadLocalStream() const { + return stream_per_thread.GetOrMake(this); +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalStagingBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VulkanStagingBuffer& result = + staging_buffer_per_thread.GetOrMake(*this, min_size, usage, staging_mtype_index); + + if (result.size < min_size) { + result = VulkanStagingBuffer(*this, min_size, usage, staging_mtype_index); + } + + return result; +} + +void VulkanDevice::AllocateThreadLocalUniformBuffer(size_t min_size) { + auto usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT; + auto buffer_info = MakeBufferCreateInfo(min_size, usage); + auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + auto mem_type_index = FindMemoryType(*this, buffer_info, prop); + + VulkanUniformBuffer& result = + uniform_buffer_per_thread.GetOrMake(*this, min_size, usage, mem_type_index); + + if (result.size < min_size) { + result = VulkanUniformBuffer(*this, min_size, usage, mem_type_index); + } +} + +VulkanStagingBuffer& VulkanDevice::ThreadLocalUniformBuffer(size_t min_size) { + VulkanStagingBuffer* buffer = uniform_buffer_per_thread.Get(); + ICHECK(buffer) << "Vulkan uniform buffer requested, but not previously allocated."; + ICHECK_GE(buffer->size, min_size) + << "Vulkan uniform buffer of size " << min_size << " requested, but only " << buffer->size + << " was previously allocated."; + return *buffer; +} + uint32_t FindMemoryType(const 
VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop) { VkBuffer buffer; @@ -512,115 +562,26 @@ uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, return 0; } -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, - VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(device.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - return info; -} - -VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(device, nbytes, usage); - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (device.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = info.size; - 
minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - VulkanHostVisibleBuffer* GetOrAllocate( int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, std::unordered_map>* buffers_ptr, bool sync_before_realloc) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + auto& buffers = *buffers_ptr; - if (!buffers[device_id]) { - buffers[device_id] = std::make_unique(); - } - auto& buf = *(buffers[device_id]); - if (buf.device != nullptr && buf.size < size) { - // free previous buffer - if (sync_before_realloc) { - // For the deferred execution mode, we need to make sure that old tasks that use - // the older, smaller buffer get finished - // Synchronization on staging buffers is done after host to device memory copy - // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization - // points - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); - } - DeleteHostVisibleBuffer(&buf); + bool needs_alloc = !buffers[device_id] || (buffers[device_id]->size < size); + bool is_realloc = buffers[device_id] && (buffers[device_id]->size < size); + if (is_realloc && sync_before_realloc) { + device.ThreadLocalStream().Synchronize(); } - const auto& vulkan_device = 
VulkanDeviceAPI::Global()->device(device_id); - - if (buf.device == nullptr) { - buf.device = vulkan_device; - } - if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vulkan_device, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vulkan_device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); - buf.size = size; + if (needs_alloc) { + auto new_buffer = + std::make_unique(device, size, usage, mem_type_index); + buffers[device_id] = std::move(new_buffer); } - return &buf; + return buffers[device_id].get(); } } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index b55eb8a3d9e0..045628bc9092 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -21,14 +21,19 @@ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ #include -#include #include +#include +#include #include +#include +#include #include +#include "../thread_map.h" #include "vulkan/vulkan_core.h" #include "vulkan_buffer.h" +#include "vulkan_stream.h" namespace tvm { namespace runtime { @@ -156,6 +161,43 @@ class VulkanDevice { */ bool HasExtension(const char* query) const; + //! \brief Return the VulkanStream for the current CPU thread + VulkanStream& ThreadLocalStream(); + + //! \brief Return the VulkanStream for the current CPU thread + const VulkanStream& ThreadLocalStream() const; + + /*! \brief Return the staging buffer for the current CPU thread + * + * This function may re-allocate the staging buffer depending on the + * size of the previously allocated buffer. + * + * \param min_size The size in bytes of the staging buffer to be + * returned. The buffer may be larger than requested, depending on + * previous use. + */ + VulkanStagingBuffer& ThreadLocalStagingBuffer(size_t min_size); + + /*! \brief Allocate the uniform buffer for the current CPU thread + * + * \param min_size The minimum size in bytes of the uniformn buffer + * to be allocated. 
If a larger uniform buffer has already been + * allocated, no allocation is performed. + */ + void AllocateThreadLocalUniformBuffer(size_t min_size); + + /*! \brief Return the uniform buffer for the current CPU thread + * + * Assumes that AllocateThreadLocalUniformBuffer has previously been + * called, with a min_size greater than or equal to the min_size of + * the current call. If this is not the case, will throw an + * exception. + * + * \param min_size The minimum size in bytes of the uniform buffer to be + * returned. + */ + VulkanUniformBuffer& ThreadLocalUniformBuffer(size_t min_size); + // Cached device properties, queried through Vulkan API. VulkanDeviceProperties device_properties{}; @@ -183,8 +225,24 @@ class VulkanDevice { */ void do_swap(VulkanDevice&& other); + /*! \brief Returns a queue family capable of running Vulkan compute + * operations + */ uint32_t SelectComputeQueueFamily() const; + + /*! \brief Returns the extensions to be enabled. + * + * All char* in the returned vector point to static memory + * allocations, and do not require cleanup. + */ std::vector SelectEnabledExtensions() const; + + /*! \brief Initialize the VkDevice + * + * Called during VulkanDevice construction. Assumes that + * queue_family_index, device_properties, and enabled_extensions + * have been set. + */ void CreateVkDevice(const VulkanInstance& instance); //! \brief Handle to the Vulkan API physical device @@ -207,19 +265,30 @@ class VulkanDevice { /*! \brief Handle to Vulkan API VkQueue. * * Work can be executed by submitted to this queue using - * VulkanDevice::SubmitQueue. + * VulkanDevice::QueueSubmit. */ VkQueue queue{nullptr}; + + /*! \brief The VulkanStream for each CPU thread. + * + * To mimic the semantics of cudaSetDevice and cuLaunchKernel, each + * CPU thread must have a separate stream of execution. The + * ThreadMap is declared mutable so that the streams can be lazily + * generated. + */ + mutable ThreadMap stream_per_thread; + + //! 
\brief The VulkanStagingBuffer for each CPU thread. + ThreadMap staging_buffer_per_thread; + + //! \brief The VulkanUniformBuffer for each CPU thread. + ThreadMap uniform_buffer_per_thread; }; uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop); -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, - VkBufferUsageFlags usage); - -VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index); +VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage); } // namespace vulkan } // namespace runtime diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index bc25f25e7e12..1fede98f7211 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -26,7 +26,6 @@ #include #include "vulkan_common.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -55,7 +54,20 @@ VulkanDeviceAPI::VulkanDeviceAPI() { VulkanDeviceAPI::~VulkanDeviceAPI() {} -void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; } +void VulkanDeviceAPI::SetDevice(Device dev) { + ICHECK_EQ(dev.device_type, kDLVulkan) + << "Active vulkan device cannot be set to non-vulkan device" << dev; + + ICHECK_LE(dev.device_id, static_cast(devices_.size())) + << "Attempted to set active vulkan device to device_id==" << dev.device_id << ", but only " + << devices_.size() << " devices present"; + + active_device_id_per_thread.GetOrMake(0) = dev.device_id; +} + +int VulkanDeviceAPI::GetActiveDeviceID() { return active_device_id_per_thread.GetOrMake(0); } + +VulkanDevice& VulkanDeviceAPI::GetActiveDevice() { return device(GetActiveDeviceID()); } void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); @@ -225,7 +237,7 @@ void* 
VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignmen const auto& device = this->device(dev.device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(device, nbytes, usage, device.compute_mtype_index); + return new VulkanBuffer(device, nbytes, usage, device.compute_mtype_index); } void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -233,19 +245,20 @@ void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { // finish all the vulkan commands that reference the buffer. StreamSync(dev, nullptr); - const auto& device = this->device(dev.device_id); auto* pbuf = static_cast(ptr); - vkDestroyBuffer(device, pbuf->buffer, nullptr); - vkFreeMemory(device, pbuf->memory, nullptr); delete pbuf; } void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - return VulkanThreadEntry::ThreadLocal()->pool->AllocWorkspace(dev, size); + auto& pool = pool_per_thread.GetOrMake(kDLVulkan, this); + return pool.AllocWorkspace(dev, size); } void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { - VulkanThreadEntry::ThreadLocal()->pool->FreeWorkspace(dev, data); + auto* pool = pool_per_thread.Get(); + ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " + << "that has never allocated a workspace"; + pool->FreeWorkspace(dev, data); } TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } @@ -263,7 +276,7 @@ void VulkanDeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { ICHECK_EQ(stream, static_cast(nullptr)); - VulkanThreadEntry::ThreadLocal()->Stream(dev.device_id)->Synchronize(); + device(dev.device_id).ThreadLocalStream().Synchronize(); } void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { @@ -282,96 +295,94 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t 
from_offset, void* int from_dev_type = static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([=](VulkanStreamState* state) { - // 1: copy - const auto* from_buf = static_cast(from); - auto* to_buf = static_cast(to); - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); - // 2: barrier(transfer-> compute|transfer) - ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Vulkan disallow cross device copy."; - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | - VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); - vkCmdPipelineBarrier( - state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, 1, - &barrier_info, 0, nullptr, 0, nullptr); - }); + ICHECK_EQ(dev_from.device_id, dev_to.device_id) + << "The Vulkan runtime does not support deviceA to deviceB copies. 
" + << "This should be changed to a deviceA to CPU copy, followed by a CPU to deviceB copy"; + + device(dev_from.device_id).ThreadLocalStream().Launch([=](VulkanStreamState* state) { + // 1: copy + const auto* from_buf = static_cast(from); + auto* to_buf = static_cast(to); + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); + // 2: barrier(transfer-> compute|transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + barrier_info.dstAccessMask = (VK_ACCESS_TRANSFER_READ_BIT | VK_ACCESS_TRANSFER_WRITE_BIT | + VK_ACCESS_SHADER_READ_BIT | VK_ACCESS_SHADER_WRITE_BIT); + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT | VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT, 0, + 1, &barrier_info, 0, nullptr, 0, nullptr); + }); } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { const auto* from_buf = static_cast(from); - const auto& device = this->device(dev_from.device_id); - auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_from.device_id) - ->Launch([&](VulkanStreamState* state) { - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = 0; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, temp->vk_buf->buffer, 1, - ©_info); - }); - VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); + auto& device = this->device(dev_from.device_id); + auto& stream = device.ThreadLocalStream(); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + stream.Launch([&](VulkanStreamState* state) { + VkBufferCopy copy_info; + copy_info.srcOffset = from_offset; + 
copy_info.dstOffset = 0; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, staging_buffer.vk_buf.buffer, 1, + ©_info); + }); + stream.Synchronize(); if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; + mrange.memory = staging_buffer.vk_buf.memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); } - memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); + memcpy(static_cast(to) + to_offset, static_cast(staging_buffer.host_addr), size); } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& device = this->device(dev_to.device_id); + auto& device = this->device(dev_to.device_id); + auto& stream = device.ThreadLocalStream(); const auto* to_buf = static_cast(to); - VulkanStagingBuffer* temp = - VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); - memcpy(temp->host_addr, static_cast(from) + from_offset, size); + auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + memcpy(staging_buffer.host_addr, static_cast(from) + from_offset, size); // host side flush if access is not coherent. 
// so writes from CPU is visible to GPU if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = temp->vk_buf->memory; + mrange.memory = staging_buffer.vk_buf.memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); } - VulkanThreadEntry::ThreadLocal() - ->Stream(dev_to.device_id) - ->Launch([&](VulkanStreamState* state) { - // 0: barrier(host->transfer) - VkMemoryBarrier barrier_info; - barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; - barrier_info.pNext = nullptr; - barrier_info.srcAccessMask = 0; - barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; - vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, - VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, - nullptr); - // 1: copy - VkBufferCopy copy_info; - copy_info.srcOffset = 0; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, temp->vk_buf->buffer, to_buf->buffer, 1, &copy_info); - }); + stream.Launch([&](VulkanStreamState* state) { + // 0: barrier(host->transfer) + VkMemoryBarrier barrier_info; + barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; + barrier_info.pNext = nullptr; + barrier_info.srcAccessMask = 0; + barrier_info.dstAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, + nullptr); + // 1: copy + VkBufferCopy copy_info; + copy_info.srcOffset = 0; + copy_info.dstOffset = to_offset; + copy_info.size = size; + vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1, + &copy_info); + }); // TODO(tulloch): should we instead make the staging buffer a property of the // Stream? This would allow us to elide synchronizations here. 
- VulkanThreadEntry::ThreadLocal()->Stream(dev_to.device_id)->Synchronize(); + stream.Synchronize(); } else { LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" << ", from=" << from_dev_type << ", to=" << to_dev_type; @@ -384,6 +395,10 @@ const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { return devices_[device_id]; } +VulkanDevice& VulkanDeviceAPI::device(size_t device_id) { + return const_cast<VulkanDevice&>(const_cast<const VulkanDeviceAPI*>(this)->device(device_id)); +} + TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast<void*>(ptr); diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index cf5652a3d9c4..b8be3eb43c79 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -21,14 +21,16 @@ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ #include +#include #include #include +#include "../thread_map.h" +#include "../workspace_pool.h" #include "vulkan/vulkan_core.h" #include "vulkan_device.h" #include "vulkan_instance.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -69,6 +71,22 @@ class VulkanDeviceAPI final : public DeviceAPI { // End of required methods for the DeviceAPI interface public: + /*! \brief Return the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + VulkanDevice& GetActiveDevice(); + + /*! \brief Return the device id of the currently active VulkanDevice + * + * The active device can be set using VulkanDeviceAPI::SetDevice. + * Each CPU thread has its own active device, mimicking the + * semantics of cudaSetDevice. + */ + int GetActiveDeviceID(); + /*! 
\brief Return the VulkanDevice associated with a specific device_id * * These are constructed during VulkanDeviceAPI initialization, so @@ -76,6 +94,13 @@ class VulkanDeviceAPI final : public DeviceAPI { */ const VulkanDevice& device(size_t device_id) const; + /*! \brief Return the VulkanDevice associated with a specific device_id + * + * These are constructed during VulkanDeviceAPI initialization, so + * this function returns immediately. + */ + VulkanDevice& device(size_t device_id); + /*! \brief Returns a property to be stored in a target. * * Returns the results of feature/property queries done during the @@ -86,9 +111,33 @@ private: std::vector<uint32_t> GetComputeQueueFamilies(VkPhysicalDevice phy_dev); + /*! \brief The Vulkan API instance owned by the VulkanDeviceAPI + * + * Holds and manages VkInstance. + */ VulkanInstance instance_; - // The physical devices, have 1 to 1 mapping to devices + + /*! \brief Handles to the Vulkan devices + * + * The physical devices. These are constructed after the instance_, + * and must be destructed before the instance_. + */ std::vector<VulkanDevice> devices_; + + /*! \brief One pool of device memory for each CPU thread. + * + * These allocate memory based on the devices stored in devices_. + * The memory pools must be destructed before devices_. + */ + ThreadMap<WorkspacePool> pool_per_thread; + + /*! \brief The index of the active device for each CPU thread. + * + * To mimic the semantics of cudaSetDevice, each CPU thread can set + * the device on which functions should run. If unset, the active + * device defaults to device_id == 0. 
+ */ + ThreadMap<int> active_device_id_per_thread; }; } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index 9784ee78503d..3eff112a6eea 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -19,6 +19,8 @@ #include "vulkan_stream.h" +#include "vulkan_device.h" + namespace tvm { namespace runtime { namespace vulkan { diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index ff02be4c5c35..fb4e447c15e1 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -26,12 +26,13 @@ #include #include "vulkan_common.h" -#include "vulkan_device.h" namespace tvm { namespace runtime { namespace vulkan { +class VulkanDevice; + class VulkanStreamState { public: VkCommandBuffer cmd_buffer_; diff --git a/src/runtime/vulkan/vulkan_thread_entry.cc b/src/runtime/vulkan/vulkan_thread_entry.cc deleted file mode 100644 index 1e2815f31146..000000000000 --- a/src/runtime/vulkan/vulkan_thread_entry.cc +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#include "vulkan_thread_entry.h" - -#include "vulkan_buffer.h" -#include "vulkan_device_api.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -VulkanThreadEntry::~VulkanThreadEntry() { - // Because the thread entry refers to Device API - // The command buffer always will be destroyed before - // the instance and device get destroyed. - // The destruction need to be manually called - // to ensure the destruction order. - - pool.reset(); - streams_.clear(); - for (const auto& kv : staging_buffers_) { - DeleteHostVisibleBuffer(kv.second.get()); - } -} - -VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } - -void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); - auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(device, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(device, info, prop); - GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, - &uniform_buffers_, true); -} - -VulkanUniformBuffer* VulkanThreadEntry::GetUniformBuffer(int device_id, size_t size) { - auto& buf = uniform_buffers_[device_id]; - ICHECK(buf); - ICHECK_GE(buf->size, size); - return buf.get(); -} - -VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, device.staging_mtype_index, &staging_buffers_); -} - -VulkanThreadEntry::VulkanThreadEntry() - : pool(std::make_unique(static_cast(kDLVulkan), - VulkanDeviceAPI::Global())) { - device.device_id = 0; - device.device_type = static_cast(kDLVulkan); -} - -VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { - if (!streams_[device_id]) { 
- streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->device(device_id))); - } - return streams_[device_id].get(); -} - -} // namespace vulkan -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_thread_entry.h b/src/runtime/vulkan/vulkan_thread_entry.h deleted file mode 100644 index cea5494823fd..000000000000 --- a/src/runtime/vulkan/vulkan_thread_entry.h +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ - -#include - -#include -#include - -#include "../workspace_pool.h" -#include "vulkan_buffer.h" -#include "vulkan_stream.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -/*! \brief Contains all per-CPU-thread resources. 
- */ -class VulkanThreadEntry { - public: - VulkanThreadEntry(); - static VulkanThreadEntry* ThreadLocal(); - - ~VulkanThreadEntry(); - - Device device; - std::unique_ptr pool; - VulkanStream* Stream(size_t device_id); - VulkanStagingBuffer* StagingBuffer(int device_id, size_t size); - void AllocateUniformBuffer(int device_id, size_t size); - VulkanUniformBuffer* GetUniformBuffer(int device_id, size_t size); - - private: - //! Map from device to the VulkanStream for it - std::unordered_map> streams_; - //! Map from device to the StagingBuffer for it - std::unordered_map> staging_buffers_; - //! Map from device to the UniformBuffer associated with it - std::unordered_map> uniform_buffers_; -}; - -typedef dmlc::ThreadLocalStore VulkanThreadStore; - -} // namespace vulkan -} // namespace runtime -} // namespace tvm - -#endif // TVM_RUNTIME_VULKAN_VULKAN_THREAD_ENTRY_H_ diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 86c3ffe23f7d..103b2aa7692c 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -25,7 +25,6 @@ #include "../file_utils.h" #include "vulkan_device_api.h" -#include "vulkan_thread_entry.h" namespace tvm { namespace runtime { @@ -45,9 +44,8 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { - int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; - ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& device = VulkanDeviceAPI::Global()->device(device_id); + int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID(); + auto& device = VulkanDeviceAPI::Global()->device(device_id); if (!scache_[device_id]) { scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } @@ -65,17 +63,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, } const size_t nbytes_scalars = num_pack_args_ 
* sizeof(ArgUnion64); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - CHECK(ubo->host_addr) << "The UBO host buffer is not allocated"; + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); VkDescriptorBufferInfo binfo; - binfo.buffer = ubo->vk_buf->buffer; + binfo.buffer = ubo.vk_buf.buffer; binfo.offset = 0; binfo.range = VK_WHOLE_SIZE; descriptor_buffers.push_back(binfo); } if (device.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { + device.ThreadLocalStream().Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( @@ -83,8 +80,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, descriptor_buffers.data()); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args, nbytes_scalars); + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args, nbytes_scalars); } else if (num_pack_args_ > 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), @@ -133,14 +130,16 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { + auto& device = VulkanDeviceAPI::Global()->device(device_id); + vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, 
pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, nullptr); if (pipeline->use_ubo) { - auto ubo = VulkanThreadEntry::ThreadLocal()->GetUniformBuffer(device_id, nbytes_scalars); - memcpy(ubo->host_addr, pack_args_storage.data(), nbytes_scalars); + auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); + memcpy(ubo.host_addr, pack_args_storage.data(), nbytes_scalars); } else if (num_pack_args_ > 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, pack_args_storage.size() * sizeof(ArgUnion64), @@ -164,8 +163,7 @@ for (size_t i = 0; i < descriptor_buffers.size(); ++i) { deferred_token.buffers_[i] = descriptor_buffers[i].buffer; } - VulkanThreadEntry::ThreadLocal()->Stream(device_id)->LaunchDeferred( - deferred_initializer, deferred_kernel, deferred_token); + device.ThreadLocalStream().LaunchDeferred(deferred_initializer, deferred_kernel, deferred_token); } VulkanModuleNode::~VulkanModuleNode() { @@ -206,7 +204,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, std::shared_ptr<VulkanPipeline> VulkanModuleNode::GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args) { - const auto& device = VulkanDeviceAPI::Global()->device(device_id); + auto& device = VulkanDeviceAPI::Global()->device(device_id); std::lock_guard<std::mutex> lock(mutex_); const auto& cp = ecache_[device_id][func_name]; if (cp) { @@ -286,7 +284,7 @@ if (pe->use_ubo) { // Use UBO instead of push constants push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); - VulkanThreadEntry::ThreadLocal()->AllocateUniformBuffer(device_id, nbytes_scalars); + device.AllocateThreadLocalUniformBuffer(nbytes_scalars); } {