Skip to content

Commit d84ba0c

Browse files
committed
[mlir][bytecode] Add support for deferred attribute/type parsing.
Add the ability for an attribute/type entry to defer its own parsing and re-enqueue itself for a later retry. This lets CallSiteLoc parsing avoid recursing as deeply as before: previously the recursion could fail (especially on large inputs in debug builds). The maximum eager-parsing depth is an arbitrary choice for now.
1 parent 028bfa2 commit d84ba0c

File tree

2 files changed

+224
-53
lines changed

2 files changed

+224
-53
lines changed

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 187 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#include <cstddef>
2929
#include <cstdint>
30+
#include <deque>
3031
#include <list>
3132
#include <memory>
3233
#include <numeric>
@@ -863,12 +864,34 @@ class AttrTypeReader {
863864
ArrayRef<uint8_t> sectionData,
864865
ArrayRef<uint8_t> offsetSectionData);
865866

867+
/// Read the attribute stored at `index` into `result`, parsing the entry if
/// it has not been resolved yet. `depth` is the current nesting level and is
/// forwarded unchanged to the entry reader. Returns failure if the entry
/// could not be read.
LogicalResult readAttribute(uint64_t index, Attribute &result,
                            uint64_t depth = 0) {
  return readEntry(attributes, index, result, "attribute", depth);
}
871+
872+
/// Read the type stored at `index` into `result`, parsing the entry if it has
/// not been resolved yet. `depth` is the current nesting level and is
/// forwarded unchanged to the entry reader. Returns failure if the entry
/// could not be read.
LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) {
  return readEntry(types, index, result, "type", depth);
}
875+
866876
/// Resolve the attribute or type at the given index. Returns nullptr on
867877
/// failure.
868-
Attribute resolveAttribute(size_t index) {
869-
return resolveEntry(attributes, index, "Attribute");
878+
/// Resolve the attribute at `index`, driving any deferred dependencies to
/// completion. `depth` is the nesting level the resolution starts from. The
/// result is null when resolution fails.
Attribute resolveAttribute(size_t index, uint64_t depth = 0) {
  return resolveEntry(attributes, index, "Attribute", depth);
}
881+
/// Resolve the type at `index`, driving any deferred dependencies to
/// completion. `depth` is the nesting level the resolution starts from. The
/// result is null when resolution fails.
Type resolveType(size_t index, uint64_t depth = 0) {
  return resolveEntry(types, index, "Type", depth);
}
884+
885+
/// Return whatever is currently stored for the attribute at `index` without
/// attempting to parse it. The result is null when `index` is out of range;
/// otherwise it is the entry's stored value, which the entry reader treats
/// as "not yet resolved" when it is null.
Attribute getAttributeOrSentinel(size_t index) {
  return index < attributes.size() ? attributes[index].entry : Attribute();
}
890+
/// Return whatever is currently stored for the type at `index` without
/// attempting to parse it. The result is null when `index` is out of range;
/// otherwise it is the entry's stored value, which the entry reader treats
/// as "not yet resolved" when it is null.
Type getTypeOrSentinel(size_t index) {
  return index < types.size() ? types[index].entry : Type();
}
871-
Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); }
872895

873896
/// Parse a reference to an attribute or type using the given reader.
874897
LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) {
@@ -909,23 +932,33 @@ class AttrTypeReader {
909932
llvm::getTypeName<T>(), ", but got: ", baseResult);
910933
}
911934

935+
/// Add an index to the deferred worklist for re-parsing.
936+
void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); }
937+
912938
private:
913939
/// Resolve the given entry at `index`.
914940
template <typename T>
915-
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
916-
StringRef entryType);
941+
T resolveEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
942+
StringRef entryType, uint64_t depth = 0);
917943

918-
/// Parse an entry using the given reader that was encoded using the textual
919-
/// assembly format.
944+
/// Read the entry at the given index, returning failure if the entry is not
945+
/// yet resolved.
920946
template <typename T>
921-
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
922-
StringRef entryType);
947+
LogicalResult readEntry(SmallVectorImpl<Entry<T>> &entries, uint64_t index,
948+
T &result, StringRef entryType, uint64_t depth);
923949

924950
/// Parse an entry using the given reader that was encoded using a custom
925951
/// bytecode format.
926952
template <typename T>
927953
LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader,
928-
StringRef entryType);
954+
StringRef entryType, uint64_t index,
955+
uint64_t depth);
956+
957+
/// Parse an entry using the given reader that was encoded using the textual
958+
/// assembly format.
959+
template <typename T>
960+
LogicalResult parseAsmEntry(T &result, EncodingReader &reader,
961+
StringRef entryType);
929962

