Skip to content

Commit 2006064

Browse files
author
Ewan Crawford
committed
[CUDA][HIP] Improve command-buffer sync points
In the CUDA/HIP adapters we assume that the user always passes a return sync-point pointer. However, this is not required, so we should check that the return sync-point pointer is non-null before assigning through it.
1 parent c5d2175 commit 2006064

File tree

2 files changed

+169
-224
lines changed

2 files changed

+169
-224
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 78 additions & 70 deletions
Original file line number | Diff line number | Diff line change
@@ -161,12 +161,11 @@ static ur_result_t enqueueCommandBufferFillHelper(
161161
const CUmemorytype_enum DstType, const void *Pattern, size_t PatternSize,
162162
size_t Size, uint32_t NumSyncPointsInWaitList,
163163
const ur_exp_command_buffer_sync_point_t *SyncPointWaitList,
164-
ur_exp_command_buffer_sync_point_t *SyncPoint) {
164+
ur_exp_command_buffer_sync_point_t *RetSyncPoint) {
165165
ur_result_t Result = UR_RESULT_SUCCESS;
166166
std::vector<CUgraphNode> DepsList;
167-
UR_CALL(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
168-
SyncPointWaitList, DepsList),
169-
Result);
167+
UR_CHECK_ERROR(getNodesFromSyncPoints(CommandBuffer, NumSyncPointsInWaitList,
168+
SyncPointWaitList, DepsList));
170169

171170
try {
172171
const size_t N = Size / PatternSize;
@@ -209,9 +208,11 @@ static ur_result_t enqueueCommandBufferFillHelper(
209208
CommandBuffer->Device->getNativeContext()));
210209

211210
// Get sync point and register the cuNode with it.
212-
*SyncPoint =
211+
auto SyncPoint =
213212
CommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
214-
213+
if (RetSyncPoint) {
214+
*RetSyncPoint = SyncPoint;
215+
}
215216
} else {
216217
// CUDA has no memset functions that allow setting values more than 4
217218
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
@@ -241,9 +242,13 @@ static ur_result_t enqueueCommandBufferFillHelper(
241242
CommandBuffer->Device->getNativeContext()));
242243

243244
// Get sync point and register the cuNode with it.
244-
*SyncPoint = CommandBuffer->addSyncPoint(
245+
auto SyncPoint = CommandBuffer->addSyncPoint(
245246
std::make_shared<CUgraphNode>(GraphNodeFirst));
246247

248+
if (RetSyncPoint) {
249+
*RetSyncPoint = SyncPoint;
250+
}
251+
247252
DepsList.clear();
248253
DepsList.push_back(GraphNodeFirst);
249254

@@ -274,8 +279,11 @@ static ur_result_t enqueueCommandBufferFillHelper(
274279

275280
GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
276281
// Get sync point and register the cuNode with it.
277-
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
282+
auto SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
278283

284+
if (RetSyncPoint) {
285+
*RetSyncPoint = SyncPoint;
286+
}
279287
DepsList.clear();
280288
DepsList.push_back(*GraphNodePtr.get());
281289
}
@@ -372,14 +380,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
372380
CUgraphNode GraphNode;
373381

374382
std::vector<CUgraphNode> DepsList;
375-
376-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
377-
pSyncPointWaitList, DepsList),
378-
Result);
379-
380-
if (Result != UR_RESULT_SUCCESS) {
381-
return Result;
382-
}
383+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
384+
pSyncPointWaitList, DepsList));
383385

384386
if (*pGlobalWorkSize == 0) {
385387
try {
@@ -388,8 +390,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
388390
DepsList.data(), DepsList.size()));
389391

390392
// Get sync point and register the cuNode with it.
391-
*pSyncPoint = hCommandBuffer->addSyncPoint(
393+
auto SyncPoint = hCommandBuffer->addSyncPoint(
392394
std::make_shared<CUgraphNode>(GraphNode));
395+
if (pSyncPoint) {
396+
*pSyncPoint = SyncPoint;
397+
}
393398
} catch (ur_result_t Err) {
394399
Result = Err;
395400
}
@@ -435,8 +440,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
435440

436441
// Get sync point and register the cuNode with it.
437442
auto NodeSP = std::make_shared<CUgraphNode>(GraphNode);
443+
auto SyncPoint = hCommandBuffer->addSyncPoint(NodeSP);
438444
if (pSyncPoint) {
439-
*pSyncPoint = hCommandBuffer->addSyncPoint(NodeSP);
445+
*pSyncPoint = SyncPoint;
440446
}
441447

