@@ -48,6 +48,24 @@ static bool IsSuitableSubReq(const Requirement *Req) {
4848 return Req->MIsSubBuffer ;
4949}
5050
51+ // Checks if the required access mode is allowed under the current one
52+ static bool isAccessModeAllowed (access::mode Required, access::mode Current) {
53+ switch (Current) {
54+ case access::mode::read:
55+ return (Required == Current);
56+ case access::mode::write:
57+ assert (false && " Write only access is expected to be mapped as read_write" );
58+ return (Required == Current || Required == access::mode::discard_write);
59+ case access::mode::read_write:
60+ case access::mode::atomic:
61+ case access::mode::discard_write:
62+ case access::mode::discard_read_write:
63+ return true ;
64+ }
65+ assert (false );
66+ return false ;
67+ }
68+
5169Scheduler::GraphBuilder::GraphBuilder () {
5270 if (const char *EnvVarCStr = SYCLConfig<SYCL_PRINT_EXECUTION_GRAPH>::get ()) {
5371 std::string GraphPrintOpts (EnvVarCStr);
@@ -199,7 +217,8 @@ UpdateHostRequirementCommand *Scheduler::GraphBuilder::insertUpdateHostReqCmd(
199217// Takes linked alloca commands. Makes AllocaCmdDst command active using map
200218// or unmap operation.
201219static Command *insertMapUnmapForLinkedCmds (AllocaCommandBase *AllocaCmdSrc,
202- AllocaCommandBase *AllocaCmdDst) {
220+ AllocaCommandBase *AllocaCmdDst,
221+ access::mode MapMode) {
203222 assert (AllocaCmdSrc->MLinkedAllocaCmd == AllocaCmdDst &&
204223 " Expected linked alloca commands" );
205224 assert (AllocaCmdSrc->MIsActive &&
@@ -215,9 +234,9 @@ static Command *insertMapUnmapForLinkedCmds(AllocaCommandBase *AllocaCmdSrc,
215234 return UnMapCmd;
216235 }
217236
218- MapMemObject *MapCmd =
219- new MapMemObject ( AllocaCmdSrc, *AllocaCmdSrc->getRequirement (),
220- &AllocaCmdDst->MMemAllocation , AllocaCmdSrc->getQueue ());
237+ MapMemObject *MapCmd = new MapMemObject (
238+ AllocaCmdSrc, *AllocaCmdSrc->getRequirement (),
239+ &AllocaCmdDst->MMemAllocation , AllocaCmdSrc->getQueue (), MapMode );
221240
222241 std::swap (AllocaCmdSrc->MIsActive , AllocaCmdDst->MIsActive );
223242
@@ -274,7 +293,12 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
274293 Command *NewCmd = nullptr ;
275294
276295 if (AllocaCmdSrc->MLinkedAllocaCmd == AllocaCmdDst) {
277- NewCmd = insertMapUnmapForLinkedCmds (AllocaCmdSrc, AllocaCmdDst);
296+ // Map write only as read-write
297+ access::mode MapMode = Req->MAccessMode ;
298+ if (MapMode == access::mode::write)
299+ MapMode = access::mode::read_write;
300+ NewCmd = insertMapUnmapForLinkedCmds (AllocaCmdSrc, AllocaCmdDst, MapMode);
301+ Record->MHostAccess = MapMode;
278302 } else {
279303
280304 // Full copy of buffer is needed to avoid loss of data that may be caused
@@ -295,6 +319,43 @@ Command *Scheduler::GraphBuilder::insertMemoryMove(MemObjRecord *Record,
295319 return NewCmd;
296320}
297321
322+ Command *Scheduler::GraphBuilder::remapMemoryObject (
323+ MemObjRecord *Record, Requirement *Req, AllocaCommandBase *HostAllocaCmd) {
324+ assert (HostAllocaCmd->getQueue ()->is_host () &&
325+ " Host alloca command expected" );
326+ assert (HostAllocaCmd->MIsActive && " Active alloca command expected" );
327+
328+ AllocaCommandBase *LinkedAllocaCmd = HostAllocaCmd->MLinkedAllocaCmd ;
329+ assert (LinkedAllocaCmd && " Linked alloca command expected" );
330+
331+ std::set<Command *> Deps = findDepsForReq (Record, Req, Record->MCurContext );
332+
333+ UnMapMemObject *UnMapCmd = new UnMapMemObject (
334+ LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement (),
335+ &HostAllocaCmd->MMemAllocation , LinkedAllocaCmd->getQueue ());
336+
337+ // Map write only as read-write
338+ access::mode MapMode = Req->MAccessMode ;
339+ if (MapMode == access::mode::write)
340+ MapMode = access::mode::read_write;
341+ MapMemObject *MapCmd = new MapMemObject (
342+ LinkedAllocaCmd, *LinkedAllocaCmd->getRequirement (),
343+ &HostAllocaCmd->MMemAllocation , LinkedAllocaCmd->getQueue (), MapMode);
344+
345+ for (Command *Dep : Deps) {
346+ UnMapCmd->addDep (DepDesc{Dep, UnMapCmd->getRequirement (), LinkedAllocaCmd});
347+ Dep->addUser (UnMapCmd);
348+ }
349+
350+ MapCmd->addDep (DepDesc{UnMapCmd, MapCmd->getRequirement (), HostAllocaCmd});
351+ UnMapCmd->addUser (MapCmd);
352+
353+ updateLeaves (Deps, Record, access::mode::read_write);
354+ addNodeToLeaves (Record, MapCmd, access::mode::read_write);
355+ Record->MHostAccess = MapMode;
356+ return MapCmd;
357+ }
358+
298359// The function adds copy operation of the up to date'st memory to the memory
299360// pointed by Req.
300361Command *Scheduler::GraphBuilder::addCopyBack (Requirement *Req) {
@@ -349,8 +410,11 @@ Command *Scheduler::GraphBuilder::addHostAccessor(Requirement *Req,
349410 AllocaCommandBase *HostAllocaCmd =
350411 getOrCreateAllocaForReq (Record, Req, HostQueue);
351412
352- if (!sameCtx (HostAllocaCmd->getQueue ()->getContextImplPtr (),
353- Record->MCurContext ))
413+ if (sameCtx (HostAllocaCmd->getQueue ()->getContextImplPtr (),
414+ Record->MCurContext )) {
415+ if (!isAccessModeAllowed (Req->MAccessMode , Record->MHostAccess ))
416+ remapMemoryObject (Record, Req, HostAllocaCmd);
417+ } else
354418 insertMemoryMove (Record, Req, HostQueue);
355419
356420 Command *UpdateHostAccCmd = insertUpdateHostReqCmd (Record, Req, HostQueue);
@@ -600,7 +664,13 @@ Scheduler::GraphBuilder::addCG(std::unique_ptr<detail::CG> CommandGroup,
600664 AllocaCommandBase *AllocaCmd = getOrCreateAllocaForReq (Record, Req, Queue);
601665 // If there is alloca command we need to check if the latest memory is in
602666 // required context.
603- if (!sameCtx (Queue->getContextImplPtr (), Record->MCurContext )) {
667+ if (sameCtx (Queue->getContextImplPtr (), Record->MCurContext )) {
668+ // If the memory is already in the required host context, check if the
669+ // required access mode is valid, remap if not.
670+ if (Record->MCurContext ->is_host () &&
671+ !isAccessModeAllowed (Req->MAccessMode , Record->MHostAccess ))
672+ remapMemoryObject (Record, Req, AllocaCmd);
673+ } else {
604674 // Cannot directly copy memory from OpenCL device to OpenCL device -
605675 // create two copies: device->host and host->device.
606676 if (!Queue->is_host () && !Record->MCurContext ->is_host ())
0 commit comments