Skip to content

Commit

Permalink
Minor cleanup on the vulkan runtime.
Browse files Browse the repository at this point in the history
- Explicitly require int64 support at device creation time, since the
  TVM-generated shaders require it.

- Allocate an appropriate pool size for the buffer inputs, including
  both uniform and storage buffers.
  • Loading branch information
Lunderberg committed May 3, 2021
1 parent 497d518 commit e7c9bc2
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vulkan/vulkan.h>
#include <vulkan/vulkan_core.h>

#include <algorithm>
#include <array>
#include <cstring>

Expand Down Expand Up @@ -621,6 +622,12 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
}
return extensions;
}();

// All TVM-generated spirv shaders are marked as requiring int64
// support, so we need to request it from the device, too.
VkPhysicalDeviceFeatures enabled_features = {};
enabled_features.shaderInt64 = VK_TRUE;

VkDeviceCreateInfo device_create_info;
device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
device_create_info.pNext = nullptr;
Expand All @@ -631,7 +638,7 @@ VulkanDeviceAPI::VulkanDeviceAPI() {
device_create_info.ppEnabledLayerNames = nullptr;
device_create_info.enabledExtensionCount = extensions.size();
device_create_info.ppEnabledExtensionNames = extensions.data();
device_create_info.pEnabledFeatures = nullptr;
device_create_info.pEnabledFeatures = &enabled_features;
VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device)));
ctx.queue_mutex.reset(new std::mutex());
vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue));
Expand Down Expand Up @@ -882,10 +889,25 @@ class VulkanModuleNode final : public runtime::ModuleNode {
}
std::vector<VkDescriptorSetLayoutBinding> arg_binding;
std::vector<VkDescriptorUpdateTemplateEntryKHR> arg_template;
std::vector<VkDescriptorPoolSize> descriptor_set_pool_sizes;
uint32_t num_pod = 0, num_buffer = 0;

auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding,
VkDescriptorType desc_type) {
auto push_arg_info = [&arg_binding, &arg_template, &descriptor_set_pool_sizes](
uint32_t binding, VkDescriptorType desc_type) {
{
auto result =
std::find_if(descriptor_set_pool_sizes.begin(), descriptor_set_pool_sizes.end(),
[&](const auto& psize) { return psize.type == desc_type; });
if (result == descriptor_set_pool_sizes.end()) {
VkDescriptorPoolSize new_size;
new_size.type = desc_type;
new_size.descriptorCount = 1;
descriptor_set_pool_sizes.push_back(new_size);
} else {
result->descriptorCount++;
}
}

{
VkDescriptorSetLayoutBinding bd;
bd.binding = binding;
Expand Down Expand Up @@ -941,22 +963,17 @@ class VulkanModuleNode final : public runtime::ModuleNode {
&(pe->descriptor_set_layout)));
}

{
VkDescriptorPoolSize pool_size;
pool_size.type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
pool_size.descriptorCount = arg_binding.size();
if (!vctx.UseImmediate()) {
VkDescriptorPoolCreateInfo descrip_pool_cinfo;
descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
descrip_pool_cinfo.pNext = nullptr;
descrip_pool_cinfo.flags = VK_DESCRIPTOR_POOL_CREATE_FREE_DESCRIPTOR_SET_BIT;
descrip_pool_cinfo.maxSets = 1;
descrip_pool_cinfo.poolSizeCount = 1;
descrip_pool_cinfo.pPoolSizes = &pool_size;
descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size();
descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data();
VULKAN_CALL(vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr,
&(pe->descriptor_pool)));
}

if (!vctx.UseImmediate()) {
VkDescriptorSetAllocateInfo alloc_info;
alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
alloc_info.pNext = nullptr;
Expand Down

0 comments on commit e7c9bc2

Please sign in to comment.