442448
auto NewCommand = new ur_exp_command_buffer_command_handle_t_{
@@ -464,9 +470,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
464470
ur_result_t Result = UR_RESULT_SUCCESS;
465471
CUgraphNode GraphNode;
466472
std::vector<CUgraphNode> DepsList;
467-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
468-
pSyncPointWaitList, DepsList),
469-
Result);
473+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
474+
pSyncPointWaitList, DepsList));
470475

471476
if (Result != UR_RESULT_SUCCESS) {
472477
return Result;
@@ -482,8 +487,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMMemcpyExp(
482487
&NodeParams, hCommandBuffer->Device->getNativeContext()));
483488

484489
// Get sync point and register the cuNode with it.
485-
*pSyncPoint =
490+
auto SyncPoint =
486491
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
492+
if (pSyncPoint) {
493+
*pSyncPoint = SyncPoint;
494+
}
487495
} catch (ur_result_t Err) {
488496
Result = Err;
489497
}
@@ -505,13 +513,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
505513
UR_ASSERT(size + srcOffset <= std::get<BufferMem>(hSrcMem->Mem).getSize(),
506514
UR_RESULT_ERROR_INVALID_SIZE);
507515

508-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
509-
pSyncPointWaitList, DepsList),
510-
Result);
511-
512-
if (Result != UR_RESULT_SUCCESS) {
513-
return Result;
514-
}
516+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
517+
pSyncPointWaitList, DepsList));
515518

516519
try {
517520
auto Src = std::get<BufferMem>(hSrcMem->Mem)
@@ -528,8 +531,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
528531
&NodeParams, hCommandBuffer->Device->getNativeContext()));
529532

530533
// Get sync point and register the cuNode with it.
531-
*pSyncPoint =
534+
auto SyncPoint =
532535
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
536+
if (pSyncPoint) {
537+
*pSyncPoint = SyncPoint;
538+
}
533539
} catch (ur_result_t Err) {
534540
Result = Err;
535541
}
@@ -547,13 +553,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
547553
ur_result_t Result = UR_RESULT_SUCCESS;
548554
CUgraphNode GraphNode;
549555
std::vector<CUgraphNode> DepsList;
550-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
551-
pSyncPointWaitList, DepsList),
552-
Result);
553-
554-
if (Result != UR_RESULT_SUCCESS) {
555-
return Result;
556-
}
556+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
557+
pSyncPointWaitList, DepsList));
557558

558559
try {
559560
auto SrcPtr =
@@ -571,8 +572,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
571572
&NodeParams, hCommandBuffer->Device->getNativeContext()));
572573

573574
// Get sync point and register the cuNode with it.
574-
*pSyncPoint =
575+
auto SyncPoint =
575576
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
577+
if (pSyncPoint) {
578+
*pSyncPoint = SyncPoint;
579+
}
576580
} catch (ur_result_t Err) {
577581
Result = Err;
578582
}
@@ -589,13 +593,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
589593
ur_result_t Result = UR_RESULT_SUCCESS;
590594
CUgraphNode GraphNode;
591595
std::vector<CUgraphNode> DepsList;
592-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
593-
pSyncPointWaitList, DepsList),
594-
Result);
595-
596-
if (Result != UR_RESULT_SUCCESS) {
597-
return Result;
598-
}
596+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
597+
pSyncPointWaitList, DepsList));
599598

600599
try {
601600
auto Dst = std::get<BufferMem>(hBuffer->Mem)
@@ -610,8 +609,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteExp(
610609
&NodeParams, hCommandBuffer->Device->getNativeContext()));
611610

612611
// Get sync point and register the cuNode with it.
613-
*pSyncPoint =
612+
auto SyncPoint =
614613
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
614+
if (pSyncPoint) {
615+
*pSyncPoint = SyncPoint;
616+
}
615617
} catch (ur_result_t Err) {
616618
Result = Err;
617619
}
@@ -627,9 +629,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
627629
ur_result_t Result = UR_RESULT_SUCCESS;
628630
CUgraphNode GraphNode;
629631
std::vector<CUgraphNode> DepsList;
630-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
631-
pSyncPointWaitList, DepsList),
632-
Result);
632+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
633+
pSyncPointWaitList, DepsList));
633634

