diff --git a/sycl/source/detail/queue_impl.cpp b/sycl/source/detail/queue_impl.cpp
index ffce483031e4e..8c94b42266db0 100644
--- a/sycl/source/detail/queue_impl.cpp
+++ b/sycl/source/detail/queue_impl.cpp
@@ -455,7 +455,7 @@ event queue_impl::submitMemOpHelper(const std::vector<event> &DepEvents,
   // If we have a command graph set we need to capture the op through the
   // handler rather than by-passing the scheduler.
   if (MGraph.expired() && Scheduler::areEventsSafeForSchedulerBypass(
-                              ExpandedDepEvents, MContext)) {
+                              ExpandedDepEvents, *MContext)) {
     auto isNoEventsMode = trySwitchingToNoEventsMode();
     if (!CallerNeedsEvent && isNoEventsMode) {
       NestedCallsTracker tracker;
diff --git a/sycl/source/detail/queue_impl.hpp b/sycl/source/detail/queue_impl.hpp
index 04d36259c2de9..4451814927260 100644
--- a/sycl/source/detail/queue_impl.hpp
+++ b/sycl/source/detail/queue_impl.hpp
@@ -726,7 +726,7 @@ class queue_impl : public std::enable_shared_from_this<queue_impl> {
       return false;
 
     if (MDefaultGraphDeps.LastEventPtr != nullptr &&
-        !Scheduler::CheckEventReadiness(MContext,
+        !Scheduler::CheckEventReadiness(*MContext,
                                         MDefaultGraphDeps.LastEventPtr))
       return false;
 
diff --git a/sycl/source/detail/scheduler/graph_builder.cpp b/sycl/source/detail/scheduler/graph_builder.cpp
index a0e3e25e07d1c..eaba6f0033455 100644
--- a/sycl/source/detail/scheduler/graph_builder.cpp
+++ b/sycl/source/detail/scheduler/graph_builder.cpp
@@ -225,7 +225,7 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
         Dev, InteropCtxPtr, async_handler{}, property_list{});
 
     MemObject->MRecord.reset(
-        new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
+        new MemObjRecord{InteropCtxPtr.get(), LeafLimit, AllocateDependency});
     std::vector<Command *> ToEnqueue;
     getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr,
                             ToEnqueue);
@@ -233,8 +233,8 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
                                 "shouldn't lead to any enqueuing (no linked "
                                 "alloca or exceeding the leaf limit).");
   } else
-    MemObject->MRecord.reset(new MemObjRecord{queue_impl::getContext(Queue),
-                                              LeafLimit, AllocateDependency});
+    MemObject->MRecord.reset(new MemObjRecord{
+        queue_impl::getContext(Queue).get(), LeafLimit, AllocateDependency});
   MMemObjs.push_back(MemObject);
 
   return MemObject->MRecord.get();
diff --git a/sycl/source/detail/scheduler/scheduler.cpp b/sycl/source/detail/scheduler/scheduler.cpp
index f071132ef2356..b84500dc65a96 100644
--- a/sycl/source/detail/scheduler/scheduler.cpp
+++ b/sycl/source/detail/scheduler/scheduler.cpp
@@ -678,7 +678,7 @@ EventImplPtr Scheduler::addCommandGraphUpdate(
   return NewCmdEvent;
 }
 
-bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
+bool Scheduler::CheckEventReadiness(context_impl &Context,
                                     const EventImplPtr &SyclEventImplPtr) {
   // Events that don't have an initialized context are throwaway events that
   // don't represent actual dependencies. Calling getContextImpl() would set
@@ -691,7 +691,7 @@ bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
     return SyclEventImplPtr->isCompleted();
   }
   // Cross-context dependencies can't be passed to the backend directly.
-  if (SyclEventImplPtr->getContextImpl() != Context)
+  if (SyclEventImplPtr->getContextImpl().get() != &Context)
     return false;
 
   // A nullptr here means that the commmand does not produce a UR event or it
