Skip to content

Commit 06aecdb

Browse files
authored
[MLIR][SCF] Verify number of regions in scf.reduce (#171450)
This patch adds `ReduceOp::verifyRegions` to ensure that the number of reduction regions equals the number of operands (`getReductions().size() == getOperands().size()`). Additionally, `ParallelOp::verify` is updated to gracefully handle cases where the number of reduce operands differs from the initial values, preventing verification logic crashes and relying on `ReduceOp` to report structural inconsistencies. Fixes: #118768
1 parent 9b6b52b commit 06aecdb

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3152,6 +3152,9 @@ LogicalResult ParallelOp::verify() {
31523152
return emitOpError() << "expects number of results: " << resultsSize
31533153
<< " to be the same as number of initial values: "
31543154
<< initValsSize;
3155+
if (reduceOp.getNumOperands() != initValsSize)
3156+
// Delegate error reporting to ReduceOp
3157+
return success();
31553158

31563159
// Check that the types of the results and reductions are the same.
31573160
for (int64_t i = 0; i < static_cast<int64_t>(reductionsSize); ++i) {
@@ -3454,6 +3457,11 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result,
34543457
}
34553458

34563459
LogicalResult ReduceOp::verifyRegions() {
3460+
if (getReductions().size() != getOperands().size())
3461+
return emitOpError() << "expects number of reduction regions: "
3462+
<< getReductions().size()
3463+
<< " to be the same as number of reduction operands: "
3464+
<< getOperands().size();
34573465
// The region of a ReduceOp has two arguments of the same type as its
34583466
// corresponding operand.
34593467
for (int64_t i = 0, e = getReductions().size(); i < e; ++i) {

mlir/test/Dialect/SCF/invalid.mlir

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,37 @@ func.func @parallel_different_types_of_results_and_reduces(
274274

275275
// -----
276276

277+
// The scf.parallel operation requires the number of operands in the terminator
278+
// (scf.reduce) to match the number of initial values provided to the loop.
279+
func.func @invalid_reduce_too_few_regions() {
280+
%c0 = arith.constant 0 : index
281+
%c1 = arith.constant 1 : index
282+
scf.parallel (%arg1) = (%c0) to (%c1) step (%c1) {
283+
// expected-error @+1 {{expects number of reduction regions: 0 to be the same as number of reduction operands: 1}}
284+
scf.reduce(%c1 : index)
285+
}
286+
return
287+
}
288+
289+
// -----
290+
291+
// The scf.parallel operation requires the number of operands in the terminator
292+
// (scf.reduce) to match the number of initial values provided to the loop.
293+
func.func @invalid_reduce_too_many_regions() {
294+
%c0 = arith.constant 0 : index
295+
%c1 = arith.constant 1 : index
296+
%0 = scf.parallel (%i0) = (%c0) to (%c1) step (%c1) init (%c0) -> (index) {
297+
// expected-error @+1 {{expects number of reduction regions: 1 to be the same as number of reduction operands: 0}}
298+
scf.reduce {
299+
^bb0(%lhs : index, %rhs : index):
300+
scf.reduce.return %lhs : index
301+
}
302+
}
303+
return
304+
}
305+
306+
// -----
307+
277308
func.func @top_level_reduce(%arg0 : f32) {
278309
// expected-error@+1 {{expects parent op 'scf.parallel'}}
279310
scf.reduce(%arg0 : f32) {

0 commit comments

Comments
 (0)