Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][flang][openmp] Rework wsloop reduction operations #80019

Merged
merged 4 commits into from
Feb 13, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
[mlir][flang][openmp] Rework wsloop reduction operations
This patch reworks the way that wsloop reduction operations function to better
match the expected semantics from the OpenMP specification, following the rework
of parallel reductions.

The new semantics create a private reduction variable as a block argument which
should be used normally for all operations on that variable in the region; this
private variable is then combined with the others into the shared variable. This
way no special omp.reduction operations are needed inside the region. These
block arguments follow the loop control block arguments.
DavidTruby authored and kiranchandramohan committed Feb 13, 2024
commit b2fb948fe3ea1f4273aab271639b944e50d667b8
18 changes: 15 additions & 3 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
@@ -2366,6 +2366,12 @@ static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
return undef.getDefiningOp();
};

llvm::SmallVector<mlir::Type> blockArgTypes;
llvm::SmallVector<mlir::Location> blockArgLocs;
blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
blockArgLocs.reserve(blockArgTypes.size());
mlir::Block *entryBlock;

// If an argument for the region is provided then create the block with that
// argument. Also update the symbol's address with the mlir argument value.
// e.g. For loops the argument is the induction variable. And all further
@@ -3429,6 +3435,7 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> linearVars, linearStepVars;
llvm::SmallVector<const Fortran::semantics::Symbol *> iv;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
mlir::omp::ClauseScheduleKindAttr scheduleValClauseOperand;
mlir::UnitAttr nowaitClauseOperand, scheduleSimdClauseOperand;
@@ -3440,7 +3447,8 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
cp.processCollapse(loc, eval, lowerBound, upperBound, step, iv,
loopVarTypeSize);
cp.processScheduleChunk(stmtCtx, scheduleChunkClauseOperand);
cp.processReduction(loc, reductionVars, reductionDeclSymbols);
cp.processReduction(loc, reductionVars, reductionDeclSymbols,
&reductionSymbols);
cp.processTODO<Fortran::parser::OmpClause::Linear,
Fortran::parser::OmpClause::Order>(loc, ompDirective);

@@ -3488,6 +3496,11 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
return genLoopVars(op, converter, loc, iv);
};

//llvm::SmallVector<mlir::Type> reductionTypes;
//reductionTypes.reserve(reductionVars.size());
//llvm::transform(reductionVars, std::back_inserter(reductionTypes),
// [](mlir::Value v) { return v.getType(); });

createBodyOfOp<mlir::omp::WsLoopOp>(
wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
.setClauses(&beginClauseList)
@@ -3594,12 +3607,11 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
// 2.9.3.1 SIMD construct
createSimdLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
currentLocation);
genOpenMPReduction(converter, loopOpClauseList);
} else {
createWsLoop(converter, semaCtx, eval, ompDirective, loopOpClauseList,
endClauseList, currentLocation);
}

genOpenMPReduction(converter, semaCtx, loopOpClauseList);
}

static void
20 changes: 16 additions & 4 deletions flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
Original file line number Diff line number Diff line change
@@ -701,10 +701,17 @@ func.func @_QPsb() {
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
// CHECK: omp.parallel {
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] -> %[[RED_ACCUMULATOR]] : !llvm.ptr) for
// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
// CHECK: omp.reduction %[[ARRAY_ELEM]], %[[RED_ACCUMULATOR]] : i32, !llvm.ptr
// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
// CHECK: omp.yield
// CHECK: omp.terminator
// CHECK: llvm.return
@@ -733,15 +740,20 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
%c1_i32 = arith.constant 1 : i32
%c100_i32 = arith.constant 100 : i32
%c1_i32_0 = arith.constant 1 : i32
omp.wsloop reduction(@eqv_reduction -> %1 : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
omp.wsloop reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
fir.store %arg1 to %3 : !fir.ref<i32>
%4 = fir.load %3 : !fir.ref<i32>
%5 = fir.convert %4 : (i32) -> i64
%c1_i64 = arith.constant 1 : i64
%6 = arith.subi %5, %c1_i64 : i64
%7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
%8 = fir.load %7 : !fir.ref<!fir.logical<4>>
omp.reduction %8, %1 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
%lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
%lprv1 = fir.convert %lprv : (!fir.logical<4>) -> i1
%9 = fir.convert %8 : (!fir.logical<4>) -> i1
%10 = arith.cmpi eq, %9, %lprv1 : i1
%11 = fir.convert %10 : (i1) -> !fir.logical<4>
fir.store %11 to %prv : !fir.ref<!fir.logical<4>>
omp.yield
}
omp.terminator
Loading