Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[mlir][flang][openmp] Rework parallel reduction operations
Browse files Browse the repository at this point in the history
This patch reworks the way that parallel reduction operations function to better
match the expected semantics from the OpenMP specification. Previously specific
omp.reduction operations were used inside the region, meaning that the reduction
only applied when the correct operation was used, whereas the specification
states that any change to the variable inside the region should be taken into
account for the reduction.

The new semantics create a private reduction variable as a block argument which
should be used normally for all operations on that variable in the region; this
private variable is then combined with the others into the shared variable. This
way no special omp.reduction operations are needed inside the region.

This patch only makes the change for the `parallel` operation, the change for
the `wsloop` operation will be in a separate patch.
DavidTruby authored and kiranchandramohan committed Feb 12, 2024
1 parent ffab5a0 commit 9e6aa2f
Showing 8 changed files with 211 additions and 63 deletions.
78 changes: 60 additions & 18 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
@@ -621,10 +621,12 @@ class ClauseProcessor {
llvm::SmallVectorImpl<mlir::Location> *mapSymLocs = nullptr,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*mapSymbols = nullptr) const;
bool processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const;
bool
processReduction(mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr) const;
bool processSectionsReduction(mlir::Location currentLocation) const;
bool processTo(llvm::SmallVectorImpl<DeclareTargetCapturePair> &result) const;
bool
@@ -1079,12 +1081,14 @@ class ReductionProcessor {

/// Creates a reduction declaration and associates it with an OpenMP block
/// directive.
static void addReductionDecl(
mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpReductionClause &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) {
static void
addReductionDecl(mlir::Location currentLocation,
Fortran::lower::AbstractConverter &converter,
const Fortran::parser::OmpReductionClause &reduction,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *>
*reductionSymbols = nullptr) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
mlir::omp::ReductionDeclareOp decl;
const auto &redOperator{
@@ -1114,6 +1118,8 @@ class ReductionProcessor {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1148,6 +1154,8 @@ class ReductionProcessor {
if (const auto *name{
Fortran::parser::Unwrap<Fortran::parser::Name>(ompObject)}) {
if (const Fortran::semantics::Symbol * symbol{name->symbol}) {
if (reductionSymbols)
reductionSymbols->push_back(symbol);
mlir::Value symVal = converter.getSymbolAddress(*symbol);
if (auto declOp = symVal.getDefiningOp<hlfir::DeclareOp>())
symVal = declOp.getBase();
@@ -1948,13 +1956,16 @@ bool ClauseProcessor::processMap(
bool ClauseProcessor::processReduction(
mlir::Location currentLocation,
llvm::SmallVectorImpl<mlir::Value> &reductionVars,
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols) const {
llvm::SmallVectorImpl<mlir::Attribute> &reductionDeclSymbols,
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *reductionSymbols)
const {
return findRepeatableClause<ClauseTy::Reduction>(
[&](const ClauseTy::Reduction *reductionClause,
const Fortran::parser::CharBlock &) {
ReductionProcessor rp;
rp.addReductionDecl(currentLocation, converter, reductionClause->v,
reductionVars, reductionDeclSymbols);
reductionVars, reductionDeclSymbols,
reductionSymbols);
});
}

@@ -2304,6 +2315,14 @@ struct OpWithBodyGenInfo {
return *this;
}

OpWithBodyGenInfo &
setReductions(llvm::SmallVector<const Fortran::semantics::Symbol *> *value1,
llvm::SmallVector<mlir::Type> *value2) {
reductionSymbols = value1;
reductionTypes = value2;
return *this;
}

OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
genRegionEntryCB = value;
return *this;
@@ -2323,6 +2342,11 @@ struct OpWithBodyGenInfo {
const Fortran::parser::OmpClauseList *clauses = nullptr;
/// [in] if provided, processes the construct's data-sharing attributes.
DataSharingProcessor *dsp = nullptr;
/// [in] if provided, list of reduction symbols
llvm::SmallVector<const Fortran::semantics::Symbol *> *reductionSymbols =
nullptr;
/// [in] if provided, list of reduction types
llvm::SmallVector<mlir::Type> *reductionTypes = nullptr;
/// [in] if provided, emits the op's region entry. Otherwise, an emtpy block
/// is created in the region.
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
@@ -2567,6 +2591,7 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Value> allocateOperands, allocatorOperands,
reductionVars;
llvm::SmallVector<mlir::Attribute> reductionDeclSymbols;
llvm::SmallVector<const Fortran::semantics::Symbol *> reductionSymbols;

ClauseProcessor cp(converter, clauseList);
cp.processIf(Fortran::parser::OmpIfClause::DirectiveNameModifier::Parallel,
@@ -2576,13 +2601,33 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
cp.processDefault();
cp.processAllocate(allocatorOperands, allocateOperands);
if (!outerCombined)
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols,
&reductionSymbols);

llvm::SmallVector<mlir::Type> reductionTypes;
reductionTypes.reserve(reductionVars.size());
llvm::transform(reductionVars, std::back_inserter(reductionTypes),
[](mlir::Value v) { return v.getType(); });

auto reductionCallback = [&](mlir::Operation *op) {
llvm::SmallVector<mlir::Location> locs(reductionVars.size(),
currentLocation);
auto block = converter.getFirOpBuilder().createBlock(&op->getRegion(0), {},
reductionTypes, locs);
for (auto [arg, prv] :
llvm::zip_equal(reductionSymbols, block->getArguments())) {
converter.bindSymbol(*arg, prv);
}
return reductionSymbols;
};

return genOpWithBody<mlir::omp::ParallelOp>(
OpWithBodyGenInfo(converter, currentLocation, eval)
.setGenNested(genNested)
.setOuterCombined(outerCombined)
.setClauses(&clauseList),
.setClauses(&clauseList)
.setReductions(&reductionSymbols, &reductionTypes)
.setGenRegionEntryCb(reductionCallback),
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
numThreadsClauseOperand, allocateOperands, allocatorOperands,
reductionVars,
@@ -3634,10 +3679,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
break;
}

if (singleDirective) {
genOpenMPReduction(converter, beginClauseList);
if (singleDirective)
return;
}

// Codegen for combined directives
bool combinedDirective = false;
@@ -3673,7 +3716,6 @@ genOMP(Fortran::lower::AbstractConverter &converter,
")");

genNestedEvaluations(converter, eval);
genOpenMPReduction(converter, beginClauseList);
}

static void
26 changes: 17 additions & 9 deletions flang/test/Lower/OpenMP/FIR/parallel-reduction-add.f90
Original file line number Diff line number Diff line change
@@ -27,9 +27,11 @@
!CHECK: %[[IREF:.*]] = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFsimple_int_addEi"}
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>) {
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.+]] = arith.constant 1 : i32
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]]
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -48,9 +50,11 @@ subroutine simple_int_add
!CHECK: %[[RREF:.*]] = fir.alloca f32 {bindc_name = "r", uniq_name = "_QFsimple_real_addEr"}
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[RREF]] -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[LPRV:.+]] = fir.load %[[PRV]] : !fir.ref<f32>
!CHECK: %[[R_INCR:.+]] = arith.constant 1.500000e+00 : f32
!CHECK: %[[RES]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
!CHECK: fir.store %[[RES]] to %[[PRV]] : !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -72,11 +76,15 @@ subroutine simple_real_add
!CHECK: fir.store %[[R_START]] to %[[RREF]] : !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: fir.store %[[I_START]] to %[[IREF]] : !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[IREF]] : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[RREF]] : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[IREF]] -> %[[PRV0:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[RREF]] -> %[[PRV1:.+]] : !fir.ref<f32>) {
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[RREF]] : f32, !fir.ref<f32>
!CHECK: %[[LPRV1:.+]] = fir.load %[[PRV1]] : !fir.ref<f32>
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[LPRV1]] {{.*}} : f32
!CHECK: fir.store %[[RES1]] to %[[PRV1]]
!CHECK: %[[LPRV0:.+]] = fir.load %[[PRV0]] : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[IREF]] : i32, !fir.ref<i32>
!CHECK: %[[RES0:.+]] = arith.addi %[[LPRV0]], %[[I_INCR]]
!CHECK: fir.store %[[RES0]] to %[[PRV0]]
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
26 changes: 19 additions & 7 deletions flang/test/Lower/OpenMP/parallel-reduction-add.f90
Original file line number Diff line number Diff line change
@@ -28,9 +28,12 @@
!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[IREF]] {uniq_name = "_QFsimple_int_addEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<i32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 1 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: %[[RES:.+]] = arith.addi %[[LPRV]], %[[I_INCR]] : i32
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -50,9 +53,12 @@ subroutine simple_int_add
!CHECK: %[[R_DECL:.*]]:2 = hlfir.declare %[[RREF]] {uniq_name = "_QFsimple_real_addEr"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_START:.*]] = arith.constant 0.000000e+00 : f32
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[PRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[P_DECL:.+]]:2 = hlfir.declare %[[PRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[LPRV:.+]] = fir.load %[[P_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[RES:.+]] = arith.addf %[[LPRV]], %[[R_INCR]] {{.*}} : f32
!CHECK: hlfir.assign %[[RES]] to %[[P_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
@@ -76,11 +82,17 @@ subroutine simple_real_add
!CHECK: hlfir.assign %[[R_START]] to %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_START:.*]] = arith.constant 0 : i32
!CHECK: hlfir.assign %[[I_START]] to %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] -> %[[I_DECL]]#0 : !fir.ref<i32>, @[[RED_F32_NAME]] -> %[[R_DECL]]#0 : !fir.ref<f32>) {
!CHECK: omp.parallel reduction(@[[RED_I32_NAME]] %[[I_DECL]]#0 -> %[[IPRV:.+]] : !fir.ref<i32>, @[[RED_F32_NAME]] %[[R_DECL]]#0 -> %[[RPRV:.+]] : !fir.ref<f32>) {
!CHECK: %[[IP_DECL:.+]]:2 = hlfir.declare %[[IPRV]] {{.*}} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
!CHECK: %[[RP_DECL:.+]]:2 = hlfir.declare %[[RPRV]] {{.*}} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
!CHECK: %[[R_INCR:.*]] = arith.constant 1.500000e+00 : f32
!CHECK: omp.reduction %[[R_INCR]], %[[R_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[R_LPRV:.+]] = fir.load %[[RP_DECL]]#0 : !fir.ref<f32>
!CHECK: %[[RES1:.+]] = arith.addf %[[R_INCR]], %[[R_LPRV]] {{.*}} : f32
!CHECK: hlfir.assign %[[RES1]] to %[[RP_DECL]]#0 : f32, !fir.ref<f32>
!CHECK: %[[I_LPRV:.+]] = fir.load %[[IP_DECL]]#0 : !fir.ref<i32>
!CHECK: %[[I_INCR:.*]] = arith.constant 3 : i32
!CHECK: omp.reduction %[[I_INCR]], %[[I_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: %[[RES0:.+]] = arith.addi %[[I_LPRV]], %[[I_INCR]] : i32
!CHECK: hlfir.assign %[[RES0]] to %[[IP_DECL]]#0 : i32, !fir.ref<i32>
!CHECK: omp.terminator
!CHECK: }
!CHECK: return
9 changes: 3 additions & 6 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
@@ -191,19 +191,16 @@ def ParallelOp : OpenMP_Op<"parallel", [
unsigned getNumReductionVars() { return getReductionVars().size(); }
}];
let assemblyFormat = [{
oilist( `reduction` `(`
custom<ReductionVarList>(
$reduction_vars, type($reduction_vars), $reductions
) `)`
| `if` `(` $if_expr_var `:` type($if_expr_var) `)`
oilist(
`if` `(` $if_expr_var `:` type($if_expr_var) `)`
| `num_threads` `(` $num_threads_var `:` type($num_threads_var) `)`
| `allocate` `(`
custom<AllocateAndAllocator>(
$allocate_vars, type($allocate_vars),
$allocators_vars, type($allocators_vars)
) `)`
| `proc_bind` `(` custom<ClauseAttr>($proc_bind_val) `)`
) $region attr-dict
) custom<ParallelRegion>($region, $reduction_vars, type($reduction_vars), $reductions) attr-dict
}];
let hasVerifier = 1;
}
68 changes: 68 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@
#include "mlir/Interfaces/FoldInterfaces.h"

#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLForwardCompat.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
@@ -34,6 +35,7 @@
#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;
using namespace mlir::omp;
@@ -427,6 +429,71 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//

ParseResult
parseReductionClause(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols,
SmallVectorImpl<OpAsmParser::Argument> &privates) {
if (failed(parser.parseOptionalKeyword("reduction")))
return failure();

SmallVector<SymbolRefAttr> reductionVec;

if (failed(
parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, [&]() {
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseArrow() ||
parser.parseArgument(privates.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
return success();
})))
return failure();

for (auto [prv, type] : llvm::zip_equal(privates, types)) {
prv.type = type;
}
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
reductionSymbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
}

static void printReductionClause(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange operands, TypeRange types,
ArrayAttr reductionSymbols) {
p << "reduction(";
llvm::interleaveComma(llvm::zip_equal(reductionSymbols, operands,
region.front().getArguments(), types),
p, [&p](auto t) {
auto [sym, op, arg, type] = t;
p << sym << " " << op << " -> " << arg << " : "
<< type;
});
p << ") ";
}

static ParseResult
parseParallelRegion(OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
SmallVectorImpl<Type> &types, ArrayAttr &reductionSymbols) {

llvm::SmallVector<OpAsmParser::Argument> privates;
if (succeeded(parseReductionClause(parser, region, operands, types,
reductionSymbols, privates)))
return parser.parseRegion(region, privates);

return parser.parseRegion(region);
}

static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
ValueRange operands, TypeRange types,
ArrayAttr reductionSymbols) {
if (reductionSymbols)
printReductionClause(p, op, region, operands, types, reductionSymbols);
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

/// reduction-entry-list ::= reduction-entry
/// | reduction-entry-list `,` reduction-entry
/// reduction-entry ::= symbol-ref `->` ssa-id `:` type
@@ -1114,6 +1181,7 @@ parseLoopControl(OpAsmParser &parser, Region &region,
loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
for (auto &iv : ivs)
iv.type = loopVarType;

return parser.parseRegion(region, ivs);
}

Loading

0 comments on commit 9e6aa2f

Please sign in to comment.