@@ -700,7 +700,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelCreate(
700700 ZeKernelDesc.pKernelName = KernelName;
701701
702702 ze_kernel_handle_t ZeKernel;
703- ZE2UR_CALL (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
703+ auto ZeResult =
704+ ZE_CALL_NOCHECK (zeKernelCreate, (ZeModule, &ZeKernelDesc, &ZeKernel));
705+ // Gracefully handle the case that kernel create fails.
706+ if (ZeResult != ZE_RESULT_SUCCESS) {
707+ delete *RetKernel;
708+ *RetKernel = nullptr ;
709+ return ze2urResult (ZeResult);
710+ }
704711
705712 auto ZeDevice = It.first ;
706713
@@ -754,20 +761,29 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgValue(
754761 PArgValue = nullptr ;
755762 }
756763
764+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
765+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
766+ }
767+
757768 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
769+ ze_result_t ZeResult = ZE_RESULT_SUCCESS;
758770 if (Kernel->ZeKernelMap .empty ()) {
759771 auto ZeKernel = Kernel->ZeKernel ;
760- ZE2UR_CALL (zeKernelSetArgumentValue,
761- (ZeKernel, ArgIndex, ArgSize, PArgValue));
772+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
773+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
762774 } else {
763775 for (auto It : Kernel->ZeKernelMap ) {
764776 auto ZeKernel = It.second ;
765- ZE2UR_CALL (zeKernelSetArgumentValue,
766- (ZeKernel, ArgIndex, ArgSize, PArgValue));
777+ ZeResult = ZE_CALL_NOCHECK (zeKernelSetArgumentValue,
778+ (ZeKernel, ArgIndex, ArgSize, PArgValue));
767779 }
768780 }
769781
770- return UR_RESULT_SUCCESS;
782+ if (ZeResult == ZE_RESULT_ERROR_INVALID_ARGUMENT) {
783+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_SIZE;
784+ }
785+
786+ return ze2urResult (ZeResult);
771787}
772788
773789UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgLocal (
@@ -816,6 +832,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelGetInfo(
816832 } catch (...) {
817833 return UR_RESULT_ERROR_UNKNOWN;
818834 }
835+ case UR_KERNEL_INFO_NUM_REGS:
819836 case UR_KERNEL_INFO_NUM_ARGS:
820837 return ReturnValue (uint32_t {Kernel->ZeKernelProperties ->numKernelArgs });
821838 case UR_KERNEL_INFO_REFERENCE_COUNT:
@@ -1066,6 +1083,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgSampler(
10661083) {
10671084 std::ignore = Properties;
10681085 std::scoped_lock<ur_shared_mutex> Guard (Kernel->Mutex );
1086+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1087+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1088+ }
10691089 ZE2UR_CALL (zeKernelSetArgumentValue, (Kernel->ZeKernel , ArgIndex,
10701090 sizeof (void *), &ArgValue->ZeSampler ));
10711091
@@ -1085,6 +1105,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgMemObj(
10851105 // The ArgValue may be a NULL pointer in which case a NULL value is used for
10861106 // the kernel argument declared as a pointer to global or constant memory.
10871107
1108+ if (ArgIndex > Kernel->ZeKernelProperties ->numKernelArgs - 1 ) {
1109+ return UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX;
1110+ }
1111+
10881112 ur_mem_handle_t_ *UrMem = ur_cast<ur_mem_handle_t_ *>(ArgValue);
10891113
10901114 ur_mem_handle_t_::access_mode_t UrAccessMode = ur_mem_handle_t_::read_write;
0 commit comments