@@ -700,7 +700,7 @@ bool Scheduler::CheckEventReadiness(const ContextImplPtr &Context,
 }
 
 bool Scheduler::areEventsSafeForSchedulerBypass(
-    const std::vector<sycl::event> &DepEvents, const ContextImplPtr &Context) {
+    const std::vector<sycl::event> &DepEvents, context_impl &Context) {
 
   return std::all_of(
       DepEvents.begin(), DepEvents.end(), [&Context](const sycl::event &Event) {
@@ -710,7 +710,7 @@ bool Scheduler::areEventsSafeForSchedulerBypass(
 }
 
 bool Scheduler::areEventsSafeForSchedulerBypass(
-    const std::vector<EventImplPtr> &DepEvents, const ContextImplPtr &Context) {
+    const std::vector<EventImplPtr> &DepEvents, context_impl &Context) {
 
   return std::all_of(DepEvents.begin(), DepEvents.end(),
                      [&Context](const EventImplPtr &SyclEventImplPtr) {
diff --git a/sycl/source/detail/scheduler/scheduler.hpp b/sycl/source/detail/scheduler/scheduler.hpp
index 69448ca817937..a43e7546e9a6b 100644
--- a/sycl/source/detail/scheduler/scheduler.hpp
+++ b/sycl/source/detail/scheduler/scheduler.hpp
@@ -9,6 +9,7 @@
 #pragma once
 
 #include <detail/cg.hpp>
+#include <detail/context_impl.hpp>
 #include <detail/scheduler/commands.hpp>
 #include <detail/sycl_mem_obj_i.hpp>
 #include <sycl/access/access.hpp>
@@ -198,10 +199,11 @@ using CommandPtr = std::unique_ptr<Command>;
 ///
 /// \ingroup sycl_graph
 struct MemObjRecord {
-  MemObjRecord(ContextImplPtr Ctx, std::size_t LeafLimit,
+  MemObjRecord(context_impl *Ctx, std::size_t LeafLimit,
                LeavesCollection::AllocateDependencyF AllocateDependency)
       : MReadLeaves{this, LeafLimit, AllocateDependency},
-        MWriteLeaves{this, LeafLimit, AllocateDependency}, MCurContext{Ctx} {}
+        MWriteLeaves{this, LeafLimit, AllocateDependency},
+        MCurContext{Ctx ? Ctx->shared_from_this() : nullptr} {}
 
   // Contains all allocation commands for the memory object.
   std::vector<AllocaCommandBase *> MAllocaCommands;
@@ -212,7 +214,7 @@ struct MemObjRecord {
   LeavesCollection MWriteLeaves;
 
   // The context which has the latest state of the memory object.
-  ContextImplPtr MCurContext;
+  std::shared_ptr<context_impl> MCurContext;
 
   // The mode this object can be accessed from the host (host_accessor).
   // Valid only if the current usage is on host.
@@ -477,15 +479,15 @@ class Scheduler {
                              const QueueImplPtr &Queue,
                              std::vector<Requirement *> Requirements,
                              std::vector<detail::EventImplPtr> &Events);
 
-  static bool CheckEventReadiness(const ContextImplPtr &Context,
+  static bool CheckEventReadiness(context_impl &Context,
                                   const EventImplPtr &SyclEventImplPtr);
 
  static bool areEventsSafeForSchedulerBypass(const std::vector<sycl::event> &DepEvents,
-                                             const ContextImplPtr &Context);
+                                             context_impl &Context);
 
  static bool areEventsSafeForSchedulerBypass(const std::vector<EventImplPtr> &DepEvents,
-                                             const ContextImplPtr &Context);
+                                             context_impl &Context);
 
 protected:
   using RWLockT = std::shared_timed_mutex;
diff --git a/sycl/source/handler.cpp b/sycl/source/handler.cpp
index 8f55572622a70..2e21286c09df6 100644
--- a/sycl/source/handler.cpp
+++ b/sycl/source/handler.cpp
@@ -411,7 +411,7 @@ event handler::finalize() {
       (Queue && !Graph && !impl->MSubgraphNode && !Queue->hasCommandGraph() &&
        !impl->CGData.MRequirements.size() && !MStreamStorage.size() &&
        detail::Scheduler::areEventsSafeForSchedulerBypass(
-           impl->CGData.MEvents, Queue->getContextImplPtr()));
+           impl->CGData.MEvents, Queue->getContextImpl()));
 
   // Extract arguments from the kernel lambda, if required.
   // Skipping this is currently limited to simple kernels on the fast path.