Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OpenMPIRBuilder] Remove wrapper function in createTask, createTeams #67723

Merged
merged 10 commits into from
Oct 9, 2023

Conversation

shraiysh
Copy link
Member

@shraiysh shraiysh commented Sep 28, 2023

This patch removes the wrapper function in OpenMPIRBuilder::createTask and OpenMPIRBuilder.createTeams. The outlined function is directly of the form that is expected by the runtime library calls. This patch also adds a utility function to help add fake values and their uses, which will be deleted in finalization callbacks.

Why we needed wrappers earlier?
Before the post outline callbacks are executed, the IR has the following structure:

define @func() {
  ;...
  call void @outlined_fn(ptr %data)
  ;...
}
define void @outlined_fn(ptr %data)

OpenMP offloading expects a specific signature for the outlined function in a runtime call. For example, __kmpc_fork_teams expects the following signature:

define @outlined_fn(ptr %global.tid, ptr %data)

As there is no way to change a function's arguments after it has been created, a wrapper function with the expected signature is created that calls the outlined function inside it.

How we are handling it now?
To handle this in the current patch, we create a "fake" global tid and add a "fake" use for it in the to-be-outlined region. We need to create these fake values so the outliner sees it as something it needs to pass to the outlined function. We also tell the outliner to exclude this global tid value from the aggregate data argument, so it comes as a separate argument in the beginning. This way, we are able to directly get the outlined function in the expected format. This is inspired by the way createParallel handles outlining (using fake values and then deleting them later). Tasks are handled with a similar approach. This simplifies the generated code and the code to do this itself also becomes simpler (because we no longer have to construct a new function).

This patch removes the wrapper function in
`OpenMPIRBuilder::createTask`. The outlined function is directly of the
form that is expected by the runtime library calls. This also fixes the
global thread ID argument, which should be used whenever
`kmpc_global_thread_num()` is called inside the outlined function.
@llvmbot
Copy link
Member

llvmbot commented Sep 28, 2023

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-openmp

Changes

This patch removes the wrapper function in OpenMPIRBuilder::createTask. The outlined function is directly of the form that is expected by the runtime library calls. This also fixes the global thread ID argument, which should be used whenever kmpc_global_thread_num() is called inside the outlined function.


Patch is 22.51 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67723.diff

