@@ -796,16 +796,17 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
796796 CapturedStmtInfo &&
797797 " CapturedStmtInfo should be set when generating the captured function" );
798798 const CapturedDecl *CD = S.getCapturedDecl ();
799+
799800 // Build the argument list.
800- // AMDGCN does not generate wrapper kernels properly, fails to launch kernel.
801- bool NeedWrapperFunction = !CGM.getTriple ().isAMDGCN () &&
802- (getDebugInfo () && CGM.getCodeGenOpts ().hasReducedDebugInfo ());
803- FunctionArgList Args;
804- llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs;
805- llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes;
801+ FunctionArgList Args, WrapperArgs;
802+ llvm::MapVector<const Decl *, std::pair<const VarDecl *, Address>> LocalAddrs,
803+ WrapperLocalAddrs;
804+ llvm::DenseMap<const Decl *, std::pair<const Expr *, llvm::Value *>> VLASizes,
805+ WrapperVLASizes;
806806 SmallString<256 > Buffer;
807807 llvm::raw_svector_ostream Out (Buffer);
808808 Out << CapturedStmtInfo->getHelperName ();
809+
809810 bool isKernel = (Out.str ().find (" __omp_offloading_" ) != std::string::npos);
810811
811812 // For host codegen, we need to determine now whether Xteam reduction is used
@@ -834,22 +835,40 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
834835 }
835836 }
836837
837- if (NeedWrapperFunction)
838+ // AMDGCN does not generate wrapper kernels properly, fails to launch kernel.
839+ // Xteam reduction does not use wrapper kernels.
840+ bool NeedWrapperFunction =
841+ !CGM.getTriple ().isAMDGCN () && !isXteamKernel &&
842+ (getDebugInfo () && CGM.getCodeGenOpts ().hasReducedDebugInfo ());
843+
844+ CodeGenFunction WrapperCGF (CGM, /* suppressNewContext=*/ true );
845+ llvm::Function *WrapperF = nullptr ;
846+ if (NeedWrapperFunction) {
847+ // Emit the final kernel early to allow attributes to be added by the
848+ // OpenMP-IR-Builder.
849+ FunctionOptions WrapperFO (&S, /* UIntPtrCastRequired=*/ true ,
850+ /* RegisterCastedArgsOnly=*/ true ,
851+ CapturedStmtInfo->getHelperName (), Loc);
852+ WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
853+ WrapperF = emitOutlinedFunctionPrologue (WrapperCGF, D, Args, LocalAddrs,
854+ VLASizes, WrapperCGF.CXXThisValue ,
855+ WrapperFO, isKernel, isXteamKernel);
838856 Out << " _debug__" ;
857+ }
839858 FunctionOptions FO (&S, !NeedWrapperFunction, /* RegisterCastedArgsOnly=*/ false ,
840859 Out.str (), Loc);
841- llvm::Function *F =
842- emitOutlinedFunctionPrologue ( *this , D, Args, LocalAddrs, VLASizes ,
843- CXXThisValue, FO, isKernel, isXteamKernel);
860+ llvm::Function *F = emitOutlinedFunctionPrologue (
861+ *this , D, WrapperArgs, WrapperLocalAddrs, WrapperVLASizes, CXXThisValue ,
862+ FO, isKernel, isXteamKernel);
844863 CodeGenFunction::OMPPrivateScope LocalScope (*this );
845- for (const auto &LocalAddrPair : LocalAddrs ) {
864+ for (const auto &LocalAddrPair : WrapperLocalAddrs ) {
846865 if (LocalAddrPair.second .first ) {
847866 LocalScope.addPrivate (LocalAddrPair.second .first ,
848867 LocalAddrPair.second .second );
849868 }
850869 }
851870 (void )LocalScope.Privatize ();
852- for (const auto &VLASizePair : VLASizes )
871+ for (const auto &VLASizePair : WrapperVLASizes )
853872 VLASizeMap[VLASizePair.second .first ] = VLASizePair.second .second ;
854873 PGO.assignRegionCounters (GlobalDecl (CD), F);
855874
@@ -861,16 +880,16 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
861880 EmitOptKernel (
862881 D, FStmt,
863882 llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP, Loc,
864- /* Args =*/ nullptr );
883+ /* WrapperArgs =*/ nullptr );
865884 else
866885 EmitOptKernel (
867886 D, FStmt,
868887 llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD_BIG_JUMP_LOOP,
869- Loc, /* Args =*/ nullptr );
888+ Loc, /* WrapperArgs =*/ nullptr );
870889 } else if (CGM.getLangOpts ().OpenMPIsTargetDevice && isXteamKernel) {
871890 EmitOptKernel (D, FStmt,
872891 llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_XTEAM_RED,
873- Loc, &Args );
892+ Loc, &WrapperArgs );
874893 } else {
875894 CapturedStmtInfo->EmitBody (*this , CD->getBody ());
876895 }
@@ -880,22 +899,9 @@ llvm::Function *CodeGenFunction::GenerateOpenMPCapturedStmtFunction(
880899 if (!NeedWrapperFunction)
881900 return F;
882901
883- FunctionOptions WrapperFO (&S, /* UIntPtrCastRequired=*/ true ,
884- /* RegisterCastedArgsOnly=*/ true ,
885- CapturedStmtInfo->getHelperName (), Loc);
886- CodeGenFunction WrapperCGF (CGM, /* suppressNewContext=*/ true );
887- WrapperCGF.CapturedStmtInfo = CapturedStmtInfo;
888- Args.clear ();
889- LocalAddrs.clear ();
890- VLASizes.clear ();
891- SmallString<256 > Buffer2;
892- llvm::raw_svector_ostream Out2 (Buffer2);
893- Out2 << CapturedStmtInfo->getHelperName ();
894- isKernel = (Out2.str ().find (" __omp_offloading_" ) != std::string::npos);
895-
896- llvm::Function *WrapperF = emitOutlinedFunctionPrologue (
897- WrapperCGF, D, Args, LocalAddrs, VLASizes, WrapperCGF.CXXThisValue ,
898- WrapperFO, isKernel, isXteamKernel);
902+ // Reverse the order.
903+ WrapperF->removeFromParent ();
904+ F->getParent ()->getFunctionList ().insertAfter (F->getIterator (), WrapperF);
899905
900906 llvm::SmallVector<llvm::Value *, 4 > CallArgs;
901907 auto *PI = F->arg_begin ();
0 commit comments