Skip to content

Commit

Permalink
[mlir][bytecode] Avoid recording null arglocs & realloc opnames.
Browse files Browse the repository at this point in the history
For block arg locs a common case is no/unknown location (where the producer
signifies they don't care about blockarg location). Also avoid needing to
dynamically resize opnames during parsing.

Assumed to be post lazy loading change, so chose version 4.

Differential Revision: https://reviews.llvm.org/D151038
  • Loading branch information
jpienaar committed May 25, 2023
1 parent 12ccc59 commit 1826fad
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 13 deletions.
5 changes: 3 additions & 2 deletions mlir/docs/BytecodeFormat.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ dialects that were also referenced.
dialect_section {
numDialects: varint,
dialectNames: varint[],
numTotalOpNames: varint,
opNames: op_name_group[]
}
Expand Down Expand Up @@ -444,8 +445,8 @@ block_arguments {
}
block_argument {
typeIndex: varint,
location: varint
typeAndLocation: varint, // (type << 1) | (hasLocation)
location: varint? // Optional, else unknown location
}
```

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Bytecode/Encoding.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enum {
kMinSupportedVersion = 0,

/// The current bytecode version.
kVersion = 3,
kVersion = 4,

/// An arbitrary value used to fill alignment padding.
kAlignmentByte = 0xCB,
Expand Down
30 changes: 25 additions & 5 deletions mlir/lib/Bytecode/Reader/BytecodeReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,14 @@ BytecodeReader::Impl::parseDialectSection(ArrayRef<uint8_t> sectionData) {
opNames.emplace_back(dialect, opName);
return success();
};
// Avoid re-allocation in bytecode version > 3, where the total number of op
// names is known up front.
if (version > 3) {
uint64_t numOps;
if (failed(sectionReader.parseVarInt(numOps)))
return failure();
opNames.reserve(numOps);
}
while (!sectionReader.empty())
if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName)))
return failure();
Expand Down Expand Up @@ -2175,13 +2183,25 @@ LogicalResult BytecodeReader::Impl::parseBlockArguments(EncodingReader &reader,
argTypes.reserve(numArgs);
argLocs.reserve(numArgs);

Location unknownLoc = UnknownLoc::get(config.getContext());
while (numArgs--) {
Type argType;
LocationAttr argLoc;
if (failed(parseType(reader, argType)) ||
failed(parseAttribute(reader, argLoc)))
return failure();

LocationAttr argLoc = unknownLoc;
if (version > 3) {
// Parse the type index together with the hasLoc flag, which indicates
// whether an explicit location follows.
uint64_t typeIdx;
bool hasLoc;
if (failed(reader.parseVarIntWithFlag(typeIdx, hasLoc)) ||
!(argType = attrTypeReader.resolveType(typeIdx)))
return failure();
if (hasLoc && failed(parseAttribute(reader, argLoc)))
return failure();
} else {
// All args have both a type and a location.
if (failed(parseType(reader, argType)) ||
failed(parseAttribute(reader, argLoc)))
return failure();
}
argTypes.push_back(argType);
argLocs.push_back(argLoc);
}
Expand Down
17 changes: 14 additions & 3 deletions mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,9 @@ void BytecodeWriter::writeDialectSection(EncodingEmitter &emitter) {
std::move(versionEmitter));
}

if (config.bytecodeVersion > 3)
dialectEmitter.emitVarInt(size(numberingState.getOpNames()));

// Emit the referenced operation names grouped by dialect.
auto emitOpName = [&](OpNameNumbering &name) {
dialectEmitter.emitVarInt(stringSection.insert(name.name.stripDialect()));
Expand Down Expand Up @@ -670,8 +673,16 @@ void BytecodeWriter::writeBlock(EncodingEmitter &emitter, Block *block) {
if (hasArgs) {
emitter.emitVarInt(args.size());
for (BlockArgument arg : args) {
emitter.emitVarInt(numberingState.getNumber(arg.getType()));
emitter.emitVarInt(numberingState.getNumber(arg.getLoc()));
Location argLoc = arg.getLoc();
if (config.bytecodeVersion > 3) {
emitter.emitVarIntWithFlag(numberingState.getNumber(arg.getType()),
!isa<UnknownLoc>(argLoc));
if (!isa<UnknownLoc>(argLoc))
emitter.emitVarInt(numberingState.getNumber(argLoc));
} else {
emitter.emitVarInt(numberingState.getNumber(arg.getType()));
emitter.emitVarInt(numberingState.getNumber(argLoc));
}
}
if (config.bytecodeVersion > 2) {
uint64_t maskOffset = emitter.size();
Expand Down Expand Up @@ -755,7 +766,7 @@ void BytecodeWriter::writeOp(EncodingEmitter &emitter, Operation *op) {

for (Region &region : op->getRegions()) {
// If the region is not isolated from above, or we are emitting bytecode
// targetting version <2, we don't use a section.
// targeting version <2, we don't use a section.
if (!isIsolatedFromAbove || config.bytecodeVersion < 2) {
writeRegion(emitter, &region);
continue;
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Bytecode/general.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
}
"bytecode.branch"()[^secondBlock] : () -> ()

^secondBlock(%arg1: i32, %arg2: !bytecode.int, %arg3: !pdl.operation):
^secondBlock(%arg1: i32 loc(unknown), %arg2: !bytecode.int, %arg3: !pdl.operation loc(unknown)):
"bytecode.regions"() ({
"bytecode.operands"(%arg1, %arg2, %arg3) : (i32, !bytecode.int, !pdl.operation) -> ()
"bytecode.return"() : () -> ()
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Bytecode/invalid/invalid-structure.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//===--------------------------------------------------------------------===//

// RUN: not mlir-opt %S/invalid-structure-version.mlirbc 2>&1 | FileCheck %s --check-prefix=VERSION
// VERSION: bytecode version 127 is newer than the current version 3
// VERSION: bytecode version 127 is newer than the current version

//===--------------------------------------------------------------------===//
// Producer
Expand Down

0 comments on commit 1826fad

Please sign in to comment.