Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions sycl/source/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <sycl/kernel.hpp> // for kernel_impl
#include <sycl/kernel_bundle.hpp> // for kernel_bundle_impl

#include <detail/device_kernel_info.hpp>

#include <assert.h> // for assert
#include <memory> // for shared_ptr, unique_ptr
#include <stddef.h> // for size_t
Expand Down Expand Up @@ -253,7 +255,6 @@ class CGExecKernel : public CG {
std::shared_ptr<detail::kernel_impl> MSyclKernel;
std::shared_ptr<detail::kernel_bundle_impl> MKernelBundle;
std::vector<ArgDesc> MArgs;
KernelNameStrT MKernelName;
DeviceKernelInfo &MDeviceKernelInfo;
std::vector<std::shared_ptr<detail::stream_impl>> MStreams;
std::vector<std::shared_ptr<const void>> MAuxiliaryResources;
Expand All @@ -269,7 +270,7 @@ class CGExecKernel : public CG {
std::shared_ptr<detail::kernel_impl> SyclKernel,
std::shared_ptr<detail::kernel_bundle_impl> KernelBundle,
CG::StorageInitHelper CGData, std::vector<ArgDesc> Args,
KernelNameStrT KernelName, DeviceKernelInfo &DeviceKernelInfo,
DeviceKernelInfo &DeviceKernelInfo,
std::vector<std::shared_ptr<detail::stream_impl>> Streams,
std::vector<std::shared_ptr<const void>> AuxiliaryResources,
CGType Type, ur_kernel_cache_config_t KernelCacheConfig,
Expand All @@ -278,8 +279,7 @@ class CGExecKernel : public CG {
: CG(Type, std::move(CGData), std::move(loc)), MNDRDesc(NDRDesc),
MHostKernel(std::move(HKernel)), MSyclKernel(std::move(SyclKernel)),
MKernelBundle(std::move(KernelBundle)), MArgs(std::move(Args)),
MKernelName(std::move(KernelName)), MDeviceKernelInfo(DeviceKernelInfo),
MStreams(std::move(Streams)),
MDeviceKernelInfo(DeviceKernelInfo), MStreams(std::move(Streams)),
MAuxiliaryResources(std::move(AuxiliaryResources)),
MAlternativeKernels{}, MKernelCacheConfig(std::move(KernelCacheConfig)),
MKernelIsCooperative(KernelIsCooperative),
Expand All @@ -291,7 +291,9 @@ class CGExecKernel : public CG {
CGExecKernel(const CGExecKernel &CGExec) = default;

const std::vector<ArgDesc> &getArguments() const { return MArgs; }
KernelNameStrRefT getKernelName() const { return MKernelName; }
std::string_view getKernelName() const {
return static_cast<std::string_view>(MDeviceKernelInfo.Name);
}
const std::vector<std::shared_ptr<detail::stream_impl>> &getStreams() const {
return MStreams;
}
Expand Down
19 changes: 9 additions & 10 deletions sycl/source/detail/graph/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -739,9 +739,8 @@ ur_exp_command_buffer_sync_point_t exec_graph_impl::enqueueNodeDirect(
CGExec->MLine, CGExec->MColumn);
std::tie(CmdTraceEvent, InstanceID) = emitKernelInstrumentationData(
sycl::detail::GSYCLStreamID, CGExec->MSyclKernel, CodeLoc,
CGExec->MIsTopCodeLoc, CGExec->MKernelName.data(),
CGExec->MDeviceKernelInfo, nullptr, CGExec->MNDRDesc,
CGExec->MKernelBundle.get(), CGExec->MArgs);
CGExec->MIsTopCodeLoc, CGExec->MDeviceKernelInfo, nullptr,
CGExec->MNDRDesc, CGExec->MKernelBundle.get(), CGExec->MArgs);
if (CmdTraceEvent)
sycl::detail::emitInstrumentationGeneral(sycl::detail::GSYCLStreamID,
InstanceID, CmdTraceEvent,
Expand Down Expand Up @@ -1401,14 +1400,14 @@ void exec_graph_impl::update(std::shared_ptr<graph_impl> GraphImpl) {
sycl::detail::CGExecKernel *TargetCGExec =
static_cast<sycl::detail::CGExecKernel *>(
MNodeStorage[i]->MCommandGroup.get());
KernelNameStrRefT TargetKernelName = TargetCGExec->getKernelName();
std::string_view TargetKernelName = TargetCGExec->getKernelName();

sycl::detail::CGExecKernel *SourceCGExec =
static_cast<sycl::detail::CGExecKernel *>(
GraphImpl->MNodeStorage[i]->MCommandGroup.get());
KernelNameStrRefT SourceKernelName = SourceCGExec->getKernelName();
std::string_view SourceKernelName = SourceCGExec->getKernelName();

if (TargetKernelName.compare(SourceKernelName) != 0) {
if (TargetKernelName != SourceKernelName) {
std::stringstream ErrorStream(
"Cannot update using a graph with mismatched kernel "
"types. Source node type ");
Expand Down Expand Up @@ -1568,14 +1567,14 @@ void exec_graph_impl::populateURKernelUpdateStructs(
UrKernel = Kernel->getHandleRef();
EliminatedArgMask = Kernel->getKernelArgMask();
} else if (auto SyclKernelImpl =
KernelBundleImplPtr
? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName)
: std::shared_ptr<kernel_impl>{nullptr}) {
KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel(
ExecCG.MDeviceKernelInfo.Name)
: std::shared_ptr<kernel_impl>{nullptr}) {
UrKernel = SyclKernelImpl->getHandleRef();
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
} else {
BundleObjs = sycl::detail::ProgramManager::getInstance().getOrCreateKernel(
ContextImpl, DeviceImpl, ExecCG.MKernelName, ExecCG.MDeviceKernelInfo);
ContextImpl, DeviceImpl, ExecCG.MDeviceKernelInfo);
UrKernel = BundleObjs->MKernelHandle;
EliminatedArgMask = BundleObjs->MKernelArgMask;
}
Expand Down
7 changes: 5 additions & 2 deletions sycl/source/detail/graph/node_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
sycl::detail::CGExecKernel *ExecKernelB =
static_cast<sycl::detail::CGExecKernel *>(Node.MCommandGroup.get());
return ExecKernelA->MKernelName.compare(ExecKernelB->MKernelName) == 0;
return std::string_view{ExecKernelA->MDeviceKernelInfo.Name} ==
std::string_view{ExecKernelB->MDeviceKernelInfo.Name};
}
case sycl::detail::CGType::CopyUSM: {
sycl::detail::CGCopyUSM *CopyA =
Expand Down Expand Up @@ -543,7 +544,9 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
Stream << "CGExecKernel \\n";
sycl::detail::CGExecKernel *Kernel =
static_cast<sycl::detail::CGExecKernel *>(MCommandGroup.get());
Stream << "NAME = " << Kernel->MKernelName << "\\n";
Stream << "NAME = "
<< static_cast<std::string_view>(Kernel->MDeviceKernelInfo.Name)
<< "\\n";
if (Verbose) {
Stream << "ARGS = \\n";
for (size_t i = 0; i < Kernel->MArgs.size(); i++) {
Expand Down
5 changes: 5 additions & 0 deletions sycl/source/detail/handler_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class handler_impl {
HandlerSubmissionState::EXPLICIT_KERNEL_BUNDLE_STATE;
}

KernelNameStrRefT getKernelName() const {
assert(MDeviceKernelInfoPtr);
return static_cast<KernelNameStrRefT>(MDeviceKernelInfoPtr->Name);
}

/// Registers mutually exclusive submission states.
HandlerSubmissionState MSubmissionState = HandlerSubmissionState::NO_STATE;

Expand Down
26 changes: 17 additions & 9 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1084,25 +1084,32 @@ ProgramManager::getBuiltURProgram(const BinImgWithDeps &ImgWithDeps,

FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
context_impl &ContextImpl, device_impl &DeviceImpl,
KernelNameStrRefT KernelName, DeviceKernelInfo &DeviceKernelInfo,
const NDRDescT &NDRDesc) {
DeviceKernelInfo &DeviceKernelInfo, const NDRDescT &NDRDesc) {
if constexpr (DbgProgMgr > 0) {
std::cerr << ">>> ProgramManager::getOrCreateKernel(" << &ContextImpl
<< ", " << &DeviceImpl << ", " << KernelName << ")\n";
<< ", " << &DeviceImpl << ", "
<< static_cast<std::string_view>(DeviceKernelInfo.Name) << ")\n";
}

KernelProgramCache &Cache = ContextImpl.getKernelProgramCache();
ur_device_handle_t UrDevice = DeviceImpl.getHandleRef();
if (SYCLConfig<SYCL_CACHE_IN_MEM>::get()) {
if (auto KernelCacheValPtr = Cache.tryToGetKernelFast(
KernelName, UrDevice, DeviceKernelInfo.getKernelSubcache())) {
if (auto KernelCacheValPtr =
Cache.tryToGetKernelFast(DeviceKernelInfo.Name, UrDevice,
DeviceKernelInfo.getKernelSubcache())) {
return KernelCacheValPtr;
}
}

Managed<ur_program_handle_t> Program =
getBuiltURProgram(ContextImpl, DeviceImpl, KernelName, NDRDesc);
Managed<ur_program_handle_t> Program = getBuiltURProgram(
ContextImpl, DeviceImpl, DeviceKernelInfo.Name, NDRDesc);

#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
// Simplify this once `DeviceKernelInfo.Name`'s type is known.
// Using `decltype(auto)` insteado of just `auto` to get reference when
// possible.
#endif
decltype(auto) KernelName = KernelNameStrRefT{DeviceKernelInfo.Name};
auto BuildF = [this, &Program, &KernelName, &ContextImpl] {
adapter_impl &Adapter = ContextImpl.getAdapter();
Managed<ur_kernel_handle_t> Kernel{Adapter};
Expand All @@ -1125,7 +1132,8 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
return std::make_pair(std::move(Kernel), ArgMask);
};

auto GetCachedBuildF = [&Cache, &KernelName, &Program]() {
auto GetCachedBuildF = [&Cache, &KernelName = DeviceKernelInfo.Name,
&Program]() {
return Cache.getOrInsertKernel(Program, KernelName);
};

Expand All @@ -1147,7 +1155,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
auto ret_val = std::make_shared<FastKernelCacheVal>(
KernelArgMaskPair.first.retain(), &(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, std::move(Program), ContextImpl.getAdapter());
Cache.saveKernel(KernelName, UrDevice, ret_val,
Cache.saveKernel(DeviceKernelInfo.Name, UrDevice, ret_val,
DeviceKernelInfo.getKernelSubcache());
return ret_val;
}
Expand Down
1 change: 0 additions & 1 deletion sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ class ProgramManager {

FastKernelCacheValPtr getOrCreateKernel(context_impl &ContextImpl,
device_impl &DeviceImpl,
KernelNameStrRefT KernelName,
DeviceKernelInfo &DeviceKernelInfo,
const NDRDescT &NDRDesc = {});

Expand Down
Loading
Loading