3 Files Affected:

  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+49-57)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+34-22)
  • (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+16-35)
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 9c70d384e55db2b..54012b488c6b671 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -35,6 +35,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
@@ -1496,6 +1497,14 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                             bool Tied, Value *Final, Value *IfCondition,
                             SmallVector<DependData> Dependencies) {
+  // We create a temporary i32 value that will represent the global tid after
+  // outlining.
+  SmallVector<Instruction *, 4> ToBeDeleted;
+  Builder.restoreIP(AllocaIP);
+  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
+  LoadInst *TID = Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use");
+  ToBeDeleted.append({TID, TIDAddr});
+
   if (!updateToLocation(Loc))
     return InsertPointTy();
 
@@ -1523,41 +1532,27 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
   BasicBlock *TaskAllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
 
+  // Fake use of TID
+  Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
+  BinaryOperator *AddInst =
+      dyn_cast<BinaryOperator>(Builder.CreateAdd(TID, Builder.getInt32(10)));
+  ToBeDeleted.push_back(AddInst);
+
   OutlineInfo OI;
   OI.EntryBB = TaskAllocaBB;
   OI.OuterAllocaBB = AllocaIP.getBlock();
   OI.ExitBB = TaskExitBB;
-  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition,
-                      Dependencies](Function &OutlinedFn) {
-    // The input IR here looks like the following-
-    // ```
-    // func @current_fn() {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-    //
-    // This is changed to the following-
-    //
-    // ```
-    // func @current_fn() {
-    //   runtime_call(..., wrapper_fn, ...)
-    // }
-    // func @wrapper_fn(..., %args) {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-
-    // The stale call instruction will be replaced with a new call instruction
-    // for runtime call with a wrapper function.
+  OI.ExcludeArgsFromAggregate = {TID};
+  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
+                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) {
+    // Replace the Stale CI by appropriate RTL function call.
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
 
     // HasShareds is true if any variables are captured in the outlined region,
     // false otherwise.
-    bool HasShareds = StaleCI->arg_size() > 0;
+    bool HasShareds = StaleCI->arg_size() > 1;
     Builder.SetInsertPoint(StaleCI);
 
     // Gather the arguments for emitting the runtime call for
@@ -1595,7 +1590,7 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     Value *SharedsSize = Builder.getInt64(0);
     if (HasShareds) {
       AllocaInst *ArgStructAlloca =
-          dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
       assert(ArgStructAlloca &&
              "Unable to find the alloca instruction corresponding to arguments "
              "for extracted function");
@@ -1606,31 +1601,17 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
       SharedsSize =
           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
     }
-
-    // Argument - task_entry (the wrapper function)
-    // If the outlined function has some captured variables (i.e. HasShareds is
-    // true), then the wrapper function will have an additional argument (the
-    // struct containing captured variables). Otherwise, no such argument will
-    // be present.
-    SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
-    if (HasShareds)
-      WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
-    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
-        (Twine(OutlinedFn.getName()) + ".wrapper").str(),
-        FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
-    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-
     // Emit the @__kmpc_omp_task_alloc runtime call
     // The runtime call returns a pointer to an area where the task captured
     // variables must be copied before the task is run (TaskData)
     CallInst *TaskData = Builder.CreateCall(
         TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                       /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
-                      /*task_func=*/WrapperFunc});
+                      /*task_func=*/&OutlinedFn});
 
     // Copy the arguments for outlined function
     if (HasShareds) {
-      Value *Shareds = StaleCI->getArgOperand(0);
+      Value *Shareds = StaleCI->getArgOperand(1);
       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
       Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
       Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
@@ -1697,10 +1678,9 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
     if (IfCondition) {
       // `SplitBlockAndInsertIfThenElse` requires the block to have a
       // terminator.
-      BasicBlock *NewBasicBlock =
-          splitBB(Builder, /*CreateBranch=*/true, "if.end");
+      splitBB(Builder, /*CreateBranch=*/true, "if.end");
       Instruction *IfTerminator =
-          NewBasicBlock->getSinglePredecessor()->getTerminator();
+          Builder.GetInsertPoint()->getParent()->getTerminator();
       Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
       Builder.SetInsertPoint(IfTerminator);
       SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
@@ -1711,10 +1691,12 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
       Function *TaskCompleteFn =
           getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
       Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+      CallInst *CI = nullptr;
       if (HasShareds)
-        Builder.CreateCall(WrapperFunc, {ThreadID, TaskData});
+        CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
       else
-        Builder.CreateCall(WrapperFunc, {ThreadID});
+        CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
+      CI->setDebugLoc(StaleCI->getDebugLoc());
       Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
       Builder.SetInsertPoint(ThenTI);
     }
@@ -1736,18 +1718,28 @@ OpenMPIRBuilder::createTask(const LocationDescription &Loc,
 
     StaleCI->eraseFromParent();
 
-    // Emit the body for wrapper function
-    BasicBlock *WrapperEntryBB =
-        BasicBlock::Create(M.getContext(), "", WrapperFunc);
-    Builder.SetInsertPoint(WrapperEntryBB);
+    Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
     if (HasShareds) {
-      llvm::Value *Shareds =
-          Builder.CreateLoad(VoidPtr, WrapperFunc->getArg(1));
-      Builder.CreateCall(&OutlinedFn, {Shareds});
-    } else {
-      Builder.CreateCall(&OutlinedFn);
+      LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
+      OutlinedFn.getArg(1)->replaceUsesWithIf(
+          Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
+    }
+
+    // Replace kmpc_global_thread_num() calls with the global thread id
+    // argument.
+    OutlinedFn.getArg(0)->setName("global.tid");
+    FunctionCallee TIDRTLFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
+    for (Instruction &Inst : instructions(OutlinedFn)) {
+      CallInst *CI = dyn_cast<CallInst>(&Inst);
+      if (!CI)
+        continue;
+      if (CI->getCalledFunction() == TIDRTLFn.getCallee())
+        CI->replaceAllUsesWith(OutlinedFn.getArg(0));
     }
-    Builder.CreateRet(Builder.getInt32(0));
+
+    for (Instruction *I : ToBeDeleted)
+      I->eraseFromParent();
   };
 
   addOutlineInfo(std::move(OI));
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index fd524f6067ee0ea..643b34270c01693 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -5486,25 +5486,28 @@ TEST_F(OpenMPIRBuilderTest, CreateTask) {
             24); // 64-bit pointer + 128-bit integer
 
   // Verify Wrapper function
-  Function *WrapperFunc =
+  Function *OutlinedFn =
       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
-  ASSERT_NE(WrapperFunc, nullptr);
+  ASSERT_NE(OutlinedFn, nullptr);
 
-  LoadInst *SharedsLoad = dyn_cast<LoadInst>(WrapperFunc->begin()->begin());
+  LoadInst *SharedsLoad = dyn_cast<LoadInst>(OutlinedFn->begin()->begin());
   ASSERT_NE(SharedsLoad, nullptr);
-  EXPECT_EQ(SharedsLoad->getPointerOperand(), WrapperFunc->getArg(1));
-
-  EXPECT_FALSE(WrapperFunc->isDeclaration());
-  CallInst *OutlinedFnCall =
-      dyn_cast<CallInst>(++WrapperFunc->begin()->begin());
-  ASSERT_NE(OutlinedFnCall, nullptr);
-  EXPECT_EQ(WrapperFunc->getArg(0)->getType(), Builder.getInt32Ty());
-  EXPECT_EQ(OutlinedFnCall->getArgOperand(0),
-            WrapperFunc->getArg(1)->uses().begin()->getUser());
+  EXPECT_EQ(SharedsLoad->getPointerOperand(), OutlinedFn->getArg(1));
+
+  EXPECT_FALSE(OutlinedFn->isDeclaration());
+  EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getInt32Ty());
+
+  // Verify that the data argument is used only once, and that too in the load
+  // instruction that is then used for accessing shared data.
+  Value *DataPtr = OutlinedFn->getArg(1);
+  EXPECT_EQ(DataPtr->getNumUses(), 1);
+  EXPECT_TRUE(isa<LoadInst>(DataPtr->uses().begin()->getUser()));
+  Value *Data = DataPtr->uses().begin()->getUser();
+  EXPECT_TRUE(all_of(Data->uses(), [](Use &U) {
+    return isa<GetElementPtrInst>(U.getUser());
+  }));
 
   // Verify the presence of `trunc` and `icmp` instructions in Outlined function
-  Function *OutlinedFn = OutlinedFnCall->getCalledFunction();
-  ASSERT_NE(OutlinedFn, nullptr);
   EXPECT_TRUE(any_of(instructions(OutlinedFn),
                      [](Instruction &inst) { return isa<TruncInst>(&inst); }));
   EXPECT_TRUE(any_of(instructions(OutlinedFn),
@@ -5547,6 +5550,14 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskNoArgs) {
   Builder.CreateRetVoid();
 
   EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  // Check that the outlined function has only one argument.
+  CallInst *TaskAllocCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc)
+          ->user_back());
+  Function *OutlinedFn = dyn_cast<Function>(TaskAllocCall->getArgOperand(5));
+  ASSERT_NE(OutlinedFn, nullptr);
+  ASSERT_EQ(OutlinedFn->arg_size(), 1);
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTaskUntied) {
@@ -5658,8 +5669,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskFinal) {
   F->setName("func");
   IRBuilder<> Builder(BB);
   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
-  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   Builder.SetInsertPoint(BodyBB);
   Value *Final = Builder.CreateICmp(
       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5711,8 +5722,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
   F->setName("func");
   IRBuilder<> Builder(BB);
   auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};
-  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "alloca.split");
+  IRBuilderBase::InsertPoint AllocaIP = Builder.saveIP();
   Builder.SetInsertPoint(BodyBB);
   Value *IfCondition = Builder.CreateICmp(
       CmpInst::Predicate::ICMP_EQ, F->getArg(0),
@@ -5758,15 +5769,16 @@ TEST_F(OpenMPIRBuilderTest, CreateTaskIfCondition) {
           ->user_back());
   ASSERT_NE(TaskBeginIfCall, nullptr);
   ASSERT_NE(TaskCompleteCall, nullptr);
-  Function *WrapperFunc =
+  Function *OulinedFn =
       dyn_cast<Function>(TaskAllocCall->getArgOperand(5)->stripPointerCasts());
-  ASSERT_NE(WrapperFunc, nullptr);
-  CallInst *WrapperFuncCall = dyn_cast<CallInst>(WrapperFunc->user_back());
-  ASSERT_NE(WrapperFuncCall, nullptr);
+  ASSERT_NE(OulinedFn, nullptr);
+  CallInst *OulinedFnCall = dyn_cast<CallInst>(OulinedFn->user_back());
+  ASSERT_NE(OulinedFnCall, nullptr);
   EXPECT_EQ(TaskBeginIfCall->getParent(),
             IfConditionBranchInst->getSuccessor(1));
-  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), WrapperFuncCall);
-  EXPECT_EQ(WrapperFuncCall->getNextNonDebugInstruction(), TaskCompleteCall);
+
+  EXPECT_EQ(TaskBeginIfCall->getNextNonDebugInstruction(), OulinedFnCall);
+  EXPECT_EQ(OulinedFnCall->getNextNonDebugInstruction(), TaskCompleteCall);
 }
 
 TEST_F(OpenMPIRBuilderTest, CreateTaskgroup) {
diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
index 28b0113a19d61b8..2cd561cb021075f 100644
--- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir
@@ -2209,7 +2209,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
   // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
-  // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
+  // CHECK-SAME:  i64 0, ptr @[[outlined_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
   omp.task {
     %n = llvm.mlir.constant(1 : i64) : i64
@@ -2222,7 +2222,7 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
   llvm.return
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
 // CHECK: task.alloca{{.*}}:
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
@@ -2236,12 +2236,6 @@ llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK:   call void @[[outlined_fn]]()
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 // CHECK-LABEL: define void @omp_task_with_deps
@@ -2259,7 +2253,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
   // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
   // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
   // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40,
-  // CHECK-SAME:  i64 0, ptr @[[wrapper_fn:.+]])
+  // CHECK-SAME:  i64 0, ptr @[[outlined_fn:.+]])
   // CHECK: call i32 @__kmpc_omp_task_with_deps(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]], {{.*}})
   omp.task depend(taskdependin -> %zaddr : !llvm.ptr<i32>) {
     %n = llvm.mlir.constant(1 : i64) : i64
@@ -2272,7 +2266,7 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
   llvm.return
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]]()
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]])
 // CHECK: task.alloca{{.*}}:
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
@@ -2286,11 +2280,6 @@ llvm.func @omp_task_with_deps(%zaddr: !llvm.ptr<i32>) {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}) {
-// CHECK:   call void @[[outlined_fn]]()
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 // CHECK-LABEL: define void @omp_task
@@ -2304,7 +2293,7 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
     // CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
     // CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
     // CHECK-SAME: (ptr @{{.+}}, i32 %[[omp_global_thread_num]], i32 1, i64 40, i64 16,
-    // CHECK-SAME: ptr @[[wrapper_fn:.+]])
+    // CHECK-SAME: ptr @[[outlined_fn:.+]])
     // CHECK: %[[shareds:.+]] = load ptr, ptr %[[task_data]]
     // CHECK: call void @llvm.memcpy.p0.p0.i64(ptr {{.+}} %[[shareds]], ptr {{.+}}, i64 16, i1 false)
     // CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
@@ -2321,8 +2310,9 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
   }
 }
 
