Skip to content

Commit 78a81f8

Browse files
committed
improvement to canonical forOp lowering + added/improved new array/for tests
1 parent a81b1e9 commit 78a81f8

File tree

7 files changed

+155
-86
lines changed

7 files changed

+155
-86
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,18 @@
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
1616
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/BuiltinOps.h"
1718
#include "mlir/IR/Location.h"
1819
#include "mlir/IR/ValueRange.h"
1920
#include "mlir/Pass/PassManager.h"
21+
#include "mlir/Support/LLVM.h"
2022
#include "mlir/Support/LogicalResult.h"
2123
#include "mlir/Transforms/DialectConversion.h"
2224
#include "clang/CIR/Dialect/IR/CIRDialect.h"
2325
#include "clang/CIR/Dialect/IR/CIRTypes.h"
2426
#include "clang/CIR/LowerToMLIR.h"
2527
#include "llvm/ADT/TypeSwitch.h"
28+
#include "llvm/IR/Module.h"
2629

2730
using namespace cir;
2831
using namespace llvm;
@@ -252,6 +255,14 @@ void SCFLoop::analysis() {
252255
if (!canonical)
253256
return;
254257

258+
// If the IV is defined before the forOp (i.e. outside the surrounding
259+
// cir.scope) this is not a canonical loop as the IV would not have the
260+
// correct value after the forOp
261+
if (ivAddr.getDefiningOp()->getBlock() != forOp->getBlock()) {
262+
canonical = false;
263+
return;
264+
}
265+
255266
cmpOp = findCmpOp();
256267
if (!cmpOp) {
257268
canonical = false;
@@ -310,17 +321,14 @@ void SCFLoop::transferToSCFForOp() {
310321
}
311322
return mlir::WalkResult::advance();
312323
});
313-
// If the IV was declared in the for op all uses have been replaced by the
314-
// scf.IV and we can remove the alloca + initial store
324+
325+
// All uses have been replaced by the scf.IV and we can remove the alloca + initial store operations
315326

316327
// The operations before the loop have been transferred to MLIR.
317-
// So we need to go through getRemappedValue to find the value.
328+
// So we need to go through getRemappedValue to find the operations.
318329
auto remapAddr = rewriter->getRemappedValue(ivAddr);
319-
// If IV has more uses than the use in the initial store op keep it
320-
if (!remapAddr || !remapAddr.hasOneUse())
321-
return;
322-
323-
// otherwise remove the alloca + initial store op
330+
331+
// Since this is a canonical loop we can remove the alloca + initial store op
324332
rewriter->eraseOp(remapAddr.getDefiningOp());
325333
rewriter->eraseOp(*remapAddr.user_begin());
326334
}

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 14 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -207,50 +207,44 @@ static bool findBaseAndIndices(mlir::Value addr, mlir::Value &base,
207207
return true;
208208
}
209209