930963
/// The string section reader used to resolve string references when parsing
931964
/// custom encoded attribute/type entries.
@@ -951,6 +984,10 @@ class AttrTypeReader {
951984

952985
/// Reference to the parser configuration.
953986
const ParserConfig &parserConfig;
987+
988+
/// Worklist for deferred attribute/type parsing. This is used to handle
989+
/// deeply nested structures like CallSiteLoc iteratively.
990+
std::vector<uint64_t> deferredWorklist;
954991
};
955992

956993
class DialectReader : public DialectBytecodeReader {
@@ -959,10 +996,11 @@ class DialectReader : public DialectBytecodeReader {
959996
const StringSectionReader &stringReader,
960997
const ResourceSectionReader &resourceReader,
961998
const llvm::StringMap<BytecodeDialect *> &dialectsMap,
962-
EncodingReader &reader, uint64_t &bytecodeVersion)
999+
EncodingReader &reader, uint64_t &bytecodeVersion,
1000+
uint64_t depth = 0)
9631001
: attrTypeReader(attrTypeReader), stringReader(stringReader),
9641002
resourceReader(resourceReader), dialectsMap(dialectsMap),
965-
reader(reader), bytecodeVersion(bytecodeVersion) {}
1003+
reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {}
9661004

9671005
InFlightDiagnostic emitError(const Twine &msg) const override {
9681006
return reader.emitError(msg);
@@ -998,14 +1036,40 @@ class DialectReader : public DialectBytecodeReader {
9981036
// IR
9991037
//===--------------------------------------------------------------------===//
10001038

1039+
/// The maximum depth to eagerly parse nested attributes/types before
1040+
/// deferring.
1041+
static constexpr uint64_t maxAttrTypeDepth = 5;
1042+
10011043
/// Read an attribute reference (a varint index) from the underlying reader.
///
/// While the current nesting depth is within `maxAttrTypeDepth`, the
/// referenced attribute is parsed eagerly one level deeper. Past that limit
/// only an already-resolved attribute is returned; an unresolved one is
/// queued for deferred parsing and failure is returned so the caller can be
/// retried once the dependency has been resolved.
LogicalResult readAttribute(Attribute &result) override {
  uint64_t attrIndex;
  if (failed(reader.parseVarInt(attrIndex)))
    return failure();

  // Still within the eager budget: recurse into the referenced entry.
  if (depth <= maxAttrTypeDepth)
    return attrTypeReader.readAttribute(attrIndex, result, depth + 1);

  // Too deep: use the entry only if it is already available.
  Attribute cached = attrTypeReader.getAttributeOrSentinel(attrIndex);
  if (!cached) {
    attrTypeReader.addDeferredParsing(attrIndex);
    return failure();
  }
  result = cached;
  return success();
}
10041057
/// Read an optional attribute reference by delegating to the attribute/type
/// reader. This path does not take the nesting depth into account.
LogicalResult readOptionalAttribute(Attribute &result) override {
  return attrTypeReader.parseOptionalAttribute(reader, result);
}
10071060
/// Read a type reference (a varint index) from the underlying reader.
///
/// While the current nesting depth is within `maxAttrTypeDepth`, the
/// referenced type is parsed eagerly one level deeper. Past that limit only
/// an already-resolved type is returned; an unresolved one is queued for
/// deferred parsing and failure is returned so the caller can be retried
/// once the dependency has been resolved.
LogicalResult readType(Type &result) override {
  uint64_t typeIndex;
  if (failed(reader.parseVarInt(typeIndex)))
    return failure();

  // Still within the eager budget: recurse into the referenced entry.
  if (depth <= maxAttrTypeDepth)
    return attrTypeReader.readType(typeIndex, result, depth + 1);

  // Too deep: use the entry only if it is already available.
  Type cached = attrTypeReader.getTypeOrSentinel(typeIndex);
  if (!cached) {
    attrTypeReader.addDeferredParsing(typeIndex);
    return failure();
  }
  result = cached;
  return success();
}
10101074

10111075
FailureOr<AsmDialectResourceHandle> readResourceHandle() override {
@@ -1095,6 +1159,7 @@ class DialectReader : public DialectBytecodeReader {
10951159
const llvm::StringMap<BytecodeDialect *> &dialectsMap;
10961160
EncodingReader &reader;
10971161
uint64_t &bytecodeVersion;
1162+
uint64_t depth;
10981163
};
10991164

11001165
/// Wraps the properties section and handles reading properties out of it.
@@ -1239,68 +1304,110 @@ LogicalResult AttrTypeReader::initialize(
12391304

12401305
template <typename T>
12411306
T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index,
1242-
StringRef entryType) {
1307+
StringRef entryType, uint64_t depth) {
12431308
if (index >= entries.size()) {
12441309
emitError(fileLoc) << "invalid " << entryType << " index: " << index;
12451310
return {};
12461311
}
12471312

1248-
// If the entry has already been resolved, there is nothing left to do.
1249-
Entry<T> &entry = entries[index];
1250-
if (entry.entry)
1251-
return entry.entry;
1313+
// Fast path: Try direct parsing without worklist overhead.
1314+
// This handles the common case where there are no deferred dependencies.
1315+
deferredWorklist.clear();
1316+
T result;
1317+
if (succeeded(readEntry(entries, index, result, entryType, depth))) {
1318+
assert(deferredWorklist.empty());
1319+
return result;
1320+
}
1321+
if (deferredWorklist.empty()) {
1322+
// Failed with no deferred entries is error.
1323+
return T();
1324+
}
12521325

1253-
// Parse the entry.
1254-
EncodingReader reader(entry.data, fileLoc);
1326+
// Slow path: Use worklist to handle deferred dependencies. Use a deque to
1327+
// iteratively resolve entries with dependencies.
1328+
// - Pop from front to process
1329+
// - Push new dependencies to front (depth-first)
1330+
// - Move failed entries to back (retry after dependencies)
1331+
std::deque<size_t> worklist;
1332+
llvm::DenseSet<size_t> inWorklist;
12551333

1256-
// Parse based on how the entry was encoded.
1257-
if (entry.hasCustomEncoding) {
1258-
if (failed(parseCustomEntry(entry, reader, entryType)))
1259-
return T();
1260-
} else if (failed(parseAsmEntry(entry.entry, reader, entryType))) {
1261-
return T();
1334+
// Add the original index and any dependencies from the fast path attempt.
1335+
worklist.push_back(index);
1336+
inWorklist.insert(index);
1337+
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1338+
if (inWorklist.insert(idx).second)
1339+
worklist.push_front(idx);
12621340
}
12631341

1264-
if (!reader.empty()) {
1265-
reader.emitError("unexpected trailing bytes after " + entryType + " entry");
1266-
return T();
1342+
while (!worklist.empty()) {
1343+
size_t currentIndex = worklist.front();
1344+
worklist.pop_front();
1345+
1346+
// Clear the deferred worklist before parsing to capture any new entries.
1347+
deferredWorklist.clear();
1348+
1349+
T result;
1350+
if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) {
1351+
inWorklist.erase(currentIndex);
1352+
continue;
1353+
}
1354+
1355+
if (deferredWorklist.empty()) {
1356+
// Parsing failed with no deferred entries which implies an error.
1357+
return T();
1358+
}
1359+
1360+
// Move this entry to the back to retry after dependencies.
1361+
worklist.push_back(currentIndex);
1362+
1363+
// Add dependencies to the front (in reverse so they maintain order).
1364+
for (uint64_t idx : llvm::reverse(deferredWorklist)) {
1365+
if (inWorklist.insert(idx).second)
1366+
worklist.push_front(idx);
1367+
}
1368+
deferredWorklist.clear();
12671369
}
1268-
return entry.entry;
1370+
return entries[index].entry;
12691371
}
12701372

12711373
template <typename T>
1272-
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1273-
StringRef entryType) {
1274-
StringRef asmStr;
1275-
if (failed(reader.parseNullTerminatedString(asmStr)))
1276-
return failure();
1374+
LogicalResult AttrTypeReader::readEntry(SmallVectorImpl<Entry<T>> &entries,
1375+
uint64_t index, T &result,
1376+
StringRef entryType, uint64_t depth) {
1377+
if (index >= entries.size())
1378+
return emitError(fileLoc) << "invalid " << entryType << " index: " << index;
12771379

1278-
// Invoke the MLIR assembly parser to parse the entry text.
1279-
size_t numRead = 0;
1280-
MLIRContext *context = fileLoc->getContext();
1281-
if constexpr (std::is_same_v<T, Type>)
1282-
result =
1283-
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1284-
else
1285-
result = ::parseAttribute(asmStr, context, Type(), &numRead,
1286-
/*isKnownNullTerminated=*/true);
1287-
if (!result)
1380+
// If the entry has already been resolved, return it.
1381+
Entry<T> &entry = entries[index];
1382+
if (entry.entry) {
1383+
result = entry.entry;
1384+
return success();
1385+
}
1386+
1387+
// If the entry hasn't been resolved, try to parse it.
1388+
EncodingReader reader(entry.data, fileLoc);
1389+
LogicalResult parseResult =
1390+
entry.hasCustomEncoding
1391+
? parseCustomEntry(entry, reader, entryType, index, depth)
1392+
: parseAsmEntry(entry.entry, reader, entryType);
1393+
if (failed(parseResult))
12881394
return failure();
12891395

1290-
// Ensure there weren't dangling characters after the entry.
1291-
if (numRead != asmStr.size()) {
1292-
return reader.emitError("trailing characters found after ", entryType,
1293-
" assembly format: ", asmStr.drop_front(numRead));
1294-
}
1396+
if (!reader.empty())
1397+
return reader.emitError("unexpected trailing bytes after " + entryType +
1398+
" entry");
1399+
1400+
result = entry.entry;
12951401
return success();
12961402
}
12971403

12981404
template <typename T>
12991405
LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
13001406
EncodingReader &reader,
1301-
StringRef entryType) {
1407+
StringRef entryType,
1408+
uint64_t index, uint64_t depth) {
13021409
DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap,
1303-
reader, bytecodeVersion);
1410+
reader, bytecodeVersion, depth);
13041411
if (failed(entry.dialect->load(dialectReader, fileLoc.getContext())))
13051412
return failure();
13061413

@@ -1350,6 +1457,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry,
13501457
return success(!!entry.entry);
13511458
}
13521459

