@@ -270,6 +270,9 @@ void swap(PipelineLayout& lhs, PipelineLayout& rhs) noexcept {
270270// ComputePipeline
271271//
272272
273+ ComputePipeline::ComputePipeline (VkDevice device, VkPipeline handle)
274+ : device_{device}, handle_{handle} {}
275+
273276ComputePipeline::ComputePipeline (
274277 VkDevice device,
275278 const ComputePipeline::Descriptor& descriptor,
@@ -444,19 +447,94 @@ ComputePipelineCache::~ComputePipelineCache() {
444447 pipeline_cache_ = VK_NULL_HANDLE;
445448}
446449
450+ bool ComputePipelineCache::contains (const ComputePipelineCache::Key& key) {
451+ std::lock_guard<std::mutex> lock (cache_mutex_);
452+
453+ auto it = cache_.find (key);
454+ return it != cache_.cend ();
455+ }
456+
457+ void ComputePipelineCache::create_pipelines (
458+ const std::unordered_set<Key, Hasher>& descriptors) {
459+ std::lock_guard<std::mutex> lock (cache_mutex_);
460+
461+ const auto num_pipelines = descriptors.size ();
462+ std::vector<VkPipeline> pipelines (num_pipelines);
463+
464+ std::vector<std::vector<VkSpecializationMapEntry>> map_entries;
465+ map_entries.reserve (num_pipelines);
466+
467+ std::vector<VkSpecializationInfo> specialization_infos;
468+ specialization_infos.reserve (num_pipelines);
469+
470+ std::vector<VkPipelineShaderStageCreateInfo> shader_stage_create_infos;
471+ shader_stage_create_infos.reserve (num_pipelines);
472+
473+ std::vector<VkComputePipelineCreateInfo> create_infos;
474+ create_infos.reserve (num_pipelines);
475+
476+ for (auto & key : descriptors) {
477+ map_entries.push_back (key.specialization_constants .generate_map_entries ());
478+
479+ specialization_infos.push_back (VkSpecializationInfo{
480+ key.specialization_constants .size (), // mapEntryCount
481+ map_entries.back ().data (), // pMapEntries
482+ key.specialization_constants .data_nbytes (), // dataSize
483+ key.specialization_constants .data (), // pData
484+ });
485+
486+ shader_stage_create_infos.push_back (VkPipelineShaderStageCreateInfo{
487+ VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, // sType
488+ nullptr , // pNext
489+ 0u , // flags
490+ VK_SHADER_STAGE_COMPUTE_BIT, // stage
491+ key.shader_module , // module
492+ " main" , // pName
493+ &specialization_infos.back (), // pSpecializationInfo
494+ });
495+
496+ create_infos.push_back (VkComputePipelineCreateInfo{
497+ VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO, // sType
498+ nullptr , // pNext
499+ 0u , // flags
500+ shader_stage_create_infos.back (), // stage
501+ key.pipeline_layout , // layout
502+ VK_NULL_HANDLE, // basePipelineHandle
503+ 0u , // basePipelineIndex
504+ });
505+ }
506+
507+ VK_CHECK (vkCreateComputePipelines (
508+ device_,
509+ pipeline_cache_,
510+ create_infos.size (),
511+ create_infos.data (),
512+ nullptr ,
513+ pipelines.data ()));
514+
515+ uint32_t i = 0 ;
516+ for (auto & key : descriptors) {
517+ auto it = cache_.find (key);
518+ if (it != cache_.cend ()) {
519+ continue ;
520+ }
521+ cache_.insert ({key, ComputePipelineCache::Value (device_, pipelines[i])});
522+ ++i;
523+ }
524+ }
525+
447526VkPipeline ComputePipelineCache::retrieve (
448527 const ComputePipelineCache::Key& key) {
449528 std::lock_guard<std::mutex> lock (cache_mutex_);
450529
451530 auto it = cache_.find (key);
452- if (cache_.cend () == it ) {
531+ if (it == cache_.cend ()) {
453532 it = cache_
454533 .insert (
455534 {key,
456535 ComputePipelineCache::Value (device_, key, pipeline_cache_)})
457536 .first ;
458537 }
459-
460538 return it->second .handle ();
461539}
462540
0 commit comments