Skip to content

[SYCL] Use shared_ptr instead of manual changing UR counters #18465

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: sycl
Choose a base branch
from
26 changes: 7 additions & 19 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ bool exec_graph_impl::needsScheduledUpdate(

void exec_graph_impl::populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
KernelProgramCache::KernelFastCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand Down Expand Up @@ -1499,11 +1499,10 @@ void exec_graph_impl::populateURKernelUpdateStructs(
UrKernel = SyclKernelImpl->getHandleRef();
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName);
BundleObjs = std::make_pair(UrProgram, UrKernel);
BundleObjs = sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName);
UrKernel = BundleObjs->MKernelHandle;
EliminatedArgMask = BundleObjs->MKernelArgMask;
}

// Remove eliminated args
Expand Down Expand Up @@ -1698,8 +1697,8 @@ void exec_graph_impl::updateURImpl(
std::vector<sycl::detail::NDRDescT> NDRDescList(NumUpdatableNodes);
std::vector<ur_exp_command_buffer_update_kernel_launch_desc_t> UpdateDescList(
NumUpdatableNodes);
std::vector<std::pair<ur_program_handle_t, ur_kernel_handle_t>>
KernelBundleObjList(NumUpdatableNodes);
std::vector<KernelProgramCache::KernelFastCacheValPtr> KernelBundleObjList(
NumUpdatableNodes);

size_t StructListIndex = 0;
for (auto &Node : Nodes) {
Expand All @@ -1724,17 +1723,6 @@ void exec_graph_impl::updateURImpl(
const sycl::detail::AdapterPtr &Adapter = ContextImpl->getAdapter();
Adapter->call<sycl::detail::UrApiKind::urCommandBufferUpdateKernelLaunchExp>(
CommandBuffer, UpdateDescList.size(), UpdateDescList.data());

for (auto &BundleObjs : KernelBundleObjList) {
// We retained these objects by inside populateUpdateStruct() by calling
// getOrCreateKernel()
if (auto &UrKernel = BundleObjs.second; nullptr != UrKernel) {
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(UrKernel);
}
if (auto &UrProgram = BundleObjs.first; nullptr != UrProgram) {
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(UrProgram);
}
}
}

modifiable_command_graph::modifiable_command_graph(
Expand Down
3 changes: 2 additions & 1 deletion sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <detail/graph_memory_pool.hpp>
#include <detail/host_task.hpp>
#include <detail/kernel_impl.hpp>
#include <detail/kernel_program_cache.hpp>
#include <detail/sycl_mem_obj_t.hpp>

#include <cstring>
Expand Down Expand Up @@ -1484,7 +1485,7 @@ class exec_graph_impl {
/// @param[out] UpdateDesc Base struct in the pointer chain.
void populateURKernelUpdateStructs(
const std::shared_ptr<node_impl> &Node,
std::pair<ur_program_handle_t, ur_kernel_handle_t> &BundleObjs,
KernelProgramCache::KernelFastCacheValPtr &BundleObjs,
std::vector<ur_exp_command_buffer_update_memobj_arg_desc_t> &MemobjDescs,
std::vector<ur_kernel_arg_mem_obj_properties_t> &MemobjProps,
std::vector<ur_exp_command_buffer_update_pointer_arg_desc_t> &PtrDescs,
Expand Down
53 changes: 37 additions & 16 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,21 +233,42 @@ class KernelProgramCache {
KernelNameStrT /* Kernel Name */
>;

using KernelFastCacheValT =
std::tuple<ur_kernel_handle_t, /* UR kernel handle pointer. */
std::mutex *, /* Mutex guarding this kernel. */
const KernelArgMask *, /* Eliminated kernel argument mask. */
ur_program_handle_t /* UR program handle corresponding to this
kernel. */
>;
struct KernelFastCacheVal {
ur_kernel_handle_t MKernelHandle; /* UR kernel handle pointer. */
std::mutex *MMutex; /* Mutex guarding this kernel. */
const KernelArgMask *MKernelArgMask; /* Eliminated kernel argument mask. */
ur_program_handle_t MProgramHandle; /* UR program handle corresponding to
this kernel. */
std::weak_ptr<Adapter> MAdapterWeakPtr; /* Weak pointer to the adapter. */

KernelFastCacheVal(ur_kernel_handle_t KernelHandle, std::mutex *Mutex,
const KernelArgMask *KernelArgMask,
ur_program_handle_t ProgramHandle,
const AdapterPtr &Adapter)
: MKernelHandle(KernelHandle), MMutex(Mutex),
MKernelArgMask(KernelArgMask), MProgramHandle(ProgramHandle),
MAdapterWeakPtr(Adapter) {}

~KernelFastCacheVal() {
if (AdapterPtr Adapter = MAdapterWeakPtr.lock()) {
if (MKernelHandle)
Adapter->call<sycl::detail::UrApiKind::urKernelRelease>(
MKernelHandle);
if (MProgramHandle)
Adapter->call<sycl::detail::UrApiKind::urProgramRelease>(
MProgramHandle);
}
}
};
using KernelFastCacheValPtr = std::shared_ptr<KernelFastCacheVal>;

// This container is used as a fast path for retrieving cached kernels.
// unordered_flat_map is used here to reduce lookup overhead.
// The slow path is used only once for each newly created kernel, so the
// higher overhead of insertion that comes with unordered_flat_map is more
// of an issue there. For that reason, those use regular unordered maps.
using KernelFastCacheT =
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValT>;
::boost::unordered_flat_map<KernelFastCacheKeyT, KernelFastCacheValPtr>;

// DS to hold data and functions related to Program cache eviction.
struct EvictionList {
Expand Down Expand Up @@ -427,34 +448,34 @@ class KernelProgramCache {
return std::make_pair(It->second, DidInsert);
}

template <typename KeyT>
KernelFastCacheValT tryToGetKernelFast(KeyT &&CacheKey) {
KernelFastCacheValPtr
tryToGetKernelFast(const KernelProgramCache::KernelFastCacheKeyT &CacheKey) {
KernelFastCacheReadLockT Lock(MKernelFastCacheMutex);
auto It = MKernelFastCache.find(CacheKey);
if (It != MKernelFastCache.end()) {
traceKernel("Kernel fetched.", CacheKey.second, true);
return It->second;
}
return std::make_tuple(nullptr, nullptr, nullptr, nullptr);
return KernelFastCacheValPtr();
}

template <typename KeyT, typename ValT>
void saveKernel(KeyT &&CacheKey, ValT &&CacheVal) {
ur_program_handle_t Program = std::get<3>(CacheVal);
void saveKernel(const KernelProgramCache::KernelFastCacheKeyT &CacheKey,
const KernelProgramCache::KernelFastCacheValPtr &CacheVal) {
if (SYCLConfig<SYCL_IN_MEM_CACHE_EVICTION_THRESHOLD>::
isProgramCacheEvictionEnabled()) {

// Save kernel in fast cache only if the corresponding program is also
// in the cache.
auto LockedCache = acquireCachedPrograms();
auto &ProgCache = LockedCache.get();
if (ProgCache.ProgramSizeMap.find(Program) ==
if (ProgCache.ProgramSizeMap.find(CacheVal->MProgramHandle) ==
ProgCache.ProgramSizeMap.end())
return;
}
// Save reference between the program and the fast cache key.
KernelFastCacheWriteLockT Lock(MKernelFastCacheMutex);
MProgramToKernelFastCacheKeyMap[Program].emplace_back(CacheKey);
MProgramToKernelFastCacheKeyMap[CacheVal->MProgramHandle].emplace_back(
CacheKey);

// if no insertion took place, thus some other thread has already inserted
// smth in the cache
Expand Down
36 changes: 14 additions & 22 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,12 +1107,9 @@ ur_program_handle_t ProgramManager::getBuiltURProgram(
}
// When caching is enabled, the returned UrProgram and UrKernel will
// already have their ref count incremented.
std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
ur_program_handle_t>
ProgramManager::getOrCreateKernel(const ContextImplPtr &ContextImpl,
device_impl &DeviceImpl,
KernelNameStrRefT KernelName,
const NDRDescT &NDRDesc) {
KernelProgramCache::KernelFastCacheValPtr ProgramManager::getOrCreateKernel(
const ContextImplPtr &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName, const NDRDescT &NDRDesc) {
if constexpr (DbgProgMgr > 0) {
std::cerr << ">>> ProgramManager::getOrCreateKernel(" << ContextImpl.get()
<< ", " << &DeviceImpl << ", " << KernelName << ")\n";
Expand All @@ -1126,15 +1123,7 @@ ProgramManager::getOrCreateKernel(const ContextImplPtr &ContextImpl,
auto key = std::make_pair(UrDevice, KernelName);
if (SYCLConfig<SYCL_CACHE_IN_MEM>::get()) {
auto ret_tuple = Cache.tryToGetKernelFast(key);
constexpr size_t Kernel = 0; // see KernelFastCacheValT tuple
constexpr size_t Program = 3; // see KernelFastCacheValT tuple
if (std::get<Kernel>(ret_tuple)) {
// Pulling a copy of a kernel and program from the cache,
// so we need to retain those resources.
ContextImpl->getAdapter()->call<UrApiKind::urKernelRetain>(
std::get<Kernel>(ret_tuple));
ContextImpl->getAdapter()->call<UrApiKind::urProgramRetain>(
std::get<Program>(ret_tuple));
if (ret_tuple) {
return ret_tuple;
}
}
Expand Down Expand Up @@ -1174,22 +1163,25 @@ ProgramManager::getOrCreateKernel(const ContextImplPtr &ContextImpl,
// threads when caching is disabled, so we can return
// nullptr for the mutex.
auto [Kernel, ArgMask] = BuildF();
return make_tuple(Kernel, nullptr, ArgMask, Program);
return std::make_shared<KernelProgramCache::KernelFastCacheVal>(
Kernel, nullptr, ArgMask, Program, ContextImpl->getAdapter());
}

auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
// getOrBuild is not supposed to return nullptr
assert(BuildResult != nullptr && "Invalid build result");
const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val;
auto ret_val = std::make_tuple(KernelArgMaskPair.first,
&(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, Program);
auto ret_val = std::make_shared<KernelProgramCache::KernelFastCacheVal>(
KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, Program, ContextImpl->getAdapter());

// If caching is enabled, one copy of the kernel handle will be
// stored in the cache, and one handle is returned to the
// caller. In that case, we need to increase the ref count of the
// kernel.
// stored in KernelProgramCache::KernelFastCacheT, and one is in
// KernelProgramCache::MKernelsPerProgramCache. To cover this,
// we need to increase the ref count of the kernel.
ContextImpl->getAdapter()->call<UrApiKind::urKernelRetain>(
KernelArgMaskPair.first);

Cache.saveKernel(key, ret_val);
return ret_val;
}
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <detail/device_global_map_entry.hpp>
#include <detail/host_pipe_map_entry.hpp>
#include <detail/kernel_arg_mask.hpp>
#include <detail/kernel_program_cache.hpp>
#include <detail/spec_constant_impl.hpp>
#include <sycl/detail/cg_types.hpp>
#include <sycl/detail/common.hpp>
Expand Down Expand Up @@ -197,8 +198,7 @@ class ProgramManager {
const DevImgPlainWithDeps *DevImgWithDeps = nullptr,
const SerializedObj &SpecConsts = {});

std::tuple<ur_kernel_handle_t, std::mutex *, const KernelArgMask *,
ur_program_handle_t>
KernelProgramCache::KernelFastCacheValPtr
getOrCreateKernel(const ContextImplPtr &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName, const NDRDescT &NDRDesc = {});

Expand Down
32 changes: 18 additions & 14 deletions sycl/source/detail/scheduler/commands.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1989,8 +1989,6 @@ void instrumentationAddExtraKernelMetadata(
auto FilterArgs = [&Args](detail::ArgDesc &Arg, int NextTrueIndex) {
Args.push_back({Arg.MType, Arg.MPtr, Arg.MSize, NextTrueIndex});
};
ur_kernel_handle_t Kernel = nullptr;
std::mutex *KernelMutex = nullptr;
const KernelArgMask *EliminatedArgMask = nullptr;

if (nullptr != SyclKernel) {
Expand All @@ -2005,10 +2003,10 @@ void instrumentationAddExtraKernelMetadata(
// NOTE: Queue can be null when kernel is directly enqueued to a command
// buffer
// by graph API, when a modifiable graph is finalized.
ur_program_handle_t Program = nullptr;
std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) =
KernelProgramCache::KernelFastCacheValPtr KernelCacheVal =
detail::ProgramManager::getInstance().getOrCreateKernel(
Queue->getContextImplPtr(), Queue->getDeviceImpl(), KernelName);
EliminatedArgMask = KernelCacheVal->MKernelArgMask;
}

applyFuncOnFilteredArgs(EliminatedArgMask, CGArgs, FilterArgs);
Expand Down Expand Up @@ -2554,9 +2552,17 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else {
ur_program_handle_t UrProgram = nullptr;
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
KernelProgramCache::KernelFastCacheValPtr KernelCacheVal =
sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, CommandGroup.MKernelName);
UrKernel = KernelCacheVal->MKernelHandle;
EliminatedArgMask = KernelCacheVal->MKernelArgMask;
UrProgram = KernelCacheVal->MProgramHandle;
// UrProgram/UrKernel are used after KernelCacheVal is destroyed, so caller
// must call ur*Release
ContextImpl->getAdapter()->call<UrApiKind::urProgramRetain>(UrProgram);
ContextImpl->getAdapter()->call<UrApiKind::urKernelRetain>(UrKernel);

UrKernelsToRelease.push_back(UrKernel);
UrProgramsToRelease.push_back(UrProgram);
}
Expand Down Expand Up @@ -2697,6 +2703,7 @@ void enqueueImpKernel(

std::shared_ptr<kernel_impl> SyclKernelImpl;
std::shared_ptr<device_image_impl> DeviceImageImpl;
KernelProgramCache::KernelFastCacheValPtr KernelCacheVal;

if (nullptr != MSyclKernel) {
assert(MSyclKernel->get_info<info::kernel::context>() ==
Expand Down Expand Up @@ -2724,9 +2731,12 @@ void enqueueImpKernel(
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
KernelMutex = SyclKernelImpl->getCacheMutex();
} else {
std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) =
detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, KernelName, NDRDesc);
KernelCacheVal = detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, KernelName, NDRDesc);
Kernel = KernelCacheVal->MKernelHandle;
KernelMutex = KernelCacheVal->MMutex;
Program = KernelCacheVal->MProgramHandle;
EliminatedArgMask = KernelCacheVal->MKernelArgMask;
}

// We may need more events for the launch, so we make another reference.
Expand Down Expand Up @@ -2771,12 +2781,6 @@ void enqueueImpKernel(
KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
KernelParamDescGetter, KernelHasSpecialCaptures);

const AdapterPtr &Adapter = Queue->getAdapter();
if (!SyclKernelImpl && !MSyclKernel) {
Adapter->call<UrApiKind::urKernelRelease>(Kernel);
Adapter->call<UrApiKind::urProgramRelease>(Program);
}
}
if (UR_RESULT_SUCCESS != Error) {
// If we have got non-success error code, let's analyze it to emit nice
Expand Down
4 changes: 2 additions & 2 deletions sycl/test-e2e/KernelAndProgram/disable-caching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ int main() {
// CHECK-CACHE: <--- urKernelRetain
// CHECK-CACHE-NOT: <--- urKernelCreate
// CHECK-CACHE: <--- urEnqueueKernelLaunch
// CHECK-CACHE: <--- urKernelRelease
// CHECK-CACHE: <--- urProgramRelease
// CHECK-CACHE: <--- urEventWait
q.single_task([] {}).wait();

Expand Down Expand Up @@ -98,6 +96,8 @@ int main() {
// windows should handle the memory cleanup.

// (Program cache releases)
// CHECK-CACHE: <--- urKernelRelease
// CHECK-CACHE: <--- urProgramRelease
// CHECK-RELEASE: <--- urKernelRelease
// CHECK-RELEASE: <--- urKernelRelease
// CHECK-RELEASE: <--- urKernelRelease
Expand Down
Loading