@@ -659,6 +659,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
659659 return NodeImpl;
660660}
661661
662+ void graph_impl::addQueue (sycl::detail::queue_impl &RecordingQueue) {
663+ MRecordingQueues.insert (RecordingQueue.weak_from_this ());
664+ }
665+
666+ void graph_impl::removeQueue (sycl::detail::queue_impl &RecordingQueue) {
667+ MRecordingQueues.erase (RecordingQueue.weak_from_this ());
668+ }
669+
662670bool graph_impl::clearQueues () {
663671 bool AnyQueuesCleared = false ;
664672 for (auto &Queue : MRecordingQueues) {
@@ -689,6 +697,24 @@ bool graph_impl::checkForCycles() {
689697 return CycleFound;
690698}
691699
700+ std::shared_ptr<node_impl>
701+ graph_impl::getLastInorderNode (sycl::detail::queue_impl *Queue) {
702+ if (!Queue) {
703+ assert (0 ==
704+ MInorderQueueMap.count (std::weak_ptr<sycl::detail::queue_impl>{}));
705+ return {};
706+ }
707+ if (0 == MInorderQueueMap.count (Queue->weak_from_this ())) {
708+ return {};
709+ }
710+ return MInorderQueueMap[Queue->weak_from_this ()];
711+ }
712+
713+ void graph_impl::setLastInorderNode (sycl::detail::queue_impl &Queue,
714+ std::shared_ptr<node_impl> Node) {
715+ MInorderQueueMap[Queue.weak_from_this ()] = Node;
716+ }
717+
692718void graph_impl::makeEdge (std::shared_ptr<node_impl> Src,
693719 std::shared_ptr<node_impl> Dest) {
694720 throwIfGraphRecordingQueue (" make_edge()" );
@@ -769,11 +795,10 @@ std::vector<sycl::detail::EventImplPtr> graph_impl::getExitNodesEvents(
769795 return Events;
770796}
771797
772- void graph_impl::beginRecording (
773- const std::shared_ptr<sycl::detail::queue_impl> &Queue) {
798+ void graph_impl::beginRecording (sycl::detail::queue_impl &Queue) {
774799 graph_impl::WriteLock Lock (MMutex);
775- if (!Queue-> hasCommandGraph ()) {
776- Queue-> setCommandGraph (shared_from_this ());
800+ if (!Queue. hasCommandGraph ()) {
801+ Queue. setCommandGraph (shared_from_this ());
777802 addQueue (Queue);
778803 }
779804}
@@ -1003,7 +1028,7 @@ exec_graph_impl::~exec_graph_impl() {
10031028}
10041029
10051030sycl::event
1006- exec_graph_impl::enqueue (const std::shared_ptr< sycl::detail::queue_impl> &Queue,
1031+ exec_graph_impl::enqueue (sycl::detail::queue_impl &Queue,
10071032 sycl::detail::CG::StorageInitHelper CGData) {
10081033 WriteLock Lock (MMutex);
10091034
@@ -1012,8 +1037,9 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10121037 PartitionsExecutionEvents;
10131038
10141039 auto CreateNewEvent ([&]() {
1015- auto NewEvent = std::make_shared<sycl::detail::event_impl>(Queue);
1016- NewEvent->setContextImpl (Queue->getContextImplPtr ());
1040+ auto NewEvent =
1041+ std::make_shared<sycl::detail::event_impl>(Queue.shared_from_this ());
1042+ NewEvent->setContextImpl (Queue.getContextImplPtr ());
10171043 NewEvent->setStateIncomplete ();
10181044 return NewEvent;
10191045 });
@@ -1035,7 +1061,7 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10351061 CGData.MEvents .push_back (PartitionsExecutionEvents[DepPartition]);
10361062 }
10371063
1038- auto CommandBuffer = CurrentPartition->MCommandBuffers [Queue-> get_device ()];
1064+ auto CommandBuffer = CurrentPartition->MCommandBuffers [Queue. get_device ()];
10391065
10401066 if (CommandBuffer) {
10411067 for (std::vector<sycl::detail::EventImplPtr>::iterator It =
@@ -1073,10 +1099,10 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10731099 if (CGData.MRequirements .empty () && CGData.MEvents .empty ()) {
10741100 NewEvent->setSubmissionTime ();
10751101 ur_result_t Res =
1076- Queue-> getAdapter ()
1102+ Queue. getAdapter ()
10771103 ->call_nocheck <
10781104 sycl::detail::UrApiKind::urEnqueueCommandBufferExp>(
1079- Queue-> getHandleRef (), CommandBuffer, 0 , nullptr , &UREvent);
1105+ Queue. getHandleRef (), CommandBuffer, 0 , nullptr , &UREvent);
10801106 NewEvent->setHandle (UREvent);
10811107 if (Res == UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES) {
10821108 throw sycl::exception (
@@ -1096,7 +1122,8 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
10961122 CommandBuffer, nullptr , std::move (CGData));
10971123
10981124 NewEvent = sycl::detail::Scheduler::getInstance ().addCG (
1099- std::move (CommandGroup), Queue, /* EventNeeded=*/ true );
1125+ std::move (CommandGroup), Queue.shared_from_this (),
1126+ /* EventNeeded=*/ true );
11001127 }
11011128 NewEvent->setEventFromSubmittedExecCommandBuffer (true );
11021129 } else if ((CurrentPartition->MSchedule .size () > 0 ) &&
@@ -1112,10 +1139,11 @@ exec_graph_impl::enqueue(const std::shared_ptr<sycl::detail::queue_impl> &Queue,
11121139 // In case of graph, this queue may differ from the actual execution
11131140 // queue. We therefore overload this Queue before submitting the task.
11141141 static_cast <sycl::detail::CGHostTask &>(*NodeImpl->MCommandGroup .get ())
1115- .MQueue = Queue;
1142+ .MQueue = Queue. shared_from_this () ;
11161143
11171144 NewEvent = sycl::detail::Scheduler::getInstance ().addCG (
1118- NodeImpl->getCGCopy (), Queue, /* EventNeeded=*/ true );
1145+ NodeImpl->getCGCopy (), Queue.shared_from_this (),
1146+ /* EventNeeded=*/ true );
11191147 }
11201148 PartitionsExecutionEvents[CurrentPartition] = NewEvent;
11211149 }
@@ -1844,21 +1872,20 @@ void modifiable_command_graph::begin_recording(
18441872 // related to graph at all.
18451873 checkGraphPropertiesAndThrow (PropList);
18461874
1847- auto QueueImpl = sycl::detail::getSyclObjImpl (RecordingQueue);
1848- assert (QueueImpl);
1875+ queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl (RecordingQueue);
18491876
1850- if (QueueImpl-> hasCommandGraph ()) {
1877+ if (QueueImpl. hasCommandGraph ()) {
18511878 throw sycl::exception (sycl::make_error_code (errc::invalid),
18521879 " begin_recording cannot be called for a queue which "
18531880 " is already in the recording state." );
18541881 }
18551882
1856- if (QueueImpl-> get_context () != impl->getContext ()) {
1883+ if (QueueImpl. get_context () != impl->getContext ()) {
18571884 throw sycl::exception (sycl::make_error_code (errc::invalid),
18581885 " begin_recording called for a queue whose context "
18591886 " differs from the graph context." );
18601887 }
1861- if (QueueImpl-> get_device () != impl->getDevice ()) {
1888+ if (QueueImpl. get_device () != impl->getDevice ()) {
18621889 throw sycl::exception (sycl::make_error_code (errc::invalid),
18631890 " begin_recording called for a queue whose device "
18641891 " differs from the graph device." );
@@ -1881,15 +1908,13 @@ void modifiable_command_graph::end_recording() {
18811908}
18821909
18831910void modifiable_command_graph::end_recording (queue &RecordingQueue) {
1884- auto QueueImpl = sycl::detail::getSyclObjImpl (RecordingQueue);
1885- if (!QueueImpl)
1886- return ;
1887- if (QueueImpl->getCommandGraph () == impl) {
1888- QueueImpl->setCommandGraph (nullptr );
1911+ queue_impl &QueueImpl = *sycl::detail::getSyclObjImpl (RecordingQueue);
1912+ if (QueueImpl.getCommandGraph () == impl) {
1913+ QueueImpl.setCommandGraph (nullptr );
18891914 graph_impl::WriteLock Lock (impl->MMutex );
18901915 impl->removeQueue (QueueImpl);
18911916 }
1892- if (QueueImpl-> hasCommandGraph ())
1917+ if (QueueImpl. hasCommandGraph ())
18931918 throw sycl::exception (sycl::make_error_code (errc::invalid),
18941919 " end_recording called for a queue which is recording "
18951920 " to a different graph." );
0 commit comments