1460+
template <typename T>
1461+
LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader,
1462+
StringRef entryType) {
1463+
StringRef asmStr;
1464+
if (failed(reader.parseNullTerminatedString(asmStr)))
1465+
return failure();
1466+
1467+
// Invoke the MLIR assembly parser to parse the entry text.
1468+
size_t numRead = 0;
1469+
MLIRContext *context = fileLoc->getContext();
1470+
if constexpr (std::is_same_v<T, Type>)
1471+
result =
1472+
::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true);
1473+
else
1474+
result = ::parseAttribute(asmStr, context, Type(), &numRead,
1475+
/*isKnownNullTerminated=*/true);
1476+
if (!result)
1477+
return failure();
1478+
1479+
// Ensure there weren't dangling characters after the entry.
1480+
if (numRead != asmStr.size()) {
1481+
return reader.emitError("trailing characters found after ", entryType,
1482+
" assembly format: ", asmStr.drop_front(numRead));
1483+
}
1484+
return success();
1485+
}
1486+
13531487
//===----------------------------------------------------------------------===//
13541488
// Bytecode Reader
13551489
//===----------------------------------------------------------------------===//

mlir/unittests/Bytecode/BytecodeTest.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/OwningOpRef.h"
1616
#include "mlir/Parser/Parser.h"
1717

