diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index 512f44fa8ad9..1a24d2873a60 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "vulkan_common.h" @@ -98,8 +99,10 @@ class VulkanStream { // It is invalid to schedule this instance on the current stream if we already // have a matching descriptor set and a non-matching buffer set. - if (std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(), + if (std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), [&](const VulkanStreamToken& token) { + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); return token.descriptor_set_ == deferred_token.descriptor_set_ && token.buffers_ != deferred_token.buffers_; })) { @@ -107,9 +110,10 @@ class VulkanStream { } // It is unnecessary to invoke our initializer if we have a matching token. - if (!std::any_of(deferred_tokens_.begin(), deferred_tokens_.end(), + if (!std::any_of(deferred_tokens_[deferred_token.descriptor_set_].begin(), + deferred_tokens_[deferred_token.descriptor_set_].end(), [&](const VulkanStreamToken& token) { - // If we have a matching descriptor set + DCHECK(token.descriptor_set_ == deferred_token.descriptor_set_); return token.descriptor_set_ == deferred_token.descriptor_set_ && token.buffers_ == deferred_token.buffers_; })) { @@ -117,7 +121,7 @@ class VulkanStream { } deferred_kernels_.push_back(deferred_kernel); - deferred_tokens_.push_back(deferred_token); + deferred_tokens_[deferred_token.descriptor_set_].push_back(deferred_token); } // Synchronize the current stream `state_` with respect to the host. @@ -172,7 +176,9 @@ class VulkanStream { private: const VulkanContext* vctx_; std::unique_ptr state_; - std::vector deferred_tokens_; + // An index of deferred tokens, allowing us to efficiently detect duplicated + // deferred_initializer blocks. + std::unordered_map> deferred_tokens_; std::vector> deferred_kernels_; VkCommandPool cmd_pool_; };