-// CHECK: define internal void @[[outlined_fn:.+]](ptr %[[task_data:.+]])
+// CHECK: define internal void @[[outlined_fn]](i32 %[[global_tid:[^ ,]+]], ptr %[[task_data:.+]])
 // CHECK: task.alloca{{.*}}:
+// CHECK:   %[[shareds:.+]] = load ptr, ptr %[[task_data]]
 // CHECK:   br label %[[task_body:[^, ]+]]
 // CHECK: [[task_body]]:
 // CHECK:   br label %[[task_region:[^, ]+]]
@@ -2333,13 +2323,6 @@ module attributes {llvm.target_triple = "x86_64-unknown-linux-gnu"} {
 // CHECK: [[exit_stub]]:
 // CHECK:   ret void
 
-
-// CHECK: define i32 @[[wrapper_fn]](i32 %{{.+}}, ptr %[[task_data:.+]]) {
-// CHECK:  %[[shareds:.+]] = load ptr, ptr %1, align 8
-// CHECK:   call void @[[outlined_fn]](ptr %[[shareds]])
-// CHECK:   ret i32 0
-// CHECK: }
-
 // -----
 
 llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
@@ -2355,14 +2338,12 @@ llvm.func @par_task_(%arg0: !llvm.ptr<i32> {fir.bindc_name = "a"}) {
 }
 
 // CHECK-LABEL: @par_task_
-// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @par_task_..omp_par.wrapper)
+// CHECK: %[[TASK_ALLOC:.*]] = call ptr @__kmpc_omp_task_alloc({{.*}}ptr @[[task_outlined_fn:.+]])
 // CHECK: call i32 @__kmpc_omp_task({{.*}}, ptr %[[TASK_ALLOC]])
-// CHECK-LABEL: define internal void @par_task_..omp_par
+// CHECK: define internal void @[[task_outlined_fn]]
 // CHECK: %[[ARG_ALLOC:.*]] = alloca { ptr }, align 8
-// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @par_task_..omp_par..omp_par, ptr %[[ARG_ALLOC]])
-// CHECK: define internal void @par_task_..omp_par..omp_par
-// CHECK: define i32 @par_task_..omp_par.wrapper
-// CHECK: call void @par_task_..omp_par
+// CHECK: call void ({{.*}}) @__kmpc_fork_call({{.*}}, ptr @[[parallel_outlined_fn:.+]], ptr %[[ARG_ALLOC]])
+// CHECK: define internal void @[[parallel_outlined_fn]]
 // -----
 
 llvm.func @foo() -> ()
@@ -2432,7 +2413,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         br label %[[codeRepl:[^,]+]]
 // CHECK:       [[codeRepl]]:
 // CHECK:         %[[omp_global_thread_num_t1:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @omp_taskgroup_task..omp_par.wrapper)
+// CHECK:         %[[t1_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], i32 1, i64 40, i64 0, ptr @[[outlined_task_fn:.+]])
 // CHECK:         %{{.+}} = call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num_t1]], ptr %[[t1_alloc]])
 // CHECK:         br label %[[task_exit:[^,]+]]
 // CHECK:       [[task_exit]]:
@@ -2445,7 +2426,7 @@ llvm.func @omp_taskgroup_task(%x: i32, %y: i32, %zaddr: !llvm.ptr<i32>) {
 // CHECK:         %[[gep3:.+]] = getelementptr { i32, i32, ptr }, ptr %[[structArg]], i32 0, i32 2
 // CHECK:         store ptr %[[zaddr]], ptr %[[gep3]], align 8
 // CHECK:         %[[omp_global_thread_num_t2:.+]] = call i32 @__kmpc_global_thread_num(ptr @{{.+}})
-// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @omp_taskgroup_task..omp_par.1.wrapper)
+// CHECK:         %[[t2_alloc:.+]] = call ptr @__kmpc_omp_task_alloc(ptr @{{.+}}, i32 %[[omp_global_thread_num_t2]], i32 1, i64 40, i64 16, ptr @[[outlined_task_fn:.+]])
 // CHECK:         %[[shareds:.+]] = load ptr, ptr %[[t2_alloc]]
 // CHECK:         call void @llvm.memcpy.p0.p0.i64(ptr align ...
[truncated]

@shraiysh shraiysh requested a review from jsjodin September 28, 2023 20:17
@shraiysh shraiysh changed the title [OpenMPIRBuilder] Remove wrapper function in createTask [OpenMPIRBuilder] Remove wrapper function in createTask, createTeams Oct 2, 2023
@shraiysh
Copy link
Member Author

shraiysh commented Oct 3, 2023

Ping for review!

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you expand the summary to describe the changes.
What are the changes required to remove the wrapper function (Why was it required in the first place?)
Why are the fake vals necessary?

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp Outdated Show resolved Hide resolved
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp Show resolved Hide resolved
llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp Outdated Show resolved Hide resolved
Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed summary. LG.

@shraiysh shraiysh merged commit 9050b27 into llvm:main Oct 9, 2023
2 checks passed
// The stale call instruction will be replaced with a new call instruction
// for runtime call with a wrapper function.
// Add the thread ID argument.
std::stack<Instruction *> ToBeDeleted;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not a SmallVector, like we use everywhere else? We can reasonably guess the size to avoid dynamic allocations.


while (!ToBeDeleted.empty()) {
ToBeDeleted.top()->eraseFromParent();
ToBeDeleted.pop();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no need to pop anything, we have this code in other places already, why do we need to come up with new and exciting ways to do the same thing?

for (auto *TBD : ToBeDeleted)
  TBD->eraseFromParent


while (!ToBeDeleted.empty()) {
ToBeDeleted.top()->eraseFromParent();
ToBeDeleted.pop();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants