@@ -593,9 +593,9 @@ void OCLToSPIRVBase::visitCallAtomicLegacy(CallInst *CI, StringRef MangledName,
593593 PostOps.push_back (OCLLegacyAtomicMemOrder);
594594 PostOps.push_back (OCLLegacyAtomicMemScope);
595595
596- Info.PostProc = [=](std::vector<Value *> &Ops ) {
596+ Info.PostProc = [=](BuiltinCallMutator &Mutator ) {
597597 for (auto &I : PostOps) {
598- Ops. push_back (addInt32 (I));
598+ Mutator. appendArg (addInt32 (I));
599599 }
600600 };
601601 transAtomicBuiltin (CI, Info);
@@ -637,9 +637,9 @@ void OCLToSPIRVBase::visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
637637
638638 OCLBuiltinTransInfo Info;
639639 Info.UniqName = std::string (" atomic_" ) + NewStem;
640- Info.PostProc = [=](std::vector<Value *> &Ops ) {
640+ Info.PostProc = [=](BuiltinCallMutator &Mutator ) {
641641 for (auto &I : PostOps) {
642- Ops. push_back (addInt32 (I));
642+ Mutator. appendArg (addInt32 (I));
643643 }
644644 };
645645
@@ -648,72 +648,65 @@ void OCLToSPIRVBase::visitCallAtomicCpp11(CallInst *CI, StringRef MangledName,
648648
649649void OCLToSPIRVBase::transAtomicBuiltin (CallInst *CI,
650650 OCLBuiltinTransInfo &Info) {
651- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
652- mutateCallInstSPIRV (
653- M, CI,
654- [=](CallInst *CI, std::vector<Value *> &Args) -> std::string {
655- Info.PostProc (Args);
656- // Order of args in OCL20:
657- // object, 0-2 other args, 1-2 order, scope
658- const size_t NumOrder =
659- getAtomicBuiltinNumMemoryOrderArgs (Info.UniqName );
660- const size_t ArgsCount = Args.size ();
661- const size_t ScopeIdx = ArgsCount - 1 ;
662- const size_t OrderIdx = ScopeIdx - NumOrder;
663-
664- Args[ScopeIdx] =
665- transOCLMemScopeIntoSPIRVScope (Args[ScopeIdx], OCLMS_device, CI);
666-
667- for (size_t I = 0 ; I < NumOrder; ++I) {
668- Args[OrderIdx + I] = transOCLMemOrderIntoSPIRVMemorySemantics (
669- Args[OrderIdx + I], OCLMO_seq_cst, CI);
670- }
671- // Order of args in SPIR-V:
672- // object, scope, 1-2 order, 0-2 other args
673- std::swap (Args[1 ], Args[ScopeIdx]);
674- if (OrderIdx > 2 ) {
675- // For atomic_compare_exchange the swap above puts Comparator/Expected
676- // argument just where it should be, so don't move the last argument
677- // then.
678- int Offset =
679- Info.UniqName .find (" atomic_compare_exchange" ) == 0 ? 1 : 0 ;
680- std::rotate (Args.begin () + 2 , Args.begin () + OrderIdx,
681- Args.end () - Offset);
682- }
683- llvm::Type *AtomicBuiltinsReturnType =
684- CI->getCalledFunction ()->getReturnType ();
685- auto IsFPType = [](llvm::Type *ReturnType) {
686- return ReturnType->isHalfTy () || ReturnType->isFloatTy () ||
687- ReturnType->isDoubleTy ();
688- };
689- auto SPIRVFunctionName =
690- getSPIRVFuncName (OCLSPIRVBuiltinMap::map (Info.UniqName ));
691- if (!IsFPType (AtomicBuiltinsReturnType))
692- return SPIRVFunctionName;
693- // Translate FP-typed atomic builtins. Currently we only need to
694- // translate atomic_fetch_[add, sub, max, min] and atomic_fetch_[add,
695- // sub, max, min]_explicit to related float instructions.
696- // Translate atomic_fetch_sub to OpAtomicFAddEXT with negative value
697- // operand
698- auto SPIRFunctionNameForFloatAtomics =
699- llvm::StringSwitch<std::string>(SPIRVFunctionName)
700- .Case (" __spirv_AtomicIAdd" , " __spirv_AtomicFAddEXT" )
701- .Case (" __spirv_AtomicISub" , " __spirv_AtomicFAddEXT" )
702- .Case (" __spirv_AtomicSMax" , " __spirv_AtomicFMaxEXT" )
703- .Case (" __spirv_AtomicSMin" , " __spirv_AtomicFMinEXT" )
704- .Default (" others" );
705- if (SPIRVFunctionName == " __spirv_AtomicISub" ) {
706- IRBuilder<> IRB (CI);
707- // Set float operand to its negation
708- CI->setOperand (1 , IRB.CreateFNeg (CI->getArgOperand (1 )));
709- // Update Args which is used to generate new call
710- Args.back () = CI->getArgOperand (1 );
711- }
712- return SPIRFunctionNameForFloatAtomics == " others"
713- ? SPIRVFunctionName
714- : SPIRFunctionNameForFloatAtomics;
715- },
716- &Attrs);
651+ llvm::Type *AtomicBuiltinsReturnType = CI->getType ();
652+ auto SPIRVFunctionName =
653+ getSPIRVFuncName (OCLSPIRVBuiltinMap::map (Info.UniqName ));
654+ bool NeedsNegate = false ;
655+ if (AtomicBuiltinsReturnType->isFloatingPointTy ()) {
656+ // Translate FP-typed atomic builtins. Currently we only need to
657+ // translate atomic_fetch_[add, sub, max, min] and atomic_fetch_[add,
658+ // sub, max, min]_explicit to related float instructions.
659+ // Translate atomic_fetch_sub to OpAtomicFAddEXT with negative value
660+ // operand
661+ auto SPIRFunctionNameForFloatAtomics =
662+ llvm::StringSwitch<std::string>(SPIRVFunctionName)
663+ .Case (" __spirv_AtomicIAdd" , " __spirv_AtomicFAddEXT" )
664+ .Case (" __spirv_AtomicISub" , " __spirv_AtomicFAddEXT" )
665+ .Case (" __spirv_AtomicSMax" , " __spirv_AtomicFMaxEXT" )
666+ .Case (" __spirv_AtomicSMin" , " __spirv_AtomicFMinEXT" )
667+ .Default (" others" );
668+ if (SPIRVFunctionName == " __spirv_AtomicISub" ) {
669+ NeedsNegate = true ;
670+ }
671+ if (SPIRFunctionNameForFloatAtomics != " others" )
672+ SPIRVFunctionName = SPIRFunctionNameForFloatAtomics;
673+ }
674+
675+ auto Mutator = mutateCallInst (CI, SPIRVFunctionName);
676+ Info.PostProc (Mutator);
677+ // Order of args in OCL20:
678+ // object, 0-2 other args, 1-2 order, scope
679+ const size_t NumOrder = getAtomicBuiltinNumMemoryOrderArgs (Info.UniqName );
680+ const size_t ArgsCount = Mutator.arg_size ();
681+ const size_t ScopeIdx = ArgsCount - 1 ;
682+ const size_t OrderIdx = ScopeIdx - NumOrder;
683+
684+ if (NeedsNegate) {
685+ Mutator.mapArg (1 , [=](Value *V) {
686+ IRBuilder<> IRB (CI);
687+ return IRB.CreateFNeg (V);
688+ });
689+ }
690+ Mutator.mapArg (ScopeIdx, [=](Value *V) {
691+ return transOCLMemScopeIntoSPIRVScope (V, OCLMS_device, CI);
692+ });
693+ for (size_t I = 0 ; I < NumOrder; ++I) {
694+ Mutator.mapArg (OrderIdx + I, [=](Value *V) {
695+ return transOCLMemOrderIntoSPIRVMemorySemantics (V, OCLMO_seq_cst, CI);
696+ });
697+ }
698+
699+ // Order of args in SPIR-V:
700+ // object, scope, 1-2 order, 0-2 other args
701+ for (size_t I = 0 ; I < NumOrder; ++I) {
702+ Mutator.moveArg (OrderIdx + I, I + 1 );
703+ }
704+ Mutator.moveArg (ScopeIdx, 1 );
705+ if (Info.UniqName .find (" atomic_compare_exchange" ) == 0 ) {
706+ // For atomic_compare_exchange, the two "other args" are in the opposite
707+ // order from the SPIR-V order. Swap these two arguments.
708+ Mutator.moveArg (Mutator.arg_size () - 1 , Mutator.arg_size () - 2 );
709+ }
717710}
718711
719712void OCLToSPIRVBase::visitCallBarrier (CallInst *CI) {
@@ -871,24 +864,29 @@ void OCLToSPIRVBase::visitCallGroupBuiltin(CallInst *CI,
871864 if (HasBoolReturnType)
872865 Info.RetTy = Type::getInt1Ty (*Ctx);
873866 Info.UniqName = DemangledName;
874- Info.PostProc = [=](std::vector<Value *> &Ops ) {
867+ Info.PostProc = [=](BuiltinCallMutator &Mutator ) {
875868 if (HasBoolArg) {
876- IRBuilder<> IRB (CI);
877- Ops[0 ] =
878- IRB.CreateICmpNE (Ops[0 ], ConstantInt::get (Type::getInt32Ty (*Ctx), 0 ));
869+ Mutator.mapArg (0 , [&](Value *V) {
870+ IRBuilder<> IRB (CI);
871+ return IRB.CreateICmpNE (V, IRB.getInt32 (0 ));
872+ });
879873 }
880- size_t E = Ops. size ();
874+ size_t E = Mutator. arg_size ();
881875 if (DemangledName == " group_broadcast" && E > 2 ) {
882876 assert (E == 3 || E == 4 );
877+ std::vector<Value *> Ops = getArguments (CI);
883878 makeVector (CI, Ops, std::make_pair (Ops.begin () + 1 , Ops.end ()));
879+ while (Mutator.arg_size () > 1 )
880+ Mutator.removeArg (1 );
881+ Mutator.appendArg (Ops.back ());
884882 }
885- Ops.insert (Ops.begin (), Consts.begin (), Consts.end ());
883+ for (unsigned I = 0 ; I < Consts.size (); I++)
884+ Mutator.insertArg (I, Consts[I]);
886885 };
887886 transBuiltin (CI, Info);
888887}
889888
890889void OCLToSPIRVBase::transBuiltin (CallInst *CI, OCLBuiltinTransInfo &Info) {
891- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
892890 Op OC = OpNop;
893891 unsigned ExtOp = ~0U ;
894892 SPIRVBuiltinVariableKind BVKind = BuiltInMax;
@@ -918,31 +916,18 @@ void OCLToSPIRVBase::transBuiltin(CallInst *CI, OCLBuiltinTransInfo &Info) {
918916 Info.UniqName = getSPIRVFuncName (BVKind);
919917 } else
920918 return ;
921- if (!Info.RetTy )
922- mutateCallInstSPIRV (
923- M, CI,
924- [=](CallInst *, std::vector<Value *> &Args) {
925- Info.PostProc (Args);
926- return Info.UniqName + Info.Postfix ;
927- },
928- &Attrs);
929- else
930- mutateCallInstSPIRV (
931- M, CI,
932- [=](CallInst *, std::vector<Value *> &Args, Type *&RetTy) {
933- Info.PostProc (Args);
934- RetTy = Info.RetTy ;
935- return Info.UniqName + Info.Postfix ;
936- },
937- [=](CallInst *NewCI) -> Instruction * {
938- if (NewCI->getType ()->isIntegerTy () && CI->getType ()->isIntegerTy ())
939- return CastInst::CreateIntegerCast (NewCI, CI->getType (),
940- Info.IsRetSigned , " " , CI);
919+ auto Mutator = mutateCallInst (CI, Info.UniqName + Info.Postfix );
920+ Info.PostProc (Mutator);
921+ if (Info.RetTy ) {
922+ Type *OldRetTy = CI->getType ();
923+ Mutator.changeReturnType (
924+ Info.RetTy , [&](IRBuilder<> &Builder, CallInst *NewCI) {
925+ if (Info.RetTy ->isIntegerTy () && OldRetTy->isIntegerTy ())
926+ return Builder.CreateIntCast (NewCI, OldRetTy, Info.IsRetSigned );
941927 else
942- return CastInst::CreatePointerBitCastOrAddrSpaceCast (
943- NewCI, CI->getType (), " " , CI);
944- },
945- &Attrs);
928+ return Builder.CreatePointerBitCastOrAddrSpaceCast (NewCI, OldRetTy);
929+ });
930+ }
946931}
947932
948933void OCLToSPIRVBase::visitCallReadImageMSAA (CallInst *CI,
@@ -1122,27 +1107,25 @@ void OCLToSPIRVBase::visitCallReadWriteImage(CallInst *CI,
11221107 Info.UniqName = kOCLBuiltinName ::ReadImage;
11231108 unsigned ImgOpMask = getImageSignZeroExt (DemangledName);
11241109 if (ImgOpMask) {
1125- Info.PostProc = [&](std::vector<Value *> &Args ) {
1126- Args. push_back (getInt32 (M, ImgOpMask));
1110+ Info.PostProc = [&](BuiltinCallMutator &Mutator ) {
1111+ Mutator. appendArg (getInt32 (M, ImgOpMask));
11271112 };
11281113 }
11291114 }
11301115
11311116 if (DemangledName.find (kOCLBuiltinName ::WriteImage) == 0 ) {
11321117 Info.UniqName = kOCLBuiltinName ::WriteImage;
1133- Info.PostProc = [&](std::vector<Value *> &Args ) {
1118+ Info.PostProc = [&](BuiltinCallMutator &Mutator ) {
11341119 unsigned ImgOpMask = getImageSignZeroExt (DemangledName);
1135- unsigned ImgOpMaskInsIndex = Args. size ();
1136- if (Args. size () == 4 ) // write with lod
1120+ unsigned ImgOpMaskInsIndex = Mutator. arg_size ();
1121+ if (Mutator. arg_size () == 4 ) // write with lod
11371122 {
1138- auto Lod = Args[2 ];
1139- Args.erase (Args.begin () + 2 );
11401123 ImgOpMask |= ImageOperandsMask::ImageOperandsLodMask;
1141- ImgOpMaskInsIndex = Args. size () ;
1142- Args. push_back (Lod );
1124+ ImgOpMaskInsIndex = Mutator. arg_size () - 1 ;
1125+ Mutator. moveArg ( 2 , Mutator. arg_size () - 1 );
11431126 }
11441127 if (ImgOpMask) {
1145- Args. insert (Args. begin () + ImgOpMaskInsIndex, getInt32 (M, ImgOpMask));
1128+ Mutator. insertArg ( ImgOpMaskInsIndex, getInt32 (M, ImgOpMask));
11461129 }
11471130 };
11481131 }
@@ -1159,11 +1142,14 @@ void OCLToSPIRVBase::visitCallToAddr(CallInst *CI, StringRef DemangledName) {
11591142 SPIRAddrSpaceCapitalizedNameMap::map (AddrSpace);
11601143 auto StorageClass = addInt32 (SPIRSPIRVAddrSpaceMap::map (AddrSpace));
11611144 Info.RetTy = getInt8PtrTy (cast<PointerType>(CI->getType ()));
1162- Info.PostProc = [=](std::vector<Value *> &Ops) {
1163- auto P = Ops.back ();
1164- Ops.pop_back ();
1165- Ops.push_back (castToInt8Ptr (P, CI));
1166- Ops.push_back (StorageClass);
1145+ Info.PostProc = [=](BuiltinCallMutator &Mutator) {
1146+ Mutator
1147+ .mapArg (Mutator.arg_size () - 1 ,
1148+ [&](Value *V) {
1149+ return std::pair<Value *, Type *>(
1150+ castToInt8Ptr (V, CI), Type::getInt8Ty (V->getContext ()));
1151+ })
1152+ .appendArg (StorageClass);
11671153 };
11681154 transBuiltin (CI, Info);
11691155}
@@ -1216,8 +1202,9 @@ void OCLToSPIRVBase::visitCallVecLoadStore(CallInst *CI, StringRef MangledName,
12161202 if (DemangledName.find (kOCLBuiltinName ::VLoadPrefix) == 0 )
12171203 Info.Postfix =
12181204 std::string (kSPIRVPostfix ::ExtDivider) + getPostfixForReturnType (CI);
1219- Info.PostProc = [=](std::vector<Value *> &Ops) {
1220- Ops.insert (Ops.end (), Consts.begin (), Consts.end ());
1205+ Info.PostProc = [=](BuiltinCallMutator &Mutator) {
1206+ for (auto *Value : Consts)
1207+ Mutator.appendArg (Value);
12211208 };
12221209 transBuiltin (CI, Info);
12231210}
@@ -1514,9 +1501,8 @@ void OCLToSPIRVBase::visitCallKernelQuery(CallInst *CI,
15141501
15151502// Add postfix to overloaded intel subgroup block read/write builtins
15161503// so new functions can be distinguished.
1517- static void processSubgroupBlockReadWriteINTEL (CallInst *CI,
1518- OCLBuiltinTransInfo &Info,
1519- const Type *DataTy, Module *M) {
1504+ void OCLToSPIRVBase::processSubgroupBlockReadWriteINTEL (
1505+ CallInst *CI, OCLBuiltinTransInfo &Info, const Type *DataTy) {
15201506 unsigned VectorNumElements = 1 ;
15211507 if (auto *VecTy = dyn_cast<FixedVectorType>(DataTy))
15221508 VectorNumElements = VecTy->getNumElements ();
@@ -1525,14 +1511,7 @@ static void processSubgroupBlockReadWriteINTEL(CallInst *CI,
15251511 Info.Postfix +=
15261512 getIntelSubgroupBlockDataPostfix (ElementBitSize, VectorNumElements);
15271513 assert (CI->getCalledFunction () && " Unexpected indirect call" );
1528- AttributeList Attrs = CI->getCalledFunction ()->getAttributes ();
1529- mutateCallInstSPIRV (
1530- M, CI,
1531- [&Info](CallInst *, std::vector<Value *> &Args) {
1532- Info.PostProc (Args);
1533- return Info.UniqName + Info.Postfix ;
1534- },
1535- &Attrs);
1514+ mutateCallInst (CI, Info.UniqName + Info.Postfix );
15361515}
15371516
15381517// The intel_sub_group_block_read built-ins are overloaded to support both
@@ -1548,7 +1527,7 @@ void OCLToSPIRVBase::visitSubgroupBlockReadINTEL(CallInst *CI) {
15481527 else
15491528 Info.UniqName = getSPIRVFuncName (spv::OpSubgroupBlockReadINTEL);
15501529 Type *DataTy = CI->getType ();
1551- processSubgroupBlockReadWriteINTEL (CI, Info, DataTy, M );
1530+ processSubgroupBlockReadWriteINTEL (CI, Info, DataTy);
15521531}
15531532
15541533// The intel_sub_group_block_write built-ins are similarly overloaded to support
@@ -1566,7 +1545,7 @@ void OCLToSPIRVBase::visitSubgroupBlockWriteINTEL(CallInst *CI) {
15661545 " Intel subgroup block write should have arguments" );
15671546 unsigned DataArg = CI->arg_size () - 1 ;
15681547 Type *DataTy = CI->getArgOperand (DataArg)->getType ();
1569- processSubgroupBlockReadWriteINTEL (CI, Info, DataTy, M );
1548+ processSubgroupBlockReadWriteINTEL (CI, Info, DataTy);
15701549}
15711550
15721551void OCLToSPIRVBase::visitSubgroupImageMediaBlockINTEL (
0 commit comments