@@ -17,19 +17,36 @@ namespace mlir {
1717namespace scf {
1818namespace {
1919
20+ static AffineExpr getTripCountExpr (OpFoldResult lb, OpFoldResult ub,
21+ OpFoldResult step,
22+ ValueBoundsConstraintSet &cstr) {
23+ AffineExpr lbExpr = cstr.getExpr (lb);
24+ AffineExpr ubExpr = cstr.getExpr (ub);
25+ AffineExpr stepExpr = cstr.getExpr (step);
26+ AffineExpr tripCountExpr =
27+ AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
28+ return tripCountExpr;
29+ }
30+
31+ static void populateIVBounds (OpFoldResult lb, OpFoldResult ub,
32+ OpFoldResult step, Value iv,
33+ ValueBoundsConstraintSet &cstr) {
34+ cstr.bound (iv) >= cstr.getExpr (lb);
35+ cstr.bound (iv) < cstr.getExpr (ub);
36+ // iv <= lb + ((ub-lb)/step - 1) * step
37+ // This bound does not replace the `iv < ub` constraint mentioned above,
38+ // since constraints involving the multiplication of two constraint set
39+ // dimensions are not supported.
40+ AffineExpr tripCountMinusOne =
41+ getTripCountExpr (lb, ub, step, cstr) - cstr.getExpr (1 );
42+ AffineExpr computedUpperBound =
43+ cstr.getExpr (lb) + AffineExpr (tripCountMinusOne * cstr.getExpr (step));
44+ cstr.bound (iv) <= computedUpperBound;
45+ }
46+
2047struct ForOpInterface
2148 : public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
2249
23- static AffineExpr getTripCountExpr (scf::ForOp forOp,
24- ValueBoundsConstraintSet &cstr) {
25- AffineExpr lbExpr = cstr.getExpr (forOp.getLowerBound ());
26- AffineExpr ubExpr = cstr.getExpr (forOp.getUpperBound ());
27- AffineExpr stepExpr = cstr.getExpr (forOp.getStep ());
28- AffineExpr tripCountExpr =
29- AffineExpr (ubExpr - lbExpr).ceilDiv (stepExpr); // (ub - lb) / step
30- return tripCountExpr;
31- }
32-
3350 // / Populate bounds of values/dimensions for iter_args/OpResults. If the
3451 // / value/dimension size does not change in an iteration, we can deduce that
3552 // / it the same as the initial value/dimension.
@@ -87,7 +104,8 @@ struct ForOpInterface
87104 // `value` is result of `forOp`, we can prove that:
88105 // %result == %init_arg + trip_count * (%yielded_value - %iter_arg).
89106 // Where trip_count is (ub - lb) / step.
90- AffineExpr tripCountExpr = getTripCountExpr (forOp, cstr);
107+ AffineExpr tripCountExpr = getTripCountExpr (
108+ forOp.getLowerBound (), forOp.getUpperBound (), forOp.getStep (), cstr);
91109 AffineExpr oneIterAdvanceExpr =
92110 cstr.getExpr (yieldedValue) - cstr.getExpr (iterArg);
93111 cstr.bound (value) ==
@@ -99,19 +117,8 @@ struct ForOpInterface
99117 auto forOp = cast<ForOp>(op);
100118
101119 if (value == forOp.getInductionVar ()) {
102- cstr.bound (value) >= forOp.getLowerBound ();
103- cstr.bound (value) < forOp.getUpperBound ();
104- // iv <= lb + ((ub-lb)/step - 1) * step
105- // This bound does not replace the `iv < ub` constraint mentioned above,
106- // since constraints involving the multiplication of two constraint set
107- // dimensions are not supported.
108- AffineExpr tripCountMinusOne =
109- getTripCountExpr (forOp, cstr) - cstr.getExpr (1 );
110- AffineExpr computedUpperBound =
111- cstr.getExpr (forOp.getLowerBound ()) +
112- AffineExpr (tripCountMinusOne * cstr.getExpr (forOp.getStep ()));
113- cstr.bound (value) <= computedUpperBound;
114- return ;
120+ return populateIVBounds (forOp.getLowerBound (), forOp.getUpperBound (),
121+ forOp.getStep (), value, cstr);
115122 }
116123
117124 // Handle iter_args and OpResults.
@@ -141,11 +148,9 @@ struct ForallOpInterface
141148 assert (blockArg.getArgNumber () < forallOp.getInductionVars ().size () &&
142149 " expected index value to be an induction var" );
143150 int64_t idx = blockArg.getArgNumber ();
144- // TODO: Take into account step size.
145- AffineExpr lb = cstr.getExpr (forallOp.getMixedLowerBound ()[idx]);
146- AffineExpr ub = cstr.getExpr (forallOp.getMixedUpperBound ()[idx]);
147- cstr.bound (value) >= lb;
148- cstr.bound (value) < ub;
151+ return populateIVBounds (forallOp.getMixedLowerBound ()[idx],
152+ forallOp.getMixedUpperBound ()[idx],
153+ forallOp.getMixedStep ()[idx], value, cstr);
149154 }
150155
151156 void populateBoundsForShapedValueDim (Operation *op, Value value, int64_t dim,
0 commit comments