diff --git a/sycl/source/detail/graph_impl.cpp b/sycl/source/detail/graph_impl.cpp index 1a328507808b8..534f70783d020 100644 --- a/sycl/source/detail/graph_impl.cpp +++ b/sycl/source/detail/graph_impl.cpp @@ -1480,8 +1480,7 @@ bool exec_graph_impl::needsScheduledUpdate( } void exec_graph_impl::populateURKernelUpdateStructs( - const std::shared_ptr &Node, - std::pair &BundleObjs, + const std::shared_ptr &Node, FastKernelCacheValPtr &BundleObjs, std::vector &MemobjDescs, std::vector &MemobjProps, std::vector &PtrDescs, @@ -1517,12 +1516,11 @@ void exec_graph_impl::populateURKernelUpdateStructs( UrKernel = SyclKernelImpl->getHandleRef(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); } else { - ur_program_handle_t UrProgram = nullptr; - std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = - sycl::detail::ProgramManager::getInstance().getOrCreateKernel( - ContextImpl, DeviceImpl, ExecCG.MKernelName, - ExecCG.MKernelNameBasedCachePtr); - BundleObjs = std::make_pair(UrProgram, UrKernel); + BundleObjs = sycl::detail::ProgramManager::getInstance().getOrCreateKernel( + ContextImpl, DeviceImpl, ExecCG.MKernelName, + ExecCG.MKernelNameBasedCachePtr); + UrKernel = BundleObjs->MKernelHandle; + EliminatedArgMask = BundleObjs->MKernelArgMask; } // Remove eliminated args @@ -1717,8 +1715,7 @@ void exec_graph_impl::updateURImpl( std::vector NDRDescList(NumUpdatableNodes); std::vector UpdateDescList( NumUpdatableNodes); - std::vector> - KernelBundleObjList(NumUpdatableNodes); + std::vector KernelBundleObjList(NumUpdatableNodes); size_t StructListIndex = 0; for (auto &Node : Nodes) { @@ -1743,17 +1740,6 @@ void exec_graph_impl::updateURImpl( const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter(); Adapter->call( CommandBuffer, UpdateDescList.size(), UpdateDescList.data()); - - for (auto &BundleObjs : KernelBundleObjList) { - // We retained these objects by inside populateUpdateStruct() by calling - // getOrCreateKernel() - if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) { - Adapter->call(UrKernel); - } - if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) { - Adapter->call(UrProgram); - } - } } modifiable_command_graph::modifiable_command_graph( diff --git a/sycl/source/detail/graph_impl.hpp b/sycl/source/detail/graph_impl.hpp index 4001eed719286..b803daa97c6b0 100644 --- a/sycl/source/detail/graph_impl.hpp +++ b/sycl/source/detail/graph_impl.hpp @@ -1521,8 +1521,7 @@ class exec_graph_impl { /// @param[out] NDRDesc ND-Range to update. /// @param[out] UpdateDesc Base struct in the pointer chain. void populateURKernelUpdateStructs( - const std::shared_ptr &Node, - std::pair &BundleObjs, + const std::shared_ptr &Node, FastKernelCacheValPtr &BundleObjs, std::vector &MemobjDescs, std::vector &MemobjProps, std::vector &PtrDescs, diff --git a/sycl/source/detail/kernel_name_based_cache_t.hpp b/sycl/source/detail/kernel_name_based_cache_t.hpp index f1ecd3ec4cd9d..63f56651b99b2 100644 --- a/sycl/source/detail/kernel_name_based_cache_t.hpp +++ b/sycl/source/detail/kernel_name_based_cache_t.hpp @@ -19,11 +19,47 @@ namespace sycl { inline namespace _V1 { namespace detail { using FastKernelCacheKeyT = std::pair; -using FastKernelCacheValT = - std::tuple; + +struct FastKernelCacheVal { + ur_kernel_handle_t MKernelHandle; /* UR kernel handle pointer. */ + std::mutex *MMutex; /* Mutex guarding this kernel. When + caching is disabled, the pointer is + nullptr. */ + const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */ + ur_program_handle_t MProgramHandle; /* UR program handle corresponding to + this kernel. */ + const Adapter &MAdapterPtr; /* We can keep reference to the adapter + because during 2-stage shutdown the kernel + cache is destroyed deliberately before the + adapter. */ + + FastKernelCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex, + const KernelArgMask *KernelArgMask, + ur_program_handle_t ProgramHandle, + const Adapter &AdapterPtr) + : MKernelHandle(KernelHandle), MMutex(Mutex), + MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle), + MAdapterPtr(AdapterPtr) {} + + ~FastKernelCacheVal() { + if (MKernelHandle) + MAdapterPtr.call(MKernelHandle); + if (MProgramHandle) + MAdapterPtr.call( + MProgramHandle); + MKernelHandle = nullptr; + MMutex = nullptr; + MKernelArgMask = nullptr; + MProgramHandle = nullptr; + } + + FastKernelCacheVal(const FastKernelCacheVal &) = delete; + FastKernelCacheVal &operator=(const FastKernelCacheVal &) = delete; +}; +using FastKernelCacheValPtr = std::shared_ptr; + using FastKernelSubcacheMapT = - ::boost::unordered_flat_map; + ::boost::unordered_flat_map; using FastKernelSubcacheMutexT = SpinLock; using FastKernelSubcacheReadLockT = std::lock_guard; diff --git a/sycl/source/detail/kernel_program_cache.hpp b/sycl/source/detail/kernel_program_cache.hpp index 7317b9fd9b309..2585710e4781e 100644 --- a/sycl/source/detail/kernel_program_cache.hpp +++ b/sycl/source/detail/kernel_program_cache.hpp @@ -468,7 +468,7 @@ class KernelProgramCache { return std::make_pair(It->second, DidInsert); } - FastKernelCacheValT + FastKernelCacheValPtr tryToGetKernelFast(KernelNameStrRefT KernelName, ur_device_handle_t Device, FastKernelSubcacheT *KernelSubcacheHint) { FastKernelCacheWriteLockT Lock(MFastKernelCacheMutex); @@ -486,27 +486,27 @@ class KernelProgramCache { traceKernel("Kernel fetched.", KernelName, true); return It->second; } - return std::make_tuple(nullptr, nullptr, nullptr, nullptr); + return FastKernelCacheValPtr(); } void saveKernel(KernelNameStrRefT KernelName, ur_device_handle_t Device, - FastKernelCacheValT CacheVal, + const FastKernelCacheValPtr &CacheVal, FastKernelSubcacheT *KernelSubcacheHint) { - ur_program_handle_t Program = std::get<3>(CacheVal); if (SYCLConfig:: isProgramCacheEvictionEnabled()) { // Save kernel in fast cache only if the corresponding program is also // in the cache. auto LockedCache = acquireCachedPrograms(); auto &ProgCache = LockedCache.get(); - if (ProgCache.ProgramSizeMap.find(Program) == + if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) == ProgCache.ProgramSizeMap.end()) return; } // Save reference between the program and the fast cache key. FastKernelCacheWriteLockT Lock(MFastKernelCacheMutex); - MProgramToFastKernelCacheKeyMap[Program].emplace_back(KernelName, Device); + MProgramToFastKernelCacheKeyMap[CacheVal->MProgramHandle].emplace_back( + KernelName, Device); // if no insertion took place, then some other thread has already inserted // smth in the cache @@ -518,7 +518,7 @@ class KernelProgramCache { FastKernelSubcacheWriteLockT SubcacheLock{KernelSubcacheHint->Mutex}; ur_context_handle_t Context = getURContext(); KernelSubcacheHint->Map.emplace(FastKernelCacheKeyT(Device, Context), - std::move(CacheVal)); + CacheVal); } // Expects locked program cache diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 5836a215b3216..ff18c879f8678 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -13,7 +13,6 @@ #include #include #include -#include #include #include #include @@ -1108,11 +1107,8 @@ ur_program_handle_t ProgramManager::getBuiltURProgram( Adapter->call(ResProgram); return ResProgram; } -// When caching is enabled, the returned UrProgram and UrKernel will -// already have their ref count incremented. -std::tuple -ProgramManager::getOrCreateKernel( + +FastKernelCacheValPtr ProgramManager::getOrCreateKernel( const ContextImplPtr &ContextImpl, device_impl &DeviceImpl, KernelNameStrRefT KernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, const NDRDescT &NDRDesc) { @@ -1129,18 +1125,11 @@ ProgramManager::getOrCreateKernel( KernelNameBasedCachePtr ? &KernelNameBasedCachePtr->FastKernelSubcache : nullptr; if (SYCLConfig::get()) { - auto ret_tuple = + auto KernelCacheValPtr = Cache.tryToGetKernelFast(KernelName, UrDevice, CacheHintPtr); - constexpr size_t Kernel = 0; // see FastKernelCacheValT tuple - constexpr size_t Program = 3; // see FastKernelCacheValT tuple - if (std::get(ret_tuple)) { - // Pulling a copy of a kernel and program from the cache, - // so we need to retain those resources. - ContextImpl->getAdapter()->call( - std::get(ret_tuple)); - ContextImpl->getAdapter()->call( - std::get(ret_tuple)); - return ret_tuple; + if (auto KernelCacheValPtr = + Cache.tryToGetKernelFast(KernelName, UrDevice, CacheHintPtr)) { + return KernelCacheValPtr; } } @@ -1179,20 +1168,21 @@ ProgramManager::getOrCreateKernel( // threads when caching is disabled, so we can return // nullptr for the mutex. auto [Kernel, ArgMask] = BuildF(); - return make_tuple(Kernel, nullptr, ArgMask, Program); + return std::make_shared( + Kernel, nullptr, ArgMask, Program, *ContextImpl->getAdapter().get()); } auto BuildResult = Cache.getOrBuild(GetCachedBuildF, BuildF); // getOrBuild is not supposed to return nullptr assert(BuildResult != nullptr && "Invalid build result"); const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val; - auto ret_val = std::make_tuple(KernelArgMaskPair.first, - &(BuildResult->MBuildResultMutex), - KernelArgMaskPair.second, Program); + auto ret_val = std::make_shared( + KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex), + KernelArgMaskPair.second, Program, *ContextImpl->getAdapter().get()); // If caching is enabled, one copy of the kernel handle will be - // stored in the cache, and one handle is returned to the - // caller. In that case, we need to increase the ref count of the - // kernel. + // stored in FastKernelCacheVal, and one is in + // KernelProgramCache::MKernelsPerProgramCache. To cover + // MKernelsPerProgramCache, we need to increase the ref count of the kernel. ContextImpl->getAdapter()->call( KernelArgMaskPair.first); Cache.saveKernel(KernelName, UrDevice, ret_val, CacheHintPtr); diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp index 27c4610421ca4..49855adad6e55 100644 --- a/sycl/source/detail/program_manager/program_manager.hpp +++ b/sycl/source/detail/program_manager/program_manager.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -197,8 +198,7 @@ class ProgramManager { const DevImgPlainWithDeps *DevImgWithDeps = nullptr, const SerializedObj &SpecConsts = {}); - std::tuple + FastKernelCacheValPtr getOrCreateKernel(const ContextImplPtr &ContextImpl, device_impl &DeviceImpl, KernelNameStrRefT KernelName, KernelNameBasedCacheT *KernelNameBasedCachePtr, diff --git a/sycl/source/detail/scheduler/commands.cpp b/sycl/source/detail/scheduler/commands.cpp index 724dcc2956734..6b02557d3266e 100644 --- a/sycl/source/detail/scheduler/commands.cpp +++ b/sycl/source/detail/scheduler/commands.cpp @@ -1991,8 +1991,6 @@ void instrumentationAddExtraKernelMetadata( auto FilterArgs = [&Args](detail::ArgDesc &Arg, int NextTrueIndex) { Args.push_back({Arg.MType, Arg.MPtr, Arg.MSize, NextTrueIndex}); }; - ur_kernel_handle_t Kernel = nullptr; - std::mutex *KernelMutex = nullptr; const KernelArgMask *EliminatedArgMask = nullptr; if (nullptr != SyclKernel) { @@ -2007,11 +2005,11 @@ void instrumentationAddExtraKernelMetadata( // NOTE: Queue can be null when kernel is directly enqueued to a command // buffer // by graph API, when a modifiable graph is finalized. - ur_program_handle_t Program = nullptr; - std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) = + FastKernelCacheValPtr FastKernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel( Queue->getContextImplPtr(), Queue->getDeviceImpl(), KernelName, KernelNameBasedCachePtr); + EliminatedArgMask = FastKernelCacheVal->MKernelArgMask; } applyFuncOnFilteredArgs(EliminatedArgMask, CGArgs, FilterArgs); @@ -2544,8 +2542,7 @@ static std::tuple, const KernelArgMask *> getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, device_impl &DeviceImpl, - std::vector &UrKernelsToRelease, - std::vector &UrProgramsToRelease) { + std::vector &KernelCacheValsToRelease) { ur_kernel_handle_t UrKernel = nullptr; std::shared_ptr DeviceImageImpl = nullptr; @@ -2564,13 +2561,14 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl, DeviceImageImpl = SyclKernelImpl->getDeviceImage(); EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); } else { - ur_program_handle_t UrProgram = nullptr; - std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) = + FastKernelCacheValPtr FastKernelCacheVal = sycl::detail::ProgramManager::getInstance().getOrCreateKernel( ContextImpl, DeviceImpl, CommandGroup.MKernelName, CommandGroup.MKernelNameBasedCachePtr); - UrKernelsToRelease.push_back(UrKernel); - UrProgramsToRelease.push_back(UrProgram); + UrKernel = FastKernelCacheVal->MKernelHandle; + EliminatedArgMask = FastKernelCacheVal->MKernelArgMask; + // To keep UrKernel valid, we return FastKernelCacheValPtr. + KernelCacheValsToRelease.push_back(std::move(FastKernelCacheVal)); } return std::make_tuple(UrKernel, DeviceImageImpl, EliminatedArgMask); } @@ -2583,20 +2581,18 @@ ur_result_t enqueueImpCommandBufferKernel( ur_exp_command_buffer_sync_point_t *OutSyncPoint, ur_exp_command_buffer_command_handle_t *OutCommand, const std::function &getMemAllocationFunc) { - // List of ur objects to be released after UR call. We don't do anything - // with the ur_program_handle_t objects, but need to update their reference - // count. - std::vector UrKernelsToRelease; - std::vector UrProgramsToRelease; + // List of fast cache elements to be released after UR call. We don't do + // anything with them, but they must exist to keep ur_kernel_handle_t-s + // valid. + std::vector FastKernelCacheValsToRelease; ur_kernel_handle_t UrKernel = nullptr; std::shared_ptr DeviceImageImpl = nullptr; const KernelArgMask *EliminatedArgMask = nullptr; auto ContextImpl = sycl::detail::getSyclObjImpl(Ctx); - std::tie(UrKernel, DeviceImageImpl, EliminatedArgMask) = - getCGKernelInfo(CommandGroup, ContextImpl, DeviceImpl, UrKernelsToRelease, - UrProgramsToRelease); + std::tie(UrKernel, DeviceImageImpl, EliminatedArgMask) = getCGKernelInfo( + CommandGroup, ContextImpl, DeviceImpl, FastKernelCacheValsToRelease); // Build up the list of UR kernel handles that the UR command could be // updated to use. @@ -2610,7 +2606,7 @@ ur_result_t enqueueImpCommandBufferKernel( ur_kernel_handle_t AltUrKernel = nullptr; std::tie(AltUrKernel, std::ignore, std::ignore) = getCGKernelInfo(*AltCGKernel.get(), ContextImpl, DeviceImpl, - UrKernelsToRelease, UrProgramsToRelease); + FastKernelCacheValsToRelease); AltUrKernels.push_back(AltUrKernel); } @@ -2671,13 +2667,6 @@ ur_result_t enqueueImpCommandBufferKernel( nullptr, OutSyncPoint, nullptr, CommandBufferDesc.isUpdatable ? OutCommand : nullptr); - for (auto &Kernel : UrKernelsToRelease) { - Adapter->call(Kernel); - } - for (auto &Program : UrProgramsToRelease) { - Adapter->call(Program); - } - if (Res != UR_RESULT_SUCCESS) { detail::enqueue_kernel_launch::handleErrorOrWarning(Res, DeviceImpl, UrKernel, NDRDesc); @@ -2709,6 +2698,7 @@ void enqueueImpKernel( std::shared_ptr SyclKernelImpl; std::shared_ptr DeviceImageImpl; + FastKernelCacheValPtr KernelCacheVal; if (nullptr != MSyclKernel) { assert(MSyclKernel->get_info() == @@ -2736,10 +2726,12 @@ void enqueueImpKernel( EliminatedArgMask = SyclKernelImpl->getKernelArgMask(); KernelMutex = SyclKernelImpl->getCacheMutex(); } else { - std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) = - detail::ProgramManager::getInstance().getOrCreateKernel( - ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, - NDRDesc); + KernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel( + ContextImpl, DeviceImpl, KernelName, KernelNameBasedCachePtr, NDRDesc); + Kernel = KernelCacheVal->MKernelHandle; + KernelMutex = KernelCacheVal->MMutex; + Program = KernelCacheVal->MProgramHandle; + EliminatedArgMask = KernelCacheVal->MKernelArgMask; } // We may need more events for the launch, so we make another reference. @@ -2784,12 +2776,6 @@ void enqueueImpKernel( KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize, BinImage, KernelName, KernelFuncPtr, KernelNumArgs, KernelParamDescGetter, KernelHasSpecialCaptures); - - const AdapterPtr &Adapter = Queue->getAdapter(); - if (!SyclKernelImpl && !MSyclKernel) { - Adapter->call(Kernel); - Adapter->call(Program); - } } if (UR_RESULT_SUCCESS != Error) { // If we have got non-success error code, let's analyze it to emit nice diff --git a/sycl/test-e2e/KernelAndProgram/disable-caching.cpp b/sycl/test-e2e/KernelAndProgram/disable-caching.cpp index 8a0b15b12311f..2da9ee4df047c 100644 --- a/sycl/test-e2e/KernelAndProgram/disable-caching.cpp +++ b/sycl/test-e2e/KernelAndProgram/disable-caching.cpp @@ -35,8 +35,6 @@ int main() { // CHECK-CACHE: <--- urKernelRetain // CHECK-CACHE-NOT: <--- urKernelCreate // CHECK-CACHE: <--- urEnqueueKernelLaunch - // CHECK-CACHE: <--- urKernelRelease - // CHECK-CACHE: <--- urProgramRelease // CHECK-CACHE: <--- urEventWait q.single_task([] {}).wait(); @@ -99,6 +97,8 @@ int main() { // (Program cache releases) // CHECK-RELEASE: <--- urKernelRelease +// CHECK-RELEASE: <--- urProgramRelease +// CHECK-RELEASE: <--- urKernelRelease // CHECK-RELEASE: <--- urKernelRelease // CHECK-RELEASE: <--- urKernelRelease // CHECK-RELEASE: <--- urProgramRelease diff --git a/sycl/test-e2e/XPTI/basic_event_collection_linux.cpp b/sycl/test-e2e/XPTI/basic_event_collection_linux.cpp index 4dfe5928bd5ee..b8dceca5367a9 100644 --- a/sycl/test-e2e/XPTI/basic_event_collection_linux.cpp +++ b/sycl/test-e2e/XPTI/basic_event_collection_linux.cpp @@ -17,6 +17,7 @@ // CHECK-NEXT: UR Call Begin : urPlatformGetInfo // CHECK-NEXT: UR Call Begin : urPlatformGetInfo // CHECK-NEXT: UR Call Begin : urKernelSetExecInfo +// CHECK-NEXT: UR Call Begin : urKernelRetain // CHECK: UR Call Begin : urKernelSetArgPointer // CHECK-NEXT: UR Call Begin : urKernelGetGroupInfo // CHECK-NEXT: UR Call Begin : urEnqueueKernelLaunch @@ -24,6 +25,7 @@ // CHECK-NEXT: UR Call Begin : urPlatformGetInfo // CHECK-NEXT: UR Call Begin : urPlatformGetInfo // CHECK-NEXT: UR Call Begin : urKernelSetExecInfo +// CHECK-NEXT: UR Call Begin : urKernelRetain // CHECK: Node create // CHECK-DAG: queue_id : {{.*}} // CHECK-DAG: sym_line_no : {{.*}} @@ -43,8 +45,6 @@ // CHECK: UR Call Begin : urKernelSetArgPointer // CHECK-NEXT: UR Call Begin : urKernelGetGroupInfo // CHECK-NEXT: UR Call Begin : urEnqueueKernelLaunch -// CHECK-NEXT: UR Call Begin : urKernelRelease -// CHECK-NEXT: UR Call Begin : urProgramRelease // CHECK-NEXT: Signal // CHECK-DAG: queue_id : {{.*}} // CHECK-DAG: sym_line_no : {{.*}} @@ -93,3 +93,7 @@ // CHECK-NEXT: Wait end // CHECK-DAG: queue_id : {{.*}} // CHECK-DAG: sycl_device_type : {{.*}} +// CHECK: UR Call Begin : urKernelRelease +// CHECK: UR Call Begin : urKernelRelease +// CHECK: UR Call Begin : urKernelRelease +// CHECK: UR Call Begin : urKernelRelease