@@ -239,76 +239,53 @@ void SPIRVToOCL20Base::visitCallSPIRVAtomicCmpExchg(CallInst *CI) {
239239}
240240
241241void SPIRVToOCL20Base::visitCallSPIRVEnqueueKernel (CallInst *CI, Op OC) {
242- assert (CI->getCalledFunction () && " Unexpected indirect call" );
243- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
244- Instruction *PInsertBefore = CI;
245-
246- mutateCallInstOCL (
247- M, CI,
248- [=](CallInst *, std::vector<Value *> &Args) {
249- bool HasVaargs = Args.size () > 10 ;
250- bool HasEvents = true ;
251- Value *EventRet = Args[5 ];
252- if (isa<ConstantPointerNull>(EventRet)) {
253- Value *NumEvents = Args[3 ];
254- if (isa<ConstantInt>(NumEvents)) {
255- ConstantInt *NE = cast<ConstantInt>(NumEvents);
256- HasEvents = NE->getZExtValue () != 0 ;
257- }
258- }
259-
260- Value *Invoke = Args[6 ];
261- auto *Int8PtrTyGen = Type::getInt8PtrTy (*Ctx, SPIRAS_Generic);
262- Args[6 ] = CastInst::CreatePointerBitCastOrAddrSpaceCast (
263- Invoke, Int8PtrTyGen, " " , PInsertBefore);
264-
265- // Don't remove arguments immediately, just mark them as removed with
266- // nullptr, and remove them at the end of processing. It allows for
267- // easier understanding of which argument is going to be removed.
268- auto MarkAsRemoved = [&Args](size_t Start, size_t End) {
269- assert (Start <= End);
270- for (size_t I = Start; I < End; I++)
271- Args[I] = nullptr ;
272- };
273-
274- if (!HasEvents) {
275- // Mark arguments at indices 3 (Num Events), 4 (Wait Events), 5 (Ret
276- // Event) as removed.
277- MarkAsRemoved (3 , 6 );
278- }
279-
280- if (!HasVaargs) {
281- // Mark arguments at indices 8 (Param Size), 9 (Param Align) as
282- // removed.
283- MarkAsRemoved (8 , 10 );
284- } else {
285- // GEP to array of sizes of local arguments
286- Value *GEP = Args[10 ];
287- size_t NumLocalArgs = Args.size () - 10 ;
288-
289- // Mark all SPIRV-specific arguments as removed
290- MarkAsRemoved (8 , Args.size ());
291-
292- Type *Int32Ty = Type::getInt32Ty (*Ctx);
293- Args[8 ] = ConstantInt::get (Int32Ty, NumLocalArgs);
294- Args[9 ] = GEP;
295- }
296-
297- Args.erase (std::remove (Args.begin (), Args.end (), nullptr ), Args.end ());
298-
299- std::string FName = " " ;
300- if (!HasVaargs && !HasEvents)
301- FName = " __enqueue_kernel_basic" ;
302- else if (!HasVaargs && HasEvents)
303- FName = " __enqueue_kernel_basic_events" ;
304- else if (HasVaargs && !HasEvents)
305- FName = " __enqueue_kernel_varargs" ;
306- else
307- FName = " __enqueue_kernel_events_varargs" ;
308-
309- return FName;
310- },
311- &Attrs);
242+ bool HasVaargs = CI->arg_size () > 10 ;
243+ bool HasEvents = true ;
244+ Value *EventRet = CI->getArgOperand (5 );
245+ if (isa<ConstantPointerNull>(EventRet)) {
246+ Value *NumEvents = CI->getArgOperand (3 );
247+ if (isa<ConstantInt>(NumEvents)) {
248+ ConstantInt *NE = cast<ConstantInt>(NumEvents);
249+ HasEvents = NE->getZExtValue () != 0 ;
250+ }
251+ }
252+
253+ StringRef FName = " " ;
254+ if (!HasVaargs && !HasEvents)
255+ FName = " __enqueue_kernel_basic" ;
256+ else if (!HasVaargs && HasEvents)
257+ FName = " __enqueue_kernel_basic_events" ;
258+ else if (HasVaargs && !HasEvents)
259+ FName = " __enqueue_kernel_varargs" ;
260+ else
261+ FName = " __enqueue_kernel_events_varargs" ;
262+
263+ auto Mutator = mutateCallInst (CI, FName.str ());
264+ Mutator.mapArg (6 , [=](IRBuilder<> &Builder, Value *Invoke) {
265+ Value *Replace = CastInst::CreatePointerBitCastOrAddrSpaceCast (
266+ Invoke, Builder.getInt8PtrTy (SPIRAS_Generic), " " , CI);
267+ return std::pair<Value *, Type *>(Replace, Builder.getInt8Ty ());
268+ });
269+
270+ if (!HasVaargs) {
271+ // Remove arguments at indices 8 (Param Size), 9 (Param Align)
272+ Mutator.removeArgs (8 , 2 );
273+ } else {
274+ // GEP to array of sizes of local arguments
275+ Mutator.moveArg (10 , 8 );
276+ Type *Int32Ty = Type::getInt32Ty (*Ctx);
277+ size_t NumLocalArgs = Mutator.arg_size () - 10 ;
278+ Mutator.insertArg (8 , ConstantInt::get (Int32Ty, NumLocalArgs));
279+
280+ // Mark all SPIRV-specific arguments as removed
281+ Mutator.removeArgs (10 , Mutator.arg_size () - 10 );
282+ }
283+
284+ if (!HasEvents) {
285+ // Remove arguments at indices 3 (Num Events), 4 (Wait Events), 5 (Ret
286+ // Event).
287+ Mutator.removeArgs (3 , 3 );
288+ }
312289}
313290
314291} // namespace SPIRV
0 commit comments