18+
#include "mlir/IR/BuiltinOps.h"
1819
#include "llvm/ADT/StringRef.h"
1920
#include "llvm/Support/Alignment.h"
2021
#include "llvm/Support/Endian.h"
@@ -228,3 +229,39 @@ TEST(Bytecode, OpWithoutProperties) {
228229
EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) ==
229230
OperationEquivalence::computeHash(roundtripped));
230231
}
232+
233+
// Round-trips a module whose location is a deeply nested CallSiteLoc chain
// through the bytecode writer and reader, and checks that the location
// survives. The chain is deep enough to exercise the reader's deferred
// (non-recursive) attribute parsing.
TEST(Bytecode, DeepCallSiteLoc) {
  MLIRContext context;
  ParserConfig config(&context);

  // Create a deep CallSiteLoc chain to test iterative parsing.
  Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1);
  Location loc = baseLoc;
  constexpr int kDepth = 1000;
  for (int i = 0; i < kDepth; ++i)
    loc = CallSiteLoc::get(loc, baseLoc);

  // Create a simple module with the deep location.
  OwningOpRef<ModuleOp> module = ModuleOp::create(loc, /*name=*/std::nullopt);
  ASSERT_TRUE(module);

  // Write to bytecode.
  std::string bytecode;
  llvm::raw_string_ostream os(bytecode);
  ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os)));

  // Parse it back using the bytecode reader.
  std::unique_ptr<Block> block = std::make_unique<Block>();
  ASSERT_TRUE(succeeded(readBytecodeFile(
      llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config)));

  // Verify we got the roundtripped module.
  ASSERT_FALSE(block->empty());
  Operation *roundTripped = &block->front();

  // Verify the location matches.
  EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc());
}

0 commit comments

Comments
 (0)