Skip to content

[SYCL] Make ur::getAdapter return raw adapter pointer instead of shared_ptr #19102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: sycl
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions sycl/source/detail/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,6 @@ class Adapter {
UrFuncPtrMapT UrFuncPtrs;
}; // class Adapter

using AdapterPtr = std::shared_ptr<Adapter>;

} // namespace detail
} // namespace _V1
} // namespace sycl
3 changes: 1 addition & 2 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,8 +366,7 @@ void GetCapabilitiesIntersectionSet(const std::vector<sycl::device> &Devices,
// convenient to be able to reference them without extra `detail::`.
inline auto get_ur_handles(sycl::detail::context_impl &Ctx) {
ur_context_handle_t urCtx = Ctx.getHandleRef();
const sycl::detail::Adapter *Adapter = Ctx.getAdapter().get();
return std::tuple{urCtx, Adapter};
return std::tuple{urCtx, Ctx.getAdapter()};
}
inline auto get_ur_handles(const sycl::context &syclContext) {
return get_ur_handles(*sycl::detail::getSyclObjImpl(syclContext));
Expand Down
9 changes: 7 additions & 2 deletions sycl/source/detail/global_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ std::mutex &GlobalHandler::getFilterMutex() {
return FilterMutex;
}

std::vector<AdapterPtr> &GlobalHandler::getAdapters() {
static std::vector<AdapterPtr> &adapters = getOrCreate(MAdapters);
std::vector<Adapter *> &GlobalHandler::getAdapters() {
static std::vector<Adapter *> &adapters = getOrCreate(MAdapters);
enableOnCrashStackPrinting();
return adapters;
}
Expand Down Expand Up @@ -314,6 +314,7 @@ void GlobalHandler::unloadAdapters() {
if (MAdapters.Inst) {
for (const auto &Adapter : getAdapters()) {
Adapter->release();
delete Adapter;
}
}

Expand Down Expand Up @@ -387,6 +388,10 @@ void shutdown_late() {
Handler->MScheduler.Inst.reset(nullptr);
Handler->MProgramManager.Inst.reset(nullptr);

// Cache stores handles to the adapter, so clear it before
// releasing adapters.
Handler->MKernelNameBasedCaches.Inst.reset(nullptr);

// Clear the adapters and reset the instance if it was there.
Handler->unloadAdapters();
if (Handler->MAdapters.Inst)
Expand Down
5 changes: 2 additions & 3 deletions sycl/source/detail/global_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ThreadPool;
struct KernelNameBasedCacheT;

using ContextImplPtr = std::shared_ptr<context_impl>;
using AdapterPtr = std::shared_ptr<Adapter>;

/// Wrapper class for global data structures with non-trivial destructors.
///
Expand Down Expand Up @@ -71,7 +70,7 @@ class GlobalHandler {
std::mutex &getPlatformToDefaultContextCacheMutex();
std::mutex &getPlatformMapMutex();
std::mutex &getFilterMutex();
std::vector<AdapterPtr> &getAdapters();
std::vector<Adapter *> &getAdapters();
ods_target_list &getOneapiDeviceSelectorTargets(const std::string &InitValue);
XPTIRegistry &getXPTIRegistry();
ThreadPool &getHostTaskThreadPool();
Expand Down Expand Up @@ -126,7 +125,7 @@ class GlobalHandler {
InstWithLock<std::mutex> MPlatformToDefaultContextCacheMutex;
InstWithLock<std::mutex> MPlatformMapMutex;
InstWithLock<std::mutex> MFilterMutex;
InstWithLock<std::vector<AdapterPtr>> MAdapters;
InstWithLock<std::vector<Adapter *>> MAdapters;
InstWithLock<ods_target_list> MOneapiDeviceSelectorTargets;
InstWithLock<XPTIRegistry> MXPTIRegistry;
// Thread pool for host task and event callbacks execution
Expand Down
32 changes: 12 additions & 20 deletions sycl/source/detail/kernel_program_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,21 @@ class KernelProgramCache {
};

struct ProgramBuildResult : public BuildResult<ur_program_handle_t> {
std::weak_ptr<Adapter> AdapterWeakPtr;
ProgramBuildResult(const AdapterPtr &Adapter) : AdapterWeakPtr(Adapter) {
AdapterPtr adapter;
ProgramBuildResult(const AdapterPtr &_adapter) : adapter(_adapter) {
Val = nullptr;
}
ProgramBuildResult(const AdapterPtr &Adapter, BuildState InitialState)
: AdapterWeakPtr(Adapter) {
ProgramBuildResult(const AdapterPtr &_adapter, BuildState InitialState)
: adapter(_adapter) {
Val = nullptr;
this->State.store(InitialState);
}
~ProgramBuildResult() {
try {
if (Val) {
AdapterPtr AdapterSharedPtr = AdapterWeakPtr.lock();
if (AdapterSharedPtr) {
ur_result_t Err =
AdapterSharedPtr->call_nocheck<UrApiKind::urProgramRelease>(
Val);
__SYCL_CHECK_UR_CODE_NO_EXC(Err, AdapterSharedPtr->getBackend());
}
ur_result_t Err =
adapter->call_nocheck<UrApiKind::urProgramRelease>(Val);
__SYCL_CHECK_UR_CODE_NO_EXC(Err, adapter->getBackend());
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~ProgramBuildResult",
Expand Down Expand Up @@ -202,20 +198,16 @@ class KernelProgramCache {
using KernelArgMaskPairT =
std::pair<ur_kernel_handle_t, const KernelArgMask *>;
struct KernelBuildResult : public BuildResult<KernelArgMaskPairT> {
std::weak_ptr<Adapter> AdapterWeakPtr;
KernelBuildResult(const AdapterPtr &Adapter) : AdapterWeakPtr(Adapter) {
AdapterPtr adapter;
KernelBuildResult(const AdapterPtr &_adapter) : adapter(_adapter) {
Val.first = nullptr;
}
~KernelBuildResult() {
try {
if (Val.first) {
AdapterPtr AdapterSharedPtr = AdapterWeakPtr.lock();
if (AdapterSharedPtr) {
ur_result_t Err =
AdapterSharedPtr->call_nocheck<UrApiKind::urKernelRelease>(
Val.first);
__SYCL_CHECK_UR_CODE_NO_EXC(Err, AdapterSharedPtr->getBackend());
}
ur_result_t Err =
adapter->call_nocheck<UrApiKind::urKernelRelease>(Val.first);
__SYCL_CHECK_UR_CODE_NO_EXC(Err, adapter->getBackend());
}
} catch (std::exception &e) {
__SYCL_REPORT_EXCEPTION_TO_STREAM("exception in ~KernelBuildResult", e);
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/platform_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ std::vector<platform> platform_impl::get_platforms() {

// See which platform we want to be served by which adapter.
// There should be just one adapter serving each backend.
std::vector<AdapterPtr> &Adapters = sycl::detail::ur::initializeUr();
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
std::vector<std::pair<platform, AdapterPtr>> PlatformsWithAdapter;

// Then check backend-specific adapters
Expand Down Expand Up @@ -487,7 +487,7 @@ platform_impl::get_devices(info::device_type DeviceType) const {
// analysis. Doing adjustment by simple copy of last device num from
// previous platform.
// Needs non const adapter reference.
std::vector<AdapterPtr> &Adapters = sycl::detail::ur::initializeUr();
std::vector<AdapterPtr> &Adapters = ur::initializeUr();
auto It = std::find_if(Adapters.begin(), Adapters.end(),
[&Platform = MPlatform](AdapterPtr &Adapter) {
return Adapter->containsUrPlatform(Platform);
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/platform_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ class platform_impl : public std::enable_shared_from_this<platform_impl> {
//
// Platforms can only be created under `GlobalHandler`'s ownership via
// `platform_impl::getOrMakePlatformImpl` method.
explicit platform_impl(ur_platform_handle_t APlatform,
const std::shared_ptr<Adapter> &AAdapter)
explicit platform_impl(ur_platform_handle_t APlatform, Adapter *AAdapter)
: MPlatform(APlatform), MAdapter(AAdapter) {
// Find out backend of the platform
ur_backend_t UrBackend = UR_BACKEND_UNKNOWN;
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
// nullptr for the mutex.
auto [Kernel, ArgMask] = BuildF();
return std::make_shared<FastKernelCacheVal>(
Kernel, nullptr, ArgMask, Program, *ContextImpl.getAdapter().get());
Kernel, nullptr, ArgMask, Program, *ContextImpl.getAdapter());
}

auto BuildResult = Cache.getOrBuild<errc::invalid>(GetCachedBuildF, BuildF);
Expand All @@ -1195,7 +1195,7 @@ FastKernelCacheValPtr ProgramManager::getOrCreateKernel(
const KernelArgMaskPairT &KernelArgMaskPair = BuildResult->Val;
auto ret_val = std::make_shared<FastKernelCacheVal>(
KernelArgMaskPair.first, &(BuildResult->MBuildResultMutex),
KernelArgMaskPair.second, Program, *ContextImpl.getAdapter().get());
KernelArgMaskPair.second, Program, *ContextImpl.getAdapter());
// If caching is enabled, one copy of the kernel handle will be
// stored in FastKernelCacheVal, and one is in
// KernelProgramCache::MKernelsPerProgramCache. To cover
Expand Down
2 changes: 1 addition & 1 deletion sycl/source/detail/sycl_mem_obj_t.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace detail {
class context_impl;
class event_impl;
class Adapter;
using AdapterPtr = std::shared_ptr<Adapter>;
using AdapterPtr = Adapter *;

using ContextImplPtr = std::shared_ptr<context_impl>;
using EventImplPtr = std::shared_ptr<event_impl>;
Expand Down
28 changes: 14 additions & 14 deletions sycl/source/detail/ur.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ bool trace(TraceLevel Level) {
return (TraceLevelMask & Level) == Level;
}

static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
static void initializeAdapters(std::vector<Adapter *> &Adapters,
ur_loader_config_handle_t LoaderConfig);

bool XPTIInitDone = false;
Expand All @@ -117,7 +117,7 @@ std::vector<AdapterPtr> &initializeUr(ur_loader_config_handle_t LoaderConfig) {
return GlobalHandler::instance().getAdapters();
}

static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
static void initializeAdapters(std::vector<Adapter *> &Adapters,
ur_loader_config_handle_t LoaderConfig) {
#define CHECK_UR_SUCCESS(Call) \
{ \
Expand Down Expand Up @@ -238,7 +238,7 @@ static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
sizeof(adapterBackend), &adapterBackend,
nullptr));
auto syclBackend = UrToSyclBackend(adapterBackend);
Adapters.emplace_back(std::make_shared<Adapter>(UrAdapter, syclBackend));
Adapters.emplace_back(new Adapter(UrAdapter, syclBackend));

const char *env_value = std::getenv("UR_LOG_CALLBACK");
if (env_value == nullptr || std::string(env_value) != "disabled") {
Expand Down Expand Up @@ -284,25 +284,25 @@ static void initializeAdapters(std::vector<AdapterPtr> &Adapters,
}

// Get the adapter serving given backend.
template <backend BE> const AdapterPtr &getAdapter() {
static AdapterPtr *Adapter = nullptr;
if (Adapter)
return *Adapter;
template <backend BE> AdapterPtr &getAdapter() {
static AdapterPtr adapterPtr = nullptr;
if (adapterPtr)
return adapterPtr;

std::vector<AdapterPtr> &Adapters = ur::initializeUr();
std::vector<AdapterPtr> Adapters = ur::initializeUr();
for (auto &P : Adapters)
if (P->hasBackend(BE)) {
Adapter = &P;
return *Adapter;
adapterPtr = P;
return adapterPtr;
}

throw exception(errc::runtime, "ur::getAdapter couldn't find adapter");
}

template const AdapterPtr &getAdapter<backend::opencl>();
template const AdapterPtr &getAdapter<backend::ext_oneapi_level_zero>();
template const AdapterPtr &getAdapter<backend::ext_oneapi_cuda>();
template const AdapterPtr &getAdapter<backend::ext_oneapi_hip>();
template AdapterPtr &getAdapter<backend::opencl>();
template AdapterPtr &getAdapter<backend::ext_oneapi_level_zero>();
template AdapterPtr &getAdapter<backend::ext_oneapi_cuda>();
template AdapterPtr &getAdapter<backend::ext_oneapi_hip>();

// Reads an integer value from ELF data.
template <typename ResT>
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ inline namespace _V1 {
enum class backend : char;
namespace detail {
class Adapter;
using AdapterPtr = std::shared_ptr<Adapter>;
using AdapterPtr = Adapter *;

namespace ur {
void *getURLoaderLibrary();
Expand All @@ -35,7 +35,7 @@ std::vector<AdapterPtr> &
initializeUr(ur_loader_config_handle_t LoaderConfig = nullptr);

// Get the adapter serving given backend.
template <backend BE> const AdapterPtr &getAdapter();
template <backend BE> AdapterPtr &getAdapter();
} // namespace ur

// Convert from UR backend to SYCL backend enum
Expand Down