Skip to content

[flang][HLFIR] fix FORALL issue 120190 #120236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -660,10 +660,7 @@ OrderedAssignmentRewriter::generateYieldedEntity(
return castIfNeeded(loc, builder, {maskedValue, std::nullopt}, castToType);
}

assert(region.hasOneBlock() && "region must contain one block");
auto oldYield = getYield(region);
mlir::Block::OpListType &ops = region.back().getOperations();

// Inside Forall, scalars that do not depend on forall indices can be hoisted
// here because their evaluation is required to only call pure procedures, and
// if they depend on a variable previously assigned to in a forall assignment,
Expand All @@ -674,24 +671,24 @@ OrderedAssignmentRewriter::generateYieldedEntity(
bool hoistComputation = false;
if (fir::isa_trivial(oldYield.getEntity().getType()) &&
!constructStack.empty()) {
hoistComputation = true;
for (mlir::Operation &op : ops)
if (llvm::any_of(op.getOperands(), [](mlir::Value value) {
return isForallIndex(value);
})) {
hoistComputation = false;
break;
}
mlir::WalkResult walkResult =
region.walk([&](mlir::Operation *op) -> mlir::WalkResult {
if (llvm::any_of(op->getOperands(), [](mlir::Value value) {
return isForallIndex(value);
}))
return mlir::WalkResult::interrupt();
return mlir::WalkResult::advance();
});
hoistComputation = !walkResult.wasInterrupted();
}
auto insertionPoint = builder.saveInsertionPoint();
if (hoistComputation)
builder.setInsertionPoint(constructStack[0]);

// Clone all operations except the final hlfir.yield.
assert(!ops.empty() && "yield block cannot be empty");
auto end = ops.end();
for (auto opIt = ops.begin(); std::next(opIt) != end; ++opIt)
(void)builder.clone(*opIt, mapper);
assert(region.hasOneBlock() && "region must contain one block");
for (auto &op : region.back().without_terminator())
(void)builder.clone(op, mapper);
// Get the value for the yielded entity, it may be the result of an operation
// that was cloned, or it may be the same as the previous value if the yield
// operand was created before the ordered assignment tree.
Expand Down
64 changes: 64 additions & 0 deletions flang/test/HLFIR/order_assignments/forall-issue120190.fir
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Regression test for https://github.com/llvm/llvm-project/issues/120190
// Verify that hlfir.forall lowering does not try to hoist a mask evaluation
// that refers to the forall index only from inside nested regions.
// RUN: fir-opt %s --lower-hlfir-ordered-assignments | FileCheck %s

// Test input: a masked FORALL running %forall_index from 1 to 100 that
// assigns 0.0 to %array(%forall_index) wherever the mask is true.
// The mask is a scalar i1 produced by a fir.if; %forall_index is used only
// by operations nested inside the fir.if regions, never as an operand of an
// operation placed directly in the hlfir.forall_mask block.
func.func @issue120190(%array: !fir.ref<!fir.array<100xf32>>, %cdt: i1) {
// Constants are defined outside the forall, so they do not depend on the
// forall index.
%cst = arith.constant 0.000000e+00 : f32
%c1 = arith.constant 1 : i64
%c50 = arith.constant 50 : i64
%c100 = arith.constant 100 : i64
// Forall bounds: lower bound %c1, upper bound %c100.
hlfir.forall lb {
hlfir.yield %c1 : i64
} ub {
hlfir.yield %c100 : i64
} (%forall_index: i64) {
hlfir.forall_mask {
// %cdt selects which comparison against 50 forms the mask value.
%mask = fir.if %cdt -> i1 {
// Reference to %forall_index is not directly in
// hlfir.forall_mask region, but is nested.
%res = arith.cmpi slt, %forall_index, %c50 : i64
fir.result %res : i1
} else {
%res = arith.cmpi sgt, %forall_index, %c50 : i64
fir.result %res : i1
}
hlfir.yield %mask : i1
} do {
// Masked assignment: right-hand side is the constant %cst, left-hand side
// is the array element designated by the forall index.
hlfir.region_assign {
hlfir.yield %cst : f32
} to {
%6 = hlfir.designate %array (%forall_index) : (!fir.ref<!fir.array<100xf32>>, i64) -> !fir.ref<f32>
hlfir.yield %6 : !fir.ref<f32>
}
}
}
return
}

// CHECK-LABEL: func.func @issue120190(
// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<100xf32>>,
// CHECK-SAME: %[[VAL_1:.*]]: i1) {
// CHECK: %[[VAL_2:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i64
// CHECK: %[[VAL_4:.*]] = arith.constant 50 : i64
// CHECK: %[[VAL_5:.*]] = arith.constant 100 : i64
// CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_3]] : (i64) -> index
// CHECK: %[[VAL_7:.*]] = fir.convert %[[VAL_5]] : (i64) -> index
// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
// CHECK: fir.do_loop %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_7]] step %[[VAL_8]] {
// CHECK: %[[VAL_10:.*]] = fir.convert %[[VAL_9]] : (index) -> i64
// CHECK: %[[VAL_11:.*]] = fir.if %[[VAL_1]] -> (i1) {
// CHECK: %[[VAL_12:.*]] = arith.cmpi slt, %[[VAL_10]], %[[VAL_4]] : i64
// CHECK: fir.result %[[VAL_12]] : i1
// CHECK: } else {
// CHECK: %[[VAL_13:.*]] = arith.cmpi sgt, %[[VAL_10]], %[[VAL_4]] : i64
// CHECK: fir.result %[[VAL_13]] : i1
// CHECK: }
// CHECK: fir.if %[[VAL_11]] {
// CHECK: %[[VAL_14:.*]] = hlfir.designate %[[VAL_0]] (%[[VAL_10]]) : (!fir.ref<!fir.array<100xf32>>, i64) -> !fir.ref<f32>
// CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_14]] : f32, !fir.ref<f32>
// CHECK: }
// CHECK: }
// CHECK: return
// CHECK: }
Loading