@@ -4033,7 +4033,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeams) {
40334033 };
40344034
40354035 OpenMPIRBuilder::LocationDescription Loc ({Builder.saveIP (), DL});
4036- Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB));
4036+ Builder.restoreIP (OMPBuilder.createTeams (
4037+ Builder, BodyGenCB, /* NumTeamsLower=*/ nullptr , /* NumTeamsUpper=*/ nullptr ,
4038+ /* ThreadLimit=*/ nullptr , /* IfExpr=*/ nullptr ));
40374039
40384040 OMPBuilder.finalize ();
40394041 Builder.CreateRetVoid ();
@@ -4095,7 +4097,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithThreadLimit) {
40954097 Builder.restoreIP (OMPBuilder.createTeams (/* =*/ Builder, BodyGenCB,
40964098 /* NumTeamsLower=*/ nullptr ,
40974099 /* NumTeamsUpper=*/ nullptr ,
4098- /* ThreadLimit=*/ F->arg_begin ()));
4100+ /* ThreadLimit=*/ F->arg_begin (),
4101+ /* IfExpr=*/ nullptr ));
40994102
41004103 Builder.CreateRetVoid ();
41014104 OMPBuilder.finalize ();
@@ -4144,7 +4147,9 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsUpper) {
41444147 // `num_teams`
41454148 Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB,
41464149 /* NumTeamsLower=*/ nullptr ,
4147- /* NumTeamsUpper=*/ F->arg_begin ()));
4150+ /* NumTeamsUpper=*/ F->arg_begin (),
4151+ /* ThreadLimit=*/ nullptr ,
4152+ /* IfExpr=*/ nullptr ));
41484153
41494154 Builder.CreateRetVoid ();
41504155 OMPBuilder.finalize ();
@@ -4197,7 +4202,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsBoth) {
41974202 // `F` already has an integer argument, so we use that as upper bound to
41984203 // `num_teams`
41994204 Builder.restoreIP (
4200- OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper));
4205+ OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper,
4206+ /* ThreadLimit=*/ nullptr , /* IfExpr=*/ nullptr ));
42014207
42024208 Builder.CreateRetVoid ();
42034209 OMPBuilder.finalize ();
@@ -4255,8 +4261,8 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
42554261 };
42564262
42574263 OpenMPIRBuilder::LocationDescription Loc ({Builder.saveIP (), DL});
4258- Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower,
4259- NumTeamsUpper, ThreadLimit));
4264+ Builder.restoreIP (OMPBuilder.createTeams (
4265+ Builder, BodyGenCB, NumTeamsLower, NumTeamsUpper, ThreadLimit, nullptr ));
42604266
42614267 Builder.CreateRetVoid ();
42624268 OMPBuilder.finalize ();
@@ -4284,6 +4290,134 @@ TEST_F(OpenMPIRBuilderTest, CreateTeamsWithNumTeamsAndThreadLimit) {
42844290 OMPBuilder.getOrCreateRuntimeFunctionPtr (OMPRTL___kmpc_fork_teams));
42854291}
42864292
4293+ TEST_F (OpenMPIRBuilderTest, CreateTeamsWithIfCondition) {
4294+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4295+ OpenMPIRBuilder OMPBuilder (*M);
4296+ OMPBuilder.initialize ();
4297+ F->setName (" func" );
4298+ IRBuilder<> &Builder = OMPBuilder.Builder ;
4299+ Builder.SetInsertPoint (BB);
4300+
4301+ Value *IfExpr = Builder.CreateLoad (Builder.getInt1Ty (),
4302+ Builder.CreateAlloca (Builder.getInt1Ty ()));
4303+
4304+ Function *FakeFunction =
4305+ Function::Create (FunctionType::get (Builder.getVoidTy (), false ),
4306+ GlobalValue::ExternalLinkage, " fakeFunction" , M.get ());
4307+
4308+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4309+ Builder.restoreIP (CodeGenIP);
4310+ Builder.CreateCall (FakeFunction, {});
4311+ };
4312+
4313+ // `F` already has an integer argument, so we use that as upper bound to
4314+ // `num_teams`
4315+ Builder.restoreIP (OMPBuilder.createTeams (
4316+ Builder, BodyGenCB, /* NumTeamsLower=*/ nullptr , /* NumTeamsUpper=*/ nullptr ,
4317+ /* ThreadLimit=*/ nullptr , IfExpr));
4318+
4319+ Builder.CreateRetVoid ();
4320+ OMPBuilder.finalize ();
4321+
4322+ ASSERT_FALSE (verifyModule (*M));
4323+
4324+ CallInst *PushNumTeamsCallInst =
4325+ findSingleCall (F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4326+ ASSERT_NE (PushNumTeamsCallInst, nullptr );
4327+ Value *NumTeamsLower = PushNumTeamsCallInst->getArgOperand (2 );
4328+ Value *NumTeamsUpper = PushNumTeamsCallInst->getArgOperand (3 );
4329+ Value *ThreadLimit = PushNumTeamsCallInst->getArgOperand (4 );
4330+
4331+ // Check the lower_bound
4332+ ASSERT_NE (NumTeamsLower, nullptr );
4333+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLower);
4334+ ASSERT_NE (NumTeamsLowerSelectInst, nullptr );
4335+ EXPECT_EQ (NumTeamsLowerSelectInst->getCondition (), IfExpr);
4336+ EXPECT_EQ (NumTeamsLowerSelectInst->getTrueValue (), Builder.getInt32 (0 ));
4337+ EXPECT_EQ (NumTeamsLowerSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4338+
4339+ // Check the upper_bound
4340+ ASSERT_NE (NumTeamsUpper, nullptr );
4341+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpper);
4342+ ASSERT_NE (NumTeamsUpperSelectInst, nullptr );
4343+ EXPECT_EQ (NumTeamsUpperSelectInst->getCondition (), IfExpr);
4344+ EXPECT_EQ (NumTeamsUpperSelectInst->getTrueValue (), Builder.getInt32 (0 ));
4345+ EXPECT_EQ (NumTeamsUpperSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4346+
4347+ // Check thread_limit
4348+ EXPECT_EQ (ThreadLimit, Builder.getInt32 (0 ));
4349+ }
4350+
4351+ TEST_F (OpenMPIRBuilderTest, CreateTeamsWithIfConditionAndNumTeams) {
4352+ using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
4353+ OpenMPIRBuilder OMPBuilder (*M);
4354+ OMPBuilder.initialize ();
4355+ F->setName (" func" );
4356+ IRBuilder<> &Builder = OMPBuilder.Builder ;
4357+ Builder.SetInsertPoint (BB);
4358+
4359+ Value *IfExpr = Builder.CreateLoad (
4360+ Builder.getInt32Ty (), Builder.CreateAlloca (Builder.getInt32Ty ()));
4361+ Value *NumTeamsLower = Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (5 ));
4362+ Value *NumTeamsUpper =
4363+ Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (10 ));
4364+ Value *ThreadLimit = Builder.CreateAdd (F->arg_begin (), Builder.getInt32 (20 ));
4365+
4366+ Function *FakeFunction =
4367+ Function::Create (FunctionType::get (Builder.getVoidTy (), false ),
4368+ GlobalValue::ExternalLinkage, " fakeFunction" , M.get ());
4369+
4370+ auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
4371+ Builder.restoreIP (CodeGenIP);
4372+ Builder.CreateCall (FakeFunction, {});
4373+ };
4374+
4375+ // `F` already has an integer argument, so we use that as upper bound to
4376+ // `num_teams`
4377+ Builder.restoreIP (OMPBuilder.createTeams (Builder, BodyGenCB, NumTeamsLower,
4378+ NumTeamsUpper, ThreadLimit, IfExpr));
4379+
4380+ Builder.CreateRetVoid ();
4381+ OMPBuilder.finalize ();
4382+
4383+ ASSERT_FALSE (verifyModule (*M));
4384+
4385+ CallInst *PushNumTeamsCallInst =
4386+ findSingleCall (F, OMPRTL___kmpc_push_num_teams_51, OMPBuilder);
4387+ ASSERT_NE (PushNumTeamsCallInst, nullptr );
4388+ Value *NumTeamsLowerArg = PushNumTeamsCallInst->getArgOperand (2 );
4389+ Value *NumTeamsUpperArg = PushNumTeamsCallInst->getArgOperand (3 );
4390+ Value *ThreadLimitArg = PushNumTeamsCallInst->getArgOperand (4 );
4391+
4392+ // Get the boolean conversion of if expression
4393+ ASSERT_EQ (IfExpr->getNumUses (), 1U );
4394+ User *IfExprInst = IfExpr->user_back ();
4395+ ICmpInst *IfExprCmpInst = dyn_cast<ICmpInst>(IfExprInst);
4396+ ASSERT_NE (IfExprCmpInst, nullptr );
4397+ EXPECT_EQ (IfExprCmpInst->getPredicate (), ICmpInst::Predicate::ICMP_NE);
4398+ EXPECT_EQ (IfExprCmpInst->getOperand (0 ), IfExpr);
4399+ EXPECT_EQ (IfExprCmpInst->getOperand (1 ), Builder.getInt32 (0 ));
4400+
4401+ // Check the lower_bound
4402+ ASSERT_NE (NumTeamsLowerArg, nullptr );
4403+ SelectInst *NumTeamsLowerSelectInst = dyn_cast<SelectInst>(NumTeamsLowerArg);
4404+ ASSERT_NE (NumTeamsLowerSelectInst, nullptr );
4405+ EXPECT_EQ (NumTeamsLowerSelectInst->getCondition (), IfExprCmpInst);
4406+ EXPECT_EQ (NumTeamsLowerSelectInst->getTrueValue (), NumTeamsLower);
4407+ EXPECT_EQ (NumTeamsLowerSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4408+
4409+ // Check the upper_bound
4410+ ASSERT_NE (NumTeamsUpperArg, nullptr );
4411+ SelectInst *NumTeamsUpperSelectInst = dyn_cast<SelectInst>(NumTeamsUpperArg);
4412+ ASSERT_NE (NumTeamsUpperSelectInst, nullptr );
4413+ EXPECT_EQ (NumTeamsUpperSelectInst->getCondition (), IfExprCmpInst);
4414+ EXPECT_EQ (NumTeamsUpperSelectInst->getTrueValue (), NumTeamsUpper);
4415+ EXPECT_EQ (NumTeamsUpperSelectInst->getFalseValue (), Builder.getInt32 (1 ));
4416+
4417+ // Check thread_limit
4418+ EXPECT_EQ (ThreadLimitArg, ThreadLimit);
4419+ }
4420+
42874421// / Returns the single instruction of InstTy type in BB that uses the value V.
42884422// / If there is more than one such instruction, returns null.
42894423template <typename InstTy>
0 commit comments