diff --git a/src/runtime/vulkan/vulkan_common.cc b/src/runtime/vulkan/vulkan_common.cc new file mode 100644 index 000000000000..30df8b86ecd5 --- /dev/null +++ b/src/runtime/vulkan/vulkan_common.cc @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_common.h" + +#include + +namespace tvm { +namespace runtime { +namespace vulkan { + +std::vector FindEnabledExtensions( + const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions) { + std::set available_extensions; + for (const auto& prop : ext_prop) { + if (prop.specVersion > 0) { + available_extensions.insert(prop.extensionName); + } + } + + std::vector enabled_extensions; + for (const auto& ext : required_extensions) { + ICHECK(available_extensions.count(ext)) + << "Required vulkan extension \"" << ext << "\" not supported by driver"; + enabled_extensions.push_back(ext); + } + + for (const auto& ext : optional_extensions) { + if (available_extensions.count(ext)) { + enabled_extensions.push_back(ext); + } + } + + return enabled_extensions; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index 8fce5dbd192a..a03801cf511f 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -106,6 +106,10 @@ inline const char* VKGetErrorString(VkResult error) { VULKAN_CHECK_ERROR(__e); \ } +std::vector FindEnabledExtensions(const std::vector& ext_prop, + const std::vector& required_extensions, + const std::vector& optional_extensions); + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_context.cc b/src/runtime/vulkan/vulkan_context.cc deleted file mode 100644 index 7e59c9da47b5..000000000000 --- a/src/runtime/vulkan/vulkan_context.cc +++ /dev/null @@ -1,354 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "vulkan_context.h" - -#include -#include - -#include "vulkan_common.h" -#include "vulkan_device_api.h" -#include "vulkan_thread_entry.h" - -namespace tvm { -namespace runtime { -namespace vulkan { - -VulkanDeviceProperties::VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_dev, - const std::vector instance_extensions, - const std::vector device_extensions) { - auto has_instance_extension = [&](const char* query) { - return std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - auto has_device_extension = [&](const char* query) { - return std::any_of(device_extensions.begin(), device_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - /////////////////////////////////////////////////////////////// - // Query properties from Vulkan API // - /////////////////////////////////////////////////////////////// - - // Declare output locations for properties - VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; - VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; - VkPhysicalDeviceSubgroupProperties subgroup = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; - - // Need to do initial query in order to check the 
apiVersion. - vkGetPhysicalDeviceProperties(phy_dev, &properties.properties); - - // Set up linked list for property query - { - void** pp_next = &properties.pNext; - if (has_device_extension("VK_KHR_driver_properties")) { - *pp_next = &driver; - pp_next = &driver.pNext; - } - if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { - *pp_next = &subgroup; - pp_next = &subgroup.pNext; - } - } - - // Declare output locations for features - VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - // Set up linked list for feature query - { - void** pp_next = &features.pNext; - if (has_device_extension("VK_KHR_8bit_storage")) { - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (has_device_extension("VK_KHR_16bit_storage")) { - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - if (has_device_extension("VK_KHR_shader_float16_int8")) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - } - - if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { - // Preferred method, call to get all properties that can be queried. 
- auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); - vkGetPhysicalDeviceProperties2KHR(phy_dev, &properties); - - auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( - vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); - vkGetPhysicalDeviceFeatures2KHR(phy_dev, &features); - } else { - // Fallback, get as many features as we can from the Vulkan1.0 - // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. - vkGetPhysicalDeviceFeatures(phy_dev, &features.features); - } - - /////////////////////////////////////////////////////////////// - // Fill member variables from Vulkan structures // - /////////////////////////////////////////////////////////////// - - supports_float16 = float16_int8.shaderFloat16; - supports_float32 = true; - supports_float64 = features.features.shaderFloat64; - supports_int8 = float16_int8.shaderInt8; - supports_int16 = features.features.shaderInt16; - supports_int32 = true; - supports_int64 = features.features.shaderInt64; - supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess; - supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess; - supports_storage_buffer_storage_class = - has_device_extension("VK_KHR_storage_buffer_storage_class"); - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. - supports_push_descriptor = has_device_extension("VK_KHR_push_descriptor") && - has_device_extension("VK_KHR_descriptor_update_template"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); - if (disable && *disable) { - supports_push_descriptor = false; - } - } - - // Support is available based on these extensions, but allow it to - // be disabled based on an environment variable. 
- supports_dedicated_allocation = has_device_extension("VK_KHR_get_memory_requirements2") && - has_device_extension("VK_KHR_dedicated_allocation"); - { - const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); - if (disable && *disable) { - supports_dedicated_allocation = false; - } - } - - // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically - // needed, since it will be set so long at least one queue has - // VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future - // confusion.. - supported_subgroup_operations = - (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; - - max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations; - - // Even if we can't query it, warp size must be at least 1. - thread_warp_size = std::max(subgroup.subgroupSize, 1U); - - max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0]; - max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1]; - max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2]; - max_push_constants_size = properties.properties.limits.maxPushConstantsSize; - max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange; - max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange; - max_per_stage_descriptor_storage_buffer = - properties.properties.limits.maxPerStageDescriptorStorageBuffers; - max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize; - device_name = properties.properties.deviceName; - driver_version = properties.properties.driverVersion; - - // By default, use the maximum API version that the driver allows, - // so that any supported features can be used by TVM shaders. - // However, if we can query the conformance version, then limit to - // only using the api version that passes the vulkan conformance - // tests. 
- vulkan_api_version = properties.properties.apiVersion; - if (has_device_extension("VK_KHR_driver_properties")) { - auto api_major = VK_VERSION_MAJOR(vulkan_api_version); - auto api_minor = VK_VERSION_MINOR(vulkan_api_version); - if ((api_major > driver.conformanceVersion.major) || - ((api_major == driver.conformanceVersion.major) && - (api_minor > driver.conformanceVersion.minor))) { - vulkan_api_version = - VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); - } - } - - // From "Versions and Formats" section of Vulkan spec. - max_spirv_version = 0x10000; - if (vulkan_api_version >= VK_API_VERSION_1_2) { - max_spirv_version = 0x10500; - } else if (has_device_extension("VK_KHR_spirv_1_4")) { - max_spirv_version = 0x10400; - } else if (vulkan_api_version >= VK_API_VERSION_1_1) { - max_spirv_version = 0x10300; - } -} - -VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { - vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); - vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); - vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); - vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); -} - -VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions( - VkDevice device) { - vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( - vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); -} - -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo 
info, - VkMemoryPropertyFlags req_prop) { - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(vctx.device, buffer, &mem_reqs); - uint32_t type_bits = mem_reqs.memoryTypeBits; - VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(vctx.phy_device, &phy_mem_prop); - for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { - if ((type_bits & 1) == 1 && - (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { - return i; - } - type_bits >>= 1; - } - LOG(FATAL) << "Requested memory type not found"; - return 0; -} - -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, - VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - return info; -} - -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) { - auto info = MakeBufferCreateInfo(vctx, nbytes, usage); - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (vctx.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - 
vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = info.size; - minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - -VulkanHostVisibleBuffer* GetOrAllocate( - int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, - std::unordered_map>* buffers_ptr, - bool sync_before_realloc) { - auto& buffers = *buffers_ptr; - if (!buffers[device_id]) { - buffers[device_id] = std::make_unique(); - } - - auto& buf = *(buffers[device_id]); - if (buf.device != nullptr && buf.size < size) { - // free previous buffer - if (sync_before_realloc) { - // For the deferred execution mode, we need to make sure that old tasks that use - // the older, smaller buffer get finished - // Synchronization on staging buffers is done after host to device memory copy - // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization - // points - 
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); - } - DeleteHostVisibleBuffer(&buf); - } - - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - - if (buf.device == nullptr) { - buf.device = vctx.device; - } - if (buf.host_addr == nullptr) { - buf.vk_buf = CreateBuffer(vctx, size, usage, mem_type_index); - VULKAN_CALL(vkMapMemory(vctx.device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); - buf.size = size; - } - return &buf; -} - -} // namespace vulkan -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc new file mode 100644 index 000000000000..e92b566e0aab --- /dev/null +++ b/src/runtime/vulkan/vulkan_device.cc @@ -0,0 +1,628 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_device.h" + +#include +#include +#include +#include + +#include "vulkan_common.h" +#include "vulkan_device.h" +#include "vulkan_device_api.h" +#include "vulkan_instance.h" +#include "vulkan_thread_entry.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, + const VulkanDevice& device) { + /////////////////////////////////////////////////////////////// + // Query properties from Vulkan API // + /////////////////////////////////////////////////////////////// + + // Declare output locations for properties + VkPhysicalDeviceProperties2 properties = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2}; + VkPhysicalDeviceDriverProperties driver = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES}; + VkPhysicalDeviceSubgroupProperties subgroup = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_PROPERTIES}; + + // Need to do initial query in order to check the apiVersion. + vkGetPhysicalDeviceProperties(device, &properties.properties); + + // Set up linked list for property query + { + void** pp_next = &properties.pNext; + if (device.HasExtension("VK_KHR_driver_properties")) { + *pp_next = &driver; + pp_next = &driver.pNext; + } + if (properties.properties.apiVersion >= VK_API_VERSION_1_1) { + *pp_next = &subgroup; + pp_next = &subgroup.pNext; + } + } + + // Declare output locations for features + VkPhysicalDeviceFeatures2 features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + // Set up linked list for feature query + { + void** pp_next = &features.pNext; + if 
(device.HasExtension("VK_KHR_8bit_storage")) { + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (device.HasExtension("VK_KHR_16bit_storage")) { + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + if (device.HasExtension("VK_KHR_shader_float16_int8")) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + } + + if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { + // Preferred method, call to get all properties that can be queried. + auto vkGetPhysicalDeviceProperties2KHR = (PFN_vkGetPhysicalDeviceProperties2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceProperties2KHR")); + vkGetPhysicalDeviceProperties2KHR(device, &properties); + + auto vkGetPhysicalDeviceFeatures2KHR = (PFN_vkGetPhysicalDeviceFeatures2KHR)ICHECK_NOTNULL( + vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceFeatures2KHR")); + vkGetPhysicalDeviceFeatures2KHR(device, &features); + } else { + // Fallback, get as many features as we can from the Vulkan1.0 + // API. Corresponding vkGetPhysicalDeviceProperties was already done earlier. 
+ vkGetPhysicalDeviceFeatures(device, &features.features); + } + + /////////////////////////////////////////////////////////////// + // Fill member variables from Vulkan structures // + /////////////////////////////////////////////////////////////// + + supports_float16 = float16_int8.shaderFloat16; + supports_float32 = true; + supports_float64 = features.features.shaderFloat64; + supports_int8 = float16_int8.shaderInt8; + supports_int16 = features.features.shaderInt16; + supports_int32 = true; + supports_int64 = features.features.shaderInt64; + supports_8bit_buffer = storage_8bit.storageBuffer8BitAccess; + supports_16bit_buffer = storage_16bit.storageBuffer16BitAccess; + supports_storage_buffer_storage_class = + device.HasExtension("VK_KHR_storage_buffer_storage_class"); + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + supports_push_descriptor = device.HasExtension("VK_KHR_push_descriptor") && + device.HasExtension("VK_KHR_descriptor_update_template"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_PUSH_DESCRIPTOR"); + if (disable && *disable) { + supports_push_descriptor = false; + } + } + + // Support is available based on these extensions, but allow it to + // be disabled based on an environment variable. + supports_dedicated_allocation = device.HasExtension("VK_KHR_get_memory_requirements2") && + device.HasExtension("VK_KHR_dedicated_allocation"); + { + const char* disable = std::getenv("TVM_VULKAN_DISABLE_DEDICATED_ALLOCATION"); + if (disable && *disable) { + supports_dedicated_allocation = false; + } + } + + // The check of VK_SHADER_STAGE_COMPUTE_BIT isn't technically + // needed, since it will be set so long as at least one queue has + // VK_QUEUE_COMPUTE_BIT. Including it to avoid potential future + // confusion. + supported_subgroup_operations = + (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? 
subgroup.supportedOperations : 0; + + max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations; + + // Even if we can't query it, warp size must be at least 1. + thread_warp_size = std::max(subgroup.subgroupSize, 1U); + + max_block_size_x = properties.properties.limits.maxComputeWorkGroupSize[0]; + max_block_size_y = properties.properties.limits.maxComputeWorkGroupSize[1]; + max_block_size_z = properties.properties.limits.maxComputeWorkGroupSize[2]; + max_push_constants_size = properties.properties.limits.maxPushConstantsSize; + max_uniform_buffer_range = properties.properties.limits.maxUniformBufferRange; + max_storage_buffer_range = properties.properties.limits.maxStorageBufferRange; + max_per_stage_descriptor_storage_buffer = + properties.properties.limits.maxPerStageDescriptorStorageBuffers; + max_shared_memory_per_block = properties.properties.limits.maxComputeSharedMemorySize; + device_name = properties.properties.deviceName; + driver_version = properties.properties.driverVersion; + + // By default, use the maximum API version that the driver allows, + // so that any supported features can be used by TVM shaders. + // However, if we can query the conformance version, then limit to + // only using the api version that passes the vulkan conformance + // tests. + vulkan_api_version = properties.properties.apiVersion; + if (device.HasExtension("VK_KHR_driver_properties")) { + auto api_major = VK_VERSION_MAJOR(vulkan_api_version); + auto api_minor = VK_VERSION_MINOR(vulkan_api_version); + if ((api_major > driver.conformanceVersion.major) || + ((api_major == driver.conformanceVersion.major) && + (api_minor > driver.conformanceVersion.minor))) { + vulkan_api_version = + VK_MAKE_VERSION(driver.conformanceVersion.major, driver.conformanceVersion.minor, 0); + } + } + + // From "Versions and Formats" section of Vulkan spec. 
+ max_spirv_version = 0x10000; + if (vulkan_api_version >= VK_API_VERSION_1_2) { + max_spirv_version = 0x10500; + } else if (device.HasExtension("VK_KHR_spirv_1_4")) { + max_spirv_version = 0x10400; + } else if (vulkan_api_version >= VK_API_VERSION_1_1) { + max_spirv_version = 0x10300; + } +} + +VulkanDescriptorTemplateKHRFunctions::VulkanDescriptorTemplateKHRFunctions(VkDevice device) { + vkCreateDescriptorUpdateTemplateKHR = (PFN_vkCreateDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCreateDescriptorUpdateTemplateKHR")); + vkDestroyDescriptorUpdateTemplateKHR = (PFN_vkDestroyDescriptorUpdateTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkDestroyDescriptorUpdateTemplateKHR")); + vkUpdateDescriptorSetWithTemplateKHR = (PFN_vkUpdateDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkUpdateDescriptorSetWithTemplateKHR")); + vkCmdPushDescriptorSetWithTemplateKHR = (PFN_vkCmdPushDescriptorSetWithTemplateKHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkCmdPushDescriptorSetWithTemplateKHR")); +} + +VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2Functions( + VkDevice device) { + vkGetBufferMemoryRequirements2KHR = (PFN_vkGetBufferMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); +} + +VulkanDevice::VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_device) + : physical_device_(phy_device) { + queue_family_index = SelectComputeQueueFamily(); + if (queue_family_index == uint32_t(-1)) { + // The GPU doesn't support compute, cannot use + return; + } + + enabled_extensions = SelectEnabledExtensions(); + device_properties = VulkanDeviceProperties(instance, *this); + CreateVkDevice(instance); + + // Currently, any exceptions called after this point will prevent + // vkDestroyDevice from being called in the destructor. 
If this + * becomes an issue, can split out the VulkanDevice into two + * classes, one of which strictly holds the VkDevice, and one which + * holds the ancillary handles that TVM needs. + + vkGetDeviceQueue(device_, queue_family_index, 0, &queue); + + // Find suitable memory type for staging and compute + // Find suitable compute index. + VkBuffer buffer; + VkMemoryRequirements req_staging, req_compute; + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = 1024; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &queue_family_index; + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + + // get staging requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_staging); + vkDestroyBuffer(device_, buffer, nullptr); + // get compute requirement + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; + VULKAN_CALL(vkCreateBuffer(device_, &info, nullptr, &buffer)); + vkGetBufferMemoryRequirements(device_, buffer, &req_compute); + vkDestroyBuffer(device_, buffer, nullptr); + + // Query physical device property + // find a memory that is host visible, no need to be consistent + int win_rank = -1; + VkPhysicalDeviceMemoryProperties prop; + vkGetPhysicalDeviceMemoryProperties(physical_device_, &prop); + + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; + // match copy requirement + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; + if (rank > 
win_rank) { + win_rank = rank; + staging_mtype_index = k; + coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; + + win_rank = -1; + for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { + VkMemoryType ty = prop.memoryTypes[k]; + size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; + // host visible + if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; + // match copy requirement + if (!(req_staging.memoryTypeBits & (1 << k))) continue; + if (heap_size < 1024) continue; + int rank = 0; + // prefer not host visible + rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); + if (rank > win_rank) { + win_rank = rank; + compute_mtype_index = k; + } + } + ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; + + if (device_properties.supports_push_descriptor) { + descriptor_template_khr_functions = + std::make_unique(device_); + } + + if (device_properties.supports_dedicated_allocation) { + get_buffer_memory_requirements_2_functions = + std::make_unique(device_); + } +} + +VulkanDevice::~VulkanDevice() { + if (device_) { + vkDestroyDevice(device_, nullptr); + } +} + +VulkanDevice::VulkanDevice(VulkanDevice&& other) { do_swap(std::move(other)); } + +VulkanDevice& VulkanDevice::operator=(VulkanDevice&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanDevice::do_swap(VulkanDevice&& other) { + if (this == &other) { + return; + } + + std::lock(queue_mutex, other.queue_mutex); + std::lock_guard lock_self(queue_mutex, std::adopt_lock); + std::lock_guard lock_other(other.queue_mutex, std::adopt_lock); + + std::swap(device_properties, other.device_properties); + std::swap(staging_mtype_index, other.staging_mtype_index); + std::swap(coherent_staging, other.coherent_staging); + std::swap(descriptor_template_khr_functions, other.descriptor_template_khr_functions); + 
std::swap(get_buffer_memory_requirements_2_functions, + other.get_buffer_memory_requirements_2_functions); + std::swap(compute_mtype_index, other.compute_mtype_index); + std::swap(queue, other.queue); + std::swap(queue_family_index, other.queue_family_index); + std::swap(physical_device_, other.physical_device_); + std::swap(enabled_extensions, other.enabled_extensions); + std::swap(device_, other.device_); +} + +bool VulkanDevice::SupportsCompute() const { return queue_family_index != uint32_t(-1); } + +void VulkanDevice::QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const { + // Multiple streams (on different threads) use the same VulkanDevice + // instance, so we need to externally synchronize accesses. + std::lock_guard lock(queue_mutex); + VULKAN_CALL(vkQueueSubmit(queue, 1, &submit_info, fence)); +} + +uint32_t VulkanDevice::SelectComputeQueueFamily() const { + // Get a queue family that supports compute. We currently only use + // one queue from one family. + uint32_t queue_prop_count = 0; + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, nullptr); + std::vector queue_props(queue_prop_count); + vkGetPhysicalDeviceQueueFamilyProperties(physical_device_, &queue_prop_count, + dmlc::BeginPtr(queue_props)); + + std::vector result; + // Prefer compute-only queues. On certain devices supporting this (e.g. Mesa RADV), using + // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). + for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { + return i; + } + } + // Now, push the compute queues that we skipped above into the list. 
+ for (uint32_t i = 0; i != queue_prop_count; ++i) { + if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && + (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { + return i; + } + } + + // No queues support compute capability, this GPU cannot be used. + return -1; +} + +std::vector VulkanDevice::SelectEnabledExtensions() const { + std::vector required_extensions{}; + std::vector optional_extensions{ + "VK_KHR_driver_properties", + "VK_KHR_storage_buffer_storage_class", + "VK_KHR_8bit_storage", + "VK_KHR_16bit_storage", + "VK_KHR_shader_float16_int8", + "VK_KHR_push_descriptor", + "VK_KHR_descriptor_update_template", + "VK_KHR_get_memory_requirements2", + "VK_KHR_dedicated_allocation", + "VK_KHR_spirv_1_4", + }; + + uint32_t device_extension_prop_count; + VULKAN_CALL(vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, + &device_extension_prop_count, nullptr)); + std::vector device_extension_prop(device_extension_prop_count); + VULKAN_CALL(vkEnumerateDeviceExtensionProperties( + physical_device_, nullptr, &device_extension_prop_count, device_extension_prop.data())); + + return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); +} + +bool VulkanDevice::HasExtension(const char* query) const { + return std::any_of(enabled_extensions.begin(), enabled_extensions.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +void VulkanDevice::CreateVkDevice(const VulkanInstance& instance) { + // Enable all features we may use that a device supports. 
+ VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; + VkPhysicalDevice8BitStorageFeatures storage_8bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; + VkPhysicalDevice16BitStorageFeatures storage_16bit = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; + VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { + VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; + + void** pp_next = &enabled_features.pNext; + bool needs_float16_int8 = false; + + if (device_properties.supports_float16) { + float16_int8.shaderFloat16 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_float64) { + enabled_features.features.shaderFloat64 = true; + } + if (device_properties.supports_int8) { + float16_int8.shaderInt8 = true; + needs_float16_int8 = true; + } + if (device_properties.supports_int16) { + enabled_features.features.shaderInt16 = true; + } + if (device_properties.supports_int64) { + enabled_features.features.shaderInt64 = true; + } + if (device_properties.supports_8bit_buffer) { + storage_8bit.storageBuffer8BitAccess = true; + *pp_next = &storage_8bit; + pp_next = &storage_8bit.pNext; + } + if (device_properties.supports_16bit_buffer) { + storage_16bit.storageBuffer16BitAccess = true; + *pp_next = &storage_16bit; + pp_next = &storage_16bit.pNext; + } + + if (needs_float16_int8) { + *pp_next = &float16_int8; + pp_next = &float16_int8.pNext; + } + + float priority = 1.0f; + + struct VkDeviceQueueCreateInfo queue_create_info; + queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; + queue_create_info.pNext = nullptr; + queue_create_info.flags = 0; + queue_create_info.queueFamilyIndex = queue_family_index; + queue_create_info.queueCount = 1; + queue_create_info.pQueuePriorities = &priority; + + VkDeviceCreateInfo device_create_info; + device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; + device_create_info.pNext = nullptr; + 
device_create_info.flags = 0; + device_create_info.queueCreateInfoCount = 1; + device_create_info.pQueueCreateInfos = &queue_create_info; + device_create_info.enabledLayerCount = 0; + device_create_info.ppEnabledLayerNames = nullptr; + device_create_info.enabledExtensionCount = enabled_extensions.size(); + device_create_info.ppEnabledExtensionNames = enabled_extensions.data(); + + if (instance.HasExtension("VK_KHR_get_physical_device_properties2")) { + device_create_info.pEnabledFeatures = nullptr; + device_create_info.pNext = &enabled_features; + } else { + device_create_info.pNext = nullptr; + device_create_info.pEnabledFeatures = &enabled_features.features; + } + VULKAN_CALL(vkCreateDevice(physical_device_, &device_create_info, nullptr, &device_)); +} + +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, + VkMemoryPropertyFlags req_prop) { + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); + + VkMemoryRequirements mem_reqs; + vkGetBufferMemoryRequirements(device, buffer, &mem_reqs); + uint32_t type_bits = mem_reqs.memoryTypeBits; + VkPhysicalDeviceMemoryProperties phy_mem_prop; + vkGetPhysicalDeviceMemoryProperties(device, &phy_mem_prop); + for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { + if ((type_bits & 1) == 1 && + (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { + return i; + } + type_bits >>= 1; + } + LOG(FATAL) << "Requested memory type not found"; + return 0; +} + +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, + VkBufferUsageFlags usage) { + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = nbytes; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &(device.queue_family_index); + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = usage; + return info; +} + +VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t 
nbytes, VkBufferUsageFlags usage, + uint32_t mem_type_index) { + auto info = MakeBufferCreateInfo(device, nbytes, usage); + // create buffer + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(device, &info, nullptr, &buffer)); + + // bind to memory + bool dedicated_allocation = false; + VkMemoryRequirements2KHR req2; + + if (device.get_buffer_memory_requirements_2_functions) { + VkBufferMemoryRequirementsInfo2KHR req_info2; + req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; + req_info2.pNext = 0; + req_info2.buffer = buffer; + + req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; + req2.pNext = 0; + + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + req2.pNext = &dedicated_req; + + device.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + device, &req_info2, &req2); + dedicated_allocation = + dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; + } + + VkDeviceMemory memory; + if (!dedicated_allocation) { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = info.size; + minfo.memoryTypeIndex = mem_type_index; + VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); + } else { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = req2.memoryRequirements.size; + minfo.memoryTypeIndex = mem_type_index; + + VkMemoryDedicatedAllocateInfoKHR mdinfo; + mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; + mdinfo.pNext = 0; + mdinfo.image = 0; + mdinfo.buffer = buffer; + minfo.pNext = &mdinfo; + VULKAN_CALL(vkAllocateMemory(device, &minfo, nullptr, &memory)); + } + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); + VulkanBuffer* pbuf = new VulkanBuffer(); + pbuf->memory = memory; 
+ pbuf->buffer = buffer; + return pbuf; +} + +VulkanHostVisibleBuffer* GetOrAllocate( + int device_id, size_t size, VkBufferUsageFlags usage, uint32_t mem_type_index, + std::unordered_map>* buffers_ptr, + bool sync_before_realloc) { + auto& buffers = *buffers_ptr; + if (!buffers[device_id]) { + buffers[device_id] = std::make_unique(); + } + + auto& buf = *(buffers[device_id]); + if (buf.device != nullptr && buf.size < size) { + // free previous buffer + if (sync_before_realloc) { + // For the deferred execution mode, we need to make sure that old tasks that use + // the older, smaller buffer get finished + // Synchronization on staging buffers is done after host to device memory copy + // For UBO, we sync here before we reallocate a larger buffer, to minimize synchronization + // points + VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Synchronize(); + } + DeleteHostVisibleBuffer(&buf); + } + + const auto& vulkan_device = VulkanDeviceAPI::Global()->device(device_id); + + if (buf.device == nullptr) { + buf.device = vulkan_device; + } + if (buf.host_addr == nullptr) { + buf.vk_buf = CreateBuffer(vulkan_device, size, usage, mem_type_index); + VULKAN_CALL(vkMapMemory(vulkan_device, buf.vk_buf->memory, 0, size, 0, &(buf.host_addr))); + buf.size = size; + } + return &buf; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_context.h b/src/runtime/vulkan/vulkan_device.h similarity index 51% rename from src/runtime/vulkan/vulkan_context.h rename to src/runtime/vulkan/vulkan_device.h index 158a53043c7b..b55eb8a3d9e0 100644 --- a/src/runtime/vulkan/vulkan_context.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -17,8 +17,8 @@ * under the License. 
*/ -#ifndef TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ -#define TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ +#ifndef TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ #include #include @@ -34,6 +34,9 @@ namespace tvm { namespace runtime { namespace vulkan { +class VulkanInstance; +class VulkanDevice; + struct VulkanDescriptorTemplateKHRFunctions { explicit VulkanDescriptorTemplateKHRFunctions(VkDevice device); @@ -59,9 +62,7 @@ struct VulkanGetBufferMemoryRequirements2Functions { */ struct VulkanDeviceProperties { VulkanDeviceProperties() {} - VulkanDeviceProperties(VkInstance instance, VkPhysicalDevice phy_device, - const std::vector instance_extensions, - const std::vector device_extensions); + VulkanDeviceProperties(const VulkanInstance& instance, const VulkanDevice& device); bool supports_float16{false}; bool supports_float32{true}; @@ -92,15 +93,72 @@ struct VulkanDeviceProperties { uint32_t max_spirv_version{0x10000}; }; -struct VulkanContext { - // physical device - VkPhysicalDevice phy_device{nullptr}; +/*! \brief Handle to the Vulkan API's VkDevice + * + * Handles all setup and teardown of the class. The owner of the + * VulkanDevice object is responsible for ensuring that it remains + * alive as long as any object that accesses that device is used. + */ +class VulkanDevice { + public: + VulkanDevice(const VulkanInstance& instance, VkPhysicalDevice phy_dev); + ~VulkanDevice(); + + // Allow move constructor/assignment + VulkanDevice(VulkanDevice&&); + VulkanDevice& operator=(VulkanDevice&&); + + // Disable copy constructor/assignment + VulkanDevice(const VulkanDevice&) = delete; + VulkanDevice& operator=(const VulkanDevice&) = delete; + + /*! \brief Expose the internal VkDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkDevice handler itself. + */ + operator VkDevice() const { return device_; } + + /*! 
\brief Expose the internal VkPhysicalDevice + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkPhysicalDevice handler itself. + */ + operator VkPhysicalDevice() const { return physical_device_; } + + /*! \brief Returns whether this device supports Vulkan compute operations. + * + * If the device does not support Vulkan compute operations, it + * should not be used any further. + */ + bool SupportsCompute() const; + + /*! \brief Calls vkQueueSubmit to run work on the GPU + * + * Currently only supports submitting a single VkSubmitInfo at a + * time. Handles mutexing internally, safe to call from multiple + * CPU threads. + * + * \param submit_info The job submission information to be passed to + * vkQueueSubmit. + * + * \param fence Optional fence to be passed to vkQueueSubmit, + * signals once the command buffers submitted have completed. + */ + void QueueSubmit(VkSubmitInfo submit_info, VkFence fence) const; + + /*! \brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; // Cached device properties, queried through Vulkan API. - VulkanDeviceProperties device_properties; + VulkanDeviceProperties device_properties{}; - // Phyiscal device property - VkPhysicalDeviceProperties phy_device_prop; // Memory type index for staging. uint32_t staging_mtype_index{0}; // whether staging is coherent @@ -111,31 +169,60 @@ struct VulkanContext { get_buffer_memory_requirements_2_functions{nullptr}; // Memory type index for compute uint32_t compute_mtype_index{0}; - // The logical device - VkDevice device{nullptr}; - // command queue - std::unique_ptr queue_mutex; - VkQueue queue{nullptr}; // queue family_index; - uint32_t queue_family_index{0}; - // Queue family index. 
- VkQueueFamilyProperties queue_prop; + uint32_t queue_family_index{uint32_t(-1)}; bool UseImmediate() const { return descriptor_template_khr_functions != nullptr; } + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanDevice&& other); + + uint32_t SelectComputeQueueFamily() const; + std::vector SelectEnabledExtensions() const; + void CreateVkDevice(const VulkanInstance& instance); + + //! \brief Handle to the Vulkan API physical device + VkPhysicalDevice physical_device_{nullptr}; + + /*! \brief Extensions enabled for this device + * + * Based on supported extensions queried from physical_device_ prior + * to creating device_. Contains only statically allocated string + * literals, no cleanup required. + */ + std::vector enabled_extensions; + + //! \brief Handle to the Vulkan API logical device + VkDevice device_{nullptr}; + + //! \brief Mutex to protect access to queue + mutable std::mutex queue_mutex; + + /*! \brief Handle to Vulkan API VkQueue. + * + * Work can be executed by submitting to this queue using + * VulkanDevice::QueueSubmit. 
+ */ + VkQueue queue{nullptr}; }; -uint32_t FindMemoryType(const VulkanContext& vctx, VkBufferCreateInfo info, +uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkMemoryPropertyFlags req_prop); -VkBufferCreateInfo MakeBufferCreateInfo(const VulkanContext& vctx, size_t nbytes, +VkBufferCreateInfo MakeBufferCreateInfo(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage); -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage, +VulkanBuffer* CreateBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, uint32_t mem_type_index); } // namespace vulkan } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_VULKAN_VULKAN_CONTEXT_H_ +#endif // TVM_RUNTIME_VULKAN_VULKAN_DEVICE_H_ diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index 7cea2489cb1b..bc25f25e7e12 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -25,6 +25,7 @@ #include #include +#include "vulkan_common.h" #include "vulkan_thread_entry.h" namespace tvm { @@ -42,318 +43,28 @@ VulkanDeviceAPI* VulkanDeviceAPI::Global() { } VulkanDeviceAPI::VulkanDeviceAPI() { - const auto layers = []() -> std::vector { - uint32_t inst_layer_prop_count; - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); - std::vector inst_layer_prop(inst_layer_prop_count); - VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); - std::vector l; - - const char* enable = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); - bool validation_enabled = enable && *enable; - if (validation_enabled) { - for (const auto& lp : inst_layer_prop) { - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { - l.push_back("VK_LAYER_LUNARG_standard_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { - 
l.push_back("VK_LAYER_LUNARG_parameter_validation"); - } - if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { - l.push_back("VK_LAYER_KHRONOS_validation"); - } - } - } - return l; - }(); - - const auto instance_extensions = [this]() { - std::vector required_extensions{}; - std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; - - uint32_t inst_extension_prop_count; - VULKAN_CALL( - vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); - std::vector inst_extension_prop(inst_extension_prop_count); - VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, - inst_extension_prop.data())); - - return FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); - }(); - - auto has_instance_extension = [&instance_extensions](const char* query) { - return std::any_of(instance_extensions.begin(), instance_extensions.end(), - [&](const char* extension) { return std::strcmp(query, extension) == 0; }); - }; - - const auto instance_api_version = []() { - uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); - - // Result from vkGetInstanceProcAddr is NULL if driver only - // supports vulkan 1.0. 
- auto vkEnumerateInstanceVersion = - (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); - if (vkEnumerateInstanceVersion) { - vkEnumerateInstanceVersion(&api_version); - } - - return api_version; - }(); - - { - VkApplicationInfo app_info; - app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; - app_info.pNext = nullptr; - app_info.pApplicationName = "TVM"; - app_info.applicationVersion = 0; - app_info.pEngineName = ""; - app_info.engineVersion = 0; - app_info.apiVersion = instance_api_version; - - VkInstanceCreateInfo inst_info; - inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; - inst_info.pNext = nullptr; - inst_info.flags = 0; - inst_info.pApplicationInfo = &app_info; - inst_info.enabledLayerCount = layers.size(); - inst_info.ppEnabledLayerNames = layers.data(); - inst_info.enabledExtensionCount = instance_extensions.size(); - inst_info.ppEnabledExtensionNames = instance_extensions.data(); - - VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); - } - - uint32_t phy_dev_count = 0; - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, nullptr)); - std::vector all_phy_devs(phy_dev_count); - VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &phy_dev_count, dmlc::BeginPtr(all_phy_devs))); - for (VkPhysicalDevice phy_dev : all_phy_devs) { - // Get a list of queue families supporting compute, in order of preference. We currently only - // make use of the most preferred one family. 
- std::vector queue_family_indexes = GetComputeQueueFamilies(phy_dev); - if (queue_family_indexes.empty()) continue; - uint32_t queue_family_index = queue_family_indexes[0]; - float priority = 1.0f; - - struct VkDeviceQueueCreateInfo queue_create_info; - queue_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO; - queue_create_info.pNext = nullptr; - queue_create_info.flags = 0; - queue_create_info.queueFamilyIndex = queue_family_index; - queue_create_info.queueCount = 1; - queue_create_info.pQueuePriorities = &priority; - - VulkanContext ctx; - // setup context - ctx.phy_device = phy_dev; - vkGetPhysicalDeviceProperties(ctx.phy_device, &(ctx.phy_device_prop)); - - const auto device_extensions = [&]() { - std::vector required_extensions{}; - std::vector optional_extensions{ - "VK_KHR_driver_properties", - "VK_KHR_storage_buffer_storage_class", - "VK_KHR_8bit_storage", - "VK_KHR_16bit_storage", - "VK_KHR_shader_float16_int8", - "VK_KHR_push_descriptor", - "VK_KHR_descriptor_update_template", - "VK_KHR_get_memory_requirements2", - "VK_KHR_dedicated_allocation", - "VK_KHR_spirv_1_4", - }; - - uint32_t device_extension_prop_count; - VULKAN_CALL(vkEnumerateDeviceExtensionProperties(ctx.phy_device, nullptr, - &device_extension_prop_count, nullptr)); - std::vector device_extension_prop(device_extension_prop_count); - VULKAN_CALL(vkEnumerateDeviceExtensionProperties( - ctx.phy_device, nullptr, &device_extension_prop_count, device_extension_prop.data())); - - return FindEnabledExtensions(device_extension_prop, required_extensions, optional_extensions); - }(); - - ctx.device_properties = - VulkanDeviceProperties(instance_, phy_dev, instance_extensions, device_extensions); - - { - // Enable all features we may use that a device supports. 
- VkPhysicalDeviceFeatures2 enabled_features = {VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2}; - VkPhysicalDevice8BitStorageFeatures storage_8bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_8BIT_STORAGE_FEATURES}; - VkPhysicalDevice16BitStorageFeatures storage_16bit = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_16BIT_STORAGE_FEATURES}; - VkPhysicalDeviceShaderFloat16Int8Features float16_int8 = { - VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_FLOAT16_INT8_FEATURES}; - - void** pp_next = &enabled_features.pNext; - bool needs_float16_int8 = false; - - if (ctx.device_properties.supports_float16) { - float16_int8.shaderFloat16 = true; - needs_float16_int8 = true; - } - if (ctx.device_properties.supports_float64) { - enabled_features.features.shaderFloat64 = true; - } - if (ctx.device_properties.supports_int8) { - float16_int8.shaderInt8 = true; - needs_float16_int8 = true; - } - if (ctx.device_properties.supports_int16) { - enabled_features.features.shaderInt16 = true; - } - if (ctx.device_properties.supports_int64) { - enabled_features.features.shaderInt64 = true; - } - if (ctx.device_properties.supports_8bit_buffer) { - storage_8bit.storageBuffer8BitAccess = true; - *pp_next = &storage_8bit; - pp_next = &storage_8bit.pNext; - } - if (ctx.device_properties.supports_16bit_buffer) { - storage_16bit.storageBuffer16BitAccess = true; - *pp_next = &storage_16bit; - pp_next = &storage_16bit.pNext; - } - - if (needs_float16_int8) { - *pp_next = &float16_int8; - pp_next = &float16_int8.pNext; - } - - VkDeviceCreateInfo device_create_info; - device_create_info.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO; - device_create_info.pNext = nullptr; - device_create_info.flags = 0; - device_create_info.queueCreateInfoCount = 1; - device_create_info.pQueueCreateInfos = &queue_create_info; - device_create_info.enabledLayerCount = 0; - device_create_info.ppEnabledLayerNames = nullptr; - device_create_info.enabledExtensionCount = device_extensions.size(); - device_create_info.ppEnabledExtensionNames = 
device_extensions.data(); - - if (has_instance_extension("VK_KHR_get_physical_device_properties2")) { - device_create_info.pEnabledFeatures = nullptr; - device_create_info.pNext = &enabled_features; - } else { - device_create_info.pNext = nullptr; - device_create_info.pEnabledFeatures = &enabled_features.features; - } - VULKAN_CALL(vkCreateDevice(phy_dev, &device_create_info, nullptr, &(ctx.device))); - } - - ctx.queue_mutex.reset(new std::mutex()); - vkGetDeviceQueue(ctx.device, queue_family_index, 0, &(ctx.queue)); - ctx.queue_family_index = queue_family_index; - // Find suitable memory type for staging and compute - // Find suitable compute index. - VkBuffer buffer; - VkMemoryRequirements req_staging, req_compute; - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = 1024; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(ctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - - // get staging requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_staging); - vkDestroyBuffer(ctx.device, buffer, nullptr); - // get compute requirement - info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | - VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - VULKAN_CALL(vkCreateBuffer(ctx.device, &info, nullptr, &buffer)); - vkGetBufferMemoryRequirements(ctx.device, buffer, &req_compute); - vkDestroyBuffer(ctx.device, buffer, nullptr); - - // Query phyiscal device property - // find a memory that is host visible, no need to be consistent - int win_rank = -1; - VkPhysicalDeviceMemoryProperties prop; - vkGetPhysicalDeviceMemoryProperties(ctx.phy_device, &prop); - - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = 
prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - rank += ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_CACHED_BIT; - if (rank > win_rank) { - win_rank = rank; - ctx.staging_mtype_index = k; - ctx.coherent_staging = ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable staging memory on device."; - - win_rank = -1; - for (uint32_t k = 0; k < prop.memoryTypeCount; ++k) { - VkMemoryType ty = prop.memoryTypes[k]; - size_t heap_size = prop.memoryHeaps[ty.heapIndex].size; - // host visible - if (!(ty.propertyFlags & VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT)) continue; - // match copy requirment - if (!(req_staging.memoryTypeBits & (1 << k))) continue; - if (heap_size < 1024) continue; - int rank = 0; - // prefer not host visible - rank += !(ty.propertyFlags & VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); - if (rank > win_rank) { - win_rank = rank; - ctx.compute_mtype_index = k; - } - } - ICHECK_GE(win_rank, 0) << "Cannot find suitable local memory on device."; - - if (ctx.device_properties.supports_push_descriptor) { - ctx.descriptor_template_khr_functions = - std::make_unique(ctx.device); - } + std::vector vulkan_physical_devices = instance_.GetPhysicalDevices(); + for (VkPhysicalDevice phy_dev : vulkan_physical_devices) { + VulkanDevice device(instance_, phy_dev); - if (ctx.device_properties.supports_dedicated_allocation) { - ctx.get_buffer_memory_requirements_2_functions = - std::make_unique(ctx.device); + if (device.SupportsCompute()) { + devices_.push_back(std::move(device)); } - - context_.push_back(std::move(ctx)); - } - - LOG(INFO) << "Initialize Vulkan with " << context_.size() << " devices.."; - for (size_t i = 0; i < context_.size(); ++i) { - LOG(INFO) << "vulkan(" << i << ")=\'" << 
context_[i].phy_device_prop.deviceName - << "\' phy_dev_id=" << context_[i].phy_device - << " use_immediate=" << context_[i].UseImmediate(); } } -VulkanDeviceAPI::~VulkanDeviceAPI() { - for (auto& vctx : context_) { - vkDestroyDevice(vctx.device, nullptr); - } - if (instance_) { - vkDestroyInstance(instance_, nullptr); - } -} +VulkanDeviceAPI::~VulkanDeviceAPI() {} void VulkanDeviceAPI::SetDevice(Device dev) { VulkanThreadEntry::ThreadLocal()->device = dev; } void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); if (kind == kExist) { - *rv = static_cast(index < context_.size()); + *rv = static_cast(index < devices_.size()); return; } - const auto& prop = context(index).device_properties; + const auto& prop = device(index).device_properties; switch (kind) { case kMaxThreadsPerBlock: { @@ -420,7 +131,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, TVMRetValue* rv) { size_t index = static_cast(dev.device_id); - const auto& prop = context(index).device_properties; + const auto& prop = device(index).device_properties; if (property == "supports_float16") { *rv = prop.supports_float16; @@ -511,10 +222,10 @@ void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignmen // Vulkan seems to have issues if we return nullptr on zero size alloc nbytes = 1; } - const auto& vctx = context(dev.device_id); + const auto& device = this->device(dev.device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(vctx, nbytes, usage, vctx.compute_mtype_index); + return CreateBuffer(device, nbytes, usage, device.compute_mtype_index); } void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { @@ -522,10 +233,10 @@ void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { // finish all the 
vulkan commands that reference the buffer. StreamSync(dev, nullptr); - const auto& vctx = context(dev.device_id); + const auto& device = this->device(dev.device_id); auto* pbuf = static_cast(ptr); - vkDestroyBuffer(vctx.device, pbuf->buffer, nullptr); - vkFreeMemory(vctx.device, pbuf->memory, nullptr); + vkDestroyBuffer(device, pbuf->buffer, nullptr); + vkFreeMemory(device, pbuf->memory, nullptr); delete pbuf; } @@ -598,7 +309,7 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { const auto* from_buf = static_cast(from); - const auto& vctx = context(dev_from.device_id); + const auto& device = this->device(dev_from.device_id); auto* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_from.device_id, size); VulkanThreadEntry::ThreadLocal() ->Stream(dev_from.device_id) @@ -611,32 +322,32 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* ©_info); }); VulkanThreadEntry::ThreadLocal()->Stream(dev_from.device_id)->Synchronize(); - if (!vctx.coherent_staging) { + if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkInvalidateMappedMemoryRanges(vctx.device, 1, &mrange)); + VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); } memcpy(static_cast(to) + to_offset, static_cast(temp->host_addr), size); } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { - const auto& vctx = context(dev_to.device_id); + const auto& device = this->device(dev_to.device_id); const auto* to_buf = static_cast(to); VulkanStagingBuffer* temp = VulkanThreadEntry::ThreadLocal()->StagingBuffer(dev_to.device_id, size); memcpy(temp->host_addr, static_cast(from) + from_offset, size); // host side flush if access is not coherent. 
// so writes from CPU is visible to GPU - if (!vctx.coherent_staging) { + if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; mrange.memory = temp->vk_buf->memory; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; - VULKAN_CALL(vkFlushMappedMemoryRanges(vctx.device, 1, &mrange)); + VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); } VulkanThreadEntry::ThreadLocal() @@ -667,62 +378,10 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* } } -std::vector VulkanDeviceAPI::FindEnabledExtensions( - const std::vector& ext_prop, - const std::vector& required_extensions, - const std::vector& optional_extensions) { - std::set available_extensions; - for (const auto& prop : ext_prop) { - if (prop.specVersion > 0) { - available_extensions.insert(prop.extensionName); - } - } - - std::vector enabled_extensions; - for (const auto& ext : required_extensions) { - ICHECK(available_extensions.count(ext)) - << "Required vulkan extension \"" << ext << "\" not supported by driver"; - enabled_extensions.push_back(ext); - } - - for (const auto& ext : optional_extensions) { - if (available_extensions.count(ext)) { - enabled_extensions.push_back(ext); - } - } - - return enabled_extensions; -} - -const VulkanContext& VulkanDeviceAPI::context(size_t device_id) const { - ICHECK_LT(device_id, context_.size()) << "Requested Vulkan device_id=" << device_id - << ", but only " << context_.size() << " devices present"; - return context_[device_id]; -} - -std::vector VulkanDeviceAPI::GetComputeQueueFamilies(VkPhysicalDevice phy_dev) { - uint32_t queue_prop_count = 0; - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, nullptr); - std::vector queue_props(queue_prop_count); - vkGetPhysicalDeviceQueueFamilyProperties(phy_dev, &queue_prop_count, dmlc::BeginPtr(queue_props)); - - std::vector result; - // Prefer compute-only queues. 
On certain devices supporting this (e.g. Mesa RADV), using - // compute-only queues gives better responsiveness for other graphics workload (e.g. desktop). - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) == 0) { - result.push_back(i); - } - } - // Now, push the compute queues that we skipped above into the list. - for (uint32_t i = 0; i != queue_prop_count; ++i) { - if ((VK_QUEUE_COMPUTE_BIT & queue_props[i].queueFlags) != 0 && - (VK_QUEUE_GRAPHICS_BIT & queue_props[i].queueFlags) != 0) { - result.push_back(i); - } - } - return result; +const VulkanDevice& VulkanDeviceAPI::device(size_t device_id) const { + ICHECK_LT(device_id, devices_.size()) << "Requested Vulkan device_id=" << device_id + << ", but only " << devices_.size() << " devices present"; + return devices_[device_id]; } TVM_REGISTER_GLOBAL("device_api.vulkan").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 71c73afb0d61..cf5652a3d9c4 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -26,7 +26,8 @@ #include #include "vulkan/vulkan_core.h" -#include "vulkan_context.h" +#include "vulkan_device.h" +#include "vulkan_instance.h" #include "vulkan_thread_entry.h" namespace tvm { @@ -68,12 +69,12 @@ class VulkanDeviceAPI final : public DeviceAPI { // End of required methods for the DeviceAPI interface public: - /*! \brief Return the context associated with a specific device. + /*! \brief Return the VulkanDevice associated with a specific device_id * * These are constructed during VulkanDeviceAPI initialization, so * this function returns immediately. */ - const VulkanContext& context(size_t device_id) const; + const VulkanDevice& device(size_t device_id) const; /*! \brief Returns a property to be stored in a target. 
* @@ -85,14 +86,9 @@ class VulkanDeviceAPI final : public DeviceAPI { private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); - std::vector FindEnabledExtensions( - const std::vector& ext_prop, - const std::vector& required_extensions, - const std::vector& optional_extensions); - - VkInstance instance_{nullptr}; + VulkanInstance instance_; // The physical devices, have 1 to 1 mapping to devices - std::vector context_; + std::vector devices_; }; } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/runtime/vulkan/vulkan_instance.cc new file mode 100644 index 000000000000..351319e0e898 --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.cc @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_instance.h" + +#include +#include + +#include "vulkan_common.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanInstance::VulkanInstance() { + const auto layers = []() { + std::vector layers; + + const char* validation_enabled_env = std::getenv("TVM_VULKAN_ENABLE_VALIDATION_LAYERS"); + bool validation_enabled = validation_enabled_env && *validation_enabled_env; + + if (validation_enabled) { + uint32_t inst_layer_prop_count; + VULKAN_CALL(vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, nullptr)); + std::vector inst_layer_prop(inst_layer_prop_count); + VULKAN_CALL( + vkEnumerateInstanceLayerProperties(&inst_layer_prop_count, inst_layer_prop.data())); + + for (const auto& lp : inst_layer_prop) { + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_standard_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_standard_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_LUNARG_parameter_validation") == 0) { + layers.push_back("VK_LAYER_LUNARG_parameter_validation"); + } + if (std::strcmp(lp.layerName, "VK_LAYER_KHRONOS_validation") == 0) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + } + } + return layers; + }(); + + { + std::vector required_extensions{}; + std::vector optional_extensions{"VK_KHR_get_physical_device_properties2"}; + + uint32_t inst_extension_prop_count; + VULKAN_CALL( + vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, nullptr)); + std::vector inst_extension_prop(inst_extension_prop_count); + VULKAN_CALL(vkEnumerateInstanceExtensionProperties(nullptr, &inst_extension_prop_count, + inst_extension_prop.data())); + + enabled_extensions_ = + FindEnabledExtensions(inst_extension_prop, required_extensions, optional_extensions); + } + + uint32_t api_version = VK_MAKE_VERSION(1, 0, 0); + { + // Result from vkGetInstanceProcAddr is NULL if driver only + // supports vulkan 1.0. 
+ auto vkEnumerateInstanceVersion = + (PFN_vkEnumerateInstanceVersion)vkGetInstanceProcAddr(NULL, "vkEnumerateInstanceVersion"); + if (vkEnumerateInstanceVersion) { + vkEnumerateInstanceVersion(&api_version); + } + } + + { + VkApplicationInfo app_info; + app_info.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO; + app_info.pNext = nullptr; + app_info.pApplicationName = "TVM"; + app_info.applicationVersion = 0; + app_info.pEngineName = ""; + app_info.engineVersion = 0; + app_info.apiVersion = api_version; + + VkInstanceCreateInfo inst_info; + inst_info.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO; + inst_info.pNext = nullptr; + inst_info.flags = 0; + inst_info.pApplicationInfo = &app_info; + inst_info.enabledLayerCount = layers.size(); + inst_info.ppEnabledLayerNames = layers.data(); + inst_info.enabledExtensionCount = enabled_extensions_.size(); + inst_info.ppEnabledExtensionNames = enabled_extensions_.data(); + + VULKAN_CALL(vkCreateInstance(&inst_info, nullptr, &instance_)); + } +} + +VulkanInstance::~VulkanInstance() { + if (instance_) { + vkDestroyInstance(instance_, nullptr); + } +} + +VulkanInstance::VulkanInstance(VulkanInstance&& other) { do_swap(std::move(other)); } + +VulkanInstance& VulkanInstance::operator=(VulkanInstance&& other) { + do_swap(std::move(other)); + return *this; +} + +void VulkanInstance::do_swap(VulkanInstance&& other) { + if (this == &other) { + return; + } + + std::swap(enabled_extensions_, other.enabled_extensions_); + std::swap(instance_, other.instance_); +} + +bool VulkanInstance::HasExtension(const char* query) const { + return std::any_of(enabled_extensions_.begin(), enabled_extensions_.end(), + [&](const char* extension) { return std::strcmp(query, extension) == 0; }); +} + +std::vector VulkanInstance::GetPhysicalDevices() const { + uint32_t device_count = 0; + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, &device_count, nullptr)); + std::vector devices(device_count); + VULKAN_CALL(vkEnumeratePhysicalDevices(instance_, 
&device_count, devices.data())); + return devices; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_instance.h b/src/runtime/vulkan/vulkan_instance.h new file mode 100644 index 000000000000..06016d8f0aea --- /dev/null +++ b/src/runtime/vulkan/vulkan_instance.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ + +#include + +#include "vulkan/vulkan_core.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanInstance { + public: + VulkanInstance(); + ~VulkanInstance(); + + // Allow move assignment/construction + VulkanInstance(VulkanInstance&&); + VulkanInstance& operator=(VulkanInstance&&); + + // Forbid copy assignment/construction + VulkanInstance(const VulkanInstance&) = delete; + VulkanInstance& operator=(const VulkanInstance&) = delete; + + /*! \brief Expose the internal VkInstance + * + * Allows the managed class to be passed to Vulkan APIs as if it + * were the VkInstance handler itself. + */ + operator VkInstance() const { return instance_; } + + /*! 
\brief Checks if the device has an extension enabled + * + * Returns true if the device was initialized with the extension + * given. + * + * \param query The name of the extension to check. + */ + bool HasExtension(const char* query) const; + + /*! \brief Return all accessible physical devices + * + * Wrapper around vkEnumeratePhysicalDevices. + */ + std::vector GetPhysicalDevices() const; + + private: + /*! \brief Helper function for move assignment/construction + * + * Named "do_swap" instead of "swap" because otherwise cpplint.py + * thinks that it needs the header include. + */ + void do_swap(VulkanInstance&& other); + + /*! \brief Extensions enabled for this instance + * + * Based on supported extensions queried through + * vkEnumerateInstanceExtensionProperties, prior to creating + * instance_. Contains only statically allocated string literals, + * no cleanup required. + */ + std::vector enabled_extensions_; + + //! \brief The Vulkan API instance handle + VkInstance instance_{nullptr}; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_VULKAN_VULKAN_INSTANCE_H_ diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/runtime/vulkan/vulkan_stream.cc index fee390ad7e45..9784ee78503d 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/runtime/vulkan/vulkan_stream.cc @@ -23,15 +23,15 @@ namespace tvm { namespace runtime { namespace vulkan { -VulkanStream::VulkanStream(const VulkanContext* vctx) - : vctx_(vctx), state_(new VulkanStreamState()) { +VulkanStream::VulkanStream(const VulkanDevice* device) + : device_(device), state_(new VulkanStreamState()) { // create command pool VkCommandPoolCreateInfo cmd_pool_cinfo; cmd_pool_cinfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO; cmd_pool_cinfo.pNext = nullptr; cmd_pool_cinfo.flags = VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT; - cmd_pool_cinfo.queueFamilyIndex = vctx_->queue_family_index; - VULKAN_CALL(vkCreateCommandPool(vctx_->device, &cmd_pool_cinfo, 
nullptr, &cmd_pool_)); + cmd_pool_cinfo.queueFamilyIndex = device_->queue_family_index; + VULKAN_CALL(vkCreateCommandPool(*device_, &cmd_pool_cinfo, nullptr, &cmd_pool_)); VkCommandBufferAllocateInfo buffer_alloc_info; buffer_alloc_info.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO; @@ -39,13 +39,13 @@ VulkanStream::VulkanStream(const VulkanContext* vctx) buffer_alloc_info.commandPool = cmd_pool_; buffer_alloc_info.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY; buffer_alloc_info.commandBufferCount = 1; - VULKAN_CALL(vkAllocateCommandBuffers(vctx_->device, &buffer_alloc_info, &(state_->cmd_buffer_))); + VULKAN_CALL(vkAllocateCommandBuffers(*device_, &buffer_alloc_info, &(state_->cmd_buffer_))); VkFenceCreateInfo fence_cinfo; fence_cinfo.sType = VK_STRUCTURE_TYPE_FENCE_CREATE_INFO; fence_cinfo.pNext = nullptr; fence_cinfo.flags = 0; // VK_FENCE_CREATE_SIGNALED_BIT; - VULKAN_CALL(vkCreateFence(vctx_->device, &fence_cinfo, nullptr, &(state_->fence_))); + VULKAN_CALL(vkCreateFence(*device_, &fence_cinfo, nullptr, &(state_->fence_))); VkCommandBufferBeginInfo cb_begin; cb_begin.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO; @@ -56,12 +56,12 @@ VulkanStream::VulkanStream(const VulkanContext* vctx) } VulkanStream::~VulkanStream() { - vkDestroyFence(vctx_->device, state_->fence_, nullptr); - vkDestroyCommandPool(vctx_->device, cmd_pool_, nullptr); + vkDestroyFence(*device_, state_->fence_, nullptr); + vkDestroyCommandPool(*device_, cmd_pool_, nullptr); } void VulkanStream::Launch(const std::function& kernel) { - if (vctx_->UseImmediate()) { + if (device_->UseImmediate()) { kernel(state_.get()); } else { deferred_kernels_.push_back(kernel); @@ -71,7 +71,7 @@ void VulkanStream::Launch(const std::function& kernel) void VulkanStream::LaunchDeferred(const std::function& deferred_initializer, const std::function& deferred_kernel, const VulkanStreamToken& deferred_token) { - ICHECK(!vctx_->UseImmediate()); + ICHECK(!device_->UseImmediate()); // If the new kernel uses the 
same descriptor set as one of the // kernels already in the command buffer, we need to synchronize @@ -107,7 +107,7 @@ void VulkanStream::LaunchDeferred(const std::function& deferred_initiali } void VulkanStream::Synchronize() { - if (!vctx_->UseImmediate()) { + if (!device_->UseImmediate()) { for (const auto& deferred_kernel : deferred_kernels_) { deferred_kernel(state_.get()); } @@ -130,20 +130,16 @@ void VulkanStream::Synchronize() { cb_submit.signalSemaphoreCount = 0; cb_submit.pSignalSemaphores = nullptr; - { - // Multiple streams (on different threads) use the same VulkanContext - // instance, so we need to externally synchronize accesses. - std::lock_guard g(*(vctx_->queue_mutex)); - VULKAN_CALL(vkQueueSubmit(vctx_->queue, 1, &cb_submit, state_->fence_)); - } + device_->QueueSubmit(cb_submit, state_->fence_); + uint64_t timeout = 1UL << 30UL; VkResult res; do { - res = vkWaitForFences(vctx_->device, 1, &(state_->fence_), 0, timeout); + res = vkWaitForFences(*device_, 1, &(state_->fence_), 0, timeout); } while (res == VK_TIMEOUT); VULKAN_CHECK_ERROR(res); VULKAN_CALL(vkResetCommandBuffer(state_->cmd_buffer_, 0)); - VULKAN_CALL(vkResetFences(vctx_->device, 1, &(state_->fence_))); + VULKAN_CALL(vkResetFences(*device_, 1, &(state_->fence_))); // Re-initialize the command buffer VkCommandBufferBeginInfo cb_begin; diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index f328262a8b10..ff02be4c5c35 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -26,7 +26,7 @@ #include #include "vulkan_common.h" -#include "vulkan_context.h" +#include "vulkan_device.h" namespace tvm { namespace runtime { @@ -62,13 +62,13 @@ struct VulkanStreamToken { */ class VulkanStream { public: - explicit VulkanStream(const VulkanContext* vctx); + explicit VulkanStream(const VulkanDevice* device); ~VulkanStream(); /*! \brief Push the kernel onto the stream's command buffer. 
* - * If context.UseImmediate() is true, the kernel is executed + * If device.UseImmediate() is true, the kernel is executed * immediately to update the command buffer. Otherwise, it is added * to the list of deferred updates to be pushed onto the command * buffer. @@ -80,7 +80,7 @@ class VulkanStream { /*! \brief Push the kernel onto the stream's command buffer. * - * Can only be called if context.UseImmediate() is false. The + * Can only be called if device.UseImmediate() is false. The * kernel is delayed, and isn't pushed to the command buffer until * all kernels are collected. * @@ -102,7 +102,7 @@ class VulkanStream { void Synchronize(); private: - const VulkanContext* vctx_; + const VulkanDevice* device_; std::unique_ptr state_; // An index of deferred tokens, allowing us to efficiently detect duplicated // deferred_initializer blocks. diff --git a/src/runtime/vulkan/vulkan_thread_entry.cc b/src/runtime/vulkan/vulkan_thread_entry.cc index e7e01b9c2d06..1e2815f31146 100644 --- a/src/runtime/vulkan/vulkan_thread_entry.cc +++ b/src/runtime/vulkan/vulkan_thread_entry.cc @@ -43,10 +43,10 @@ VulkanThreadEntry::~VulkanThreadEntry() { VulkanThreadEntry* VulkanThreadEntry::ThreadLocal() { return VulkanThreadStore::Get(); } void VulkanThreadEntry::AllocateUniformBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - auto info = MakeBufferCreateInfo(vctx, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - auto mem_type_index = FindMemoryType(vctx, info, prop); + auto info = MakeBufferCreateInfo(device, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); + auto mem_type_index = FindMemoryType(device, info, prop); GetOrAllocate(device_id, size, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, mem_type_index, &uniform_buffers_, true); } @@ -59,9 +59,9 @@ VulkanUniformBuffer* 
VulkanThreadEntry::GetUniformBuffer(int device_id, size_t s } VulkanStagingBuffer* VulkanThreadEntry::StagingBuffer(int device_id, size_t size) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT; - return GetOrAllocate(device_id, size, usage, vctx.staging_mtype_index, &staging_buffers_); + return GetOrAllocate(device_id, size, usage, device.staging_mtype_index, &staging_buffers_); } VulkanThreadEntry::VulkanThreadEntry() @@ -74,7 +74,7 @@ VulkanThreadEntry::VulkanThreadEntry() VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { if (!streams_[device_id]) { streams_[device_id] = std::unique_ptr( - new VulkanStream(&VulkanDeviceAPI::Global()->context(device_id))); + new VulkanStream(&VulkanDeviceAPI::Global()->device(device_id))); } return streams_[device_id].get(); } diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 2ee46b7db80c..86c3ffe23f7d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -47,7 +47,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); if (!scache_[device_id]) { scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); } @@ -73,12 +73,12 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, binfo.range = VK_WHOLE_SIZE; descriptor_buffers.push_back(binfo); } - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. 
VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); - vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( + device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); @@ -107,7 +107,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Otherwise, the more expensive deferred path. std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { + const auto& deferred_initializer = [&device, pipeline, descriptor_buffers]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); for (size_t i = 0; i < write_descriptor_sets.size(); i++) { @@ -128,8 +128,8 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; } } - vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), - 0, 0); + vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, + 0); }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { @@ -174,17 +174,17 @@ VulkanModuleNode::~VulkanModuleNode() { for (auto& kv : ecache_[device_id]) { auto& pe = kv.second; ICHECK(pe); - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); if (pe->descriptor_update_template != VK_NULL_HANDLE) { - vctx.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( - vctx.device, 
pe->descriptor_update_template, nullptr); + device.descriptor_template_khr_functions->vkDestroyDescriptorUpdateTemplateKHR( + device, pe->descriptor_update_template, nullptr); } - vkDestroyPipeline(vctx.device, pe->pipeline, nullptr); - vkDestroyPipelineLayout(vctx.device, pe->pipeline_layout, nullptr); - vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); - vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); - vkDestroyShaderModule(vctx.device, pe->shader, nullptr); + vkDestroyPipeline(device, pe->pipeline, nullptr); + vkDestroyPipelineLayout(device, pe->pipeline_layout, nullptr); + vkDestroyDescriptorPool(device, pe->descriptor_pool, nullptr); + vkDestroyDescriptorSetLayout(device, pe->descriptor_set_layout, nullptr); + vkDestroyShaderModule(device, pe->shader, nullptr); } } } @@ -206,7 +206,7 @@ PackedFunc VulkanModuleNode::GetFunction(const std::string& name, std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, const std::string& func_name, size_t num_pack_args) { - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); + const auto& device = VulkanDeviceAPI::Global()->device(device_id); std::lock_guard lock(mutex_); const auto& cp = ecache_[device_id][func_name]; if (cp) { @@ -226,7 +226,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, shader_cinfo.flags = 0; shader_cinfo.codeSize = data.size() * sizeof(uint32_t); shader_cinfo.pCode = data.data(); - VULKAN_CALL(vkCreateShaderModule(vctx.device, &shader_cinfo, nullptr, &(pe->shader))); + VULKAN_CALL(vkCreateShaderModule(device, &shader_cinfo, nullptr, &(pe->shader))); } std::vector arg_binding; std::vector arg_template; @@ -294,16 +294,16 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; descrip_cinfo.pNext = nullptr; descrip_cinfo.flags = 0; - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { descrip_cinfo.flags |= 
VK_DESCRIPTOR_SET_LAYOUT_CREATE_PUSH_DESCRIPTOR_BIT_KHR; } descrip_cinfo.bindingCount = arg_binding.size(); descrip_cinfo.pBindings = arg_binding.data(); - VULKAN_CALL(vkCreateDescriptorSetLayout(vctx.device, &descrip_cinfo, nullptr, - &(pe->descriptor_set_layout))); + VULKAN_CALL( + vkCreateDescriptorSetLayout(device, &descrip_cinfo, nullptr, &(pe->descriptor_set_layout))); } - if (!vctx.UseImmediate()) { + if (!device.UseImmediate()) { VkDescriptorPoolCreateInfo descrip_pool_cinfo; descrip_pool_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO; descrip_pool_cinfo.pNext = nullptr; @@ -312,7 +312,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_pool_cinfo.poolSizeCount = descriptor_set_pool_sizes.size(); descrip_pool_cinfo.pPoolSizes = descriptor_set_pool_sizes.data(); VULKAN_CALL( - vkCreateDescriptorPool(vctx.device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); + vkCreateDescriptorPool(device, &descrip_pool_cinfo, nullptr, &(pe->descriptor_pool))); VkDescriptorSetAllocateInfo alloc_info; alloc_info.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO; @@ -320,7 +320,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, alloc_info.descriptorPool = pe->descriptor_pool; alloc_info.descriptorSetCount = 1; alloc_info.pSetLayouts = &(pe->descriptor_set_layout); - VULKAN_CALL(vkAllocateDescriptorSets(vctx.device, &alloc_info, &(pe->descriptor_set))); + VULKAN_CALL(vkAllocateDescriptorSets(device, &alloc_info, &(pe->descriptor_set))); } VkPushConstantRange crange; @@ -338,13 +338,19 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, if (0 < nbytes_scalars && !pe->use_ubo) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; - ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); + ICHECK_LE(crange.size, device.device_properties.max_push_constants_size) + << "The Vulkan shader uses " << crange.size + << " bytes of push constants, but 
the device only supports " + << device.device_properties.max_push_constants_size << "bytes. " + << "Please rebuild the shader using a smaller limit on push constants size " + << "by passing -max_push_constants_size=N in the Target string, " + << "or pass -from_device=0 to query all device parameters."; } else { playout_cinfo.pushConstantRangeCount = 0; playout_cinfo.pPushConstantRanges = nullptr; } - VULKAN_CALL(vkCreatePipelineLayout(vctx.device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); + VULKAN_CALL(vkCreatePipelineLayout(device, &playout_cinfo, nullptr, &(pe->pipeline_layout))); VkComputePipelineCreateInfo pipeline_cinfo; pipeline_cinfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; @@ -360,10 +366,10 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, pipeline_cinfo.layout = pe->pipeline_layout; pipeline_cinfo.basePipelineHandle = VK_NULL_HANDLE; pipeline_cinfo.basePipelineIndex = 0; - VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, + VULKAN_CALL(vkCreateComputePipelines(device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); - if (vctx.UseImmediate()) { + if (device.UseImmediate()) { VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; descrip_template_cinfo.pNext = 0; @@ -375,8 +381,8 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, descrip_template_cinfo.pipelineBindPoint = VK_PIPELINE_BIND_POINT_COMPUTE; descrip_template_cinfo.pipelineLayout = pe->pipeline_layout; descrip_template_cinfo.set = 0; - VULKAN_CALL(vctx.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( - vctx.device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); + VULKAN_CALL(device.descriptor_template_khr_functions->vkCreateDescriptorUpdateTemplateKHR( + device, &descrip_template_cinfo, 0, &(pe->descriptor_update_template))); } 
ecache_[device_id][func_name] = pe; return pe; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/runtime/vulkan/vulkan_wrapped_func.h index be5f385316ea..a174f22eba59 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/runtime/vulkan/vulkan_wrapped_func.h @@ -32,7 +32,7 @@ #include "../thread_storage_scope.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" -#include "vulkan_context.h" +#include "vulkan_device.h" #include "vulkan_shader.h" namespace tvm { @@ -40,7 +40,7 @@ namespace runtime { namespace vulkan { struct VulkanPipeline { - VulkanContext* vctx_{nullptr}; + VulkanDevice* device{nullptr}; VkShaderModule shader{VK_NULL_HANDLE}; VkDescriptorSetLayout descriptor_set_layout{VK_NULL_HANDLE}; VkDescriptorPool descriptor_pool{VK_NULL_HANDLE};