Skip to content

Commit

Permalink
Move getAsmBlockArgumentNames from OpAsmDialectInterface to OpAsmOpIn…
Browse files Browse the repository at this point in the history
…terface

This method is more suitable as an opinterface: it seems intrinsic to
individual instances of the operation instead of the dialect.
Also remove the restriction on the interface being applicable to the entry block only.

Differential Revision: https://reviews.llvm.org/D116018
  • Loading branch information
joker-eph committed Dec 20, 2021
1 parent 9c11e95 commit 7f9e9c7
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 37 deletions.
13 changes: 12 additions & 1 deletion mlir/include/mlir/IR/OpAsmInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> {
}],
"void", "getAsmResultNames",
(ins "::mlir::OpAsmSetValueNameFn":$setNameFn),
"", ";"
"", "return;"
>,
InterfaceMethod<[{
Get a special name to use when printing the block arguments for a region
immediately nested under this operation.
}],
"void", "getAsmBlockArgumentNames",
(ins
"::mlir::Region&":$region,
"::mlir::OpAsmSetValueNameFn":$setNameFn
),
"", "return;"
>,
StaticInterfaceMethod<[{
Return the default dialect used when printing/parsing operations in
Expand Down
5 changes: 0 additions & 5 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -1348,11 +1348,6 @@ class OpAsmDialectInterface
/// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
virtual void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const {}

/// Get a special name to use when printing the entry block arguments of the
/// region contained by an operation in this dialect.
virtual void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const {}
};
} // namespace mlir

Expand Down
30 changes: 15 additions & 15 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
}

void SSANameState::numberValuesInRegion(Region &region) {
auto setBlockArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(arg.cast<BlockArgument>().getOwner()->getParent() == &region &&
"arg not defined in current region");
setValueName(arg, name);
};

if (!printerFlags.shouldPrintGenericOpForm()) {
if (Operation *op = region.getParentOp()) {
if (auto asmInterface = dyn_cast<OpAsmOpInterface>(op))
asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn);
}
}

// Number the values within this region in a breadth-first order.
unsigned nextBlockID = 0;
for (auto &block : region) {
Expand All @@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region &region) {
}

void SSANameState::numberValuesInBlock(Block &block) {
auto setArgNameFn = [&](Value arg, StringRef name) {
assert(!valueIDs.count(arg) && "arg numbered multiple times");
assert(arg.cast<BlockArgument>().getOwner() == &block &&
"arg not defined in 'block'");
setValueName(arg, name);
};

bool isEntryBlock = block.isEntryBlock();
if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) {
if (auto *op = block.getParentOp()) {
if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
}
}

// Number the block arguments. We give entry block arguments a special name
// 'arg'.
bool isEntryBlock = block.isEntryBlock();
SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
llvm::raw_svector_ostream specialName(specialNameBuffer);
for (auto arg : block.getArguments()) {
Expand Down
27 changes: 13 additions & 14 deletions mlir/test/lib/Dialect/Test/TestDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}

void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const final {
auto op = block->getParentOp();
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = block->getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}
};

struct TestDialectFoldInterface : public DialectFoldInterface {
Expand Down Expand Up @@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
return parser.parseRegion(*body, ivsInfo, argTypes);
}

void PolyForOp::getAsmBlockArgumentNames(Region &region,
OpAsmSetValueNameFn setNameFn) {
auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = getRegion().front().getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}

//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
Expand Down
7 changes: 5 additions & 2 deletions mlir/test/lib/Dialect/Test/TestOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region",
let printer = [{ return ::print(p, *this); }];
}

def PolyForOp : TEST_Op<"polyfor">
def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]>
{
let summary = "polyfor operation";
let description = [{
Test op with multiple region arguments, each argument of index type.
}];

let extraClassDeclaration = [{
void getAsmBlockArgumentNames(mlir::Region &region,
mlir::OpAsmSetValueNameFn setNameFn);
}];
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
}
Expand Down

0 comments on commit 7f9e9c7

Please sign in to comment.