diff --git a/mlir/include/mlir/IR/OpAsmInterface.td b/mlir/include/mlir/IR/OpAsmInterface.td index 50b98b04dbee..b49e12ea9a85 100644 --- a/mlir/include/mlir/IR/OpAsmInterface.td +++ b/mlir/include/mlir/IR/OpAsmInterface.td @@ -49,7 +49,18 @@ def OpAsmOpInterface : OpInterface<"OpAsmOpInterface"> { }], "void", "getAsmResultNames", (ins "::mlir::OpAsmSetValueNameFn":$setNameFn), - "", ";" + "", "return;" + >, + InterfaceMethod<[{ + Get a special name to use when printing the block arguments for a region + immediately nested under this operation. + }], + "void", "getAsmBlockArgumentNames", + (ins + "::mlir::Region&":$region, + "::mlir::OpAsmSetValueNameFn":$setNameFn + ), + "", "return;" >, StaticInterfaceMethod<[{ Return the default dialect used when printing/parsing operations in diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 4b763f9efe36..2cd09abb1dc2 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1348,11 +1348,6 @@ class OpAsmDialectInterface /// OpAsmInterface.td#getAsmResultNames for usage details and documentation. virtual void getAsmResultNames(Operation *op, OpAsmSetValueNameFn setNameFn) const {} - - /// Get a special name to use when printing the entry block arguments of the - /// region contained by an operation in this dialect. - virtual void getAsmBlockArgumentNames(Block *block, - OpAsmSetValueNameFn setNameFn) const {} }; } // namespace mlir diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index fe214f2aca2a..fe6893d2147f 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1006,6 +1006,20 @@ void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { } void SSANameState::numberValuesInRegion(Region ®ion) { + auto setBlockArgNameFn = [&](Value arg, StringRef name) { + assert(!valueIDs.count(arg) && "arg numbered multiple times"); + assert(arg.cast().getOwner()->getParent() == ®ion && + "arg not defined in current region"); + setValueName(arg, name); + }; + + if (!printerFlags.shouldPrintGenericOpForm()) { + if (Operation *op = region.getParentOp()) { + if (auto asmInterface = dyn_cast(op)) + asmInterface.getAsmBlockArgumentNames(region, setBlockArgNameFn); + } + } + // Number the values within this region in a breadth-first order. unsigned nextBlockID = 0; for (auto &block : region) { @@ -1017,23 +1031,9 @@ void SSANameState::numberValuesInRegion(Region ®ion) { } void SSANameState::numberValuesInBlock(Block &block) { - auto setArgNameFn = [&](Value arg, StringRef name) { - assert(!valueIDs.count(arg) && "arg numbered multiple times"); - assert(arg.cast().getOwner() == &block && - "arg not defined in 'block'"); - setValueName(arg, name); - }; - - bool isEntryBlock = block.isEntryBlock(); - if (isEntryBlock && !printerFlags.shouldPrintGenericOpForm()) { - if (auto *op = block.getParentOp()) { - if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect())) - asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn); - } - } - // Number the block arguments. We give entry block arguments a special name // 'arg'. + bool isEntryBlock = block.isEntryBlock(); SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : ""); llvm::raw_svector_ostream specialName(specialNameBuffer); for (auto arg : block.getArguments()) { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index aae98baf1c26..b46aaf106997 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -105,20 +105,6 @@ struct TestOpAsmInterface : public OpAsmDialectInterface { if (auto asmOp = dyn_cast(op)) setNameFn(asmOp, "result"); } - - void getAsmBlockArgumentNames(Block *block, - OpAsmSetValueNameFn setNameFn) const final { - auto op = block->getParentOp(); - auto arrayAttr = op->getAttrOfType("arg_names"); - if (!arrayAttr) - return; - auto args = block->getArguments(); - auto e = std::min(arrayAttr.size(), args.size()); - for (unsigned i = 0; i < e; ++i) { - if (auto strAttr = arrayAttr[i].dyn_cast()) - setNameFn(args[i], strAttr.getValue()); - } - } }; struct TestDialectFoldInterface : public DialectFoldInterface { @@ -848,6 +834,19 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { return parser.parseRegion(*body, ivsInfo, argTypes); } +void PolyForOp::getAsmBlockArgumentNames(Region ®ion, + OpAsmSetValueNameFn setNameFn) { + auto arrayAttr = getOperation()->getAttrOfType("arg_names"); + if (!arrayAttr) + return; + auto args = getRegion().front().getArguments(); + auto e = std::min(arrayAttr.size(), args.size()); + for (unsigned i = 0; i < e; ++i) { + if (auto strAttr = arrayAttr[i].dyn_cast()) + setNameFn(args[i], strAttr.getValue()); + } +} + //===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 7dca8165b0d2..80b568c743b0 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1667,13 +1667,16 @@ def PrettyPrintedRegionOp : TEST_Op<"pretty_printed_region", let printer = [{ return ::print(p, *this); }]; } -def PolyForOp : TEST_Op<"polyfor"> +def PolyForOp : TEST_Op<"polyfor", [OpAsmOpInterface]> { let summary = "polyfor operation"; let description = [{ Test op with multiple region arguments, each argument of index type. }]; - + let extraClassDeclaration = [{ + void getAsmBlockArgumentNames(mlir::Region ®ion, + mlir::OpAsmSetValueNameFn setNameFn); + }]; let regions = (region SizedRegion<1>:$region); let parser = [{ return ::parse$cppClass(parser, result); }]; }