diff --git a/mlir/include/mlir/IR/StorageUniquerSupport.h b/mlir/include/mlir/IR/StorageUniquerSupport.h index fb64f15162df..183d3f37ed3a 100644 --- a/mlir/include/mlir/IR/StorageUniquerSupport.h +++ b/mlir/include/mlir/IR/StorageUniquerSupport.h @@ -200,7 +200,7 @@ class StorageUserBase : public BaseT, public Traits... { // If the construction invariants fail then we return a null attribute. if (failed(ConcreteT::verify(emitErrorFn, args...))) return ConcreteT(); - return UniquerT::template get(ctx, args...); + return UniquerT::template get(ctx, std::forward(args)...); } /// Get an instance of the concrete type from a void pointer. diff --git a/mlir/test/lib/Dialect/Test/TestAttrDefs.td b/mlir/test/lib/Dialect/Test/TestAttrDefs.td index 2c5acb1b99a4..64ccd8b36c5b 100644 --- a/mlir/test/lib/Dialect/Test/TestAttrDefs.td +++ b/mlir/test/lib/Dialect/Test/TestAttrDefs.td @@ -331,6 +331,7 @@ def TestCopyCount : Test_Attr<"TestCopyCount"> { let mnemonic = "copy_count"; let parameters = (ins TestParamCopyCount:$copy_count); let assemblyFormat = "`<` $copy_count `>`"; + let genVerifyDecl = 1; } def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> { diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index 32fef18261ce..c2ea88bbe45c 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -177,6 +177,16 @@ static void printTrueFalse(AsmPrinter &p, std::optional result) { p << (*result ? "true" : "false"); } +//===----------------------------------------------------------------------===// +// TestCopyCountAttr Implementation +//===----------------------------------------------------------------------===// + +LogicalResult TestCopyCountAttr::verify( + llvm::function_ref<::mlir::InFlightDiagnostic()> /*emitError*/, + CopyCount /*copy_count*/) { + return success(); +} + //===----------------------------------------------------------------------===// // CopyCountAttr Implementation //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/attrdefs.td b/mlir/test/mlir-tblgen/attrdefs.td index 35d2c49619ee..8bd4af6ee73b 100644 --- a/mlir/test/mlir-tblgen/attrdefs.td +++ b/mlir/test/mlir-tblgen/attrdefs.td @@ -115,6 +115,11 @@ def B_CompoundAttrA : TestAttr<"CompoundA"> { // DEF: return new (allocator.allocate()) // DEF-SAME: CompoundAAttrStorage(std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); +// DEF: CompoundAAttr CompoundAAttr::getChecked( +// DEF-SAME: int widthOfSomething, ::test::SimpleTypeA exampleTdType, ::llvm::APFloat apFloat, ::llvm::ArrayRef dims, ::mlir::Type inner +// DEF-SAME: ) +// DEF-NEXT: return Base::getChecked(emitError, context, std::move(widthOfSomething), std::move(exampleTdType), std::move(apFloat), std::move(dims), std::move(inner)); + // DEF: ::mlir::Type CompoundAAttr::getInner() const { // DEF-NEXT: return getImpl()->inner; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 8cc831441810..37968d863954 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -393,7 +393,7 @@ void DefGen::emitCheckedBuilder() { MethodBody &body = m->body().indent(); auto scope = body.scope("return Base::getChecked(emitError, context", ");"); for (const auto ¶m : params) - body << ", " << param.getName(); + body << ", std::move(" << param.getName() << ")"; } static SmallVector diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index e72bfe9d82e7..9f8abdb5096e 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -462,8 +462,9 @@ TEST(SubElementTest, Nested) { {strAttr, trueAttr, falseAttr, boolArrayAttr, dictAttr})); } -// Test how many times we call copy-ctor when building an attribute. -TEST(CopyCountAttr, CopyCount) { +// Test how many times we call copy-ctor when building an attribute with the +// 'get' method. +TEST(CopyCountAttr, CopyCountGet) { MLIRContext context; context.loadDialect(); @@ -483,6 +484,23 @@ TEST(CopyCountAttr, CopyCount) { #endif } +// Test how many times we call copy-ctor when building an attribute with the +// 'getChecked' method. +TEST(CopyCountAttr, CopyCountGetChecked) { + MLIRContext context; + context.loadDialect(); + test::CopyCount::counter = 0; + test::CopyCount copyCount("hello"); + auto loc = UnknownLoc::get(&context); + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + int counter1 = test::CopyCount::counter; + test::CopyCount::counter = 0; + test::TestCopyCountAttr::getChecked(loc, &context, std::move(copyCount)); + // One verification requires a copy. + EXPECT_EQ(counter1, 1); + EXPECT_EQ(test::CopyCount::counter, 1); +} + // Test stripped printing using test dialect attribute. TEST(CopyCountAttr, PrintStripped) { MLIRContext context;