210-
// For memref.reinterpret_cast has multiple users, erasing the operation
211-
// after the last load or store been generated.
210+
// If the memref.reinterpret_cast has multiple users (i.e the original
211+
// cir.ptr_stride op has multiple users), only erase the operation after the
212+
// last load or store has been generated.
212213
static void eraseIfSafe(mlir::Value oldAddr, mlir::Value newAddr,
213214
SmallVector<mlir::Operation *> &eraseList,
214215
mlir::ConversionPatternRewriter &rewriter) {
215-
newAddr.getDefiningOp()->getParentOfType<mlir::ModuleOp>()->dump();
216-
oldAddr.dump();
217-
newAddr.dump();
218216

219217
unsigned oldUsedNum =
220218
std::distance(oldAddr.getUses().begin(), oldAddr.getUses().end());
221219
unsigned newUsedNum = 0;
220+
// Count the uses of the newAddr (the result of the original base alloca) in
221+
// load/store ops using an forwarded offset from the current
222+
// memref.reinterpret_cast op
222223
for (auto *user : newAddr.getUsers()) {
223-
user->dump();
224224
if (auto loadOpUser = mlir::dyn_cast_or_null<mlir::memref::LoadOp>(*user)) {
225-
if (auto strideVal = loadOpUser.getIndices()[0]) {
226-
strideVal.dump();
227-
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
228-
.getOffsets()[0]
229-
.dump();
225+
if (!loadOpUser.getIndices().empty()) {
226+
auto strideVal = loadOpUser.getIndices()[0];
230227
if (strideVal ==
231228
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
232229
.getOffsets()[0])
233230
++newUsedNum;
234231
}
235232
} else if (auto storeOpUser =
236233
mlir::dyn_cast_or_null<mlir::memref::StoreOp>(*user)) {
237-
if (auto strideVal = storeOpUser.getIndices()[0]) {
238-
strideVal.dump();
239-
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
240-
.getOffsets()[0]
241-
.dump();
234+
if (!storeOpUser.getIndices().empty()) {
235+
auto strideVal = storeOpUser.getIndices()[0];
242236
if (strideVal ==
243237
mlir::dyn_cast<mlir::memref::ReinterpretCastOp>(eraseList.back())
244238
.getOffsets()[0])
245239
++newUsedNum;
246240
}
247241
}
248242
}
243+
// If all load/store ops using forwarded offsets from the current
244+
// memref.reinterpret_cast ops erase the memref.reinterpret_cast ops
249245
if (oldUsedNum == newUsedNum) {
250-
for (auto op : eraseList) {
251-
op->dump();
246+
for (auto op : eraseList)
252247
rewriter.eraseOp(op);
253-
}
254248
}
255249
}
256250

@@ -269,7 +263,6 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
269263
rewriter)) {
270264
newLoad = rewriter.create<mlir::memref::LoadOp>(
271265
op.getLoc(), base, indices, op.getIsNontemporal());
272-
newLoad->dump();
273266
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
274267
} else
275268
newLoad = rewriter.create<mlir::memref::LoadOp>(
@@ -788,8 +781,6 @@ class CIRScopeOpLowering : public mlir::OpConversionPattern<cir::ScopeOp> {
788781
return mlir::success();
789782
}
790783

791-
// TODO: evaluate if a different mlir core dialect op is better suited for
792-
// this
793784
for (auto &block : scopeOp.getScopeRegion()) {
794785
rewriter.setInsertionPointToEnd(&block);
795786
auto *terminator = block.getTerminator();
@@ -1484,13 +1475,9 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule,
14841475
pm.addPass(createConvertCIRToMLIRPass());
14851476

14861477
auto result = !mlir::failed(pm.run(theModule));
1487-
if (!result) {
1488-
// just for debugging purposes
1489-
// TODO: remove before creating a PR
1490-
theModule->dump();
1478+
if (!result)
14911479
report_fatal_error(
14921480
"The pass manager failed to lower CIR to MLIR standard dialects!");
1493-
}
14941481
// Now that we ran all the lowering passes, verify the final output.
14951482
if (theModule.verify().failed())
14961483
report_fatal_error(

clang/test/CIR/Lowering/ThroughMLIR/array.c

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,32 @@ int test_array2() {
2929
int a[3][4];
3030
return a[1][2];
3131
}
32+
33+
int test_array3() {
34+
// CIR-LABEL: cir.func {{.*}} @test_array3()
35+
// CIR: %[[ARRAY:.*]] = cir.alloca !cir.array<!s32i x 3>, !cir.ptr<!cir.array<!s32i x 3>>, ["a"] {alignment = 4 : i64}
36+
// CIR: %[[PTRDECAY1:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
37+
// CIR: %[[PTRSTRIDE1:.*]] = cir.ptr_stride(%[[PTRDECAY1]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
38+
// CIR: {{.*}} = cir.load align(4) %[[PTRSTRIDE1]] : !cir.ptr<!s32i>, !s32i
39+
// CIR: %[[PTRDECAY2:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
40+
// CIR: %[[PTRSTRIDE2:.*]] = cir.ptr_stride(%[[PTRDECAY2]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
41+
// CIR: %{{.*}} = cir.load align(4) %[[PTRSTRIDE2]] : !cir.ptr<!s32i>, !s32i
42+
// CIR: cir.store align(4) {{.*}}, %[[PTRSTRIDE2]] : !s32i, !cir.ptr<!s32i>
43+
// CIR: %[[PTRDECAY3:.*]] = cir.cast(array_to_ptrdecay, %[[ARRAY]] : !cir.ptr<!cir.array<!s32i x 3>>), !cir.ptr<!s32i>
44+
// CIR: %[[PTRSTRIDE3:.*]] = cir.ptr_stride(%[[PTRDECAY3]] : !cir.ptr<!s32i>, {{.*}} : !s32i), !cir.ptr<!s32i>
45+
// CIR: %{{.*}} = cir.load align(4) %[[PTRSTRIDE3]] : !cir.ptr<!s32i>, !s32i
46+
47+
// MLIR-LABEL: func @test_array3
48+
// MLIR: %{{.*}} = memref.alloca() {alignment = 4 : i64} : memref<i32>
49+
// MLIR: %[[ARRAY:.*]] = memref.alloca() {alignment = 4 : i64} : memref<3xi32>
50+
// MLIR: %[[IDX1:.*]] = arith.index_cast %{{.*}} : i32 to index
51+
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX1]]] : memref<3xi32>
52+
// MLIR: %[[IDX2:.*]] = arith.index_cast %{{.*}} : i32 to index
53+
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX2]]] : memref<3xi32>
54+
// MLIR: memref.store %{{.*}}, %[[ARRAY]][%[[IDX2]]] : memref<3xi32>
55+
// MLIR: %[[IDX3:.*]] = arith.index_cast %{{.*}} : i32 to index
56+
// MLIR: %{{.*}} = memref.load %[[ARRAY]][%[[IDX3]]] : memref<3xi32>
57+
int a[3];
58+
a[0] += a[2];
59+
return a[1];
60+
}

