diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 410a6bffd345e..496a7b036e65d 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -17,19 +17,36 @@ namespace mlir { namespace scf { namespace { +static AffineExpr getTripCountExpr(OpFoldResult lb, OpFoldResult ub, + OpFoldResult step, + ValueBoundsConstraintSet &cstr) { + AffineExpr lbExpr = cstr.getExpr(lb); + AffineExpr ubExpr = cstr.getExpr(ub); + AffineExpr stepExpr = cstr.getExpr(step); + AffineExpr tripCountExpr = + AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step + return tripCountExpr; +} + +static void populateIVBounds(OpFoldResult lb, OpFoldResult ub, + OpFoldResult step, Value iv, + ValueBoundsConstraintSet &cstr) { + cstr.bound(iv) >= cstr.getExpr(lb); + cstr.bound(iv) < cstr.getExpr(ub); + // iv <= lb + ((ub-lb)/step - 1) * step + // This bound does not replace the `iv < ub` constraint mentioned above, + // since constraints involving the multiplication of two constraint set + // dimensions are not supported. + AffineExpr tripCountMinusOne = + getTripCountExpr(lb, ub, step, cstr) - cstr.getExpr(1); + AffineExpr computedUpperBound = + cstr.getExpr(lb) + AffineExpr(tripCountMinusOne * cstr.getExpr(step)); + cstr.bound(iv) <= computedUpperBound; +} + struct ForOpInterface : public ValueBoundsOpInterface::ExternalModel { - static AffineExpr getTripCountExpr(scf::ForOp forOp, - ValueBoundsConstraintSet &cstr) { - AffineExpr lbExpr = cstr.getExpr(forOp.getLowerBound()); - AffineExpr ubExpr = cstr.getExpr(forOp.getUpperBound()); - AffineExpr stepExpr = cstr.getExpr(forOp.getStep()); - AffineExpr tripCountExpr = - AffineExpr(ubExpr - lbExpr).ceilDiv(stepExpr); // (ub - lb) / step - return tripCountExpr; - } - /// Populate bounds of values/dimensions for iter_args/OpResults. If the /// value/dimension size does not change in an iteration, we can deduce that /// it the same as the initial value/dimension. @@ -87,7 +104,8 @@ struct ForOpInterface // `value` is result of `forOp`, we can prove that: // %result == %init_arg + trip_count * (%yielded_value - %iter_arg). // Where trip_count is (ub - lb) / step. - AffineExpr tripCountExpr = getTripCountExpr(forOp, cstr); + AffineExpr tripCountExpr = getTripCountExpr( + forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), cstr); AffineExpr oneIterAdvanceExpr = cstr.getExpr(yieldedValue) - cstr.getExpr(iterArg); cstr.bound(value) == @@ -99,19 +117,8 @@ struct ForOpInterface auto forOp = cast(op); if (value == forOp.getInductionVar()) { - cstr.bound(value) >= forOp.getLowerBound(); - cstr.bound(value) < forOp.getUpperBound(); - // iv <= lb + ((ub-lb)/step - 1) * step - // This bound does not replace the `iv < ub` constraint mentioned above, - // since constraints involving the multiplication of two constraint set - // dimensions are not supported. - AffineExpr tripCountMinusOne = - getTripCountExpr(forOp, cstr) - cstr.getExpr(1); - AffineExpr computedUpperBound = - cstr.getExpr(forOp.getLowerBound()) + - AffineExpr(tripCountMinusOne * cstr.getExpr(forOp.getStep())); - cstr.bound(value) <= computedUpperBound; - return; + return populateIVBounds(forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), value, cstr); } // Handle iter_args and OpResults. @@ -141,11 +148,9 @@ struct ForallOpInterface assert(blockArg.getArgNumber() < forallOp.getInductionVars().size() && "expected index value to be an induction var"); int64_t idx = blockArg.getArgNumber(); - // TODO: Take into account step size. - AffineExpr lb = cstr.getExpr(forallOp.getMixedLowerBound()[idx]); - AffineExpr ub = cstr.getExpr(forallOp.getMixedUpperBound()[idx]); - cstr.bound(value) >= lb; - cstr.bound(value) < ub; + return populateIVBounds(forallOp.getMixedLowerBound()[idx], + forallOp.getMixedUpperBound()[idx], + forallOp.getMixedStep()[idx], value, cstr); } void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir index 339d97df001c5..60fe96d52d20b 100644 --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -379,3 +379,12 @@ func.func @scf_for_result_infer_dynamic_init_big_step(%i : index) { "test.compare"(%0, %7) {cmp = "LE"} : (index, index) -> () return } + +func.func @scf_forall_computed_upper_bound(%x: index) { + %c6 = arith.constant 6 : index + scf.forall (%iv) = (0) to (8) step (3) { + // expected-remark @below{{true}} + "test.compare"(%iv, %c6) {cmp = "LE"} : (index, index) -> () + } + return +}