@@ -340,6 +340,42 @@ BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
340340 return splitBB (Builder, CreateBranch, Old->getName () + Suffix);
341341}
342342
343+ // This function creates a fake integer value and a fake use for the integer
344+ // value. It returns the fake value created. This is useful in modeling the
345+ // extra arguments to the outlined functions.
346+ Value *createFakeIntVal (IRBuilder<> &Builder,
347+ OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
348+ std::stack<Instruction *> &ToBeDeleted,
349+ OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
350+ const Twine &Name = " " , bool AsPtr = true ) {
351+ Builder.restoreIP (OuterAllocaIP);
352+ Instruction *FakeVal;
353+ AllocaInst *FakeValAddr =
354+ Builder.CreateAlloca (Builder.getInt32Ty (), nullptr , Name + " .addr" );
355+ ToBeDeleted.push (FakeValAddr);
356+
357+ if (AsPtr) {
358+ FakeVal = FakeValAddr;
359+ } else {
360+ FakeVal =
361+ Builder.CreateLoad (Builder.getInt32Ty (), FakeValAddr, Name + " .val" );
362+ ToBeDeleted.push (FakeVal);
363+ }
364+
365+ // Generate a fake use of this value
366+ Builder.restoreIP (InnerAllocaIP);
367+ Instruction *UseFakeVal;
368+ if (AsPtr) {
369+ UseFakeVal =
370+ Builder.CreateLoad (Builder.getInt32Ty (), FakeVal, Name + " .use" );
371+ } else {
372+ UseFakeVal =
373+ cast<BinaryOperator>(Builder.CreateAdd (FakeVal, Builder.getInt32 (10 )));
374+ }
375+ ToBeDeleted.push (UseFakeVal);
376+ return FakeVal;
377+ }
378+
343379// ===----------------------------------------------------------------------===//
344380// OpenMPIRBuilderConfig
345381// ===----------------------------------------------------------------------===//
@@ -1496,6 +1532,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
14961532 InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
14971533 bool Tied, Value *Final, Value *IfCondition,
14981534 SmallVector<DependData> Dependencies) {
1535+
14991536 if (!updateToLocation (Loc))
15001537 return InsertPointTy ();
15011538
@@ -1523,41 +1560,31 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15231560 BasicBlock *TaskAllocaBB =
15241561 splitBB (Builder, /* CreateBranch=*/ true , " task.alloca" );
15251562
1563+ InsertPointTy TaskAllocaIP =
1564+ InsertPointTy (TaskAllocaBB, TaskAllocaBB->begin ());
1565+ InsertPointTy TaskBodyIP = InsertPointTy (TaskBodyBB, TaskBodyBB->begin ());
1566+ BodyGenCB (TaskAllocaIP, TaskBodyIP);
1567+
15261568 OutlineInfo OI;
15271569 OI.EntryBB = TaskAllocaBB;
15281570 OI.OuterAllocaBB = AllocaIP.getBlock ();
15291571 OI.ExitBB = TaskExitBB;
1530- OI.PostOutlineCB = [this , Ident, Tied, Final, IfCondition,
1531- Dependencies](Function &OutlinedFn) {
1532- // The input IR here looks like the following-
1533- // ```
1534- // func @current_fn() {
1535- // outlined_fn(%args)
1536- // }
1537- // func @outlined_fn(%args) { ... }
1538- // ```
1539- //
1540- // This is changed to the following-
1541- //
1542- // ```
1543- // func @current_fn() {
1544- // runtime_call(..., wrapper_fn, ...)
1545- // }
1546- // func @wrapper_fn(..., %args) {
1547- // outlined_fn(%args)
1548- // }
1549- // func @outlined_fn(%args) { ... }
1550- // ```
15511572
1552- // The stale call instruction will be replaced with a new call instruction
1553- // for runtime call with a wrapper function.
1573+ // Add the thread ID argument.
1574+ std::stack<Instruction *> ToBeDeleted;
1575+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
1576+ Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, " global.tid" , false ));
1577+
1578+ OI.PostOutlineCB = [this , Ident, Tied, Final, IfCondition, Dependencies,
1579+ TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
1580+ // Replace the Stale CI by appropriate RTL function call.
15541581 assert (OutlinedFn.getNumUses () == 1 &&
15551582 " there must be a single user for the outlined function" );
15561583 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back ());
15571584
15581585 // HasShareds is true if any variables are captured in the outlined region,
15591586 // false otherwise.
1560- bool HasShareds = StaleCI->arg_size () > 0 ;
1587+ bool HasShareds = StaleCI->arg_size () > 1 ;
15611588 Builder.SetInsertPoint (StaleCI);
15621589
15631590 // Gather the arguments for emitting the runtime call for
@@ -1595,7 +1622,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
15951622 Value *SharedsSize = Builder.getInt64 (0 );
15961623 if (HasShareds) {
15971624 AllocaInst *ArgStructAlloca =
1598- dyn_cast<AllocaInst>(StaleCI->getArgOperand (0 ));
1625+ dyn_cast<AllocaInst>(StaleCI->getArgOperand (1 ));
15991626 assert (ArgStructAlloca &&
16001627 " Unable to find the alloca instruction corresponding to arguments "
16011628 " for extracted function" );
@@ -1606,31 +1633,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
16061633 SharedsSize =
16071634 Builder.getInt64 (M.getDataLayout ().getTypeStoreSize (ArgStructType));
16081635 }
1609-
1610- // Argument - task_entry (the wrapper function)
1611- // If the outlined function has some captured variables (i.e. HasShareds is
1612- // true), then the wrapper function will have an additional argument (the
1613- // struct containing captured variables). Otherwise, no such argument will
1614- // be present.
1615- SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty ()};
1616- if (HasShareds)
1617- WrapperArgTys.push_back (OutlinedFn.getArg (0 )->getType ());
1618- FunctionCallee WrapperFuncVal = M.getOrInsertFunction (
1619- (Twine (OutlinedFn.getName ()) + " .wrapper" ).str (),
1620- FunctionType::get (Builder.getInt32Ty (), WrapperArgTys, false ));
1621- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee ());
1622-
16231636 // Emit the @__kmpc_omp_task_alloc runtime call
16241637 // The runtime call returns a pointer to an area where the task captured
16251638 // variables must be copied before the task is run (TaskData)
16261639 CallInst *TaskData = Builder.CreateCall (
16271640 TaskAllocFn, {/* loc_ref=*/ Ident, /* gtid=*/ ThreadID, /* flags=*/ Flags,
16281641 /* sizeof_task=*/ TaskSize, /* sizeof_shared=*/ SharedsSize,
1629- /* task_func=*/ WrapperFunc });
1642+ /* task_func=*/ &OutlinedFn });
16301643
16311644 // Copy the arguments for outlined function
16321645 if (HasShareds) {
1633- Value *Shareds = StaleCI->getArgOperand (0 );
1646+ Value *Shareds = StaleCI->getArgOperand (1 );
16341647 Align Alignment = TaskData->getPointerAlignment (M.getDataLayout ());
16351648 Value *TaskShareds = Builder.CreateLoad (VoidPtr, TaskData);
16361649 Builder.CreateMemCpy (TaskShareds, Alignment, Shareds, Alignment,
@@ -1689,18 +1702,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
16891702 // br label %exit
16901703 // else:
16911704 // call @__kmpc_omp_task_begin_if0(...)
1692- // call @wrapper_fn (...)
1705+ // call @outlined_fn (...)
16931706 // call @__kmpc_omp_task_complete_if0(...)
16941707 // br label %exit
16951708 // exit:
16961709 // ...
16971710 if (IfCondition) {
16981711 // `SplitBlockAndInsertIfThenElse` requires the block to have a
16991712 // terminator.
1700- BasicBlock *NewBasicBlock =
1701- splitBB (Builder, /* CreateBranch=*/ true , " if.end" );
1713+ splitBB (Builder, /* CreateBranch=*/ true , " if.end" );
17021714 Instruction *IfTerminator =
1703- NewBasicBlock-> getSinglePredecessor ()->getTerminator ();
1715+ Builder. GetInsertPoint ()-> getParent ()->getTerminator ();
17041716 Instruction *ThenTI = IfTerminator, *ElseTI = nullptr ;
17051717 Builder.SetInsertPoint (IfTerminator);
17061718 SplitBlockAndInsertIfThenElse (IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1723,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17111723 Function *TaskCompleteFn =
17121724 getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_omp_task_complete_if0);
17131725 Builder.CreateCall (TaskBeginFn, {Ident, ThreadID, TaskData});
1726+ CallInst *CI = nullptr ;
17141727 if (HasShareds)
1715- Builder.CreateCall (WrapperFunc , {ThreadID, TaskData});
1728+ CI = Builder.CreateCall (&OutlinedFn , {ThreadID, TaskData});
17161729 else
1717- Builder.CreateCall (WrapperFunc, {ThreadID});
1730+ CI = Builder.CreateCall (&OutlinedFn, {ThreadID});
1731+ CI->setDebugLoc (StaleCI->getDebugLoc ());
17181732 Builder.CreateCall (TaskCompleteFn, {Ident, ThreadID, TaskData});
17191733 Builder.SetInsertPoint (ThenTI);
17201734 }
@@ -1736,26 +1750,20 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
17361750
17371751 StaleCI->eraseFromParent ();
17381752
1739- // Emit the body for wrapper function
1740- BasicBlock *WrapperEntryBB =
1741- BasicBlock::Create (M.getContext (), " " , WrapperFunc);
1742- Builder.SetInsertPoint (WrapperEntryBB);
1753+ Builder.SetInsertPoint (TaskAllocaBB, TaskAllocaBB->begin ());
17431754 if (HasShareds) {
1744- llvm::Value *Shareds =
1745- Builder.CreateLoad (VoidPtr, WrapperFunc->getArg (1 ));
1746- Builder.CreateCall (&OutlinedFn, {Shareds});
1747- } else {
1748- Builder.CreateCall (&OutlinedFn);
1755+ LoadInst *Shareds = Builder.CreateLoad (VoidPtr, OutlinedFn.getArg (1 ));
1756+ OutlinedFn.getArg (1 )->replaceUsesWithIf (
1757+ Shareds, [Shareds](Use &U) { return U.getUser () != Shareds; });
1758+ }
1759+
1760+ while (!ToBeDeleted.empty ()) {
1761+ ToBeDeleted.top ()->eraseFromParent ();
1762+ ToBeDeleted.pop ();
17491763 }
1750- Builder.CreateRet (Builder.getInt32 (0 ));
17511764 };
17521765
17531766 addOutlineInfo (std::move (OI));
1754-
1755- InsertPointTy TaskAllocaIP =
1756- InsertPointTy (TaskAllocaBB, TaskAllocaBB->begin ());
1757- InsertPointTy TaskBodyIP = InsertPointTy (TaskBodyBB, TaskBodyBB->begin ());
1758- BodyGenCB (TaskAllocaIP, TaskBodyIP);
17591767 Builder.SetInsertPoint (TaskExitBB, TaskExitBB->begin ());
17601768
17611769 return Builder.saveIP ();
@@ -5763,84 +5771,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
57635771 BasicBlock *AllocaBB =
57645772 splitBB (Builder, /* CreateBranch=*/ true , " teams.alloca" );
57655773
5774+ // Generate the body of teams.
5775+ InsertPointTy AllocaIP (AllocaBB, AllocaBB->begin ());
5776+ InsertPointTy CodeGenIP (BodyBB, BodyBB->begin ());
5777+ BodyGenCB (AllocaIP, CodeGenIP);
5778+
57665779 OutlineInfo OI;
57675780 OI.EntryBB = AllocaBB;
57685781 OI.ExitBB = ExitBB;
57695782 OI.OuterAllocaBB = &OuterAllocaBB;
5770- OI.PostOutlineCB = [this , Ident](Function &OutlinedFn) {
5771- // The input IR here looks like the following-
5772- // ```
5773- // func @current_fn() {
5774- // outlined_fn(%args)
5775- // }
5776- // func @outlined_fn(%args) { ... }
5777- // ```
5778- //
5779- // This is changed to the following-
5780- //
5781- // ```
5782- // func @current_fn() {
5783- // runtime_call(..., wrapper_fn, ...)
5784- // }
5785- // func @wrapper_fn(..., %args) {
5786- // outlined_fn(%args)
5787- // }
5788- // func @outlined_fn(%args) { ... }
5789- // ```
57905783
5784+ // Insert fake values for global tid and bound tid.
5785+ std::stack<Instruction *> ToBeDeleted;
5786+ InsertPointTy OuterAllocaIP (&OuterAllocaBB, OuterAllocaBB.begin ());
5787+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
5788+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, " gid" , true ));
5789+ OI.ExcludeArgsFromAggregate .push_back (createFakeIntVal (
5790+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, " tid" , true ));
5791+
5792+ OI.PostOutlineCB = [this , Ident, ToBeDeleted](Function &OutlinedFn) mutable {
57915793 // The stale call instruction will be replaced with a new call instruction
5792- // for runtime call with a wrapper function.
5794+ // for runtime call with the outlined function.
57935795
57945796 assert (OutlinedFn.getNumUses () == 1 &&
57955797 " there must be a single user for the outlined function" );
57965798 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back ());
5799+ ToBeDeleted.push (StaleCI);
5800+
5801+ assert ((OutlinedFn.arg_size () == 2 || OutlinedFn.arg_size () == 3 ) &&
5802+ " Outlined function must have two or three arguments only" );
5803+
5804+ bool HasShared = OutlinedFn.arg_size () == 3 ;
57975805
5798- // Create the wrapper function.
5799- SmallVector<Type *> WrapperArgTys{Builder.getPtrTy (), Builder.getPtrTy ()};
5800- for (auto &Arg : OutlinedFn.args ())
5801- WrapperArgTys.push_back (Arg.getType ());
5802- FunctionCallee WrapperFuncVal = M.getOrInsertFunction (
5803- (Twine (OutlinedFn.getName ()) + " .teams" ).str (),
5804- FunctionType::get (Builder.getVoidTy (), WrapperArgTys, false ));
5805- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee ());
5806- WrapperFunc->getArg (0 )->setName (" global_tid" );
5807- WrapperFunc->getArg (1 )->setName (" bound_tid" );
5808- if (WrapperFunc->arg_size () > 2 )
5809- WrapperFunc->getArg (2 )->setName (" data" );
5810-
5811- // Emit the body of the wrapper function - just a call to outlined function
5812- // and return statement.
5813- BasicBlock *WrapperEntryBB =
5814- BasicBlock::Create (M.getContext (), " entrybb" , WrapperFunc);
5815- Builder.SetInsertPoint (WrapperEntryBB);
5816- SmallVector<Value *> Args;
5817- for (size_t ArgIndex = 2 ; ArgIndex < WrapperFunc->arg_size (); ArgIndex++)
5818- Args.push_back (WrapperFunc->getArg (ArgIndex));
5819- Builder.CreateCall (&OutlinedFn, Args);
5820- Builder.CreateRetVoid ();
5821-
5822- OutlinedFn.addFnAttr (Attribute::AttrKind::AlwaysInline);
5806+ OutlinedFn.getArg (0 )->setName (" global.tid.ptr" );
5807+ OutlinedFn.getArg (1 )->setName (" bound.tid.ptr" );
5808+ if (HasShared)
5809+ OutlinedFn.getArg (2 )->setName (" data" );
58235810
58245811 // Call to the runtime function for teams in the current function.
58255812 assert (StaleCI && " Error while outlining - no CallInst user found for the "
58265813 " outlined function." );
58275814 Builder.SetInsertPoint (StaleCI);
5828- Args = {Ident, Builder.getInt32 (StaleCI->arg_size ()), WrapperFunc};
5829- for (Use &Arg : StaleCI->args ())
5830- Args.push_back (Arg);
5815+ SmallVector<Value *> Args = {
5816+ Ident, Builder.getInt32 (StaleCI->arg_size () - 2 ), &OutlinedFn};
5817+ if (HasShared)
5818+ Args.push_back (StaleCI->getArgOperand (2 ));
58315819 Builder.CreateCall (getOrCreateRuntimeFunctionPtr (
58325820 omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
58335821 Args);
5834- StaleCI->eraseFromParent ();
5822+
5823+ while (!ToBeDeleted.empty ()) {
5824+ ToBeDeleted.top ()->eraseFromParent ();
5825+ ToBeDeleted.pop ();
5826+ }
58355827 };
58365828
58375829 addOutlineInfo (std::move (OI));
58385830
5839- // Generate the body of teams.
5840- InsertPointTy AllocaIP (AllocaBB, AllocaBB->begin ());
5841- InsertPointTy CodeGenIP (BodyBB, BodyBB->begin ());
5842- BodyGenCB (AllocaIP, CodeGenIP);
5843-
58445831 Builder.SetInsertPoint (ExitBB, ExitBB->begin ());
58455832
58465833 return Builder.saveIP ();
0 commit comments