|
13 | 13 |
|
14 | 14 | #include "mlir/Dialect/Vector/IR/VectorOps.h" |
15 | 15 |
|
| 16 | +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" |
16 | 17 | #include "mlir/Dialect/Arith/IR/Arith.h" |
17 | 18 | #include "mlir/Dialect/Arith/Utils/Utils.h" |
18 | 19 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
|
30 | 31 | #include "mlir/IR/OpImplementation.h" |
31 | 32 | #include "mlir/IR/PatternMatch.h" |
32 | 33 | #include "mlir/IR/TypeUtilities.h" |
| 34 | +#include "mlir/Interfaces/ValueBoundsOpInterface.h" |
33 | 35 | #include "mlir/Support/LLVM.h" |
34 | 36 | #include "llvm/ADT/ArrayRef.h" |
35 | 37 | #include "llvm/ADT/STLExtras.h" |
@@ -168,39 +170,76 @@ bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write, |
168 | 170 | } |
169 | 171 |
|
170 | 172 | bool mlir::vector::isDisjointTransferIndices( |
171 | | - VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) { |
| 173 | + VectorTransferOpInterface transferA, VectorTransferOpInterface transferB, |
| 174 | + bool testDynamicValueUsingBounds) { |
172 | 175 | // For simplicity only look at transfer of same type. |
173 | 176 | if (transferA.getVectorType() != transferB.getVectorType()) |
174 | 177 | return false; |
175 | 178 | unsigned rankOffset = transferA.getLeadingShapedRank(); |
176 | 179 | for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) { |
177 | | - auto indexA = getConstantIntValue(transferA.indices()[i]); |
178 | | - auto indexB = getConstantIntValue(transferB.indices()[i]); |
179 | | - // If any of the indices are dynamic we cannot prove anything. |
180 | | - if (!indexA.has_value() || !indexB.has_value()) |
181 | | - continue; |
| 180 | + Value indexA = transferA.indices()[i]; |
| 181 | + Value indexB = transferB.indices()[i]; |
| 182 | + std::optional<int64_t> cstIndexA = getConstantIntValue(indexA); |
| 183 | + std::optional<int64_t> cstIndexB = getConstantIntValue(indexB); |
182 | 184 |
|
183 | 185 | if (i < rankOffset) { |
184 | 186 | // For leading dimensions, if we can prove that index are different we |
185 | 187 | // know we are accessing disjoint slices. |
186 | | - if (*indexA != *indexB) |
187 | | - return true; |
| 188 | + if (cstIndexA.has_value() && cstIndexB.has_value()) { |
| 189 | + if (*cstIndexA != *cstIndexB) |
| 190 | + return true; |
| 191 | + continue; |
| 192 | + } |
| 193 | + if (testDynamicValueUsingBounds) { |
| 194 | + // First try to see if we can fully compose and simplify the affine |
| 195 | + // expression as a fast track. |
| 196 | + FailureOr<uint64_t> delta = |
| 197 | + affine::fullyComposeAndComputeConstantDelta(indexA, indexB); |
| 198 | + if (succeeded(delta) && *delta != 0) |
| 199 | + return true; |
| 200 | + |
| 201 | + FailureOr<bool> testEqual = |
| 202 | + ValueBoundsConstraintSet::areEqual(indexA, indexB); |
| 203 | + if (succeeded(testEqual) && !testEqual.value()) |
| 204 | + return true; |
| 205 | + } |
188 | 206 | } else { |
189 | 207 | // For this dimension, we slice a part of the memref we need to make sure |
190 | 208 | // the intervals accessed don't overlap. |
191 | | - int64_t distance = std::abs(*indexA - *indexB); |
192 | | - if (distance >= transferA.getVectorType().getDimSize(i - rankOffset)) |
193 | | - return true; |
| 209 | + int64_t vectorDim = transferA.getVectorType().getDimSize(i - rankOffset); |
| 210 | + if (cstIndexA.has_value() && cstIndexB.has_value()) { |
| 211 | + int64_t distance = std::abs(*cstIndexA - *cstIndexB); |
| 212 | + if (distance >= vectorDim) |
| 213 | + return true; |
| 214 | + continue; |
| 215 | + } |
| 216 | + if (testDynamicValueUsingBounds) { |
| 217 | + // First try to see if we can fully compose and simplify the affine |
| 218 | + // expression as a fast track. |
| 219 | + FailureOr<int64_t> delta = |
| 220 | + affine::fullyComposeAndComputeConstantDelta(indexA, indexB); |
| 221 | + if (succeeded(delta) && std::abs(*delta) >= vectorDim) |
| 222 | + return true; |
| 223 | + |
| 224 | + FailureOr<int64_t> computeDelta = |
| 225 | + ValueBoundsConstraintSet::computeConstantDelta(indexA, indexB); |
| 226 | + if (succeeded(computeDelta)) { |
| 227 | + if (std::abs(computeDelta.value()) >= vectorDim) |
| 228 | + return true; |
| 229 | + } |
| 230 | + } |
194 | 231 | } |
195 | 232 | } |
196 | 233 | return false; |
197 | 234 | } |
198 | 235 |
|
199 | 236 | bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA, |
200 | | - VectorTransferOpInterface transferB) { |
| 237 | + VectorTransferOpInterface transferB, |
| 238 | + bool testDynamicValueUsingBounds) { |
201 | 239 | if (transferA.source() != transferB.source()) |
202 | 240 | return false; |
203 | | - return isDisjointTransferIndices(transferA, transferB); |
| 241 | + return isDisjointTransferIndices(transferA, transferB, |
| 242 | + testDynamicValueUsingBounds); |
204 | 243 | } |
205 | 244 |
|
206 | 245 | // Helper to iterate over n-D vector slice elements. Calculate the next |
|
0 commit comments