clang/test/CIR/Lowering/ThroughMLIR/for-reject-1.cpp

Lines changed: 0 additions & 24 deletions
This file was deleted.

clang/test/CIR/Lowering/ThroughMLIR/for-reject-2.cpp

Lines changed: 0 additions & 25 deletions
This file was deleted.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void f() {}
5+
6+
void reject_test1() {
7+
for (int i = 0; i < 100; i++, f());
8+
// CHECK: %[[ALLOCA:.+]] = memref.alloca
9+
// CHECK: %[[ZERO:.+]] = arith.constant 0
10+
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
11+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
12+
// CHECK: scf.while : () -> () {
13+
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
14+
// CHECK: %[[TMP1:.+]] = arith.cmpi slt, %0, %[[HUNDRED]]
15+
// CHECK: scf.condition(%[[TMP1]])
16+
// CHECK: } do {
17+
// CHECK: %[[TMP2:.+]] = memref.load %[[ALLOCA]]
18+
// CHECK: %[[ONE:.+]] = arith.constant 1
19+
// CHECK: %[[TMP3:.+]] = arith.addi %[[TMP2]], %[[ONE]]
20+
// CHECK: memref.store %[[TMP3]], %[[ALLOCA]]
21+
// CHECK: func.call @_Z1fv()
22+
// CHECK: scf.yield
23+
// CHECK: }
24+
}
25+
26+
void reject_test2() {
27+
for (int i = 0; i < 100; i++, i++);
28+
// CHECK: %[[ALLOCA:.+]] = memref.alloca
29+
// CHECK: %[[ZERO:.+]] = arith.constant 0
30+
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
31+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
32+
// CHECK: scf.while : () -> () {
33+
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
34+
// CHECK: %[[TMP2:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]]
35+
// CHECK: scf.condition(%[[TMP2]])
36+
// CHECK: } do {
37+
// CHECK: %[[TMP3:.+]] = memref.load %[[ALLOCA]]
38+
// CHECK: %[[ONE:.+]] = arith.constant 1
39+
// CHECK: %[[ADD:.+]] = arith.addi %[[TMP3]], %[[ONE]]
40+
// CHECK: memref.store %[[ADD]], %[[ALLOCA]]
41+
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOCA]]
42+
// CHECK: %[[ONE2:.+]] = arith.constant 1
43+
// CHECK: %[[ADD2:.+]] = arith.addi %[[LOAD]], %[[ONE2]]
44+
// CHECK: memref.store %[[ADD2]], %[[ALLOCA]]
45+
// CHECK: scf.yield
46+
// CHECK: }
47+
}
48+
49+
void reject_test3() {
50+
int i;
51+
for (i = 0; i < 100; i++);
52+
i += 10;
53+
// CHECK: %[[ALLOCA:.+]] = memref.alloca()
54+
// CHECK: memref.alloca_scope {
55+
// CHECK: %[[ZERO:.+]] = arith.constant 0
56+
// CHECK: memref.store %[[ZERO]], %[[ALLOCA]]
57+
// CHECK: %[[HUNDRED:.+]] = arith.constant 100
58+
// CHECK: scf.while : () -> () {
59+
// CHECK: %[[TMP:.+]] = memref.load %[[ALLOCA]]
60+
// CHECK: %[[TMP2:.+]] = arith.cmpi slt, %[[TMP]], %[[HUNDRED]]
61+
// CHECK: scf.condition(%[[TMP2]])
62+
// CHECK: } do {
63+
// CHECK: %[[TMP3:.+]] = memref.load %[[ALLOCA]]
64+
// CHECK: %[[ONE:.+]] = arith.constant 1
65+
// CHECK: %[[ADD:.+]] = arith.addi %[[TMP3]], %[[ONE]]
66+
// CHECK: memref.store %[[ADD]], %[[ALLOCA]]
67+
// CHECK: scf.yield
68+
// CHECK: }
69+
// CHECK: }
70+
// CHECK: %[[TEN:.+]] = arith.constant 10
71+
// CHECK: %[[TMP4:.+]] = memref.load %[[ALLOCA]]
72+
// CHECK: %[[TMP5:.+]] = arith.addi %[[TMP4]], %[[TEN]]
73+
// CHECK: memref.store %[[TMP5]], %[[ALLOCA]]
74+
}

0 commit comments

Comments
 (0)