|
10 | 10 | #include "mlir/Dialect/SCF/IR/SCF.h" |
11 | 11 | #include "mlir/IR/Diagnostics.h" |
12 | 12 | #include "mlir/IR/MLIRContext.h" |
| 13 | +#include "mlir/IR/OwningOpRef.h" |
13 | 14 | #include "gtest/gtest.h" |
14 | 15 |
|
15 | 16 | using namespace mlir; |
@@ -55,35 +56,50 @@ class SCFLoopLikeTest : public ::testing::Test { |
55 | 56 | }; |
56 | 57 |
|
57 | 58 | TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) { |
58 | | - Value lb = b.create<arith::ConstantIndexOp>(loc, 0); |
59 | | - Value ub = b.create<arith::ConstantIndexOp>(loc, 10); |
60 | | - Value step = b.create<arith::ConstantIndexOp>(loc, 2); |
| 59 | + OwningOpRef<arith::ConstantIndexOp> lb = |
| 60 | + b.create<arith::ConstantIndexOp>(loc, 0); |
| 61 | + OwningOpRef<arith::ConstantIndexOp> ub = |
| 62 | + b.create<arith::ConstantIndexOp>(loc, 10); |
| 63 | + OwningOpRef<arith::ConstantIndexOp> step = |
| 64 | + b.create<arith::ConstantIndexOp>(loc, 2); |
61 | 65 |
|
62 | | - auto forOp = b.create<scf::ForOp>(loc, lb, ub, step); |
63 | | - checkUnidimensional(forOp); |
| 66 | + OwningOpRef<scf::ForOp> forOp = |
| 67 | + b.create<scf::ForOp>(loc, lb.get(), ub.get(), step.get()); |
| 68 | + checkUnidimensional(forOp.get()); |
64 | 69 |
|
65 | | - auto forallOp = b.create<scf::ForallOp>( |
66 | | - loc, ArrayRef<OpFoldResult>(lb), ArrayRef<OpFoldResult>(ub), |
67 | | - ArrayRef<OpFoldResult>(step), ValueRange(), std::nullopt); |
68 | | - checkUnidimensional(forallOp); |
| 70 | + OwningOpRef<scf::ForallOp> forallOp = b.create<scf::ForallOp>( |
| 71 | + loc, ArrayRef<OpFoldResult>(static_cast<Value>(lb.get())), |
| 72 | + ArrayRef<OpFoldResult>(static_cast<Value>(ub.get())), |
| 73 | + ArrayRef<OpFoldResult>(static_cast<Value>(step.get())), ValueRange(), |
| 74 | + std::nullopt); |
| 75 | + checkUnidimensional(forallOp.get()); |
69 | 76 |
|
70 | | - auto parallelOp = b.create<scf::ParallelOp>( |
71 | | - loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange()); |
72 | | - checkUnidimensional(parallelOp); |
| 77 | + OwningOpRef<scf::ParallelOp> parallelOp = |
| 78 | + b.create<scf::ParallelOp>(loc, ValueRange(lb.get()), ValueRange(ub.get()), |
| 79 | + ValueRange(step.get()), ValueRange()); |
| 80 | + checkUnidimensional(parallelOp.get()); |
73 | 81 | } |
74 | 82 |
|
75 | 83 | TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { |
76 | | - Value lb = b.create<arith::ConstantIndexOp>(loc, 0); |
77 | | - Value ub = b.create<arith::ConstantIndexOp>(loc, 10); |
78 | | - Value step = b.create<arith::ConstantIndexOp>(loc, 2); |
| 84 | + OwningOpRef<arith::ConstantIndexOp> lb = |
| 85 | + b.create<arith::ConstantIndexOp>(loc, 0); |
| 86 | + OwningOpRef<arith::ConstantIndexOp> ub = |
| 87 | + b.create<arith::ConstantIndexOp>(loc, 10); |
| 88 | + OwningOpRef<arith::ConstantIndexOp> step = |
| 89 | + b.create<arith::ConstantIndexOp>(loc, 2); |
| 90 | + auto lbValue = static_cast<Value>(lb.get()); |
| 91 | + auto ubValue = static_cast<Value>(ub.get()); |
| 92 | + auto stepValue = static_cast<Value>(step.get()); |
79 | 93 |
|
80 | | - auto forallOp = b.create<scf::ForallOp>( |
81 | | - loc, ArrayRef<OpFoldResult>({lb, lb}), ArrayRef<OpFoldResult>({ub, ub}), |
82 | | - ArrayRef<OpFoldResult>({step, step}), ValueRange(), std::nullopt); |
83 | | - checkMultidimensional(forallOp); |
| 94 | + OwningOpRef<scf::ForallOp> forallOp = |
| 95 | + b.create<scf::ForallOp>(loc, ArrayRef<OpFoldResult>({lbValue, lbValue}), |
| 96 | + ArrayRef<OpFoldResult>({ubValue, ubValue}), |
| 97 | + ArrayRef<OpFoldResult>({stepValue, stepValue}), |
| 98 | + ValueRange(), std::nullopt); |
| 99 | + checkMultidimensional(forallOp.get()); |
84 | 100 |
|
85 | | - auto parallelOp = |
86 | | - b.create<scf::ParallelOp>(loc, ValueRange({lb, lb}), ValueRange({ub, ub}), |
87 | | - ValueRange({step, step}), ValueRange()); |
88 | | - checkMultidimensional(parallelOp); |
| 101 | + OwningOpRef<scf::ParallelOp> parallelOp = b.create<scf::ParallelOp>( |
| 102 | + loc, ValueRange({lbValue, lbValue}), ValueRange({ubValue, ubValue}), |
| 103 | + ValueRange({stepValue, stepValue}), ValueRange()); |
| 104 | + checkMultidimensional(parallelOp.get()); |
89 | 105 | } |
0 commit comments