Skip to content

Commit 468d3b1

Browse files
authored
[flang][openacc][NFC] Simplify lowering of recipe (#68836)
Refactor some of the lowering in the reduction and firstprivate recipe to avoid duplicated code.
1 parent e32cde6 commit 468d3b1

File tree

1 file changed

+74
-101
lines changed

1 file changed

+74
-101
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 74 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ bool isConstantBound(mlir::acc::DataBoundsOp &op) {
463463
}
464464

465465
/// Return true iff all the bounds are expressed with constant values.
466-
bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
466+
bool areAllBoundConstant(const llvm::SmallVector<mlir::Value> &bounds) {
467467
for (auto bound : bounds) {
468468
auto dataBound =
469469
mlir::dyn_cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
@@ -474,27 +474,6 @@ bool areAllBoundConstant(llvm::SmallVector<mlir::Value> &bounds) {
474474
return true;
475475
}
476476

477-
static fir::ShapeOp
478-
genShapeFromBounds(mlir::Location loc, fir::FirOpBuilder &builder,
479-
const llvm::SmallVector<mlir::Value> &args) {
480-
assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
481-
llvm::SmallVector<mlir::Value> extents;
482-
mlir::Type idxTy = builder.getIndexType();
483-
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
484-
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
485-
for (unsigned i = 0; i < args.size(); i += 3) {
486-
mlir::Value s1 =
487-
builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
488-
mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
489-
mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
490-
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
491-
loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
492-
mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
493-
extents.push_back(ext);
494-
}
495-
return builder.create<fir::ShapeOp>(loc, extents);
496-
}
497-
498477
static llvm::SmallVector<mlir::Value>
499478
genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
500479
mlir::acc::DataBoundsOp &dataBound) {
@@ -520,6 +499,63 @@ genConstantBounds(fir::FirOpBuilder &builder, mlir::Location loc,
520499
return {lb, ub, step};
521500
}
522501

502+
static fir::ShapeOp genShapeFromBoundsOrArgs(
503+
mlir::Location loc, fir::FirOpBuilder &builder, fir::SequenceType seqTy,
504+
const llvm::SmallVector<mlir::Value> &bounds, mlir::ValueRange arguments) {
505+
llvm::SmallVector<mlir::Value> args;
506+
if (areAllBoundConstant(bounds)) {
507+
for (auto bound : llvm::reverse(bounds)) {
508+
auto dataBound =
509+
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
510+
args.append(genConstantBounds(builder, loc, dataBound));
511+
}
512+
} else {
513+
assert(((arguments.size() - 2) / 3 == seqTy.getDimension()) &&
514+
"Expect 3 block arguments per dimension");
515+
for (auto arg : arguments.drop_front(2))
516+
args.push_back(arg);
517+
}
518+
519+
assert(args.size() % 3 == 0 && "Triplets must be a multiple of 3");
520+
llvm::SmallVector<mlir::Value> extents;
521+
mlir::Type idxTy = builder.getIndexType();
522+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
523+
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
524+
for (unsigned i = 0; i < args.size(); i += 3) {
525+
mlir::Value s1 =
526+
builder.create<mlir::arith::SubIOp>(loc, args[i + 1], args[0]);
527+
mlir::Value s2 = builder.create<mlir::arith::AddIOp>(loc, s1, one);
528+
mlir::Value s3 = builder.create<mlir::arith::DivSIOp>(loc, s2, args[i + 2]);
529+
mlir::Value cmp = builder.create<mlir::arith::CmpIOp>(
530+
loc, mlir::arith::CmpIPredicate::sgt, s3, zero);
531+
mlir::Value ext = builder.create<mlir::arith::SelectOp>(loc, cmp, s3, zero);
532+
extents.push_back(ext);
533+
}
534+
return builder.create<fir::ShapeOp>(loc, extents);
535+
}
536+
537+
static hlfir::DesignateOp::Subscripts
538+
getSubscriptsFromArgs(mlir::ValueRange args) {
539+
hlfir::DesignateOp::Subscripts triplets;
540+
for (unsigned i = 2; i < args.size(); i += 3)
541+
triplets.emplace_back(
542+
hlfir::DesignateOp::Triplet{args[i], args[i + 1], args[i + 2]});
543+
return triplets;
544+
}
545+
546+
static hlfir::Entity genDesignateWithTriplets(
547+
fir::FirOpBuilder &builder, mlir::Location loc, hlfir::Entity &entity,
548+
hlfir::DesignateOp::Subscripts &triplets, mlir::Value shape) {
549+
llvm::SmallVector<mlir::Value> lenParams;
550+
hlfir::genLengthParameters(loc, builder, entity, lenParams);
551+
auto designate = builder.create<hlfir::DesignateOp>(
552+
loc, entity.getBase().getType(), entity, /*component=*/"",
553+
/*componentShape=*/mlir::Value{}, triplets,
554+
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt, shape,
555+
lenParams);
556+
return hlfir::Entity{designate.getResult()};
557+
}
558+
523559
mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
524560
mlir::OpBuilder &builder, llvm::StringRef recipeName, mlir::Location loc,
525561
mlir::Type ty, llvm::SmallVector<mlir::Value> &bounds) {
@@ -600,47 +636,16 @@ mlir::acc::FirstprivateRecipeOp Fortran::lower::createOrGetFirstprivateRecipe(
600636
if (!seqTy)
601637
TODO(loc, "Unsupported boxed type in OpenACC firstprivate");
602638

603-
if (allConstantBound) {
604-
for (auto bound : llvm::reverse(bounds)) {
605-
auto dataBound =
606-
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
607-
tripletArgs.append(genConstantBounds(firBuilder, loc, dataBound));
608-
}
609-
} else {
610-
assert(((recipe.getCopyRegion().getArguments().size() - 2) / 3 ==
611-
seqTy.getDimension()) &&
612-
"Expect 3 block arguments per dimension");
613-
for (auto arg : recipe.getCopyRegion().getArguments().drop_front(2))
614-
tripletArgs.push_back(arg);
615-
}
616-
auto shape = genShapeFromBounds(loc, firBuilder, tripletArgs);
617-
hlfir::DesignateOp::Subscripts triplets;
618-
for (unsigned i = 2; i < recipe.getCopyRegion().getArguments().size();
619-
i += 3)
620-
triplets.emplace_back(hlfir::DesignateOp::Triplet{
621-
recipe.getCopyRegion().getArgument(i),
622-
recipe.getCopyRegion().getArgument(i + 1),
623-
recipe.getCopyRegion().getArgument(i + 2)});
624-
625-
llvm::SmallVector<mlir::Value> lenParamsLeft;
639+
auto shape = genShapeFromBoundsOrArgs(
640+
loc, firBuilder, seqTy, bounds, recipe.getCopyRegion().getArguments());
641+
hlfir::DesignateOp::Subscripts triplets =
642+
getSubscriptsFromArgs(recipe.getCopyRegion().getArguments());
626643
auto leftEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(0)};
627-
hlfir::genLengthParameters(loc, firBuilder, leftEntity, lenParamsLeft);
628-
auto leftDesignate = firBuilder.create<hlfir::DesignateOp>(
629-
loc, leftEntity.getBase().getType(), leftEntity, /*component=*/"",
630-
/*componentShape=*/mlir::Value{}, triplets,
631-
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
632-
shape, lenParamsLeft);
633-
auto left = hlfir::Entity{leftDesignate.getResult()};
634-
635-
llvm::SmallVector<mlir::Value> lenParamsRight;
644+
auto left =
645+
genDesignateWithTriplets(firBuilder, loc, leftEntity, triplets, shape);
636646
auto rightEntity = hlfir::Entity{recipe.getCopyRegion().getArgument(1)};
637-
hlfir::genLengthParameters(loc, firBuilder, rightEntity, lenParamsRight);
638-
auto rightDesignate = firBuilder.create<hlfir::DesignateOp>(
639-
loc, rightEntity.getBase().getType(), rightEntity, /*component=*/"",
640-
/*componentShape=*/mlir::Value{}, triplets,
641-
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
642-
shape, lenParamsRight);
643-
auto right = hlfir::Entity{rightDesignate.getResult()};
647+
auto right =
648+
genDesignateWithTriplets(firBuilder, loc, rightEntity, triplets, shape);
644649
firBuilder.create<hlfir::AssignOp>(loc, left, right);
645650
}
646651

@@ -1110,48 +1115,16 @@ static void genCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
11101115
if (!seqTy)
11111116
TODO(loc, "Unsupported boxed type in OpenACC reduction");
11121117

1113-
if (allConstantBound) {
1114-
for (auto bound : llvm::reverse(bounds)) {
1115-
auto dataBound =
1116-
mlir::cast<mlir::acc::DataBoundsOp>(bound.getDefiningOp());
1117-
tripletArgs.append(genConstantBounds(builder, loc, dataBound));
1118-
}
1119-
} else {
1120-
assert(((recipe.getCombinerRegion().getArguments().size() - 2) / 3 ==
1121-
seqTy.getDimension()) &&
1122-
"Expect 3 block arguments per dimension");
1123-
for (auto arg : recipe.getCombinerRegion().getArguments().drop_front(2))
1124-
tripletArgs.push_back(arg);
1125-
}
1126-
auto shape = genShapeFromBounds(loc, builder, tripletArgs);
1127-
1128-
hlfir::DesignateOp::Subscripts triplets;
1129-
for (unsigned i = 2; i < recipe.getCombinerRegion().getArguments().size();
1130-
i += 3)
1131-
triplets.emplace_back(hlfir::DesignateOp::Triplet{
1132-
recipe.getCombinerRegion().getArgument(i),
1133-
recipe.getCombinerRegion().getArgument(i + 1),
1134-
recipe.getCombinerRegion().getArgument(i + 2)});
1135-
1136-
llvm::SmallVector<mlir::Value> lenParamsLeft;
1118+
auto shape = genShapeFromBoundsOrArgs(
1119+
loc, builder, seqTy, bounds, recipe.getCombinerRegion().getArguments());
1120+
hlfir::DesignateOp::Subscripts triplets =
1121+
getSubscriptsFromArgs(recipe.getCombinerRegion().getArguments());
11371122
auto leftEntity = hlfir::Entity{value1};
1138-
hlfir::genLengthParameters(loc, builder, leftEntity, lenParamsLeft);
1139-
auto leftDesignate = builder.create<hlfir::DesignateOp>(
1140-
loc, value1.getType(), leftEntity, /*component=*/"",
1141-
/*componentShape=*/mlir::Value{}, triplets,
1142-
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1143-
shape, lenParamsLeft);
1144-
auto left = hlfir::Entity{leftDesignate.getResult()};
1145-
1146-
llvm::SmallVector<mlir::Value> lenParamsRight;
1123+
auto left =
1124+
genDesignateWithTriplets(builder, loc, leftEntity, triplets, shape);
11471125
auto rightEntity = hlfir::Entity{value2};
1148-
hlfir::genLengthParameters(loc, builder, rightEntity, lenParamsRight);
1149-
auto rightDesignate = builder.create<hlfir::DesignateOp>(
1150-
loc, value2.getType(), rightEntity, /*component=*/"",
1151-
/*componentShape=*/mlir::Value{}, triplets,
1152-
/*substring=*/mlir::ValueRange{}, /*complexPartAttr=*/std::nullopt,
1153-
shape, lenParamsRight);
1154-
auto right = hlfir::Entity{rightDesignate.getResult()};
1126+
auto right =
1127+
genDesignateWithTriplets(builder, loc, rightEntity, triplets, shape);
11551128

11561129
llvm::SmallVector<mlir::Value, 1> typeParams;
11571130
auto genKernel = [&builder, &loc, op, seqTy, &left, &right](

0 commit comments

Comments
 (0)