diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 1659437e1eb24..dd367b5922558 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -830,6 +831,23 @@ namespace { /// This class provides support for reading attribute and type entries from the /// bytecode. Attribute and Type entries are read lazily on demand, so we use /// this reader to manage when to actually parse them from the bytecode. +/// +/// The parsing of attributes & types is generally recursive; this can lead to +/// stack overflows for deeply nested structures, so we track a few extra pieces +/// of information to avoid this: +/// +/// - `depth`: The current depth while parsing nested attributes. We defer on +/// parsing deeply nested attributes to avoid potential stack overflows. The +/// deferred parsing is achieved by reporting a failure when parsing a nested +/// attribute/type and registering the index of the encountered attribute/type +/// in the deferred parsing worklist. Hence, a failure with a deferred entry +/// does not constitute a failure; it also requires that callers return on +/// first failure rather than attempting additional parses. +/// - `deferredWorklist`: A list of attribute/type indices that we could not +/// parse due to hitting the depth limit. The worklist is used to capture the +/// indices of attributes/types that need to be parsed/reparsed when we hit +/// the depth limit. This enables moving the tracking of what needs to be +/// parsed to the heap. class AttrTypeReader { /// This class represents a single attribute or type entry. 
template @@ -863,12 +881,34 @@ class AttrTypeReader { ArrayRef sectionData, ArrayRef offsetSectionData); + LogicalResult readAttribute(uint64_t index, Attribute &result, + uint64_t depth = 0) { + return readEntry(attributes, index, result, "attribute", depth); + } + + LogicalResult readType(uint64_t index, Type &result, uint64_t depth = 0) { + return readEntry(types, index, result, "type", depth); + } + /// Resolve the attribute or type at the given index. Returns nullptr on /// failure. - Attribute resolveAttribute(size_t index) { - return resolveEntry(attributes, index, "Attribute"); + Attribute resolveAttribute(size_t index, uint64_t depth = 0) { + return resolveEntry(attributes, index, "Attribute", depth); + } + Type resolveType(size_t index, uint64_t depth = 0) { + return resolveEntry(types, index, "Type", depth); + } + + Attribute getAttributeOrSentinel(size_t index) { + if (index >= attributes.size()) + return nullptr; + return attributes[index].entry; + } + Type getTypeOrSentinel(size_t index) { + if (index >= types.size()) + return nullptr; + return types[index].entry; } - Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } /// Parse a reference to an attribute or type using the given reader. LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { @@ -909,23 +949,33 @@ class AttrTypeReader { llvm::getTypeName(), ", but got: ", baseResult); } + /// Add an index to the deferred worklist for re-parsing. + void addDeferredParsing(uint64_t index) { deferredWorklist.push_back(index); } + private: /// Resolve the given entry at `index`. template - T resolveEntry(SmallVectorImpl> &entries, size_t index, - StringRef entryType); + T resolveEntry(SmallVectorImpl> &entries, uint64_t index, + StringRef entryType, uint64_t depth = 0); - /// Parse an entry using the given reader that was encoded using the textual - /// assembly format. 
+ /// Read the entry at the given index, returning failure if the entry is not + /// yet resolved. template - LogicalResult parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType); + LogicalResult readEntry(SmallVectorImpl> &entries, uint64_t index, + T &result, StringRef entryType, uint64_t depth); /// Parse an entry using the given reader that was encoded using a custom /// bytecode format. template LogicalResult parseCustomEntry(Entry &entry, EncodingReader &reader, - StringRef entryType); + StringRef entryType, uint64_t index, + uint64_t depth); + + /// Parse an entry using the given reader that was encoded using the textual + /// assembly format. + template + LogicalResult parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType); /// The string section reader used to resolve string references when parsing /// custom encoded attribute/type entries. @@ -951,6 +1001,10 @@ class AttrTypeReader { /// Reference to the parser configuration. const ParserConfig &parserConfig; + + /// Worklist for deferred attribute/type parsing. This is used to handle + /// deeply nested structures like CallSiteLoc iteratively. 
+ std::vector deferredWorklist; }; class DialectReader : public DialectBytecodeReader { @@ -959,10 +1013,11 @@ class DialectReader : public DialectBytecodeReader { const StringSectionReader &stringReader, const ResourceSectionReader &resourceReader, const llvm::StringMap &dialectsMap, - EncodingReader &reader, uint64_t &bytecodeVersion) + EncodingReader &reader, uint64_t &bytecodeVersion, + uint64_t depth = 0) : attrTypeReader(attrTypeReader), stringReader(stringReader), resourceReader(resourceReader), dialectsMap(dialectsMap), - reader(reader), bytecodeVersion(bytecodeVersion) {} + reader(reader), bytecodeVersion(bytecodeVersion), depth(depth) {} InFlightDiagnostic emitError(const Twine &msg) const override { return reader.emitError(msg); @@ -998,14 +1053,40 @@ class DialectReader : public DialectBytecodeReader { // IR //===--------------------------------------------------------------------===// + /// The maximum depth to eagerly parse nested attributes/types before + /// deferring. + static constexpr uint64_t maxAttrTypeDepth = 5; + LogicalResult readAttribute(Attribute &result) override { - return attrTypeReader.parseAttribute(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Attribute attr = attrTypeReader.getAttributeOrSentinel(index)) { + result = attr; + return success(); + } + attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readAttribute(index, result, depth + 1); } LogicalResult readOptionalAttribute(Attribute &result) override { return attrTypeReader.parseOptionalAttribute(reader, result); } LogicalResult readType(Type &result) override { - return attrTypeReader.parseType(reader, result); + uint64_t index; + if (failed(reader.parseVarInt(index))) + return failure(); + if (depth > maxAttrTypeDepth) { + if (Type type = attrTypeReader.getTypeOrSentinel(index)) { + result = type; + return success(); + } + 
attrTypeReader.addDeferredParsing(index); + return failure(); + } + return attrTypeReader.readType(index, result, depth + 1); } FailureOr readResourceHandle() override { @@ -1095,6 +1176,7 @@ class DialectReader : public DialectBytecodeReader { const llvm::StringMap &dialectsMap; EncodingReader &reader; uint64_t &bytecodeVersion; + uint64_t depth; }; /// Wraps the properties section and handles reading properties out of it. @@ -1239,68 +1321,110 @@ LogicalResult AttrTypeReader::initialize( template T AttrTypeReader::resolveEntry(SmallVectorImpl> &entries, size_t index, - StringRef entryType) { + StringRef entryType, uint64_t depth) { if (index >= entries.size()) { emitError(fileLoc) << "invalid " << entryType << " index: " << index; return {}; } - // If the entry has already been resolved, there is nothing left to do. - Entry &entry = entries[index]; - if (entry.entry) - return entry.entry; + // Fast path: Try direct parsing without worklist overhead. This handles the + // common case where there are no deferred dependencies. + assert(deferredWorklist.empty()); + T result; + if (succeeded(readEntry(entries, index, result, entryType, depth))) { + assert(deferredWorklist.empty()); + return result; + } + if (deferredWorklist.empty()) { + // Failing with no deferred entries is an error. + return T(); + } - // Parse the entry. - EncodingReader reader(entry.data, fileLoc); + // Slow path: Use worklist to handle deferred dependencies. Use a deque to + // iteratively resolve entries with dependencies. + // - Pop from front to process + // - Push new dependencies to front (depth-first) + // - Move failed entries to back (retry after dependencies) + std::deque worklist; + llvm::DenseSet inWorklist; - // Parse based on how the entry was encoded. 
- if (entry.hasCustomEncoding) { - if (failed(parseCustomEntry(entry, reader, entryType))) - return T(); - } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { - return T(); + // Add the original index and any dependencies from the fast path attempt. + worklist.push_back(index); + inWorklist.insert(index); + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); } - if (!reader.empty()) { - reader.emitError("unexpected trailing bytes after " + entryType + " entry"); - return T(); + while (!worklist.empty()) { + size_t currentIndex = worklist.front(); + worklist.pop_front(); + + // Clear the deferred worklist before parsing to capture any new entries. + deferredWorklist.clear(); + + T result; + if (succeeded(readEntry(entries, currentIndex, result, entryType, depth))) { + inWorklist.erase(currentIndex); + continue; + } + + if (deferredWorklist.empty()) { + // Parsing failed with no deferred entries which implies an error. + return T(); + } + + // Move this entry to the back to retry after dependencies. + worklist.push_back(currentIndex); + + // Add dependencies to the front (in reverse so they maintain order). + for (uint64_t idx : llvm::reverse(deferredWorklist)) { + if (inWorklist.insert(idx).second) + worklist.push_front(idx); + } + deferredWorklist.clear(); } - return entry.entry; + return entries[index].entry; } template -LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, - StringRef entryType) { - StringRef asmStr; - if (failed(reader.parseNullTerminatedString(asmStr))) - return failure(); +LogicalResult AttrTypeReader::readEntry(SmallVectorImpl> &entries, + uint64_t index, T &result, + StringRef entryType, uint64_t depth) { + if (index >= entries.size()) + return emitError(fileLoc) << "invalid " << entryType << " index: " << index; - // Invoke the MLIR assembly parser to parse the entry text. 
- size_t numRead = 0; - MLIRContext *context = fileLoc->getContext(); - if constexpr (std::is_same_v) - result = - ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); - else - result = ::parseAttribute(asmStr, context, Type(), &numRead, - /*isKnownNullTerminated=*/true); - if (!result) + // If the entry has already been resolved, return it. + Entry &entry = entries[index]; + if (entry.entry) { + result = entry.entry; + return success(); + } + + // If the entry hasn't been resolved, try to parse it. + EncodingReader reader(entry.data, fileLoc); + LogicalResult parseResult = + entry.hasCustomEncoding + ? parseCustomEntry(entry, reader, entryType, index, depth) + : parseAsmEntry(entry.entry, reader, entryType); + if (failed(parseResult)) return failure(); - // Ensure there weren't dangling characters after the entry. - if (numRead != asmStr.size()) { - return reader.emitError("trailing characters found after ", entryType, - " assembly format: ", asmStr.drop_front(numRead)); - } + if (!reader.empty()) + return reader.emitError("unexpected trailing bytes after " + entryType + + " entry"); + + result = entry.entry; return success(); } template LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, EncodingReader &reader, - StringRef entryType) { + StringRef entryType, + uint64_t index, uint64_t depth) { DialectReader dialectReader(*this, stringReader, resourceReader, dialectsMap, - reader, bytecodeVersion); + reader, bytecodeVersion, depth); if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) return failure(); @@ -1350,6 +1474,33 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, return success(!!entry.entry); } +template +LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType) { + StringRef asmStr; + if (failed(reader.parseNullTerminatedString(asmStr))) + return failure(); + + // Invoke the MLIR assembly parser to parse the entry text. 
+ size_t numRead = 0; + MLIRContext *context = fileLoc->getContext(); + if constexpr (std::is_same_v) + result = + ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); + else + result = ::parseAttribute(asmStr, context, Type(), &numRead, + /*isKnownNullTerminated=*/true); + if (!result) + return failure(); + + // Ensure there weren't dangling characters after the entry. + if (numRead != asmStr.size()) { + return reader.emitError("trailing characters found after ", entryType, + " assembly format: ", asmStr.drop_front(numRead)); + } + return success(); +} + //===----------------------------------------------------------------------===// // Bytecode Reader //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/Bytecode/BytecodeTest.cpp b/mlir/unittests/Bytecode/BytecodeTest.cpp index d7b442f6832d0..30e7ed9b6cb7e 100644 --- a/mlir/unittests/Bytecode/BytecodeTest.cpp +++ b/mlir/unittests/Bytecode/BytecodeTest.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/OwningOpRef.h" #include "mlir/Parser/Parser.h" +#include "mlir/IR/BuiltinOps.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Alignment.h" #include "llvm/Support/Endian.h" @@ -228,3 +229,39 @@ TEST(Bytecode, OpWithoutProperties) { EXPECT_TRUE(OperationEquivalence::computeHash(op.get()) == OperationEquivalence::computeHash(roundtripped)); } + +TEST(Bytecode, DeepCallSiteLoc) { + MLIRContext context; + ParserConfig config(&context); + + // Create a deep CallSiteLoc chain to test iterative parsing. + Location baseLoc = FileLineColLoc::get(&context, "test.mlir", 1, 1); + Location loc = baseLoc; + constexpr int kDepth = 1000; + for (int i = 0; i < kDepth; ++i) { + loc = CallSiteLoc::get(loc, baseLoc); + } + + // Create a simple module with the deep location. + Builder builder(&context); + OwningOpRef module = + ModuleOp::create(loc, /*attributes=*/std::nullopt); + ASSERT_TRUE(module); + + // Write to bytecode. 
+ std::string bytecode; + llvm::raw_string_ostream os(bytecode); + ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), os))); + + // Parse it back using the bytecode reader. + std::unique_ptr block = std::make_unique(); + ASSERT_TRUE(succeeded(readBytecodeFile( + llvm::MemoryBufferRef(bytecode, "string-buffer"), block.get(), config))); + + // Verify we got the roundtripped module. + ASSERT_FALSE(block->empty()); + Operation *roundTripped = &block->front(); + + // Verify the location matches. + EXPECT_EQ(module.get()->getLoc(), roundTripped->getLoc()); +}