From 59fdf110765087bd8c39c2a20739bc54e901b029 Mon Sep 17 00:00:00 2001 From: masahi Date: Mon, 12 Apr 2021 04:08:24 +0900 Subject: [PATCH] Revert "[Vulkan] Support uniform buffer object for passing many scalar arguments (#7717)" (#7821) This reverts commit 5bc1cec4c4acf0a54889227c1d19a6b65b6803c2. --- src/runtime/vulkan/vulkan.cc | 268 ++++++++++------------------- src/runtime/vulkan/vulkan_common.h | 3 - src/target/spirv/codegen_spirv.cc | 23 +-- src/target/spirv/ir_builder.cc | 31 +--- src/target/spirv/ir_builder.h | 32 +--- 5 files changed, 102 insertions(+), 255 deletions(-) diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index c8a0858ec1bc..5cd4812f41c4 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -91,11 +91,6 @@ struct VulkanBuffer { VkDeviceMemory memory{VK_NULL_HANDLE}; }; -struct UniformBuffer { - VulkanBuffer* vk_buf; - void* host_buf; -}; - struct VulkanPipeline { VulkanContext* vctx_{nullptr}; VkShaderModule shader{VK_NULL_HANDLE}; @@ -105,105 +100,10 @@ struct VulkanPipeline { VkPipelineLayout pipeline_layout{VK_NULL_HANDLE}; VkPipeline pipeline{VK_NULL_HANDLE}; VkDescriptorUpdateTemplateKHR descriptor_update_template{VK_NULL_HANDLE}; - UniformBuffer ubo; }; typedef dmlc::ThreadLocalStore VulkanThreadStore; -uint32_t FindMemoryType(VkDevice logical_device, VkPhysicalDevice phy_device, VkBuffer buffer, - VkMemoryPropertyFlags req_prop) { - VkMemoryRequirements mem_reqs; - vkGetBufferMemoryRequirements(logical_device, buffer, &mem_reqs); - uint32_t type_bits = mem_reqs.memoryTypeBits; - VkPhysicalDeviceMemoryProperties phy_mem_prop; - vkGetPhysicalDeviceMemoryProperties(phy_device, &phy_mem_prop); - for (uint32_t i = 0; i < phy_mem_prop.memoryTypeCount; i++) { - if ((type_bits & 1) == 1 && - (phy_mem_prop.memoryTypes[i].propertyFlags & req_prop) == req_prop) { - return i; - } - type_bits >>= 1; - } - LOG(FATAL) << "Requested memory type not found"; - return 0; -} - -VulkanBuffer* CreateBuffer(const VulkanContext& vctx, size_t nbytes, VkBufferUsageFlags usage) { - VkBufferCreateInfo info; - info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; - info.pNext = nullptr; - info.flags = 0; - info.size = nbytes; - info.queueFamilyIndexCount = 1; - info.pQueueFamilyIndices = &(vctx.queue_family_index); - info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; - info.usage = usage; - // create buffer - VkBuffer buffer; - VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); - - uint32_t mem_type_index = vctx.compute_mtype_index; - - if (usage & VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT) { - // Find a memory type that supports UBO - auto prop = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT; - mem_type_index = FindMemoryType(vctx.device, vctx.phy_device, buffer, prop); - } - - // bind to memory - bool dedicated_allocation = false; - VkMemoryRequirements2KHR req2; - - if (vctx.get_buffer_memory_requirements_2_functions) { - VkBufferMemoryRequirementsInfo2KHR req_info2; - req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; - req_info2.pNext = 0; - req_info2.buffer = buffer; - - req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; - req2.pNext = 0; - - VkMemoryDedicatedRequirementsKHR dedicated_req; - dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; - dedicated_req.pNext = 0; - req2.pNext = &dedicated_req; - - vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( - vctx.device, &req_info2, &req2); - dedicated_allocation = - dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; - } - - VkDeviceMemory memory; - if (!dedicated_allocation) { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = nbytes; - minfo.memoryTypeIndex = mem_type_index; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } else { - VkMemoryAllocateInfo minfo; - minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; - minfo.pNext = nullptr; - minfo.allocationSize = req2.memoryRequirements.size; - minfo.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR mdinfo; - mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; - mdinfo.pNext = 0; - mdinfo.image = 0; - mdinfo.buffer = buffer; - minfo.pNext = &mdinfo; - VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); - } - VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); - VulkanBuffer* pbuf = new VulkanBuffer(); - pbuf->memory = memory; - pbuf->buffer = buffer; - return pbuf; -} - class VulkanDeviceAPI final : public DeviceAPI { public: VulkanDeviceAPI(); @@ -224,9 +124,70 @@ class VulkanDeviceAPI final : public DeviceAPI { nbytes = 1; } const auto& vctx = context(dev.device_id); - auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | + VkBufferCreateInfo info; + info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; + info.pNext = nullptr; + info.flags = 0; + info.size = nbytes; + info.queueFamilyIndexCount = 1; + info.pQueueFamilyIndices = &(vctx.queue_family_index); + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return CreateBuffer(vctx, nbytes, usage); + // create buffer + VkBuffer buffer; + VULKAN_CALL(vkCreateBuffer(vctx.device, &info, nullptr, &buffer)); + // bind to memory + VkBufferMemoryRequirementsInfo2KHR req_info2; + req_info2.sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_REQUIREMENTS_INFO_2_KHR; + req_info2.pNext = 0; + req_info2.buffer = buffer; + + VkMemoryRequirements2KHR req2; + req2.sType = VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR; + req2.pNext = 0; + + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = 0; + req2.pNext = &dedicated_req; + + bool dedicated_allocation = false; + if (vctx.get_buffer_memory_requirements_2_functions) { + vctx.get_buffer_memory_requirements_2_functions->vkGetBufferMemoryRequirements2KHR( + vctx.device, &req_info2, &req2); + dedicated_allocation = + dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation; + } + + VkDeviceMemory memory; + if (!dedicated_allocation) { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = nbytes; + minfo.memoryTypeIndex = vctx.compute_mtype_index; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } else { + VkMemoryAllocateInfo minfo; + minfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO; + minfo.pNext = nullptr; + minfo.allocationSize = req2.memoryRequirements.size; + minfo.memoryTypeIndex = vctx.compute_mtype_index; + + VkMemoryDedicatedAllocateInfoKHR mdinfo; + mdinfo.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR; + mdinfo.pNext = 0; + mdinfo.image = 0; + mdinfo.buffer = buffer; + minfo.pNext = &mdinfo; + VULKAN_CALL(vkAllocateMemory(vctx.device, &minfo, nullptr, &memory)); + } + VULKAN_CALL(vkBindBufferMemory(vctx.device, buffer, memory, 0)); + VulkanBuffer* pbuf = new VulkanBuffer(); + pbuf->memory = memory; + pbuf->buffer = buffer; + return pbuf; } void FreeDataSpace(Device dev, void* ptr) final { @@ -786,7 +747,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { public: explicit VulkanModuleNode(std::unordered_map smap, std::unordered_map fmap, std::string source) - : smap_(smap), fmap_(fmap), source_(source), max_push_constants_(GetMaxPushConstantsSize()) {} + : smap_(smap), fmap_(fmap), source_(source) {} const char* type_key() const final { return "vulkan"; } @@ -820,13 +781,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { vkDestroyDescriptorPool(vctx.device, pe->descriptor_pool, nullptr); vkDestroyDescriptorSetLayout(vctx.device, pe->descriptor_set_layout, nullptr); vkDestroyShaderModule(vctx.device, pe->shader, nullptr); - // UBO - if (pe->ubo.vk_buf) { - vkUnmapMemory(vctx.device, pe->ubo.vk_buf->memory); - vkDestroyBuffer(vctx.device, pe->ubo.vk_buf->buffer, nullptr); - vkFreeMemory(vctx.device, pe->ubo.vk_buf->memory, nullptr); - delete pe->ubo.vk_buf; - } } } } @@ -858,35 +812,30 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::vector arg_template; uint32_t num_pod = 0, num_buffer = 0; - auto push_arg_info = [&arg_binding, &arg_template](uint32_t binding, - VkDescriptorType desc_type) { - { - VkDescriptorSetLayoutBinding bd; - bd.binding = binding; - bd.descriptorType = desc_type; - bd.descriptorCount = 1; - bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; - bd.pImmutableSamplers = nullptr; - arg_binding.push_back(bd); - } - { - VkDescriptorUpdateTemplateEntryKHR tpl; - tpl.dstBinding = binding; - tpl.dstArrayElement = 0; - tpl.descriptorCount = 1; - tpl.descriptorType = desc_type; - tpl.offset = binding * sizeof(VkDescriptorBufferInfo); - tpl.stride = sizeof(VkDescriptorBufferInfo); - arg_template.push_back(tpl); - } - }; - { auto fit = fmap_.find(func_name); ICHECK(fit != fmap_.end()); for (DLDataType arg_type : fit->second.arg_types) { if (arg_type.code == kTVMOpaqueHandle) { - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); + { + VkDescriptorSetLayoutBinding bd; + bd.binding = num_buffer; + bd.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + bd.descriptorCount = 1; + bd.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; + bd.pImmutableSamplers = nullptr; + arg_binding.push_back(bd); + } + { + VkDescriptorUpdateTemplateEntryKHR tpl; + tpl.dstBinding = num_buffer; + tpl.dstArrayElement = 0; + tpl.descriptorCount = 1; + tpl.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + tpl.offset = num_buffer * sizeof(VkDescriptorBufferInfo); + tpl.stride = sizeof(VkDescriptorBufferInfo); + arg_template.push_back(tpl); + } ++num_buffer; } else { ++num_pod; @@ -894,11 +843,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { } } - size_t nbytes_scalars = num_pod * sizeof(ArgUnion64); - if (nbytes_scalars > max_push_constants_) { - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER); - } - { VkDescriptorSetLayoutCreateInfo descrip_cinfo; descrip_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO; @@ -950,7 +894,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { playout_cinfo.setLayoutCount = 1; playout_cinfo.pSetLayouts = &(pe->descriptor_set_layout); - if (0 < nbytes_scalars && nbytes_scalars <= max_push_constants_) { + if (num_pack_args != 0) { playout_cinfo.pushConstantRangeCount = 1; playout_cinfo.pPushConstantRanges = &crange; ICHECK_LE(crange.size, vctx.phy_device_prop.limits.maxPushConstantsSize); @@ -979,13 +923,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { VULKAN_CALL(vkCreateComputePipelines(vctx.device, VK_NULL_HANDLE, 1, &pipeline_cinfo, nullptr, &(pe->pipeline))); - if (nbytes_scalars > max_push_constants_) { - // Allocate, bind and map UBO - UniformBuffer& ubo = pe->ubo; - ubo.vk_buf = CreateBuffer(vctx, nbytes_scalars, VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT); - vkMapMemory(vctx.device, ubo.vk_buf->memory, 0, nbytes_scalars, 0, &(ubo.host_buf)); - } - if (vctx.UseImmediate()) { VkDescriptorUpdateTemplateCreateInfoKHR descrip_template_cinfo; descrip_template_cinfo.sType = VK_STRUCTURE_TYPE_DESCRIPTOR_UPDATE_TEMPLATE_CREATE_INFO_KHR; @@ -1029,8 +966,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { return source_; } - uint32_t MaxPushConstantsSize() const { return max_push_constants_; } - private: // function information table. std::unordered_map smap_; @@ -1040,8 +975,6 @@ class VulkanModuleNode final : public runtime::ModuleNode { std::string fmt_{"vulkan"}; // The source std::string source_; - // The maximum size of push constants in bytes - const uint32_t max_push_constants_; // Guards accesses to `ecache_` std::mutex mutex_; @@ -1143,17 +1076,6 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, binfo.range = VK_WHOLE_SIZE; descriptor_buffers[i] = binfo; } - const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); - bool use_ubo = num_pack_args_ != 0 && nbytes_scalars > m_->MaxPushConstantsSize(); - if (use_ubo) { - CHECK(pipeline->ubo.host_buf) << "The UBO host buffer is not allocated"; - memcpy(pipeline->ubo.host_buf, pack_args, nbytes_scalars); - VkDescriptorBufferInfo binfo; - binfo.buffer = pipeline->ubo.vk_buf->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers.push_back(binfo); - } if (vctx.UseImmediate()) { // Can safely capture by reference as this lambda is immediately executed on the calling thread. VulkanThreadEntry::ThreadLocal()->Stream(device_id)->Launch([&](VulkanStreamState* state) { @@ -1162,7 +1084,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, vctx.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, descriptor_buffers.data()); - if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) { + if (num_pack_args_ != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); @@ -1183,7 +1105,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, // Otherwise, the more expensive deferred path. std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers, use_ubo]() { + const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { std::vector write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); for (size_t i = 0; i < write_descriptor_sets.size(); i++) { @@ -1193,26 +1115,20 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, write_descriptor_sets[i].dstBinding = i; write_descriptor_sets[i].dstArrayElement = 0; write_descriptor_sets[i].descriptorCount = 1; + write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; write_descriptor_sets[i].pImageInfo = 0; write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); write_descriptor_sets[i].pTexelBufferView = 0; - - if (use_ubo && i == write_descriptor_sets.size() - 1) { - // The last binding is for UBO - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; - } else { - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; - } } vkUpdateDescriptorSets(vctx.device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, 0); }; - const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage](VulkanStreamState* state) { + const auto& deferred_kernel = [pipeline, wl, pack_args_storage](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); vkCmdBindDescriptorSets(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline_layout, 0, 1, &(pipeline->descriptor_set), 0, nullptr); - if (num_pack_args_ > 0 && num_pack_args_ <= m_->MaxPushConstantsSize()) { + if (pack_args_storage.size() != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); @@ -1267,12 +1183,6 @@ Module VulkanModuleLoadBinary(void* strm) { return VulkanModuleCreate(smap, fmap, ""); } -uint32_t GetMaxPushConstantsSize() { - int device_id = VulkanThreadEntry::ThreadLocal()->device.device_id; - const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); - return vctx.phy_device_prop.limits.maxPushConstantsSize; -} - TVM_REGISTER_GLOBAL("runtime.module.loadfile_vulkan").set_body_typed(VulkanModuleLoadFile); TVM_REGISTER_GLOBAL("runtime.module.loadbinary_vulkan").set_body_typed(VulkanModuleLoadBinary); diff --git a/src/runtime/vulkan/vulkan_common.h b/src/runtime/vulkan/vulkan_common.h index e94a9fe7fa90..3083ba6f9ce4 100644 --- a/src/runtime/vulkan/vulkan_common.h +++ b/src/runtime/vulkan/vulkan_common.h @@ -142,9 +142,6 @@ struct VulkanContext { bool UseImmediate() const { return descriptor_template_khr_functions.get() != nullptr; } }; -/*! \brief returns maximum push constant sizes in bytes for the target platform */ -uint32_t GetMaxPushConstantsSize(); - } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 4d55f4c49a5f..24608ebc93f4 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -30,9 +30,6 @@ #include -#include "../../runtime/pack_args.h" -#include "../../runtime/vulkan/vulkan_common.h" - namespace tvm { namespace codegen { @@ -69,26 +66,16 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std:: spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); + // All the POD arguments are passed in through PushConstant if (pod_args.size() != 0) { std::vector value_types; for (size_t i = 0; i < pod_args.size(); ++i) { value_types.push_back(builder_->GetSType(pod_args[i].dtype())); } - const auto max_push_constants = runtime::vulkan::GetMaxPushConstantsSize(); - if (pod_args.size() * sizeof(runtime::ArgUnion64) <= max_push_constants) { - spirv::Value ptr = builder_->DeclarePushConstant(value_types); - for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = - builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); - var_map_[pod_args[i].get()] = value; - } - } else { - // If we need to pass more arguments than push constants could handle, we use UBO. - spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, num_buffer); - for (size_t i = 0; i < pod_args.size(); ++i) { - spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); - var_map_[pod_args[i].get()] = value; - } + spirv::Value ptr = builder_->DeclarePushConstant(value_types); + for (size_t i = 0; i < pod_args.size(); ++i) { + spirv::Value value = builder_->GetPushConstant(ptr, value_types[i], static_cast(i)); + var_map_[pod_args[i].get()] = value; } } this->VisitStmt(f->body); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index cd48c93530ec..5a1457387ae5 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -205,8 +205,8 @@ Value IRBuilder::BufferArgument(const SType& value_type, uint32_t descriptor_set return val; } -Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, - spv::StorageClass storage_class, ValueKind kind) { +Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { + ICHECK_EQ(push_const_.id, 0); SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); @@ -226,26 +226,22 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); if (t.bits() == 32) { - // In our Vulkan runtime, each scalar argument always occupies 64 bit. + // In our Vulkan runtime, each push constant always occupies 64 bit. offset += bytes * 2; } else { ICHECK_EQ(t.bits(), 64); offset += bytes; } } + // Decorate push constants as UBO this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock); - SType ptr_type = GetPointerType(struct_type, storage_class); - Value val = NewValue(ptr_type, kind); - ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, storage_class).Commit(&global_); + SType ptr_type = GetPointerType(struct_type, spv::StorageClassPushConstant); + Value val = NewValue(ptr_type, kPushConstantPtr); + ib_.Begin(spv::OpVariable).AddSeq(ptr_type, val, spv::StorageClassPushConstant).Commit(&global_); return val; } -Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { - ICHECK_EQ(push_const_.id, 0); - return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr); -} - Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index) { SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassPushConstant); Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, @@ -253,19 +249,6 @@ Value IRBuilder::GetPushConstant(Value ptr_push_const, const SType& v_type, uint return this->MakeValue(spv::OpLoad, v_type, ptr); } -Value IRBuilder::DeclareUniformBuffer(const std::vector& value_types, uint32_t binding) { - Value val = DeclareStorageVariable(value_types, spv::StorageClassUniform, kUniformPtr); - this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); - return val; -} - -Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t index) { - SType ptr_vtype = this->GetPointerType(v_type, spv::StorageClassUniform); - Value ptr = this->MakeValue(spv::OpAccessChain, ptr_vtype, ptr_push_const, - IntImm(t_int32_, static_cast(index))); - return this->MakeValue(spv::OpLoad, v_type, ptr); -} - Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 05a2bc631743..8a08048e1955 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -60,8 +60,7 @@ enum ValueKind { kStructArrayPtr, kPushConstantPtr, kFunction, - kExtInst, - kUniformPtr + kExtInst }; /*! \brief Represent the SPIRV Value */ @@ -474,7 +473,6 @@ class IRBuilder { * \param The argument type. */ Value BufferArgument(const SType& value_type, uint32_t descriptor_set, uint32_t binding); - /*! * \brief Declare POD arguments through push constants. * @@ -490,23 +488,6 @@ class IRBuilder { * \return the value of push constant */ Value GetPushConstant(Value ptr_push_const, const SType& v_type, uint32_t index); - - /*! - * \brief Declare POD arguments through uniform buffer. - * - * \note Only call this function once! - * \param value_types The values in the uniform buffer - * \param binding The binding locaiton in descriptor set - * \return reference to self. - */ - Value DeclareUniformBuffer(const std::vector& value_types, uint32_t binding); - /*! - * \brief Get i-th uniform constant - * \param v_type The value type - * \param index The uniform index - * \return the value of uniform constant - */ - Value GetUniform(Value ptr_ubo, const SType& v_type, uint32_t index); /*! * \brief Declare a new function * \return The created function ID. @@ -574,17 +555,6 @@ class IRBuilder { val.flag = flag; return val; } - - /*! - * \brief The common function to declare push constants or uniform buffer - * \param value_types The values in the push constants or uniform buffer - * \param storage_class An enum defined by SPIR-V indicating push constant or uniform - * \param kind An enum indicating push constant or uniform - * \return The created new label - */ - Value DeclareStorageVariable(const std::vector& value_types, - spv::StorageClass storage_class, ValueKind kind); - // get constant given value encoded in uint64_t Value GetConst_(const SType& dtype, const uint64_t* pvalue); // declare type