[OpenMPIRBuilder] Remove wrapper function in createTask, createTeams #67723
Conversation
This patch removes the wrapper function in `OpenMPIRBuilder::createTask`. The outlined function is directly of the form that is expected by the runtime library calls. This also fixes the global thread ID argument, which should be used whenever `kmpc_global_thread_num()` is called inside the outlined function.
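For context on why the global thread ID becomes the first argument: the task entry routine handed to `__kmpc_omp_task_alloc` receives the caller's global thread ID when the runtime invokes it. A rough sketch of the runtime-side declarations involved, following the LLVM OpenMP runtime's `kmp.h` conventions (spelled from memory, so treat the exact names as approximate):

#include <stddef.h>

typedef int kmp_int32;
typedef struct ident ident_t;       // source-location descriptor
typedef struct kmp_task kmp_task_t; // task descriptor allocated by the runtime

// The task entry routine gets the global thread id plus the task data.
typedef kmp_int32 (*kmp_routine_entry_t)(kmp_int32 gtid, void *task_data);

kmp_task_t *__kmpc_omp_task_alloc(ident_t *loc_ref, kmp_int32 gtid,
                                  kmp_int32 flags, size_t sizeof_kmp_task_t,
                                  size_t sizeof_shareds,
                                  kmp_routine_entry_t task_entry);

With the wrapper gone, the outlined function itself takes the gtid (and, when there are captured variables, the shareds pointer) directly, so its address can be passed as `task_entry` without an intermediate wrapper, and calls to `kmpc_global_thread_num()` inside the body can be replaced by the incoming gtid argument.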
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-flang-openmp

Changes

This patch removes the wrapper function in `OpenMPIRBuilder::createTask`.

Patch is 22.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67723.diff

3 Files Affected:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..54012b488c6b671 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -35,6 +35,7 @@
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
@@ -1496,6 +1497,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
bool Tied, Value *Final, Value *IfCondition,
SmallVector<DependData> Dependencies) {
+ // We create a temporary i32 value that will represent the global tid after
+ // outlining.
+ SmallVector<Instruction *, 4> ToBeDeleted;
+ Builder.restoreIP(AllocaIP);
+ AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
+ LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
+ ToBeDeleted.append({TID, TIDAddr});
+
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -1523,41 +1532,27 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
BasicBlock *TaskAllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
+ // Fake use of TID
+ Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
+ BinaryOperator *AddInst =
+ dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
+ ToBeDeleted.push_back(AddInst);
+
OutlineInfo OI;
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
- OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
- Dependencies](Function &OutlinedFn) {
- // The input IR here looks like the following-
- // ```
- // func @current_fn() {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
- //
- // This is changed to the following-
- //
- // ```
- // func @current_fn() {
- // runtime_call(..., wrapper_fn, ...)
- // }
- // func @wrapper_fn(..., %args) {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
-
- // The stale call instruction will be replaced with a new call instruction
- // for runtime call with a wrapper function.
+ OI.ExcludeArgsFromAggregate = {TID};
+ OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+ TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+ // Replace the Stale CI by appropriate RTL function call.
assert(OutlinedFn.getNumUses() == 1 &&
"there must be a single user for the outlined function");
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
// HasShareds is true if any variables are captured in the outlined region,
// false otherwise.
- bool HasShareds = StaleCI->arg_size() > 0;
+ bool HasShareds = StaleCI->arg_size() > 1;
Builder.SetInsertPoint(StaleCI);
// Gather the arguments for emitting the runtime call for
@@ -1595,7 +1590,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Value *SharedsSize = Builder.getInt64(0);
if (HasShareds) {
AllocaInst *ArgStructAlloca =
- dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
assert(ArgStructAlloca &&
"Unable to find the alloca instruction corresponding to arguments "
"for extracted function");
@@ -1606,31 +1601,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
SharedsSize =
Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
}
-
- // Argument - task_entry (the wrapper function)
- // If the outlined function has some captured variables (i.e. HasShareds is
- // true), then the wrapper function will have an additional argument (the
- // struct containing captured variables). Otherwise, no such argument will
- // be present.
- SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
- if (HasShareds)
- WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
- (Twine(OutlinedFn.getName()) + ".wrapper").str(),
- FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-
// Emit the @__kmpc_omp_task_alloc runtime call
// The runtime call returns a pointer to an area where the task captured
// variables must be copied before the task is run (TaskData)
CallInst *TaskData = Builder.CreateCall(
TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
- /*task_func=*/WrapperFunc});
+ /*task_func=*/&OutlinedFn});
// Copy the arguments for outlined function
if (HasShareds) {
- Value *Shareds = StaleCI->getArgOperand(0);
+ Value *Shareds = StaleCI->getArgOperand(1);
Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
@@ -1697,10 +1678,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
if (IfCondition) {
// `SplitBlockAndInsertIfThenElse` requires the block to have a
// terminator.
- BasicBlock *NewBasicBlock =
- splitBB(Builder, /*CreateBranch=*/true, "if.end");
+ splitBB(Builder, /*CreateBranch=*/true, "if.end");
Instruction *IfTerminator =
- NewBasicBlock->getSinglePredecessor()->getTerminator();
+ Builder.GetInsertPoint()->getParent()->getTerminator();
Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
Builder.SetInsertPoint(IfTerminator);
SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1691,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
Function *TaskCompleteFn =
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+ CallInst *CI = nullptr;
if (HasShareds)
- Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
+ CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
else
- Builder.CreateCall(WrapperFunc, {ThreadID});
+ CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
+ CI->setDebugLoc(StaleCI->getDebugLoc());
Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
Builder.SetInsertPoint(ThenTI);
}
@@ -1736,18 +1718,28 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
StaleCI->eraseFromParent();
- // Emit the body for wrapper function
- BasicBlock *WrapperEntryBB =
- BasicBlock::Create(M.getContext(), "", WrapperFunc);
- Builder.SetInsertPoint(WrapperEntryBB);
+ Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
if (HasShareds) {
- llvm::Value *Shareds =
- Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
- Builder.CreateCall(&OutlinedFn, {Shareds});
- } else {
- Builder.CreateCall(&OutlinedFn);
+ LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
+ OutlinedFn.getArg(1)->replaceUsesWithIf(
+ Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
+ }
+
+ // Replace kmpc_global_thread_num() calls with the global thread id
+ // argument.
+ OutlinedFn.getArg(0)->setName("global.tid");
+ FunctionCallee TIDRTLFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
+ for (Instruction &Inst : instructions(OutlinedFn)) {
+ CallInst *CI = dyn_cast<CallInst>(&Inst);
+ if (!CI)
+ continue;
+ if (CI->getCalledFunction() == TIDRTLFn.getCallee())
+ CI->replaceAllUsesWith(OutlinedFn.getArg(0));
}
- Builder.CreateRet(Builder.getInt32(0));
+
+ for (Instruction *I : ToBeDeleted)
+ I->eraseFromParent();
};
addOutlineInfo(std::move(OI));
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..643b34270c01693 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5486,25 +5486,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
24); // 64-bit pointer + 128-bit integer
// Verify Wrapper function
- Function *WrapperFunc =
+ Function *OutlinedFn =
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
- ASSERT_NE(WrapperFunc, nullptr);
+ ASSERT_NE(OutlinedFn, nullptr);
- LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+ LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin());
ASSERT_NE(SharedsLoad, nullptr);
- EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
-
- EXPECT_FALSE(WrapperFunc->isDeclaration());
- CallInst *OutlinedFnCall =
- dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
- ASSERT_NE(OutlinedFnCall, nullptr);
- EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
- EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
- WrapperFunc->getArg(1)->uses().begin()->getUser());
+ EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1));
+
+ EXPECT_FALSE(OutlinedFn->isDeclaration());
+ EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty());
+
+ // Verify that the data argument is used only once, and that too in the load
+ // instruction that is then used for accessing shared data.
+ Value *DataPtr = OutlinedFn->getArg(1);
+ EXPECT_EQ(DataPtr->getNumUses(), 1);
+ EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser()));
+ Value *Data = DataPtr->uses().begin()->getUser();
+ EXPECT_TRUE(all_of(Data->uses(), [](Use &U) {
+ return isa<GetElementPtrInst>(U.getUser());
+ }));
// Verify the presence of `trunc` and `icmp` instructions in Outlined function
- Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
- ASSERT_NE(OutlinedFn, nullptr);
EXPECT_TRUE(any_of(instructions(OutlinedFn),
[](Instruction &inst) { return isa<TruncInst>(&inst); }));
EXPECT_TRUE(any_of(instructions(OutlinedFn),
@@ -5547,6 +5550,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
Builder.CreateRetVoid();
EXPECT_FALSE(verifyModule(*M, &errs()));
+
+ // Check that the outlined function has only one argument.
+ CallInst *TaskAllocCall = dyn_cast<CallInst>(
+ OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+ ->user_back());
+ Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5));
+ ASSERT_NE(OutlinedFn, nullptr);
+ ASSERT_EQ(OutlinedFn->arg_size(), 1);
}
TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
@@ -5658,8 +5669,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
F->setName("func");
IRBuilder<> Builder(BB);
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
- IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
Builder.SetInsertPoint(BodyBB);
Value *Final = Builder.CreateICmp(
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5711,8 +5722,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
F->setName("func");
IRBuilder<> Builder(BB);
auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
- IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+ IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
Builder.SetInsertPoint(BodyBB);
Value *IfCondition = Builder.CreateICmp(
CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5758,15 +5769,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
->user_back());
ASSERT_NE(TaskBeginIfCall, nullptr);
ASSERT_NE(TaskCompleteCall, nullptr);
- Function *WrapperFunc =
+ Function *OulinedFn =
dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
- ASSERT_NE(WrapperFunc, nullptr);
- CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
- ASSERT_NE(WrapperFuncCall, nullptr);
+ ASSERT_NE(OulinedFn, nullptr);
+ CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back());
+ ASSERT_NE(OulinedFnCall, nullptr);
EXPECT_EQ(TaskBeginIfCall->getParent(),
IfConditionBranchInst->getSuccessor(1));
- EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
- EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+
+ EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall);
+ EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall);
}
TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 28b0113a19d61b8..2cd561cb021075f 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
- // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
omp.task {
%n = llvm.mlir.constant(1 : i64) : i64
@@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
llvm.return
}
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
// CHECK: task.alloca{{.*}}:
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
@@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK: call void @[[outlined_fn]]()
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
// CHECK-LABEL: define void @omp_task_with_deps
@@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
- // CHECK-SAME: i64 0, ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: i64 0, ptr @[[outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
%n = llvm.mlir.constant(1 : i64) : i64
@@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
llvm.return
}
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
// CHECK: task.alloca{{.*}}:
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
@@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK: call void @[[outlined_fn]]()
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
// CHECK-LABEL: define void @omp_task
@@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
- // CHECK-SAME: ptr @[[wrapper_fn:.+]])
+ // CHECK-SAME: ptr @[[outlined_fn:.+]])
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
@@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
}
}
-// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]])
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]])
// CHECK: task.alloca{{.*}}:
+// CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
// CHECK: br label %[[task_body:[^, ]+]]
// CHECK: [[task_body]]:
// CHECK: br label %[[task_region:[^, ]+]]
@@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
// CHECK: [[exit_stub]]:
// CHECK: ret void
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK: %[[shareds:.+]] = load ptr, ptr %1, align 8
-// CHECK: call void @[[outlined_fn]](ptr %[[shareds]])
-// CHECK: ret i32 0
-// CHECK: }
-
// -----
llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
@@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
}
// CHECK-LABEL: @par_task_
-// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper)
+// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
// CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
-// CHECK-LABEL: define internal void @par_task_..omp_par
+// CHECK: define internal void @[[task_outlined_fn]]
// CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
-// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]])
-// CHECK: define internal void @par_task_..omp_par..omp_par
-// CHECK: define i32 @par_task_..omp_par.wrapper
-// CHECK: call void @par_task_..omp_par
+// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
+// CHECK: define internal void @[[parallel_outlined_fn]]
// -----
llvm.func @foo() -> ()
@@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: br label %[[codeRepl:[^,]+]]
// CHECK: [[codeRepl]]:
// CHECK: %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK: %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]])
// CHECK: %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
// CHECK: br label %[[task_exit:[^,]+]]
// CHECK: [[task_exit]]:
@@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
// CHECK: %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
// CHECK: store ptr %[[zaddr]], ptr %[[gep3]], align 8
// CHECK: %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK: %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]])
// CHECK: %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
// CHECK: call void @llvm.memcpy.p0.p0.i64(ptr align ...
[truncated]
Ping for review!
Could you expand the summary to describe the changes?
What are the changes required to remove the wrapper function? (Why was it required in the first place?)
Why are the fake values necessary?
Thanks for the detailed summary. LG.
// The stale call instruction will be replaced with a new call instruction
// for runtime call with a wrapper function.
// Add the thread ID argument.
std::stack<Instruction *> ToBeDeleted;
Why not a SmallVector, like we use everywhere else? We can reasonably guess the size to avoid dynamic allocations.
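For illustration, the suggestion matches what the createTask hunk above already uses; a small inline buffer avoids heap allocation for the handful of temporaries involved (the inline size of 4 is just a reasonable guess):

#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/Instruction.h"

// Inline storage for up to 4 Instruction pointers; nothing is heap-allocated
// unless a fifth element is appended.
llvm::SmallVector<llvm::Instruction *, 4> ToBeDeleted;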
while (!ToBeDeleted.empty()) {
  ToBeDeleted.top()->eraseFromParent();
  ToBeDeleted.pop();
There is no need to pop anything; we have this code in other places already. Why do we need to come up with new and exciting ways to do the same thing?
for (auto *TBD : ToBeDeleted)
  TBD->eraseFromParent();
while (!ToBeDeleted.empty()) {
  ToBeDeleted.top()->eraseFromParent();
  ToBeDeleted.pop();
same
This patch removes the wrapper function in `OpenMPIRBuilder::createTask` and `OpenMPIRBuilder::createTeams`. The outlined function is directly of the form that is expected by the runtime library calls. This patch also adds a utility function to help add fake values and their uses, which will be deleted in finalization callbacks.

Why did we need wrappers earlier?
Before the post outline callbacks are executed, the IR has the following structure:
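A sketch of that structure, and of what the old post-outline callback turned it into, taken from the doc comment this patch removes from OMPIRBuilder.cpp:

func @current_fn() {
  outlined_fn(%args)
}
func @outlined_fn(%args) { ... }

which the previous implementation rewrote into:

func @current_fn() {
  runtime_call(..., wrapper_fn, ...)
}
func @wrapper_fn(..., %args) {
  outlined_fn(%args)
}
func @outlined_fn(%args) { ... }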
OpenMP offloading expects a specific signature for the outlined function in a runtime call. For example, `__kmpc_fork_teams` expects the signature sketched below. As there is no way to change a function's arguments after it has been created, a wrapper function with the expected signature was created that called the outlined function inside it.
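A sketch of that expected signature, following the runtime's declarations (spelled from memory, so treat the exact names as approximate):

typedef int kmp_int32;
typedef struct ident ident_t;

// The outlined body handed to __kmpc_fork_teams must look like a kmpc_micro,
// i.e. it must accept the global and bound thread ids before the captured args.
typedef void (*kmpc_micro)(kmp_int32 *global_tid, kmp_int32 *bound_tid, ...);

void __kmpc_fork_teams(ident_t *loc, kmp_int32 argc, kmpc_micro microtask, ...);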
How are we handling it now?
To handle this in the current patch, we create a "fake" global tid and add a "fake" use for it in the to-be-outlined region. We need to create these fake values so that the outliner sees them as something it needs to pass to the outlined function. We also tell the outliner to exclude this global tid value from the aggregate `data` argument, so it comes as a separate argument at the beginning. This way, we get the outlined function directly in the expected format. This is inspired by the way `createParallel` handles outlining (using fake values and then deleting them later). Tasks are handled with a similar approach. This simplifies the generated code, and the code to do this itself also becomes simpler (because we no longer have to construct a new function).
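A condensed sketch of that approach as it appears in the createTask hunk above (assume this runs inside `OpenMPIRBuilder::createTask`, so `Builder`, `Int32`, `AllocaIP`, `TaskAllocaBB`, `TaskExitBB`, `OutlineInfo`, and `addOutlineInfo` are the members and locals visible there; the shareds handling and the runtime-call rewrite are elided):

// A placeholder global tid: because the load is used inside the region, the
// outliner will pass it into the outlined function.
SmallVector<Instruction *, 4> ToBeDeleted;
Builder.restoreIP(AllocaIP);
AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
ToBeDeleted.append({TID, TIDAddr});

// A fake use inside the to-be-outlined region keeps TID live across outlining.
Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
Instruction *FakeUse =
    cast<Instruction>(Builder.CreateAdd(TID, Builder.getInt32(10)));
ToBeDeleted.push_back(FakeUse);

OutlineInfo OI;
OI.EntryBB = TaskAllocaBB;
OI.OuterAllocaBB = AllocaIP.getBlock();
OI.ExitBB = TaskExitBB;
// Keep the tid out of the aggregated argument struct so it becomes a separate
// leading i32 argument of the outlined function.
OI.ExcludeArgsFromAggregate = {TID};
OI.PostOutlineCB = [ToBeDeleted](Function &OutlinedFn) {
  OutlinedFn.getArg(0)->setName("global.tid");
  // ... emit the __kmpc_omp_task_* runtime calls in place of the stale call,
  // then drop the temporary values created above.
  for (Instruction *I : ToBeDeleted)
    I->eraseFromParent();
};
addOutlineInfo(std::move(OI));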