Skip to content

Commit 24a6403

Browse files
Add unittest that tests all overloads.
1 parent abe59c4 commit 24a6403

File tree

2 files changed

+108
-0
lines changed

2 files changed

+108
-0
lines changed

mlir/unittests/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_unittest(MLIRIRTests
88
OperationSupportTest.cpp
99
PatternMatchTest.cpp
1010
ShapedTypeTest.cpp
11+
SymbolTableTest.cpp
1112
TypeTest.cpp
1213
OpPropertiesTest.cpp
1314

mlir/unittests/IR/SymbolTableTest.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===- SymbolTableTest.cpp - SymbolTable unit tests -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#include "mlir/IR/SymbolTable.h"
9+
#include "../../test/lib/Dialect/Test/TestDialect.h"
10+
#include "mlir/IR/Verifier.h"
11+
#include "mlir/Interfaces/CallInterfaces.h"
12+
#include "mlir/Interfaces/FunctionInterfaces.h"
13+
#include "mlir/Parser/Parser.h"
14+
15+
#include "gtest/gtest.h"
16+
17+
using namespace mlir;
18+
19+
namespace {
20+
TEST(SymbolTableTest, ReplaceAllSymbolUses) {
21+
MLIRContext context;
22+
context.getOrLoadDialect<test::TestDialect>();
23+
24+
auto testReplaceAllSymbolUses = [&](auto replaceFn) {
25+
const static llvm::StringLiteral input = R"MLIR(
26+
module {
27+
test.conversion_func_op private @foo() {
28+
"test.conversion_call_op"() { callee=@bar } : () -> ()
29+
"test.return"() : () -> ()
30+
}
31+
test.conversion_func_op private @bar()
32+
}
33+
)MLIR";
34+
35+
// Set up IR and find func ops.
36+
OwningOpRef<Operation *> module = parseSourceString(input, &context);
37+
SymbolTable symbolTable(module.get());
38+
auto ops = module->getRegion(0).getBlocks().front().getOperations().begin();
39+
auto fooOp = cast<FunctionOpInterface>(ops++);
40+
auto barOp = cast<FunctionOpInterface>(ops++);
41+
ASSERT_EQ(fooOp.getNameAttr(), "foo");
42+
ASSERT_EQ(barOp.getNameAttr(), "bar");
43+
44+
// Call test function that does symbol replacement.
45+
LogicalResult res = replaceFn(symbolTable, module.get(), fooOp, barOp);
46+
ASSERT_TRUE(succeeded(res));
47+
ASSERT_TRUE(succeeded(verify(module.get())));
48+
49+
// Check that it got renamed.
50+
bool calleeFound = false;
51+
fooOp->walk([&](CallOpInterface callOp) {
52+
StringAttr callee = callOp.getCallableForCallee()
53+
.dyn_cast<SymbolRefAttr>()
54+
.getLeafReference();
55+
EXPECT_EQ(callee, "baz");
56+
calleeFound = true;
57+
});
58+
EXPECT_TRUE(calleeFound);
59+
};
60+
61+
// Symbol as `Operation *`, rename within module.
62+
testReplaceAllSymbolUses(
63+
[&](auto symbolTable, auto module, auto fooOp, auto barOp) {
64+
return symbolTable.replaceAllSymbolUses(
65+
barOp, StringAttr::get(&context, "baz"), module);
66+
});
67+
68+
// Symbol as `StringAttr`, rename within module.
69+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
70+
auto barOp) {
71+
return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
72+
StringAttr::get(&context, "baz"),
73+
module);
74+
});
75+
76+
// Symbol as `Operation *`, rename within module body.
77+
testReplaceAllSymbolUses(
78+
[&](auto symbolTable, auto module, auto fooOp, auto barOp) {
79+
return symbolTable.replaceAllSymbolUses(
80+
barOp, StringAttr::get(&context, "baz"), &module->getRegion(0));
81+
});
82+
83+
// Symbol as `StringAttr`, rename within module body.
84+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
85+
auto barOp) {
86+
return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
87+
StringAttr::get(&context, "baz"),
88+
&module->getRegion(0));
89+
});
90+
91+
// Symbol as `Operation *`, rename within function.
92+
testReplaceAllSymbolUses(
93+
[&](auto symbolTable, auto module, auto fooOp, auto barOp) {
94+
return symbolTable.replaceAllSymbolUses(
95+
barOp, StringAttr::get(&context, "baz"), fooOp);
96+
});
97+
98+
// Symbol as `StringAttr`, rename within function.
99+
testReplaceAllSymbolUses([&](auto symbolTable, auto module, auto fooOp,
100+
auto barOp) {
101+
return symbolTable.replaceAllSymbolUses(StringAttr::get(&context, "bar"),
102+
StringAttr::get(&context, "baz"),
103+
fooOp);
104+
});
105+
}
106+
107+
} // namespace

0 commit comments

Comments
 (0)