634635
if (Result != UR_RESULT_SUCCESS) {
635636
return Result;
@@ -648,8 +649,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadExp(
648649
&NodeParams, hCommandBuffer->Device->getNativeContext()));
649650

650651
// Get sync point and register the cuNode with it.
651-
*pSyncPoint =
652+
auto SyncPoint =
652653
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
654+
if (pSyncPoint) {
655+
*pSyncPoint = SyncPoint;
656+
}
653657
} catch (ur_result_t Err) {
654658
Result = Err;
655659
}
@@ -668,9 +672,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
668672
ur_result_t Result = UR_RESULT_SUCCESS;
669673
CUgraphNode GraphNode;
670674
std::vector<CUgraphNode> DepsList;
671-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
672-
pSyncPointWaitList, DepsList),
673-
Result);
675+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
676+
pSyncPointWaitList, DepsList));
674677

675678
if (Result != UR_RESULT_SUCCESS) {
676679
return Result;
@@ -691,8 +694,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferWriteRectExp(
691694
&NodeParams, hCommandBuffer->Device->getNativeContext()));
692695

693696
// Get sync point and register the cuNode with it.
694-
*pSyncPoint =
697+
auto SyncPoint =
695698
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
699+
if (pSyncPoint) {
700+
*pSyncPoint = SyncPoint;
701+
}
696702
} catch (ur_result_t Err) {
697703
Result = Err;
698704
}
@@ -711,13 +717,8 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
711717
ur_result_t Result = UR_RESULT_SUCCESS;
712718
CUgraphNode GraphNode;
713719
std::vector<CUgraphNode> DepsList;
714-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
715-
pSyncPointWaitList, DepsList),
716-
Result);
717-
718-
if (Result != UR_RESULT_SUCCESS) {
719-
return Result;
720-
}
720+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
721+
pSyncPointWaitList, DepsList));
721722

722723
try {
723724
auto SrcPtr =
@@ -734,8 +735,11 @@ ur_result_t UR_APICALL urCommandBufferAppendMemBufferReadRectExp(
734735
&NodeParams, hCommandBuffer->Device->getNativeContext()));
735736

736737
// Get sync point and register the cuNode with it.
737-
*pSyncPoint =
738+
auto SyncPoint =
738739
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
740+
if (pSyncPoint) {
741+
*pSyncPoint = SyncPoint;
742+
}
739743
} catch (ur_result_t Err) {
740744
Result = Err;
741745
}
@@ -754,18 +758,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMPrefetchExp(
754758
CUgraphNode GraphNode;
755759

756760
std::vector<CUgraphNode> DepsList;
757-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
758-
pSyncPointWaitList, DepsList),
759-
Result);
761+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
762+
pSyncPointWaitList, DepsList));
760763

761764
try {
762765
// Add an empty node to preserve dependencies.
763766
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
764767
DepsList.data(), DepsList.size()));
765768

766769
// Get sync point and register the cuNode with it.
767-
*pSyncPoint =
770+
auto SyncPoint =
768771
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
772+
if (pSyncPoint) {
773+
*pSyncPoint = SyncPoint;
774+
}
769775

770776
setErrorMessage("Prefetch hint ignored and replaced with empty node as "
771777
"prefetch is not supported by CUDA Graph backend",
@@ -789,18 +795,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendUSMAdviseExp(
789795
CUgraphNode GraphNode;
790796

791797
std::vector<CUgraphNode> DepsList;
792-
UR_CALL(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
793-
pSyncPointWaitList, DepsList),
794-
Result);
798+
UR_CHECK_ERROR(getNodesFromSyncPoints(hCommandBuffer, numSyncPointsInWaitList,
799+
pSyncPointWaitList, DepsList));
795800

796801
try {
797802
// Add an empty node to preserve dependencies.
798803
UR_CHECK_ERROR(cuGraphAddEmptyNode(&GraphNode, hCommandBuffer->CudaGraph,
799804
DepsList.data(), DepsList.size()));
800805

801806
// Get sync point and register the cuNode with it.
802-
*pSyncPoint =
807+
auto SyncPoint =
803808
hCommandBuffer->addSyncPoint(std::make_shared<CUgraphNode>(GraphNode));
809+
if (pSyncPoint) {
810+
*pSyncPoint = SyncPoint;
811+
}
804812

805813
setErrorMessage("Memory advice ignored and replaced with empty node as "
806814
"memory advice is not supported by CUDA Graph backend",

0 commit comments

Comments
 (0)