diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp index ca052392f2f5f..65592a5c5d698 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -12,6 +12,7 @@ #include "TestDenseDataFlowAnalysis.h" #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp index 29480f5ad63ee..3f9ce2dc0bc50 100644 --- a/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseForwardDataFlowAnalysis.cpp @@ -12,6 +12,7 @@ #include "TestDenseDataFlowAnalysis.h" #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp index 5e17779660f39..f878a262512ee 100644 --- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp +++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertCallOp.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "TestTypes.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" diff --git a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp index 3c4067b35d8e5..cc1af59c5e15b 100644 --- a/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp +++ b/mlir/test/lib/Conversion/OneToNTypeConversion/TestOneToNTypeConversionPass.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index b098a5a23fd31..34513cd418e4c 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Affine/Transforms/Transforms.h" diff --git a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp index 84f45b3160319..56f309f150ca5 100644 --- a/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp +++ b/mlir/test/lib/Dialect/DLTI/TestDataLayoutQuery.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Analysis/DataLayoutAnalysis.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp index 10aba733bd569..0d7dce2240f4c 100644 --- a/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp +++ b/mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/DecomposeCallGraphTypes.h" #include "mlir/IR/Builders.h" diff --git a/mlir/test/lib/Dialect/Test/CMakeLists.txt b/mlir/test/lib/Dialect/Test/CMakeLists.txt index d246c0492a3bd..f63e4d330e6ac 100644 --- a/mlir/test/lib/Dialect/Test/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Test/CMakeLists.txt @@ -47,7 +47,10 @@ add_public_tablegen_target(MLIRTestOpsSyntaxIncGen) add_mlir_library(MLIRTestDialect TestAttributes.cpp TestDialect.cpp + TestFormatUtils.cpp TestInterfaces.cpp + TestOpDefs.cpp + TestOps.cpp TestPatterns.cpp TestTraits.cpp TestTypes.cpp diff --git a/mlir/test/lib/Dialect/Test/TestAttributes.cpp b/mlir/test/lib/Dialect/Test/TestAttributes.cpp index d41d495c38e55..2cc051e664bee 100644 --- a/mlir/test/lib/Dialect/Test/TestAttributes.cpp +++ b/mlir/test/lib/Dialect/Test/TestAttributes.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Types.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Support/ErrorHandling.h" @@ -244,7 +245,7 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) { //===----------------------------------------------------------------------===// #include "TestAttrInterfaces.cpp.inc" - +#include "TestOpEnums.cpp.inc" #define GET_ATTRDEF_CLASSES #include "TestAttrDefs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index a23ed89c4b04d..77fd7e61bd3a0 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -7,8 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" -#include "TestAttributes.h" -#include "TestInterfaces.h" +#include "TestOps.h" #include "TestTypes.h" #include "mlir/Bytecode/BytecodeImplementation.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -39,17 +38,85 @@ #include "llvm/Support/Base64.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Interfaces/FoldInterfaces.h" +#include "mlir/Reducer/ReductionPatternInterface.h" +#include "mlir/Transforms/InliningUtils.h" #include #include #include -// Include this before the using namespace lines below to -// test that we don't have namespace dependencies. +// Include this before the using namespace lines below to test that we don't +// have namespace dependencies. #include "TestOpsDialect.cpp.inc" using namespace mlir; using namespace test; +//===----------------------------------------------------------------------===// +// PropertiesWithCustomPrint +//===----------------------------------------------------------------------===// + +LogicalResult +test::setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, + Attribute attr, + function_ref emitError) { + DictionaryAttr dict = dyn_cast(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set TestProperties"; + return failure(); + } + auto label = dict.getAs("label"); + if (!label) { + emitError() << "expected StringAttr for key `label`"; + return failure(); + } + auto valueAttr = dict.getAs("value"); + if (!valueAttr) { + emitError() << "expected IntegerAttr for key `value`"; + return failure(); + } + + prop.label = std::make_shared(label.getValue()); + prop.value = valueAttr.getValue().getSExtValue(); + return success(); +} + +DictionaryAttr +test::getPropertiesAsAttribute(MLIRContext *ctx, + const PropertiesWithCustomPrint &prop) { + SmallVector attrs; + Builder b{ctx}; + attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label))); + attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value))); + return b.getDictionaryAttr(attrs); +} + +llvm::hash_code test::computeHash(const PropertiesWithCustomPrint &prop) { + return llvm::hash_combine(prop.value, StringRef(*prop.label)); +} + +void test::customPrintProperties(OpAsmPrinter &p, + const PropertiesWithCustomPrint &prop) { + p.printKeywordOrString(*prop.label); + p << " is " << prop.value; +} + +ParseResult test::customParseProperties(OpAsmParser &parser, + PropertiesWithCustomPrint &prop) { + std::string label; + if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") || + parser.parseInteger(prop.value)) + return failure(); + prop.label = std::make_shared(std::move(label)); + return success(); +} + +//===----------------------------------------------------------------------===// +// MyPropStruct +//===----------------------------------------------------------------------===// + Attribute MyPropStruct::asAttribute(MLIRContext *ctx) const { return StringAttr::get(ctx, content); } @@ -70,8 +137,8 @@ llvm::hash_code MyPropStruct::hash() const { return hash_value(StringRef(content)); } -static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, - MyPropStruct &prop) { +LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, + MyPropStruct &prop) { StringRef str; if (failed(reader.readString(str))) return failure(); @@ -79,13 +146,71 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, return success(); } -static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer, - MyPropStruct &prop) { +void test::writeToMlirBytecode(DialectBytecodeWriter &writer, + MyPropStruct &prop) { writer.writeOwnedString(prop.content); } -static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, - MutableArrayRef prop) { +//===----------------------------------------------------------------------===// +// VersionedProperties +//===----------------------------------------------------------------------===// + +LogicalResult +test::setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, + function_ref emitError) { + DictionaryAttr dict = dyn_cast(attr); + if (!dict) { + emitError() << "expected DictionaryAttr to set VersionedProperties"; + return failure(); + } + auto value1Attr = dict.getAs("value1"); + if (!value1Attr) { + emitError() << "expected IntegerAttr for key `value1`"; + return failure(); + } + auto value2Attr = dict.getAs("value2"); + if (!value2Attr) { + emitError() << "expected IntegerAttr for key `value2`"; + return failure(); + } + + prop.value1 = value1Attr.getValue().getSExtValue(); + prop.value2 = value2Attr.getValue().getSExtValue(); + return success(); +} + +DictionaryAttr test::getPropertiesAsAttribute(MLIRContext *ctx, + const VersionedProperties &prop) { + SmallVector attrs; + Builder b{ctx}; + attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1))); + attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2))); + return b.getDictionaryAttr(attrs); +} + +llvm::hash_code test::computeHash(const VersionedProperties &prop) { + return llvm::hash_combine(prop.value1, prop.value2); +} + +void test::customPrintProperties(OpAsmPrinter &p, + const VersionedProperties &prop) { + p << prop.value1 << " | " << prop.value2; +} + +ParseResult test::customParseProperties(OpAsmParser &parser, + VersionedProperties &prop) { + if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() || + parser.parseInteger(prop.value2)) + return failure(); + return success(); +} + +//===----------------------------------------------------------------------===// +// Bytecode Support +//===----------------------------------------------------------------------===// + +LogicalResult test::readFromMlirBytecode(DialectBytecodeReader &reader, + MutableArrayRef prop) { uint64_t size; if (failed(reader.readVarInt(size))) return failure(); @@ -101,45 +226,13 @@ static LogicalResult readFromMlirBytecode(DialectBytecodeReader &reader, return success(); } -static void writeToMlirBytecode(::mlir::DialectBytecodeWriter &writer, - ArrayRef prop) { +void test::writeToMlirBytecode(DialectBytecodeWriter &writer, + ArrayRef prop) { writer.writeVarInt(prop.size()); for (auto elt : prop) writer.writeVarInt(elt); } -static LogicalResult -setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr, - function_ref emitError); -static DictionaryAttr -getPropertiesAsAttribute(MLIRContext *ctx, - const PropertiesWithCustomPrint &prop); -static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop); -static void customPrintProperties(OpAsmPrinter &p, - const PropertiesWithCustomPrint &prop); -static ParseResult customParseProperties(OpAsmParser &parser, - PropertiesWithCustomPrint &prop); -static LogicalResult -setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, - function_ref emitError); -static DictionaryAttr getPropertiesAsAttribute(MLIRContext *ctx, - const VersionedProperties &prop); -static llvm::hash_code computeHash(const VersionedProperties &prop); -static void customPrintProperties(OpAsmPrinter &p, - const VersionedProperties &prop); -static ParseResult customParseProperties(OpAsmParser &parser, - VersionedProperties &prop); -static ParseResult -parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, - SmallVectorImpl> &caseRegions); - -static void printSwitchCases(OpAsmPrinter &p, Operation *op, - DenseI64ArrayAttr cases, RegionRange caseRegions); - -void test::registerTestDialect(DialectRegistry ®istry) { - registry.insert(); -} - //===----------------------------------------------------------------------===// // Dynamic operations //===----------------------------------------------------------------------===// @@ -196,9 +289,20 @@ getDynamicCustomParserPrinterOp(TestDialect *dialect) { // TestDialect //===----------------------------------------------------------------------===// -static void testSideEffectOpGetEffect( +void test::registerTestDialect(DialectRegistry ®istry) { + registry.insert(); +} + +void test::testSideEffectOpGetEffect( Operation *op, - SmallVectorImpl> &effects); + SmallVectorImpl> + &effects) { + auto effectsAttr = op->getAttrOfType("effect_parameter"); + if (!effectsAttr) + return; + + effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); +} // This is the implementation of a dialect fallback for `TestEffectOpInterface`. struct TestOpEffectInterfaceFallback @@ -318,57 +422,6 @@ TestDialect::getOperationPrinter(Operation *op) const { return {}; } -//===----------------------------------------------------------------------===// -// TypedAttrOp -//===----------------------------------------------------------------------===// - -/// Parse an attribute with a given type. -static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type, - Attribute &attr) { - return parser.parseAttribute(attr, type.getValue()); -} - -/// Print an attribute without its type. -static void printAttrElideType(AsmPrinter &printer, Operation *op, - TypeAttr type, Attribute attr) { - printer.printAttributeWithoutType(attr); -} - -//===----------------------------------------------------------------------===// -// TestBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { - assert(index == 0 && "invalid successor index"); - return SuccessorOperands(getTargetOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// TestProducingBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { - assert(index <= 1 && "invalid successor index"); - if (index == 1) - return SuccessorOperands(getFirstOperandsMutable()); - return SuccessorOperands(getSecondOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// TestProducingBranchOp -//===----------------------------------------------------------------------===// - -SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { - assert(index <= 1 && "invalid successor index"); - if (index == 0) - return SuccessorOperands(0, getSuccessOperandsMutable()); - return SuccessorOperands(1, getErrorOperandsMutable()); -} - -//===----------------------------------------------------------------------===// -// TestDialectCanonicalizerOp -//===----------------------------------------------------------------------===// - static LogicalResult dialectCanonicalizationPattern(TestDialectCanonicalizerOp op, PatternRewriter &rewriter) { @@ -381,1206 +434,3 @@ void TestDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add(&dialectCanonicalizationPattern); } - -//===----------------------------------------------------------------------===// -// TestCallOp -//===----------------------------------------------------------------------===// - -LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - // Check that the callee attribute was specified. - auto fnAttr = (*this)->getAttrOfType("callee"); - if (!fnAttr) - return emitOpError("requires a 'callee' symbol reference attribute"); - if (!symbolTable.lookupNearestSymbolFrom(*this, fnAttr)) - return emitOpError() << "'" << fnAttr.getValue() - << "' does not reference a valid function"; - return success(); -} - -//===----------------------------------------------------------------------===// -// ConversionFuncOp -//===----------------------------------------------------------------------===// - -ParseResult ConversionFuncOp::parse(OpAsmParser &parser, - OperationState &result) { - auto buildFuncType = - [](Builder &builder, ArrayRef argTypes, ArrayRef results, - function_interface_impl::VariadicFlag, - std::string &) { return builder.getFunctionType(argTypes, results); }; - - return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); -} - -void ConversionFuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); -} - -//===----------------------------------------------------------------------===// -// TestFoldToCallOp -//===----------------------------------------------------------------------===// - -namespace { -struct FoldToCallOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FoldToCallOp op, - PatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, TypeRange(), - op.getCalleeAttr(), ValueRange()); - return success(); - } -}; -} // namespace - -void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -//===----------------------------------------------------------------------===// -// Test IsolatedRegionOp - parse passthrough region arguments. -//===----------------------------------------------------------------------===// - -ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse the input operand. - OpAsmParser::Argument argInfo; - argInfo.type = parser.getBuilder().getIndexType(); - if (parser.parseOperand(argInfo.ssaName) || - parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) - return failure(); - - // Parse the body region, and reuse the operand info as the argument info. - Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); -} - -void IsolatedRegionOp::print(OpAsmPrinter &p) { - p << ' '; - p.printOperand(getOperand()); - p.shadowRegionArgs(getRegion(), getOperand()); - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -//===----------------------------------------------------------------------===// -// Test SSACFGRegionOp -//===----------------------------------------------------------------------===// - -RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { - return RegionKind::SSACFG; -} - -//===----------------------------------------------------------------------===// -// Test GraphRegionOp -//===----------------------------------------------------------------------===// - -RegionKind GraphRegionOp::getRegionKind(unsigned index) { - return RegionKind::Graph; -} - -//===----------------------------------------------------------------------===// -// Test AffineScopeOp -//===----------------------------------------------------------------------===// - -ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the body region, and reuse the operand info as the argument info. - Region *body = result.addRegion(); - return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); -} - -void AffineScopeOp::print(OpAsmPrinter &p) { - p << " "; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -//===----------------------------------------------------------------------===// -// Test OptionalCustomAttrOp -//===----------------------------------------------------------------------===// - -static OptionalParseResult parseOptionalCustomParser(AsmParser &p, - IntegerAttr &result) { - if (succeeded(p.parseOptionalKeyword("foo"))) - return p.parseAttribute(result); - return {}; -} - -static void printOptionalCustomParser(AsmPrinter &p, Operation *, - IntegerAttr result) { - p << "foo "; - p.printAttribute(result); -} - -//===----------------------------------------------------------------------===// -// ReifyBoundOp -//===----------------------------------------------------------------------===// - -::mlir::presburger::BoundType ReifyBoundOp::getBoundType() { - if (getType() == "EQ") - return ::mlir::presburger::BoundType::EQ; - if (getType() == "LB") - return ::mlir::presburger::BoundType::LB; - if (getType() == "UB") - return ::mlir::presburger::BoundType::UB; - llvm_unreachable("invalid bound type"); -} - -LogicalResult ReifyBoundOp::verify() { - if (isa(getVar().getType())) { - if (!getDim().has_value()) - return emitOpError("expected 'dim' attribute for shaped type variable"); - } else if (getVar().getType().isIndex()) { - if (getDim().has_value()) - return emitOpError("unexpected 'dim' attribute for index variable"); - } else { - return emitOpError("expected index-typed variable or shape type variable"); - } - if (getConstant() && getScalable()) - return emitOpError("'scalable' and 'constant' are mutually exlusive"); - if (getScalable() != getVscaleMin().has_value()) - return emitOpError("expected 'vscale_min' if and only if 'scalable'"); - if (getScalable() != getVscaleMax().has_value()) - return emitOpError("expected 'vscale_min' if and only if 'scalable'"); - return success(); -} - -::mlir::ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { - if (getDim().has_value()) - return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); - return ValueBoundsConstraintSet::Variable(getVar()); -} - -::mlir::ValueBoundsConstraintSet::ComparisonOperator -CompareOp::getComparisonOperator() { - if (getCmp() == "EQ") - return ValueBoundsConstraintSet::ComparisonOperator::EQ; - if (getCmp() == "LT") - return ValueBoundsConstraintSet::ComparisonOperator::LT; - if (getCmp() == "LE") - return ValueBoundsConstraintSet::ComparisonOperator::LE; - if (getCmp() == "GT") - return ValueBoundsConstraintSet::ComparisonOperator::GT; - if (getCmp() == "GE") - return ValueBoundsConstraintSet::ComparisonOperator::GE; - llvm_unreachable("invalid comparison operator"); -} - -::mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { - if (!getLhsMap()) - return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); - SmallVector mapOperands( - getVarOperands().slice(0, getLhsMap()->getNumInputs())); - return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); -} - -::mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { - int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; - if (!getRhsMap()) - return ValueBoundsConstraintSet::Variable( - getVarOperands()[rhsOperandsBegin]); - SmallVector mapOperands( - getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); - return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); -} - -LogicalResult CompareOp::verify() { - if (getCompose() && (getLhsMap() || getRhsMap())) - return emitOpError( - "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); - int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; - expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; - if (getVarOperands().size() != size_t(expectedNumOperands)) - return emitOpError("expected ") - << expectedNumOperands << " operands, but got " - << getVarOperands().size(); - return success(); -} - -//===----------------------------------------------------------------------===// -// Test removing op with inner ops. -//===----------------------------------------------------------------------===// - -namespace { -struct TestRemoveOpWithInnerOps - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } - - LogicalResult matchAndRewrite(TestOpWithRegionPattern op, - PatternRewriter &rewriter) const override { - rewriter.eraseOp(op); - return success(); - } -}; -} // namespace - -void TestOpWithRegionPattern::getCanonicalizationPatterns( - RewritePatternSet &results, MLIRContext *context) { - results.add(context); -} - -OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { - return getOperand(); -} - -OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } - -LogicalResult TestOpWithVariadicResultsAndFolder::fold( - FoldAdaptor adaptor, SmallVectorImpl &results) { - for (Value input : this->getOperands()) { - results.push_back(input); - } - return success(); -} - -OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { - // Exercise the fact that an operation created with createOrFold should be - // allowed to access its parent block. - assert(getOperation()->getBlock() && - "expected that operation is not unlinked"); - - if (adaptor.getOp() && !getProperties().attr) { - // The folder adds "attr" if not present. - getProperties().attr = dyn_cast_or_null(adaptor.getOp()); - return getResult(); - } - return {}; -} - -OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { - int64_t sum = 0; - if (auto value = dyn_cast_or_null(adaptor.getOp())) - sum += value.getValue().getSExtValue(); - - for (Attribute attr : adaptor.getVariadic()) - if (auto value = dyn_cast_or_null(attr)) - sum += 2 * value.getValue().getSExtValue(); - - for (ArrayRef attrs : adaptor.getVarOfVar()) - for (Attribute attr : attrs) - if (auto value = dyn_cast_or_null(attr)) - sum += 3 * value.getValue().getSExtValue(); - - sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); - - return IntegerAttr::get(getType(), sum); -} - -LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( - MLIRContext *, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - if (operands[0].getType() != operands[1].getType()) { - return emitOptionalError(location, "operand type mismatch ", - operands[0].getType(), " vs ", - operands[1].getType()); - } - inferredReturnTypes.assign({operands[0].getType()}); - return success(); -} - -LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( - MLIRContext *, std::optional location, - OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, - SmallVectorImpl &inferredReturnTypes) { - if (adaptor.getX().getType() != adaptor.getY().getType()) { - return emitOptionalError(location, "operand type mismatch ", - adaptor.getX().getType(), " vs ", - adaptor.getY().getType()); - } - inferredReturnTypes.assign({adaptor.getX().getType()}); - return success(); -} - -// TODO: We should be able to only define either inferReturnType or -// refineReturnType, currently only refineReturnType can be omitted. -LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( - MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &returnTypes) { - returnTypes.clear(); - return OpWithRefineTypeInterfaceOp::refineReturnTypes( - context, location, operands, attributes, properties, regions, - returnTypes); -} - -LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( - MLIRContext *, std::optional location, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &returnTypes) { - if (operands[0].getType() != operands[1].getType()) { - return emitOptionalError(location, "operand type mismatch ", - operands[0].getType(), " vs ", - operands[1].getType()); - } - // TODO: Add helper to make this more concise to write. - if (returnTypes.empty()) - returnTypes.resize(1, nullptr); - if (returnTypes[0] && returnTypes[0] != operands[0].getType()) - return emitOptionalError(location, - "required first operand and result to match"); - returnTypes[0] = operands[0].getType(); - return success(); -} - -LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( - MLIRContext *context, std::optional location, - ValueShapeRange operands, DictionaryAttr attributes, - OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnShapes) { - // Create return type consisting of the last element of the first operand. - auto operandType = operands.front().getType(); - auto sval = dyn_cast(operandType); - if (!sval) - return emitOptionalError(location, "only shaped type operands allowed"); - int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; - auto type = IntegerType::get(context, 17); - - Attribute encoding; - if (auto rankedTy = dyn_cast(sval)) - encoding = rankedTy.getEncoding(); - inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); - return success(); -} - -LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - shapes = SmallVector{ - builder.createOrFold(getLoc(), operands.front(), 0)}; - return success(); -} - -LogicalResult -OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( - MLIRContext *context, std::optional location, - OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, - SmallVectorImpl &inferredReturnShapes) { - // Create return type consisting of the last element of the first operand. - auto operandType = adaptor.getOperand1().getType(); - auto sval = dyn_cast(operandType); - if (!sval) - return emitOptionalError(location, "only shaped type operands allowed"); - int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; - auto type = IntegerType::get(context, 17); - - Attribute encoding; - if (auto rankedTy = dyn_cast(sval)) - encoding = rankedTy.getEncoding(); - inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); - return success(); -} - -LogicalResult -OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - shapes = SmallVector{ - builder.createOrFold(getLoc(), operands.front(), 0)}; - return success(); -} - -LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( - OpBuilder &builder, ValueRange operands, - llvm::SmallVectorImpl &shapes) { - Location loc = getLoc(); - shapes.reserve(operands.size()); - for (Value operand : llvm::reverse(operands)) { - auto rank = cast(operand.getType()).getRank(); - auto currShape = llvm::to_vector<4>( - llvm::map_range(llvm::seq(0, rank), [&](int64_t dim) -> Value { - return builder.createOrFold(loc, operand, dim); - })); - shapes.push_back(builder.create( - getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), - currShape)); - } - return success(); -} - -LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( - OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { - Location loc = getLoc(); - shapes.reserve(getNumOperands()); - for (Value operand : llvm::reverse(getOperands())) { - auto tensorType = cast(operand.getType()); - auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq(0, tensorType.getRank()), - [&](int64_t dim) -> OpFoldResult { - return tensorType.isDynamicDim(dim) - ? static_cast( - builder.createOrFold(loc, operand, - dim)) - : static_cast( - builder.getIndexAttr(tensorType.getDimSize(dim))); - })); - shapes.emplace_back(std::move(currShape)); - } - return success(); -} - -LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( - MLIRContext *context, std::optional, ValueRange operands, - DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, - SmallVectorImpl &inferredReturnTypes) { - - Adaptor adaptor(operands, attributes, properties, regions); - inferredReturnTypes.push_back(IntegerType::get( - context, adaptor.getLhs() + adaptor.getProperties().rhs)); - return success(); -} - -//===----------------------------------------------------------------------===// -// Test SideEffect interfaces -//===----------------------------------------------------------------------===// - -namespace { -/// A test resource for side effects. -struct TestResource : public SideEffects::Resource::Base { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) - - StringRef getName() final { return ""; } -}; -} // namespace - -static void testSideEffectOpGetEffect( - Operation *op, - SmallVectorImpl> - &effects) { - auto effectsAttr = op->getAttrOfType("effect_parameter"); - if (!effectsAttr) - return; - - effects.emplace_back(TestEffects::Concrete::get(), effectsAttr); -} - -void SideEffectOp::getEffects( - SmallVectorImpl &effects) { - // Check for an effects attribute on the op instance. - ArrayAttr effectsAttr = (*this)->getAttrOfType("effects"); - if (!effectsAttr) - return; - - // If there is one, it is an array of dictionary attributes that hold - // information on the effects of this operation. - for (Attribute element : effectsAttr) { - DictionaryAttr effectElement = cast(element); - - // Get the specific memory effect. - MemoryEffects::Effect *effect = - StringSwitch( - cast(effectElement.get("effect")).getValue()) - .Case("allocate", MemoryEffects::Allocate::get()) - .Case("free", MemoryEffects::Free::get()) - .Case("read", MemoryEffects::Read::get()) - .Case("write", MemoryEffects::Write::get()); - - // Check for a non-default resource to use. - SideEffects::Resource *resource = SideEffects::DefaultResource::get(); - if (effectElement.get("test_resource")) - resource = TestResource::get(); - - // Check for a result to affect. - if (effectElement.get("on_result")) - effects.emplace_back(effect, getResult(), resource); - else if (Attribute ref = effectElement.get("on_reference")) - effects.emplace_back(effect, cast(ref), resource); - else - effects.emplace_back(effect, resource); - } -} - -void SideEffectOp::getEffects( - SmallVectorImpl &effects) { - testSideEffectOpGetEffect(getOperation(), effects); -} - -//===----------------------------------------------------------------------===// -// StringAttrPrettyNameOp -//===----------------------------------------------------------------------===// - -// This op has fancy handling of its SSA result name. -ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, - OperationState &result) { - // Add the result types. - for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) - result.addTypes(parser.getBuilder().getIntegerType(32)); - - if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) - return failure(); - - // If the attribute dictionary contains no 'names' attribute, infer it from - // the SSA name (if specified). - bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { - return attr.getName() == "names"; - }); - - // If there was no name specified, check to see if there was a useful name - // specified in the asm file. - if (hadNames || parser.getNumResults() == 0) - return success(); - - SmallVector names; - auto *context = result.getContext(); - - for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { - auto resultName = parser.getResultName(i); - StringRef nameStr; - if (!resultName.first.empty() && !isdigit(resultName.first[0])) - nameStr = resultName.first; - - names.push_back(nameStr); - } - - auto namesAttr = parser.getBuilder().getStrArrayAttr(names); - result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); - return success(); -} - -void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { - // Note that we only need to print the "name" attribute if the asmprinter - // result name disagrees with it. This can happen in strange cases, e.g. - // when there are conflicts. - bool namesDisagree = getNames().size() != getNumResults(); - - SmallString<32> resultNameStr; - for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { - resultNameStr.clear(); - llvm::raw_svector_ostream tmpStream(resultNameStr); - p.printOperand(getResult(i), tmpStream); - - auto expectedName = dyn_cast(getNames()[i]); - if (!expectedName || - tmpStream.str().drop_front() != expectedName.getValue()) { - namesDisagree = true; - } - } - - if (namesDisagree) - p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); - else - p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); -} - -// We set the SSA name in the asm syntax to the contents of the name -// attribute. -void StringAttrPrettyNameOp::getAsmResultNames( - function_ref setNameFn) { - - auto value = getNames(); - for (size_t i = 0, e = value.size(); i != e; ++i) - if (auto str = dyn_cast(value[i])) - if (!str.getValue().empty()) - setNameFn(getResult(i), str.getValue()); -} - -void CustomResultsNameOp::getAsmResultNames( - function_ref setNameFn) { - ArrayAttr value = getNames(); - for (size_t i = 0, e = value.size(); i != e; ++i) - if (auto str = dyn_cast(value[i])) - if (!str.empty()) - setNameFn(getResult(i), str.getValue()); -} - -//===----------------------------------------------------------------------===// -// ResultTypeWithTraitOp -//===----------------------------------------------------------------------===// - -LogicalResult ResultTypeWithTraitOp::verify() { - if ((*this)->getResultTypes()[0].hasTrait()) - return success(); - return emitError("result type should have trait 'TestTypeTrait'"); -} - -//===----------------------------------------------------------------------===// -// AttrWithTraitOp -//===----------------------------------------------------------------------===// - -LogicalResult AttrWithTraitOp::verify() { - if (getAttr().hasTrait()) - return success(); - return emitError("'attr' attribute should have trait 'TestAttrTrait'"); -} - -//===----------------------------------------------------------------------===// -// RegionIfOp -//===----------------------------------------------------------------------===// - -void RegionIfOp::print(OpAsmPrinter &p) { - p << " "; - p.printOperands(getOperands()); - p << ": " << getOperandTypes(); - p.printArrowTypeList(getResultTypes()); - p << " then "; - p.printRegion(getThenRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); - p << " else "; - p.printRegion(getElseRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); - p << " join "; - p.printRegion(getJoinRegion(), - /*printEntryBlockArgs=*/true, - /*printBlockTerminators=*/true); -} - -ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operandInfos; - SmallVector operandTypes; - - result.regions.reserve(3); - Region *thenRegion = result.addRegion(); - Region *elseRegion = result.addRegion(); - Region *joinRegion = result.addRegion(); - - // Parse operand, type and arrow type lists. - if (parser.parseOperandList(operandInfos) || - parser.parseColonTypeList(operandTypes) || - parser.parseArrowTypeList(result.types)) - return failure(); - - // Parse all attached regions. - if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || - parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || - parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) - return failure(); - - return parser.resolveOperands(operandInfos, operandTypes, - parser.getCurrentLocation(), result.operands); -} - -OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && - "invalid region index"); - return getOperands(); -} - -void RegionIfOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - // We always branch to the join region. - if (!point.isParent()) { - if (point != getJoinRegion()) - regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); - else - regions.push_back(RegionSuccessor(getResults())); - return; - } - - // The then and else regions are the entry regions of this op. - regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); - regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); -} - -void RegionIfOp::getRegionInvocationBounds( - ArrayRef operands, - SmallVectorImpl &invocationBounds) { - // Each region is invoked at most once. - invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); -} - -//===----------------------------------------------------------------------===// -// AnyCondOp -//===----------------------------------------------------------------------===// - -void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, - SmallVectorImpl ®ions) { - // The parent op branches into the only region, and the region branches back - // to the parent op. - if (point.isParent()) - regions.emplace_back(&getRegion()); - else - regions.emplace_back(getResults()); -} - -void AnyCondOp::getRegionInvocationBounds( - ArrayRef operands, - SmallVectorImpl &invocationBounds) { - invocationBounds.emplace_back(1, 1); -} - -//===----------------------------------------------------------------------===// -// LoopBlockOp -//===----------------------------------------------------------------------===// - -void LoopBlockOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - regions.emplace_back(&getBody(), getBody().getArguments()); - if (point.isParent()) - return; - - regions.emplace_back((*this)->getResults()); -} - -OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { - assert(point == getBody()); - return MutableOperandRange(getInitMutable()); -} - -//===----------------------------------------------------------------------===// -// LoopBlockTerminatorOp -//===----------------------------------------------------------------------===// - -MutableOperandRange -LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { - if (point.isParent()) - return getExitArgMutable(); - return getNextIterArgMutable(); -} - -//===----------------------------------------------------------------------===// -// SwitchWithNoBreakOp -//===----------------------------------------------------------------------===// - -void TestNoTerminatorOp::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) {} - -//===----------------------------------------------------------------------===// -// SingleNoTerminatorCustomAsmOp -//===----------------------------------------------------------------------===// - -ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, - OperationState &state) { - Region *body = state.addRegion(); - if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) - return failure(); - return success(); -} - -void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { - printer.printRegion( - getRegion(), /*printEntryBlockArgs=*/false, - // This op has a single block without terminators. But explicitly mark - // as not printing block terminators for testing. - /*printBlockTerminators=*/false); -} - -//===----------------------------------------------------------------------===// -// TestVerifiersOp -//===----------------------------------------------------------------------===// - -LogicalResult TestVerifiersOp::verify() { - if (!getRegion().hasOneBlock()) - return emitOpError("`hasOneBlock` trait hasn't been verified"); - - Operation *definingOp = getInput().getDefiningOp(); - if (definingOp && failed(mlir::verify(definingOp))) - return emitOpError("operand hasn't been verified"); - - // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier - // loop. - mlir::emitRemark(getLoc(), "success run of verifier"); - - return success(); -} - -LogicalResult TestVerifiersOp::verifyRegions() { - if (!getRegion().hasOneBlock()) - return emitOpError("`hasOneBlock` trait hasn't been verified"); - - for (Block &block : getRegion()) - for (Operation &op : block) - if (failed(mlir::verify(&op))) - return emitOpError("nested op hasn't been verified"); - - // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier - // loop. - mlir::emitRemark(getLoc(), "success run of region verifier"); - - return success(); -} - -//===----------------------------------------------------------------------===// -// Test InferIntRangeInterface -//===----------------------------------------------------------------------===// - -void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); -} - -ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, - OperationState &result) { - if (parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - // Parse the input argument - OpAsmParser::Argument argInfo; - argInfo.type = parser.getBuilder().getIndexType(); - if (failed(parser.parseArgument(argInfo))) - return failure(); - - // Parse the body region, and reuse the operand info as the argument info. - Region *body = result.addRegion(); - return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); -} - -void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { - p.printOptionalAttrDict((*this)->getAttrs()); - p << ' '; - p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, - /*omitType=*/true); - p << ' '; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - -void TestWithBoundsRegionOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRanges) { - Value arg = getRegion().getArgument(0); - setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); -} - -void TestIncrementOp::inferResultRanges(ArrayRef argRanges, - SetIntRangeFn setResultRanges) { - const ConstantIntRanges &range = argRanges[0]; - APInt one(range.umin().getBitWidth(), 1); - setResultRanges(getResult(), - {range.umin().uadd_sat(one), range.umax().uadd_sat(one), - range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); -} - -void TestReflectBoundsOp::inferResultRanges( - ArrayRef argRanges, SetIntRangeFn setResultRanges) { - const ConstantIntRanges &range = argRanges[0]; - MLIRContext *ctx = getContext(); - Builder b(ctx); - setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); - setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); - setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); - setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); - setResultRanges(getResult(), range); -} - -OpFoldResult ManualCppOpWithFold::fold(ArrayRef attributes) { - // Just a simple fold for testing purposes that reads an operands constant - // value and returns it. - if (!attributes.empty()) - return attributes.front(); - return nullptr; -} - -static LogicalResult -setPropertiesFromAttribute(PropertiesWithCustomPrint &prop, Attribute attr, - function_ref emitError) { - DictionaryAttr dict = dyn_cast(attr); - if (!dict) { - emitError() << "expected DictionaryAttr to set TestProperties"; - return failure(); - } - auto label = dict.getAs("label"); - if (!label) { - emitError() << "expected StringAttr for key `label`"; - return failure(); - } - auto valueAttr = dict.getAs("value"); - if (!valueAttr) { - emitError() << "expected IntegerAttr for key `value`"; - return failure(); - } - - prop.label = std::make_shared(label.getValue()); - prop.value = valueAttr.getValue().getSExtValue(); - return success(); -} - -static DictionaryAttr -getPropertiesAsAttribute(MLIRContext *ctx, - const PropertiesWithCustomPrint &prop) { - SmallVector attrs; - Builder b{ctx}; - attrs.push_back(b.getNamedAttr("label", b.getStringAttr(*prop.label))); - attrs.push_back(b.getNamedAttr("value", b.getI32IntegerAttr(prop.value))); - return b.getDictionaryAttr(attrs); -} - -static llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop) { - return llvm::hash_combine(prop.value, StringRef(*prop.label)); -} - -static void customPrintProperties(OpAsmPrinter &p, - const PropertiesWithCustomPrint &prop) { - p.printKeywordOrString(*prop.label); - p << " is " << prop.value; -} - -static ParseResult customParseProperties(OpAsmParser &parser, - PropertiesWithCustomPrint &prop) { - std::string label; - if (parser.parseKeywordOrString(&label) || parser.parseKeyword("is") || - parser.parseInteger(prop.value)) - return failure(); - prop.label = std::make_shared(std::move(label)); - return success(); -} - -static ParseResult -parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, - SmallVectorImpl> &caseRegions) { - SmallVector caseValues; - while (succeeded(p.parseOptionalKeyword("case"))) { - int64_t value; - Region ®ion = *caseRegions.emplace_back(std::make_unique()); - if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) - return failure(); - caseValues.push_back(value); - } - cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); - return success(); -} - -static void printSwitchCases(OpAsmPrinter &p, Operation *op, - DenseI64ArrayAttr cases, RegionRange caseRegions) { - for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { - p.printNewline(); - p << "case " << value << ' '; - p.printRegion(*region, /*printEntryBlockArgs=*/false); - } -} - -static LogicalResult -setPropertiesFromAttribute(VersionedProperties &prop, Attribute attr, - function_ref emitError) { - DictionaryAttr dict = dyn_cast(attr); - if (!dict) { - emitError() << "expected DictionaryAttr to set VersionedProperties"; - return failure(); - } - auto value1Attr = dict.getAs("value1"); - if (!value1Attr) { - emitError() << "expected IntegerAttr for key `value1`"; - return failure(); - } - auto value2Attr = dict.getAs("value2"); - if (!value2Attr) { - emitError() << "expected IntegerAttr for key `value2`"; - return failure(); - } - - prop.value1 = value1Attr.getValue().getSExtValue(); - prop.value2 = value2Attr.getValue().getSExtValue(); - return success(); -} - -static DictionaryAttr -getPropertiesAsAttribute(MLIRContext *ctx, const VersionedProperties &prop) { - SmallVector attrs; - Builder b{ctx}; - attrs.push_back(b.getNamedAttr("value1", b.getI32IntegerAttr(prop.value1))); - attrs.push_back(b.getNamedAttr("value2", b.getI32IntegerAttr(prop.value2))); - return b.getDictionaryAttr(attrs); -} - -static llvm::hash_code computeHash(const VersionedProperties &prop) { - return llvm::hash_combine(prop.value1, prop.value2); -} - -static void customPrintProperties(OpAsmPrinter &p, - const VersionedProperties &prop) { - p << prop.value1 << " | " << prop.value2; -} - -static ParseResult customParseProperties(OpAsmParser &parser, - VersionedProperties &prop) { - if (parser.parseInteger(prop.value1) || parser.parseVerticalBar() || - parser.parseInteger(prop.value2)) - return failure(); - return success(); -} - -static bool parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { - return parser.parseLSquare() || parser.parseInteger(value[0]) || - parser.parseComma() || parser.parseInteger(value[1]) || - parser.parseComma() || parser.parseInteger(value[2]) || - parser.parseRSquare(); -} - -static void printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, - ArrayRef value) { - printer << '[' << value << ']'; -} - -static bool parseIntProperty(OpAsmParser &parser, int64_t &value) { - return failed(parser.parseInteger(value)); -} - -static void printIntProperty(OpAsmPrinter &printer, Operation *op, - int64_t value) { - printer << value; -} - -static bool parseSumProperty(OpAsmParser &parser, int64_t &second, - int64_t first) { - int64_t sum; - auto loc = parser.getCurrentLocation(); - if (parser.parseInteger(second) || parser.parseEqual() || - parser.parseInteger(sum)) - return true; - if (sum != second + first) { - parser.emitError(loc, "Expected sum to equal first + second"); - return true; - } - return false; -} - -static void printSumProperty(OpAsmPrinter &printer, Operation *op, - int64_t second, int64_t first) { - printer << second << " = " << (second + first); -} - -//===----------------------------------------------------------------------===// -// Tensor/Buffer Ops -//===----------------------------------------------------------------------===// - -void ReadBufferOp::getEffects( - SmallVectorImpl> - &effects) { - // The buffer operand is read. - effects.emplace_back(MemoryEffects::Read::get(), getBuffer(), - SideEffects::DefaultResource::get()); - // The buffer contents are dumped. - effects.emplace_back(MemoryEffects::Write::get(), - SideEffects::DefaultResource::get()); -} - -//===----------------------------------------------------------------------===// -// Test Dataflow -//===----------------------------------------------------------------------===// - -CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { - return getCallee(); -} - -void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { - setCalleeAttr(callee.get()); -} - -Operation::operand_range TestCallAndStoreOp::getArgOperands() { - return getCalleeOperands(); -} - -MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { - return getCalleeOperandsMutable(); -} - -CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { - return getCallee(); -} - -void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { - setCalleeAttr(callee.get()); -} - -Operation::operand_range TestCallOnDeviceOp::getArgOperands() { - return getForwardedOperands(); -} - -MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { - return getForwardedOperandsMutable(); -} - -void TestStoreWithARegion::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - if (point.isParent()) - regions.emplace_back(&getBody(), getBody().front().getArguments()); - else - regions.emplace_back(); -} - -void TestStoreWithALoopRegion::getSuccessorRegions( - RegionBranchPoint point, SmallVectorImpl ®ions) { - // Both the operation itself and the region may be branching into the body or - // back into the operation itself. It is possible for the operation not to - // enter the body. - regions.emplace_back( - RegionSuccessor(&getBody(), getBody().front().getArguments())); - regions.emplace_back(); -} - -LogicalResult -TestVersionedOpA::readProperties(::mlir::DialectBytecodeReader &reader, - ::mlir::OperationState &state) { - auto &prop = state.getOrAddProperties(); - if (::mlir::failed(reader.readAttribute(prop.dims))) - return ::mlir::failure(); - - // Check if we have a version. If not, assume we are parsing the current - // version. - auto maybeVersion = reader.getDialectVersion(); - if (succeeded(maybeVersion)) { - // If version is less than 2.0, there is no additional attribute to parse. - // We can materialize missing properties post parsing before verification. - const auto *version = - reinterpret_cast(*maybeVersion); - if ((version->major_ < 2)) { - return success(); - } - } - - if (::mlir::failed(reader.readAttribute(prop.modifier))) - return ::mlir::failure(); - return ::mlir::success(); -} - -void TestVersionedOpA::writeProperties(::mlir::DialectBytecodeWriter &writer) { - auto &prop = getProperties(); - writer.writeAttribute(prop.dims); - - auto maybeVersion = writer.getDialectVersion(); - if (succeeded(maybeVersion)) { - // If version is less than 2.0, there is no additional attribute to write. - const auto *version = - reinterpret_cast(*maybeVersion); - if ((version->major_ < 2)) { - llvm::outs() << "downgrading op properties...\n"; - return; - } - } - writer.writeAttribute(prop.modifier); -} - -::mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( - ::mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { - uint64_t value1, value2 = 0; - if (failed(reader.readVarInt(value1))) - return failure(); - - // Check if we have a version. If not, assume we are parsing the current - // version. - auto maybeVersion = reader.getDialectVersion(); - bool needToParseAnotherInt = true; - if (succeeded(maybeVersion)) { - // If version is less than 2.0, there is no additional attribute to parse. - // We can materialize missing properties post parsing before verification. - const auto *version = - reinterpret_cast(*maybeVersion); - if ((version->major_ < 2)) - needToParseAnotherInt = false; - } - if (needToParseAnotherInt && failed(reader.readVarInt(value2))) - return failure(); - - prop.value1 = value1; - prop.value2 = value2; - return success(); -} - -void TestOpWithVersionedProperties::writeToMlirBytecode( - ::mlir::DialectBytecodeWriter &writer, - const test::VersionedProperties &prop) { - writer.writeVarInt(prop.value1); - writer.writeVarInt(prop.value2); -} - -#include "TestOpEnums.cpp.inc" -#include "TestOpInterfaces.cpp.inc" -#include "TestTypeInterfaces.cpp.inc" - -#define GET_OP_CLASSES -#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h index d5b2fbeafc410..c05e15fc642a2 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -43,19 +43,18 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" #include namespace mlir { -class DLTIDialect; class RewritePatternSet; -} // namespace mlir +} // end namespace mlir //===----------------------------------------------------------------------===// // TestDialect //===----------------------------------------------------------------------===// -#include "TestOpInterfaces.h.inc" #include "TestOpsDialect.h.inc" namespace test { @@ -75,49 +74,8 @@ struct TestDialectVersion : public mlir::DialectVersion { uint32_t minor_ = 0; }; -// Define some classes to exercises the Properties feature. - -struct PropertiesWithCustomPrint { - /// A shared_ptr to a const object is safe: it is equivalent to a value-based - /// member. Here the label will be deallocated when the last operation - /// refering to it is destroyed. However there is no pool-allocation: this is - /// offloaded to the client. - std::shared_ptr label; - int value; - bool operator==(const PropertiesWithCustomPrint &rhs) const { - return value == rhs.value && *label == *rhs.label; - } -}; -class MyPropStruct { -public: - std::string content; - // These three methods are invoked through the `MyStructProperty` wrapper - // defined in TestOps.td - mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const; - static mlir::LogicalResult - setFromAttr(MyPropStruct &prop, mlir::Attribute attr, - llvm::function_ref emitError); - llvm::hash_code hash() const; - bool operator==(const MyPropStruct &rhs) const { - return content == rhs.content; - } -}; -struct VersionedProperties { - // For the sake of testing, assume that this object was associated to version - // 1.2 of the test dialect when having only one int value. In the current - // version 2.0, the property has two values. We also assume that the class is - // upgrade-able if value2 = 0. - int value1; - int value2; - bool operator==(const VersionedProperties &rhs) const { - return value1 == rhs.value1 && value2 == rhs.value2; - } -}; } // namespace test -#define GET_OP_CLASSES -#include "TestOps.h.inc" - namespace test { // Op deliberately defined in C++ code rather than ODS to test that C++ @@ -138,6 +96,10 @@ class ManualCppOpWithFold void registerTestDialect(::mlir::DialectRegistry ®istry); void populateTestReductionPatterns(::mlir::RewritePatternSet &patterns); +void testSideEffectOpGetEffect( + mlir::Operation *op, + llvm::SmallVectorImpl< + mlir::SideEffects::EffectInstance> &effects); } // namespace test #endif // MLIR_TESTDIALECT_H diff --git a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp index 66578b246afab..a3a8913d5964c 100644 --- a/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp new file mode 100644 index 0000000000000..6e75dd3932281 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.cpp @@ -0,0 +1,377 @@ +//===- TestFormatUtils.cpp - MLIR Test Dialect Assembly Format Utilities --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestFormatUtils.h" +#include "mlir/IR/Builders.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperands +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOperands( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, + std::optional &optOperand, + SmallVectorImpl &varOperands) { + if (parser.parseOperand(operand)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand)) + return failure(); + } + if (parser.parseArrow() || parser.parseLParen() || + parser.parseOperandList(varOperands) || parser.parseRParen()) + return failure(); + return success(); +} + +void test::printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *, + Value operand, Value optOperand, + OperandRange varOperands) { + printer << operand; + if (optOperand) + printer << ", " << optOperand; + printer << " -> (" << varOperands << ")"; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveResults +//===----------------------------------------------------------------------===// + +ParseResult +test::parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType, + Type &optOperandType, + SmallVectorImpl &varOperandTypes) { + if (parser.parseColon()) + return failure(); + + if (parser.parseType(operandType)) + return failure(); + if (succeeded(parser.parseOptionalComma())) + if (parser.parseType(optOperandType)) + return failure(); + if (parser.parseArrow() || parser.parseLParen() || + parser.parseTypeList(varOperandTypes) || parser.parseRParen()) + return failure(); + return success(); +} + +void test::printCustomDirectiveResults(OpAsmPrinter &printer, Operation *, + Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + printer << " : " << operandType; + if (optOperandType) + printer << ", " << optOperandType; + printer << " -> (" << varOperandTypes << ")"; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveWithTypeRefs +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveWithTypeRefs( + OpAsmParser &parser, Type operandType, Type optOperandType, + const SmallVectorImpl &varOperandTypes) { + if (parser.parseKeyword("type_refs_capture")) + return failure(); + + Type operandType2, optOperandType2; + SmallVector varOperandTypes2; + if (parseCustomDirectiveResults(parser, operandType2, optOperandType2, + varOperandTypes2)) + return failure(); + + if (operandType != operandType2 || optOperandType != optOperandType2 || + varOperandTypes != varOperandTypes2) + return failure(); + + return success(); +} + +void test::printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer, + Operation *op, Type operandType, + Type optOperandType, + TypeRange varOperandTypes) { + printer << " type_refs_capture "; + printCustomDirectiveResults(printer, op, operandType, optOperandType, + varOperandTypes); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperandsAndTypes +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOperandsAndTypes( + OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand, + std::optional &optOperand, + SmallVectorImpl &varOperands, + Type &operandType, Type &optOperandType, + SmallVectorImpl &varOperandTypes) { + if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) || + parseCustomDirectiveResults(parser, operandType, optOperandType, + varOperandTypes)) + return failure(); + return success(); +} + +void test::printCustomDirectiveOperandsAndTypes( + OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand, + OperandRange varOperands, Type operandType, Type optOperandType, + TypeRange varOperandTypes) { + printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands); + printCustomDirectiveResults(printer, op, operandType, optOperandType, + varOperandTypes); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveRegions +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveRegions( + OpAsmParser &parser, Region ®ion, + SmallVectorImpl> &varRegions) { + if (parser.parseRegion(region)) + return failure(); + if (failed(parser.parseOptionalComma())) + return success(); + std::unique_ptr varRegion = std::make_unique(); + if (parser.parseRegion(*varRegion)) + return failure(); + varRegions.emplace_back(std::move(varRegion)); + return success(); +} + +void test::printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *, + Region ®ion, + MutableArrayRef varRegions) { + printer.printRegion(region); + if (!varRegions.empty()) { + printer << ", "; + for (Region ®ion : varRegions) + printer.printRegion(region); + } +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveSuccessors +//===----------------------------------------------------------------------===// + +ParseResult +test::parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor, + SmallVectorImpl &varSuccessors) { + if (parser.parseSuccessor(successor)) + return failure(); + if (failed(parser.parseOptionalComma())) + return success(); + Block *varSuccessor; + if (parser.parseSuccessor(varSuccessor)) + return failure(); + varSuccessors.append(2, varSuccessor); + return success(); +} + +void test::printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *, + Block *successor, + SuccessorRange varSuccessors) { + printer << successor; + if (!varSuccessors.empty()) + printer << ", " << varSuccessors.front(); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttributes +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveAttributes(OpAsmParser &parser, + IntegerAttr &attr, + IntegerAttr &optAttr) { + if (parser.parseAttribute(attr)) + return failure(); + if (succeeded(parser.parseOptionalComma())) { + if (parser.parseAttribute(optAttr)) + return failure(); + } + return success(); +} + +void test::printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *, + Attribute attribute, + Attribute optAttribute) { + printer << attribute; + if (optAttribute) + printer << ", " << optAttribute; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrDict +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveAttrDict(OpAsmParser &parser, + NamedAttrList &attrs) { + return parser.parseOptionalAttrDict(attrs); +} + +void test::printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op, + DictionaryAttr attrs) { + printer.printOptionalAttrDict(attrs.getValue()); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperandRef +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomDirectiveOptionalOperandRef( + OpAsmParser &parser, + std::optional &optOperand) { + int64_t operandCount = 0; + if (parser.parseInteger(operandCount)) + return failure(); + bool expectedOptionalOperand = operandCount == 0; + return success(expectedOptionalOperand != !!optOperand); +} + +void test::printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer, + Operation *op, + Value optOperand) { + printer << (optOperand ? "1" : "0"); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperand +//===----------------------------------------------------------------------===// + +ParseResult test::parseCustomOptionalOperand( + OpAsmParser &parser, + std::optional &optOperand) { + if (succeeded(parser.parseOptionalLParen())) { + optOperand.emplace(); + if (parser.parseOperand(*optOperand) || parser.parseRParen()) + return failure(); + } + return success(); +} + +void test::printCustomOptionalOperand(OpAsmPrinter &printer, Operation *, + Value optOperand) { + if (optOperand) + printer << "(" << optOperand << ") "; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveSwitchCases +//===----------------------------------------------------------------------===// + +ParseResult +test::parseSwitchCases(OpAsmParser &p, DenseI64ArrayAttr &cases, + SmallVectorImpl> &caseRegions) { + SmallVector caseValues; + while (succeeded(p.parseOptionalKeyword("case"))) { + int64_t value; + Region ®ion = *caseRegions.emplace_back(std::make_unique()); + if (p.parseInteger(value) || p.parseRegion(region, /*arguments=*/{})) + return failure(); + caseValues.push_back(value); + } + cases = p.getBuilder().getDenseI64ArrayAttr(caseValues); + return success(); +} + +void test::printSwitchCases(OpAsmPrinter &p, Operation *op, + DenseI64ArrayAttr cases, RegionRange caseRegions) { + for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) { + p.printNewline(); + p << "case " << value << ' '; + p.printRegion(*region, /*printEntryBlockArgs=*/false); + } +} + +//===----------------------------------------------------------------------===// +// CustomUsingPropertyInCustom +//===----------------------------------------------------------------------===// + +bool test::parseUsingPropertyInCustom(OpAsmParser &parser, int64_t value[3]) { + return parser.parseLSquare() || parser.parseInteger(value[0]) || + parser.parseComma() || parser.parseInteger(value[1]) || + parser.parseComma() || parser.parseInteger(value[2]) || + parser.parseRSquare(); +} + +void test::printUsingPropertyInCustom(OpAsmPrinter &printer, Operation *op, + ArrayRef value) { + printer << '[' << value << ']'; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveIntProperty +//===----------------------------------------------------------------------===// + +bool test::parseIntProperty(OpAsmParser &parser, int64_t &value) { + return failed(parser.parseInteger(value)); +} + +void test::printIntProperty(OpAsmPrinter &printer, Operation *op, + int64_t value) { + printer << value; +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveSumProperty +//===----------------------------------------------------------------------===// + +bool test::parseSumProperty(OpAsmParser &parser, int64_t &second, + int64_t first) { + int64_t sum; + auto loc = parser.getCurrentLocation(); + if (parser.parseInteger(second) || parser.parseEqual() || + parser.parseInteger(sum)) + return true; + if (sum != second + first) { + parser.emitError(loc, "Expected sum to equal first + second"); + return true; + } + return false; +} + +void test::printSumProperty(OpAsmPrinter &printer, Operation *op, + int64_t second, int64_t first) { + printer << second << " = " << (second + first); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalCustomParser +//===----------------------------------------------------------------------===// + +OptionalParseResult test::parseOptionalCustomParser(AsmParser &p, + IntegerAttr &result) { + if (succeeded(p.parseOptionalKeyword("foo"))) + return p.parseAttribute(result); + return {}; +} + +void test::printOptionalCustomParser(AsmPrinter &p, Operation *, + IntegerAttr result) { + p << "foo "; + p.printAttribute(result); +} + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrElideType +//===----------------------------------------------------------------------===// + +ParseResult test::parseAttrElideType(AsmParser &parser, TypeAttr type, + Attribute &attr) { + return parser.parseAttribute(attr, type.getValue()); +} + +void test::printAttrElideType(AsmPrinter &printer, Operation *op, TypeAttr type, + Attribute attr) { + printer.printAttributeWithoutType(attr); +} diff --git a/mlir/test/lib/Dialect/Test/TestFormatUtils.h b/mlir/test/lib/Dialect/Test/TestFormatUtils.h new file mode 100644 index 0000000000000..7e9cd834278e3 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestFormatUtils.h @@ -0,0 +1,211 @@ +//===- TestFormatUtils.h - MLIR Test Dialect Assembly Format Utilities ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTFORMATUTILS_H +#define MLIR_TESTFORMATUTILS_H + +#include "mlir/IR/OpImplementation.h" + +namespace test { + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperands +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOperands( + mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand, + std::optional &optOperand, + llvm::SmallVectorImpl &varOperands); + +void printCustomDirectiveOperands(mlir::OpAsmPrinter &printer, + mlir::Operation *, mlir::Value operand, + mlir::Value optOperand, + mlir::OperandRange varOperands); + +//===----------------------------------------------------------------------===// +// CustomDirectiveResults +//===----------------------------------------------------------------------===// + +mlir::ParseResult +parseCustomDirectiveResults(mlir::OpAsmParser &parser, mlir::Type &operandType, + mlir::Type &optOperandType, + llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveResults(mlir::OpAsmPrinter &printer, mlir::Operation *, + mlir::Type operandType, + mlir::Type optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveWithTypeRefs +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveWithTypeRefs( + mlir::OpAsmParser &parser, mlir::Type operandType, + mlir::Type optOperandType, + const llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveWithTypeRefs(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::Type operandType, + mlir::Type optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOperandsAndTypes +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOperandsAndTypes( + mlir::OpAsmParser &parser, mlir::OpAsmParser::UnresolvedOperand &operand, + std::optional &optOperand, + llvm::SmallVectorImpl &varOperands, + mlir::Type &operandType, mlir::Type &optOperandType, + llvm::SmallVectorImpl &varOperandTypes); + +void printCustomDirectiveOperandsAndTypes( + mlir::OpAsmPrinter &printer, mlir::Operation *op, mlir::Value operand, + mlir::Value optOperand, mlir::OperandRange varOperands, + mlir::Type operandType, mlir::Type optOperandType, + mlir::TypeRange varOperandTypes); + +//===----------------------------------------------------------------------===// +// CustomDirectiveRegions +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveRegions( + mlir::OpAsmParser &parser, mlir::Region ®ion, + llvm::SmallVectorImpl> &varRegions); + +void printCustomDirectiveRegions( + mlir::OpAsmPrinter &printer, mlir::Operation *, mlir::Region ®ion, + llvm::MutableArrayRef varRegions); + +//===----------------------------------------------------------------------===// +// CustomDirectiveSuccessors +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveSuccessors( + mlir::OpAsmParser &parser, mlir::Block *&successor, + llvm::SmallVectorImpl &varSuccessors); + +void printCustomDirectiveSuccessors(mlir::OpAsmPrinter &printer, + mlir::Operation *, mlir::Block *successor, + mlir::SuccessorRange varSuccessors); + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttributes +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveAttributes(mlir::OpAsmParser &parser, + mlir::IntegerAttr &attr, + mlir::IntegerAttr &optAttr); + +void printCustomDirectiveAttributes(mlir::OpAsmPrinter &printer, + mlir::Operation *, + mlir::Attribute attribute, + mlir::Attribute optAttribute); + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrDict +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveAttrDict(mlir::OpAsmParser &parser, + mlir::NamedAttrList &attrs); + +void printCustomDirectiveAttrDict(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::DictionaryAttr attrs); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperandRef +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomDirectiveOptionalOperandRef( + mlir::OpAsmParser &parser, + std::optional &optOperand); + +void printCustomDirectiveOptionalOperandRef(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + mlir::Value optOperand); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalOperand +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseCustomOptionalOperand( + mlir::OpAsmParser &parser, + std::optional &optOperand); + +void printCustomOptionalOperand(mlir::OpAsmPrinter &printer, mlir::Operation *, + mlir::Value optOperand); + +//===----------------------------------------------------------------------===// +// CustomDirectiveSwitchCases +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseSwitchCases( + mlir::OpAsmParser &p, mlir::DenseI64ArrayAttr &cases, + llvm::SmallVectorImpl> &caseRegions); + +void printSwitchCases(mlir::OpAsmPrinter &p, mlir::Operation *op, + mlir::DenseI64ArrayAttr cases, + mlir::RegionRange caseRegions); + +//===----------------------------------------------------------------------===// +// CustomUsingPropertyInCustom +//===----------------------------------------------------------------------===// + +bool parseUsingPropertyInCustom(mlir::OpAsmParser &parser, int64_t value[3]); + +void printUsingPropertyInCustom(mlir::OpAsmPrinter &printer, + mlir::Operation *op, + llvm::ArrayRef value); + +//===----------------------------------------------------------------------===// +// CustomDirectiveIntProperty +//===----------------------------------------------------------------------===// + +bool parseIntProperty(mlir::OpAsmParser &parser, int64_t &value); + +void printIntProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op, + int64_t value); + +//===----------------------------------------------------------------------===// +// CustomDirectiveSumProperty +//===----------------------------------------------------------------------===// + +bool parseSumProperty(mlir::OpAsmParser &parser, int64_t &second, + int64_t first); + +void printSumProperty(mlir::OpAsmPrinter &printer, mlir::Operation *op, + int64_t second, int64_t first); + +//===----------------------------------------------------------------------===// +// CustomDirectiveOptionalCustomParser +//===----------------------------------------------------------------------===// + +mlir::OptionalParseResult parseOptionalCustomParser(mlir::AsmParser &p, + mlir::IntegerAttr &result); + +void printOptionalCustomParser(mlir::AsmPrinter &p, mlir::Operation *, + mlir::IntegerAttr result); + +//===----------------------------------------------------------------------===// +// CustomDirectiveAttrElideType +//===----------------------------------------------------------------------===// + +mlir::ParseResult parseAttrElideType(mlir::AsmParser &parser, + mlir::TypeAttr type, + mlir::Attribute &attr); + +void printAttrElideType(mlir::AsmPrinter &printer, mlir::Operation *op, + mlir::TypeAttr type, mlir::Attribute attr); + +} // end namespace test + +#endif // MLIR_TESTFORMATUTILS_H diff --git a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp index 3673d62bea2c9..dc6413b25707e 100644 --- a/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestFromLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp index 64ec82ecb24ff..14099bb4bb16b 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.cpp +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.cpp @@ -6,3 +6,5 @@ bool mlir::TestEffects::Effect::classof( const mlir::SideEffects::Effect *effect) { return isa(effect); } + +#include "TestOpInterfaces.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestInterfaces.h b/mlir/test/lib/Dialect/Test/TestInterfaces.h index 3239584a93326..d58d1aafbe66c 100644 --- a/mlir/test/lib/Dialect/Test/TestInterfaces.h +++ b/mlir/test/lib/Dialect/Test/TestInterfaces.h @@ -34,4 +34,6 @@ struct Concrete : public Effect::Base {}; } // namespace TestEffects } // namespace mlir +#include "TestOpInterfaces.h.inc" + #endif // MLIR_TEST_LIB_DIALECT_TEST_TESTINTERFACES_H diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp new file mode 100644 index 0000000000000..7263774ca158e --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -0,0 +1,1161 @@ +//===- TestOpDefs.cpp - MLIR Test Dialect Operation Hooks -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "TestOps.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/FunctionImplementation.h" + +using namespace mlir; +using namespace test; + +//===----------------------------------------------------------------------===// +// TestBranchOp +//===----------------------------------------------------------------------===// + +SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) { + assert(index == 0 && "invalid successor index"); + return SuccessorOperands(getTargetOperandsMutable()); +} + +//===----------------------------------------------------------------------===// +// TestProducingBranchOp +//===----------------------------------------------------------------------===// + +SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) { + assert(index <= 1 && "invalid successor index"); + if (index == 1) + return SuccessorOperands(getFirstOperandsMutable()); + return SuccessorOperands(getSecondOperandsMutable()); +} + +//===----------------------------------------------------------------------===// +// TestInternalBranchOp +//===----------------------------------------------------------------------===// + +SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) { + assert(index <= 1 && "invalid successor index"); + if (index == 0) + return SuccessorOperands(0, getSuccessOperandsMutable()); + return SuccessorOperands(1, getErrorOperandsMutable()); +} + +//===----------------------------------------------------------------------===// +// TestCallOp +//===----------------------------------------------------------------------===// + +LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Check that the callee attribute was specified. + auto fnAttr = (*this)->getAttrOfType("callee"); + if (!fnAttr) + return emitOpError("requires a 'callee' symbol reference attribute"); + if (!symbolTable.lookupNearestSymbolFrom(*this, fnAttr)) + return emitOpError() << "'" << fnAttr.getValue() + << "' does not reference a valid function"; + return success(); +} + +//===----------------------------------------------------------------------===// +// FoldToCallOp +//===----------------------------------------------------------------------===// + +namespace { +struct FoldToCallOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(FoldToCallOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp(op, TypeRange(), + op.getCalleeAttr(), ValueRange()); + return success(); + } +}; +} // namespace + +void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// IsolatedRegionOp - test parsing passthrough operands +//===----------------------------------------------------------------------===// + +ParseResult IsolatedRegionOp::parse(OpAsmParser &parser, + OperationState &result) { + // Parse the input operand. + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + if (parser.parseOperand(argInfo.ssaName) || + parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands)) + return failure(); + + // Parse the body region, and reuse the operand info as the argument info. + Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true); +} + +void IsolatedRegionOp::print(OpAsmPrinter &p) { + p << ' '; + p.printOperand(getOperand()); + p.shadowRegionArgs(getRegion(), getOperand()); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +//===----------------------------------------------------------------------===// +// SSACFGRegionOp +//===----------------------------------------------------------------------===// + +RegionKind SSACFGRegionOp::getRegionKind(unsigned index) { + return RegionKind::SSACFG; +} + +//===----------------------------------------------------------------------===// +// GraphRegionOp +//===----------------------------------------------------------------------===// + +RegionKind GraphRegionOp::getRegionKind(unsigned index) { + return RegionKind::Graph; +} + +//===----------------------------------------------------------------------===// +// AffineScopeOp +//===----------------------------------------------------------------------===// + +ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) { + // Parse the body region, and reuse the operand info as the argument info. + Region *body = result.addRegion(); + return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); +} + +void AffineScopeOp::print(OpAsmPrinter &p) { + p << " "; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +//===----------------------------------------------------------------------===// +// TestRemoveOpWithInnerOps +//===----------------------------------------------------------------------===// + +namespace { +struct TestRemoveOpWithInnerOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + void initialize() { setDebugName("TestRemoveOpWithInnerOps"); } + + LogicalResult matchAndRewrite(TestOpWithRegionPattern op, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// TestOpWithRegionPattern +//===----------------------------------------------------------------------===// + +void TestOpWithRegionPattern::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.add(context); +} + +//===----------------------------------------------------------------------===// +// TestOpWithRegionFold +//===----------------------------------------------------------------------===// + +OpFoldResult TestOpWithRegionFold::fold(FoldAdaptor adaptor) { + return getOperand(); +} + +//===----------------------------------------------------------------------===// +// TestOpConstant +//===----------------------------------------------------------------------===// + +OpFoldResult TestOpConstant::fold(FoldAdaptor adaptor) { return getValue(); } + +//===----------------------------------------------------------------------===// +// TestOpWithVariadicResultsAndFolder +//===----------------------------------------------------------------------===// + +LogicalResult TestOpWithVariadicResultsAndFolder::fold( + FoldAdaptor adaptor, SmallVectorImpl &results) { + for (Value input : this->getOperands()) { + results.push_back(input); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// TestOpInPlaceFold +//===----------------------------------------------------------------------===// + +OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) { + // Exercise the fact that an operation created with createOrFold should be + // allowed to access its parent block. + assert(getOperation()->getBlock() && + "expected that operation is not unlinked"); + + if (adaptor.getOp() && !getProperties().attr) { + // The folder adds "attr" if not present. + getProperties().attr = dyn_cast_or_null(adaptor.getOp()); + return getResult(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// OpWithInferTypeInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes( + MLIRContext *, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + if (operands[0].getType() != operands[1].getType()) { + return emitOptionalError(location, "operand type mismatch ", + operands[0].getType(), " vs ", + operands[1].getType()); + } + inferredReturnTypes.assign({operands[0].getType()}); + return success(); +} + +//===----------------------------------------------------------------------===// +// OpWithShapedTypeInferTypeInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents( + MLIRContext *context, std::optional location, + ValueShapeRange operands, DictionaryAttr attributes, + OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnShapes) { + // Create return type consisting of the last element of the first operand. + auto operandType = operands.front().getType(); + auto sval = dyn_cast(operandType); + if (!sval) + return emitOptionalError(location, "only shaped type operands allowed"); + int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; + auto type = IntegerType::get(context, 17); + + Attribute encoding; + if (auto rankedTy = dyn_cast(sval)) + encoding = rankedTy.getEncoding(); + inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); + return success(); +} + +LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + shapes = SmallVector{ + builder.createOrFold(getLoc(), operands.front(), 0)}; + return success(); +} + +//===----------------------------------------------------------------------===// +// OpWithResultShapeInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + Location loc = getLoc(); + shapes.reserve(operands.size()); + for (Value operand : llvm::reverse(operands)) { + auto rank = cast(operand.getType()).getRank(); + auto currShape = llvm::to_vector<4>( + llvm::map_range(llvm::seq(0, rank), [&](int64_t dim) -> Value { + return builder.createOrFold(loc, operand, dim); + })); + shapes.push_back(builder.create( + getLoc(), RankedTensorType::get({rank}, builder.getIndexType()), + currShape)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// OpWithResultShapePerDimInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes( + OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) { + Location loc = getLoc(); + shapes.reserve(getNumOperands()); + for (Value operand : llvm::reverse(getOperands())) { + auto tensorType = cast(operand.getType()); + auto currShape = llvm::to_vector<4>(llvm::map_range( + llvm::seq(0, tensorType.getRank()), + [&](int64_t dim) -> OpFoldResult { + return tensorType.isDynamicDim(dim) + ? static_cast( + builder.createOrFold(loc, operand, + dim)) + : static_cast( + builder.getIndexAttr(tensorType.getDimSize(dim))); + })); + shapes.emplace_back(std::move(currShape)); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// SideEffectOp +//===----------------------------------------------------------------------===// + +namespace { +/// A test resource for side effects. +struct TestResource : public SideEffects::Resource::Base { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource) + + StringRef getName() final { return ""; } +}; +} // namespace + +void SideEffectOp::getEffects( + SmallVectorImpl &effects) { + // Check for an effects attribute on the op instance. + ArrayAttr effectsAttr = (*this)->getAttrOfType("effects"); + if (!effectsAttr) + return; + + // If there is one, it is an array of dictionary attributes that hold + // information on the effects of this operation. + for (Attribute element : effectsAttr) { + DictionaryAttr effectElement = cast(element); + + // Get the specific memory effect. + MemoryEffects::Effect *effect = + StringSwitch( + cast(effectElement.get("effect")).getValue()) + .Case("allocate", MemoryEffects::Allocate::get()) + .Case("free", MemoryEffects::Free::get()) + .Case("read", MemoryEffects::Read::get()) + .Case("write", MemoryEffects::Write::get()); + + // Check for a non-default resource to use. + SideEffects::Resource *resource = SideEffects::DefaultResource::get(); + if (effectElement.get("test_resource")) + resource = TestResource::get(); + + // Check for a result to affect. + if (effectElement.get("on_result")) + effects.emplace_back(effect, getResult(), resource); + else if (Attribute ref = effectElement.get("on_reference")) + effects.emplace_back(effect, cast(ref), resource); + else + effects.emplace_back(effect, resource); + } +} + +void SideEffectOp::getEffects( + SmallVectorImpl &effects) { + testSideEffectOpGetEffect(getOperation(), effects); +} + +//===----------------------------------------------------------------------===// +// StringAttrPrettyNameOp +//===----------------------------------------------------------------------===// + +// This op has fancy handling of its SSA result name. +ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser, + OperationState &result) { + // Add the result types. + for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) + result.addTypes(parser.getBuilder().getIntegerType(32)); + + if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) + return failure(); + + // If the attribute dictionary contains no 'names' attribute, infer it from + // the SSA name (if specified). + bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) { + return attr.getName() == "names"; + }); + + // If there was no name specified, check to see if there was a useful name + // specified in the asm file. + if (hadNames || parser.getNumResults() == 0) + return success(); + + SmallVector names; + auto *context = result.getContext(); + + for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) { + auto resultName = parser.getResultName(i); + StringRef nameStr; + if (!resultName.first.empty() && !isdigit(resultName.first[0])) + nameStr = resultName.first; + + names.push_back(nameStr); + } + + auto namesAttr = parser.getBuilder().getStrArrayAttr(names); + result.attributes.push_back({StringAttr::get(context, "names"), namesAttr}); + return success(); +} + +void StringAttrPrettyNameOp::print(OpAsmPrinter &p) { + // Note that we only need to print the "name" attribute if the asmprinter + // result name disagrees with it. This can happen in strange cases, e.g. + // when there are conflicts. + bool namesDisagree = getNames().size() != getNumResults(); + + SmallString<32> resultNameStr; + for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) { + resultNameStr.clear(); + llvm::raw_svector_ostream tmpStream(resultNameStr); + p.printOperand(getResult(i), tmpStream); + + auto expectedName = dyn_cast(getNames()[i]); + if (!expectedName || + tmpStream.str().drop_front() != expectedName.getValue()) { + namesDisagree = true; + } + } + + if (namesDisagree) + p.printOptionalAttrDictWithKeyword((*this)->getAttrs()); + else + p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"}); +} + +// We set the SSA name in the asm syntax to the contents of the name +// attribute. +void StringAttrPrettyNameOp::getAsmResultNames( + function_ref setNameFn) { + + auto value = getNames(); + for (size_t i = 0, e = value.size(); i != e; ++i) + if (auto str = dyn_cast(value[i])) + if (!str.getValue().empty()) + setNameFn(getResult(i), str.getValue()); +} + +//===----------------------------------------------------------------------===// +// CustomResultsNameOp +//===----------------------------------------------------------------------===// + +void CustomResultsNameOp::getAsmResultNames( + function_ref setNameFn) { + ArrayAttr value = getNames(); + for (size_t i = 0, e = value.size(); i != e; ++i) + if (auto str = dyn_cast(value[i])) + if (!str.empty()) + setNameFn(getResult(i), str.getValue()); +} + +//===----------------------------------------------------------------------===// +// ResultTypeWithTraitOp +//===----------------------------------------------------------------------===// + +LogicalResult ResultTypeWithTraitOp::verify() { + if ((*this)->getResultTypes()[0].hasTrait()) + return success(); + return emitError("result type should have trait 'TestTypeTrait'"); +} + +//===----------------------------------------------------------------------===// +// AttrWithTraitOp +//===----------------------------------------------------------------------===// + +LogicalResult AttrWithTraitOp::verify() { + if (getAttr().hasTrait()) + return success(); + return emitError("'attr' attribute should have trait 'TestAttrTrait'"); +} + +//===----------------------------------------------------------------------===// +// RegionIfOp +//===----------------------------------------------------------------------===// + +void RegionIfOp::print(OpAsmPrinter &p) { + p << " "; + p.printOperands(getOperands()); + p << ": " << getOperandTypes(); + p.printArrowTypeList(getResultTypes()); + p << " then "; + p.printRegion(getThenRegion(), + /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true); + p << " else "; + p.printRegion(getElseRegion(), + /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true); + p << " join "; + p.printRegion(getJoinRegion(), + /*printEntryBlockArgs=*/true, + /*printBlockTerminators=*/true); +} + +ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector operandInfos; + SmallVector operandTypes; + + result.regions.reserve(3); + Region *thenRegion = result.addRegion(); + Region *elseRegion = result.addRegion(); + Region *joinRegion = result.addRegion(); + + // Parse operand, type and arrow type lists. + if (parser.parseOperandList(operandInfos) || + parser.parseColonTypeList(operandTypes) || + parser.parseArrowTypeList(result.types)) + return failure(); + + // Parse all attached regions. + if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) || + parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) || + parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {})) + return failure(); + + return parser.resolveOperands(operandInfos, operandTypes, + parser.getCurrentLocation(), result.operands); +} + +OperandRange RegionIfOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(llvm::is_contained({&getThenRegion(), &getElseRegion()}, point) && + "invalid region index"); + return getOperands(); +} + +void RegionIfOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + // We always branch to the join region. + if (!point.isParent()) { + if (point != getJoinRegion()) + regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs())); + else + regions.push_back(RegionSuccessor(getResults())); + return; + } + + // The then and else regions are the entry regions of this op. + regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs())); + regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs())); +} + +void RegionIfOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl &invocationBounds) { + // Each region is invoked at most once. + invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1}); +} + +//===----------------------------------------------------------------------===// +// AnyCondOp +//===----------------------------------------------------------------------===// + +void AnyCondOp::getSuccessorRegions(RegionBranchPoint point, + SmallVectorImpl ®ions) { + // The parent op branches into the only region, and the region branches back + // to the parent op. + if (point.isParent()) + regions.emplace_back(&getRegion()); + else + regions.emplace_back(getResults()); +} + +void AnyCondOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl &invocationBounds) { + invocationBounds.emplace_back(1, 1); +} + +//===----------------------------------------------------------------------===// +// SingleBlockImplicitTerminatorOp +//===----------------------------------------------------------------------===// + +/// Testing the correctness of some traits. +static_assert( + llvm::is_detected::value, + "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp"); +static_assert(OpTrait::hasSingleBlockImplicitTerminator< + SingleBlockImplicitTerminatorOp>::value, + "hasSingleBlockImplicitTerminator does not match " + "SingleBlockImplicitTerminatorOp"); + +//===----------------------------------------------------------------------===// +// SingleNoTerminatorCustomAsmOp +//===----------------------------------------------------------------------===// + +ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser, + OperationState &state) { + Region *body = state.addRegion(); + if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + return success(); +} + +void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) { + printer.printRegion( + getRegion(), /*printEntryBlockArgs=*/false, + // This op has a single block without terminators. But explicitly mark + // as not printing block terminators for testing. + /*printBlockTerminators=*/false); +} + +//===----------------------------------------------------------------------===// +// TestVerifiersOp +//===----------------------------------------------------------------------===// + +LogicalResult TestVerifiersOp::verify() { + if (!getRegion().hasOneBlock()) + return emitOpError("`hasOneBlock` trait hasn't been verified"); + + Operation *definingOp = getInput().getDefiningOp(); + if (definingOp && failed(mlir::verify(definingOp))) + return emitOpError("operand hasn't been verified"); + + // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier + // loop. + mlir::emitRemark(getLoc(), "success run of verifier"); + + return success(); +} + +LogicalResult TestVerifiersOp::verifyRegions() { + if (!getRegion().hasOneBlock()) + return emitOpError("`hasOneBlock` trait hasn't been verified"); + + for (Block &block : getRegion()) + for (Operation &op : block) + if (failed(mlir::verify(&op))) + return emitOpError("nested op hasn't been verified"); + + // Avoid using `emitRemark(msg)` since that will trigger an infinite verifier + // loop. + mlir::emitRemark(getLoc(), "success run of region verifier"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TestWithBoundsOp + +void TestWithBoundsOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +//===----------------------------------------------------------------------===// +// TestWithBoundsRegionOp + +ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser, + OperationState &result) { + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse the input argument + OpAsmParser::Argument argInfo; + argInfo.type = parser.getBuilder().getIndexType(); + if (failed(parser.parseArgument(argInfo))) + return failure(); + + // Parse the body region, and reuse the operand info as the argument info. + Region *body = result.addRegion(); + return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false); +} + +void TestWithBoundsRegionOp::print(OpAsmPrinter &p) { + p.printOptionalAttrDict((*this)->getAttrs()); + p << ' '; + p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{}, + /*omitType=*/true); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + +void TestWithBoundsRegionOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + Value arg = getRegion().getArgument(0); + setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()}); +} + +//===----------------------------------------------------------------------===// +// TestIncrementOp + +void TestIncrementOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + APInt one(range.umin().getBitWidth(), 1); + setResultRanges(getResult(), + {range.umin().uadd_sat(one), range.umax().uadd_sat(one), + range.smin().sadd_sat(one), range.smax().sadd_sat(one)}); +} + +//===----------------------------------------------------------------------===// +// TestReflectBoundsOp + +void TestReflectBoundsOp::inferResultRanges( + ArrayRef argRanges, SetIntRangeFn setResultRanges) { + const ConstantIntRanges &range = argRanges[0]; + MLIRContext *ctx = getContext(); + Builder b(ctx); + setUminAttr(b.getIndexAttr(range.umin().getZExtValue())); + setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue())); + setSminAttr(b.getIndexAttr(range.smin().getSExtValue())); + setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue())); + setResultRanges(getResult(), range); +} + +//===----------------------------------------------------------------------===// +// ConversionFuncOp +//===----------------------------------------------------------------------===// + +ParseResult ConversionFuncOp::parse(OpAsmParser &parser, + OperationState &result) { + auto buildFuncType = + [](Builder &builder, ArrayRef argTypes, ArrayRef results, + function_interface_impl::VariadicFlag, + std::string &) { return builder.getFunctionType(argTypes, results); }; + + return function_interface_impl::parseFunctionOp( + parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); +} + +void ConversionFuncOp::print(OpAsmPrinter &p) { + function_interface_impl::printFunctionOp( + p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); +} + +//===----------------------------------------------------------------------===// +// ReifyBoundOp +//===----------------------------------------------------------------------===// + +mlir::presburger::BoundType ReifyBoundOp::getBoundType() { + if (getType() == "EQ") + return mlir::presburger::BoundType::EQ; + if (getType() == "LB") + return mlir::presburger::BoundType::LB; + if (getType() == "UB") + return mlir::presburger::BoundType::UB; + llvm_unreachable("invalid bound type"); +} + +LogicalResult ReifyBoundOp::verify() { + if (isa(getVar().getType())) { + if (!getDim().has_value()) + return emitOpError("expected 'dim' attribute for shaped type variable"); + } else if (getVar().getType().isIndex()) { + if (getDim().has_value()) + return emitOpError("unexpected 'dim' attribute for index variable"); + } else { + return emitOpError("expected index-typed variable or shape type variable"); + } + if (getConstant() && getScalable()) + return emitOpError("'scalable' and 'constant' are mutually exlusive"); + if (getScalable() != getVscaleMin().has_value()) + return emitOpError("expected 'vscale_min' if and only if 'scalable'"); + if (getScalable() != getVscaleMax().has_value()) + return emitOpError("expected 'vscale_min' if and only if 'scalable'"); + return success(); +} + +ValueBoundsConstraintSet::Variable ReifyBoundOp::getVariable() { + if (getDim().has_value()) + return ValueBoundsConstraintSet::Variable(getVar(), *getDim()); + return ValueBoundsConstraintSet::Variable(getVar()); +} + +//===----------------------------------------------------------------------===// +// CompareOp +//===----------------------------------------------------------------------===// + +ValueBoundsConstraintSet::ComparisonOperator +CompareOp::getComparisonOperator() { + if (getCmp() == "EQ") + return ValueBoundsConstraintSet::ComparisonOperator::EQ; + if (getCmp() == "LT") + return ValueBoundsConstraintSet::ComparisonOperator::LT; + if (getCmp() == "LE") + return ValueBoundsConstraintSet::ComparisonOperator::LE; + if (getCmp() == "GT") + return ValueBoundsConstraintSet::ComparisonOperator::GT; + if (getCmp() == "GE") + return ValueBoundsConstraintSet::ComparisonOperator::GE; + llvm_unreachable("invalid comparison operator"); +} + +mlir::ValueBoundsConstraintSet::Variable CompareOp::getLhs() { + if (!getLhsMap()) + return ValueBoundsConstraintSet::Variable(getVarOperands()[0]); + SmallVector mapOperands( + getVarOperands().slice(0, getLhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getLhsMap(), mapOperands); +} + +mlir::ValueBoundsConstraintSet::Variable CompareOp::getRhs() { + int64_t rhsOperandsBegin = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + if (!getRhsMap()) + return ValueBoundsConstraintSet::Variable( + getVarOperands()[rhsOperandsBegin]); + SmallVector mapOperands( + getVarOperands().slice(rhsOperandsBegin, getRhsMap()->getNumInputs())); + return ValueBoundsConstraintSet::Variable(*getRhsMap(), mapOperands); +} + +LogicalResult CompareOp::verify() { + if (getCompose() && (getLhsMap() || getRhsMap())) + return emitOpError( + "'compose' not supported when 'lhs_map' or 'rhs_map' is present"); + int64_t expectedNumOperands = getLhsMap() ? getLhsMap()->getNumInputs() : 1; + expectedNumOperands += getRhsMap() ? getRhsMap()->getNumInputs() : 1; + if (getVarOperands().size() != size_t(expectedNumOperands)) + return emitOpError("expected ") + << expectedNumOperands << " operands, but got " + << getVarOperands().size(); + return success(); +} + +//===----------------------------------------------------------------------===// +// TestOpFoldWithFoldAdaptor +//===----------------------------------------------------------------------===// + +OpFoldResult TestOpFoldWithFoldAdaptor::fold(FoldAdaptor adaptor) { + int64_t sum = 0; + if (auto value = dyn_cast_or_null(adaptor.getOp())) + sum += value.getValue().getSExtValue(); + + for (Attribute attr : adaptor.getVariadic()) + if (auto value = dyn_cast_or_null(attr)) + sum += 2 * value.getValue().getSExtValue(); + + for (ArrayRef attrs : adaptor.getVarOfVar()) + for (Attribute attr : attrs) + if (auto value = dyn_cast_or_null(attr)) + sum += 3 * value.getValue().getSExtValue(); + + sum += 4 * std::distance(adaptor.getBody().begin(), adaptor.getBody().end()); + + return IntegerAttr::get(getType(), sum); +} + +//===----------------------------------------------------------------------===// +// OpWithInferTypeAdaptorInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult OpWithInferTypeAdaptorInterfaceOp::inferReturnTypes( + MLIRContext *, std::optional location, + OpWithInferTypeAdaptorInterfaceOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnTypes) { + if (adaptor.getX().getType() != adaptor.getY().getType()) { + return emitOptionalError(location, "operand type mismatch ", + adaptor.getX().getType(), " vs ", + adaptor.getY().getType()); + } + inferredReturnTypes.assign({adaptor.getX().getType()}); + return success(); +} + +//===----------------------------------------------------------------------===// +// OpWithRefineTypeInterfaceOp +//===----------------------------------------------------------------------===// + +// TODO: We should be able to only define either inferReturnType or +// refineReturnType, currently only refineReturnType can be omitted. +LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes( + MLIRContext *context, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &returnTypes) { + returnTypes.clear(); + return OpWithRefineTypeInterfaceOp::refineReturnTypes( + context, location, operands, attributes, properties, regions, + returnTypes); +} + +LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( + MLIRContext *, std::optional location, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &returnTypes) { + if (operands[0].getType() != operands[1].getType()) { + return emitOptionalError(location, "operand type mismatch ", + operands[0].getType(), " vs ", + operands[1].getType()); + } + // TODO: Add helper to make this more concise to write. + if (returnTypes.empty()) + returnTypes.resize(1, nullptr); + if (returnTypes[0] && returnTypes[0] != operands[0].getType()) + return emitOptionalError(location, + "required first operand and result to match"); + returnTypes[0] = operands[0].getType(); + return success(); +} + +//===----------------------------------------------------------------------===// +// OpWithShapedTypeInferTypeAdaptorInterfaceOp +//===----------------------------------------------------------------------===// + +LogicalResult +OpWithShapedTypeInferTypeAdaptorInterfaceOp::inferReturnTypeComponents( + MLIRContext *context, std::optional location, + OpWithShapedTypeInferTypeAdaptorInterfaceOp::Adaptor adaptor, + SmallVectorImpl &inferredReturnShapes) { + // Create return type consisting of the last element of the first operand. + auto operandType = adaptor.getOperand1().getType(); + auto sval = dyn_cast(operandType); + if (!sval) + return emitOptionalError(location, "only shaped type operands allowed"); + int64_t dim = sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamic; + auto type = IntegerType::get(context, 17); + + Attribute encoding; + if (auto rankedTy = dyn_cast(sval)) + encoding = rankedTy.getEncoding(); + inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type, encoding)); + return success(); +} + +LogicalResult +OpWithShapedTypeInferTypeAdaptorInterfaceOp::reifyReturnTypeShapes( + OpBuilder &builder, ValueRange operands, + llvm::SmallVectorImpl &shapes) { + shapes = SmallVector{ + builder.createOrFold(getLoc(), operands.front(), 0)}; + return success(); +} + +//===----------------------------------------------------------------------===// +// TestOpWithPropertiesAndInferredType +//===----------------------------------------------------------------------===// + +LogicalResult TestOpWithPropertiesAndInferredType::inferReturnTypes( + MLIRContext *context, std::optional, ValueRange operands, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + + Adaptor adaptor(operands, attributes, properties, regions); + inferredReturnTypes.push_back(IntegerType::get( + context, adaptor.getLhs() + adaptor.getProperties().rhs)); + return success(); +} + +//===----------------------------------------------------------------------===// +// LoopBlockOp +//===----------------------------------------------------------------------===// + +void LoopBlockOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + regions.emplace_back(&getBody(), getBody().getArguments()); + if (point.isParent()) + return; + + regions.emplace_back((*this)->getResults()); +} + +OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { + assert(point == getBody()); + return MutableOperandRange(getInitMutable()); +} + +//===----------------------------------------------------------------------===// +// LoopBlockTerminatorOp +//===----------------------------------------------------------------------===// + +MutableOperandRange +LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionBranchPoint point) { + if (point.isParent()) + return getExitArgMutable(); + return getNextIterArgMutable(); +} + +//===----------------------------------------------------------------------===// +// SwitchWithNoBreakOp +//===----------------------------------------------------------------------===// + +void TestNoTerminatorOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) {} + +//===----------------------------------------------------------------------===// +// Test InferIntRangeInterface +//===----------------------------------------------------------------------===// + +OpFoldResult ManualCppOpWithFold::fold(ArrayRef attributes) { + // Just a simple fold for testing purposes that reads an operands constant + // value and returns it. + if (!attributes.empty()) + return attributes.front(); + return nullptr; +} + +//===----------------------------------------------------------------------===// +// Tensor/Buffer Ops +//===----------------------------------------------------------------------===// + +void ReadBufferOp::getEffects( + SmallVectorImpl> + &effects) { + // The buffer operand is read. + effects.emplace_back(MemoryEffects::Read::get(), getBuffer(), + SideEffects::DefaultResource::get()); + // The buffer contents are dumped. + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); +} + +//===----------------------------------------------------------------------===// +// Test Dataflow +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// TestCallAndStoreOp + +CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { + return getCallee(); +} + +void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { + setCalleeAttr(callee.get()); +} + +Operation::operand_range TestCallAndStoreOp::getArgOperands() { + return getCalleeOperands(); +} + +MutableOperandRange TestCallAndStoreOp::getArgOperandsMutable() { + return getCalleeOperandsMutable(); +} + +//===----------------------------------------------------------------------===// +// TestCallOnDeviceOp + +CallInterfaceCallable TestCallOnDeviceOp::getCallableForCallee() { + return getCallee(); +} + +void TestCallOnDeviceOp::setCalleeFromCallable(CallInterfaceCallable callee) { + setCalleeAttr(callee.get()); +} + +Operation::operand_range TestCallOnDeviceOp::getArgOperands() { + return getForwardedOperands(); +} + +MutableOperandRange TestCallOnDeviceOp::getArgOperandsMutable() { + return getForwardedOperandsMutable(); +} + +//===----------------------------------------------------------------------===// +// TestStoreWithARegion + +void TestStoreWithARegion::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + if (point.isParent()) + regions.emplace_back(&getBody(), getBody().front().getArguments()); + else + regions.emplace_back(); +} + +//===----------------------------------------------------------------------===// +// TestStoreWithALoopRegion + +void TestStoreWithALoopRegion::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + // Both the operation itself and the region may be branching into the body or + // back into the operation itself. It is possible for the operation not to + // enter the body. + regions.emplace_back( + RegionSuccessor(&getBody(), getBody().front().getArguments())); + regions.emplace_back(); +} + +//===----------------------------------------------------------------------===// +// TestVersionedOpA +//===----------------------------------------------------------------------===// + +LogicalResult +TestVersionedOpA::readProperties(mlir::DialectBytecodeReader &reader, + mlir::OperationState &state) { + auto &prop = state.getOrAddProperties(); + if (mlir::failed(reader.readAttribute(prop.dims))) + return mlir::failure(); + + // Check if we have a version. If not, assume we are parsing the current + // version. + auto maybeVersion = reader.getDialectVersion(); + if (succeeded(maybeVersion)) { + // If version is less than 2.0, there is no additional attribute to parse. + // We can materialize missing properties post parsing before verification. + const auto *version = + reinterpret_cast(*maybeVersion); + if ((version->major_ < 2)) { + return success(); + } + } + + if (mlir::failed(reader.readAttribute(prop.modifier))) + return mlir::failure(); + return mlir::success(); +} + +void TestVersionedOpA::writeProperties(mlir::DialectBytecodeWriter &writer) { + auto &prop = getProperties(); + writer.writeAttribute(prop.dims); + + auto maybeVersion = writer.getDialectVersion(); + if (succeeded(maybeVersion)) { + // If version is less than 2.0, there is no additional attribute to write. + const auto *version = + reinterpret_cast(*maybeVersion); + if ((version->major_ < 2)) { + llvm::outs() << "downgrading op properties...\n"; + return; + } + } + writer.writeAttribute(prop.modifier); +} + +//===----------------------------------------------------------------------===// +// TestOpWithVersionedProperties +//===----------------------------------------------------------------------===// + +mlir::LogicalResult TestOpWithVersionedProperties::readFromMlirBytecode( + mlir::DialectBytecodeReader &reader, test::VersionedProperties &prop) { + uint64_t value1, value2 = 0; + if (failed(reader.readVarInt(value1))) + return failure(); + + // Check if we have a version. If not, assume we are parsing the current + // version. + auto maybeVersion = reader.getDialectVersion(); + bool needToParseAnotherInt = true; + if (succeeded(maybeVersion)) { + // If version is less than 2.0, there is no additional attribute to parse. + // We can materialize missing properties post parsing before verification. + const auto *version = + reinterpret_cast(*maybeVersion); + if ((version->major_ < 2)) + needToParseAnotherInt = false; + } + if (needToParseAnotherInt && failed(reader.readVarInt(value2))) + return failure(); + + prop.value1 = value1; + prop.value2 = value2; + return success(); +} + +void TestOpWithVersionedProperties::writeToMlirBytecode( + mlir::DialectBytecodeWriter &writer, + const test::VersionedProperties &prop) { + writer.writeVarInt(prop.value1); + writer.writeVarInt(prop.value2); +} diff --git a/mlir/test/lib/Dialect/Test/TestOps.cpp b/mlir/test/lib/Dialect/Test/TestOps.cpp new file mode 100644 index 0000000000000..ce7e476be74e6 --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOps.cpp @@ -0,0 +1,18 @@ +//===- TestOps.cpp - MLIR Test Dialect Operations ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TestOps.h" +#include "TestDialect.h" +#include "TestFormatUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +using namespace mlir; +using namespace test; + +#define GET_OP_CLASSES +#include "TestOps.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.h b/mlir/test/lib/Dialect/Test/TestOps.h new file mode 100644 index 0000000000000..f9925855bb9db --- /dev/null +++ b/mlir/test/lib/Dialect/Test/TestOps.h @@ -0,0 +1,149 @@ +//===- TestOps.h - MLIR Test Dialect Operations ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TESTOPS_H +#define MLIR_TESTOPS_H + +#include "TestAttributes.h" +#include "TestInterfaces.h" +#include "TestTypes.h" +#include "mlir/Bytecode/BytecodeImplementation.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/DLTI/Traits.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/IR/ExtensibleDialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/RegionKindInterface.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/CopyOpInterface.h" +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" +#include "mlir/Interfaces/InferIntRangeInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ValueBoundsOpInterface.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "llvm/ADT/SetVector.h" + +namespace test { +class TestDialect; + +//===----------------------------------------------------------------------===// +// TestResource +//===----------------------------------------------------------------------===// + +/// A test resource for side effects. +struct TestResource : public mlir::SideEffects::Resource::Base { + llvm::StringRef getName() final { return ""; } +}; + +//===----------------------------------------------------------------------===// +// PropertiesWithCustomPrint +//===----------------------------------------------------------------------===// + +struct PropertiesWithCustomPrint { + /// A shared_ptr to a const object is safe: it is equivalent to a value-based + /// member. Here the label will be deallocated when the last operation + /// refering to it is destroyed. However there is no pool-allocation: this is + /// offloaded to the client. + std::shared_ptr label; + int value; + bool operator==(const PropertiesWithCustomPrint &rhs) const { + return value == rhs.value && *label == *rhs.label; + } +}; + +mlir::LogicalResult setPropertiesFromAttribute( + PropertiesWithCustomPrint &prop, mlir::Attribute attr, + llvm::function_ref emitError); +mlir::DictionaryAttr +getPropertiesAsAttribute(mlir::MLIRContext *ctx, + const PropertiesWithCustomPrint &prop); +llvm::hash_code computeHash(const PropertiesWithCustomPrint &prop); +void customPrintProperties(mlir::OpAsmPrinter &p, + const PropertiesWithCustomPrint &prop); +mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser, + PropertiesWithCustomPrint &prop); + +//===----------------------------------------------------------------------===// +// MyPropStruct +//===----------------------------------------------------------------------===// + +class MyPropStruct { +public: + std::string content; + // These three methods are invoked through the `MyStructProperty` wrapper + // defined in TestOps.td + mlir::Attribute asAttribute(mlir::MLIRContext *ctx) const; + static mlir::LogicalResult + setFromAttr(MyPropStruct &prop, mlir::Attribute attr, + llvm::function_ref emitError); + llvm::hash_code hash() const; + bool operator==(const MyPropStruct &rhs) const { + return content == rhs.content; + } +}; + +mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader, + MyPropStruct &prop); +void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer, + MyPropStruct &prop); + +//===----------------------------------------------------------------------===// +// VersionedProperties +//===----------------------------------------------------------------------===// + +struct VersionedProperties { + // For the sake of testing, assume that this object was associated to version + // 1.2 of the test dialect when having only one int value. In the current + // version 2.0, the property has two values. We also assume that the class is + // upgrade-able if value2 = 0. + int value1; + int value2; + bool operator==(const VersionedProperties &rhs) const { + return value1 == rhs.value1 && value2 == rhs.value2; + } +}; + +mlir::LogicalResult setPropertiesFromAttribute( + VersionedProperties &prop, mlir::Attribute attr, + llvm::function_ref emitError); +mlir::DictionaryAttr getPropertiesAsAttribute(mlir::MLIRContext *ctx, + const VersionedProperties &prop); +llvm::hash_code computeHash(const VersionedProperties &prop); +void customPrintProperties(mlir::OpAsmPrinter &p, + const VersionedProperties &prop); +mlir::ParseResult customParseProperties(mlir::OpAsmParser &parser, + VersionedProperties &prop); + +//===----------------------------------------------------------------------===// +// Bytecode Support +//===----------------------------------------------------------------------===// + +mlir::LogicalResult readFromMlirBytecode(mlir::DialectBytecodeReader &reader, + llvm::MutableArrayRef prop); +void writeToMlirBytecode(mlir::DialectBytecodeWriter &writer, + llvm::ArrayRef prop); + +} // namespace test + +#define GET_OP_CLASSES +#include "TestOps.h.inc" + +#endif // MLIR_TESTOPS_H diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp index 84e6a43655cac..c376d6c73c645 100644 --- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.cpp @@ -8,6 +8,7 @@ #include "TestOpsSyntax.h" #include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/OpImplementation.h" #include "llvm/Support/Base64.h" diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 76dc825fe4451..0c1731ba5f07c 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "TestTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp index fa093cafcb0dc..57e7d658fb501 100644 --- a/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp +++ b/mlir/test/lib/Dialect/Test/TestToLLVMIRTranslation.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp index d9b67ef95ace8..031e1062dac76 100644 --- a/mlir/test/lib/Dialect/Test/TestTraits.cpp +++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index 7a195eb25a3ba..1593b6d7d7534 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -139,6 +139,7 @@ static void printBarString(AsmPrinter &printer, StringRef foo) { // Tablegen Generated Definitions //===----------------------------------------------------------------------===// +#include "TestTypeInterfaces.cpp.inc" #define GET_TYPEDEF_CLASSES #include "TestTypeDefs.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index b1b5921d8fadd..da5604944d5a3 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -31,11 +31,11 @@ class TestAttrWithFormatAttr; /// FieldInfo represents a field in the StructType data type. It is used as a /// parameter in TestTypeDefs.td. struct FieldInfo { - ::llvm::StringRef name; - ::mlir::Type type; + llvm::StringRef name; + mlir::Type type; // Custom allocation called from generated constructor code - FieldInfo allocateInto(::mlir::TypeStorageAllocator &alloc) const { + FieldInfo allocateInto(mlir::TypeStorageAllocator &alloc) const { return FieldInfo{alloc.copyInto(name), type}; } }; diff --git a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp index e668224d34323..4894ad5294990 100644 --- a/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp +++ b/mlir/test/lib/IR/TestBytecodeRoundtrip.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Bytecode/BytecodeReader.h" #include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/IR/BuiltinOps.h" diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp index 7b18f219b915f..b742b316c7712 100644 --- a/mlir/test/lib/IR/TestClone.cpp +++ b/mlir/test/lib/IR/TestClone.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/IR/TestSideEffects.cpp b/mlir/test/lib/IR/TestSideEffects.cpp index 09ad136322824..8e13dd9751398 100644 --- a/mlir/test/lib/IR/TestSideEffects.cpp +++ b/mlir/test/lib/IR/TestSideEffects.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp index 0e1368f2e0eca..b470b15c533b5 100644 --- a/mlir/test/lib/IR/TestSymbolUses.cpp +++ b/mlir/test/lib/IR/TestSymbolUses.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/IR/TestTypes.cpp b/mlir/test/lib/IR/TestTypes.cpp index 2bd63a48f77d1..c6bce111d3ea7 100644 --- a/mlir/test/lib/IR/TestTypes.cpp +++ b/mlir/test/lib/IR/TestTypes.cpp @@ -8,6 +8,7 @@ #include "TestTypes.h" #include "TestDialect.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp index 00148df26e351..4556671df0ba0 100644 --- a/mlir/test/lib/IR/TestVisitorsGeneric.cpp +++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "TestDialect.h" +#include "TestOps.h" #include "mlir/Pass/Pass.h" using namespace mlir; diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp index 477b75916f80c..2762e25490324 100644 --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/test/lib/Transforms/TestInlining.cpp b/mlir/test/lib/Transforms/TestInlining.cpp index 9821179d05e89..223cc78dd1e21 100644 --- a/mlir/test/lib/Transforms/TestInlining.cpp +++ b/mlir/test/lib/Transforms/TestInlining.cpp @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" diff --git a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp index 61e1fbcf3feaf..82fa6cdb68d23 100644 --- a/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp +++ b/mlir/test/lib/Transforms/TestMakeIsolatedFromAbove.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" diff --git a/mlir/unittests/IR/AdaptorTest.cpp b/mlir/unittests/IR/AdaptorTest.cpp index 66ce53bbbadec..0a5fa8d3c475c 100644 --- a/mlir/unittests/IR/AdaptorTest.cpp +++ b/mlir/unittests/IR/AdaptorTest.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestOpsSyntax.h" #include "gmock/gmock.h" #include "gtest/gtest.h" diff --git a/mlir/unittests/IR/IRMapping.cpp b/mlir/unittests/IR/IRMapping.cpp index 83627975006ee..b88009d1e3c36 100644 --- a/mlir/unittests/IR/IRMapping.cpp +++ b/mlir/unittests/IR/IRMapping.cpp @@ -11,6 +11,7 @@ #include "gtest/gtest.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" using namespace mlir; diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp index 58049a9969e3a..b6066dd5685dc 100644 --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -19,6 +19,7 @@ #include "../../test/lib/Dialect/Test/TestAttributes.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" #include "mlir/IR/OwningOpRef.h" diff --git a/mlir/unittests/IR/InterfaceTest.cpp b/mlir/unittests/IR/InterfaceTest.cpp index 5ab4d9a106231..42196b003e7da 100644 --- a/mlir/unittests/IR/InterfaceTest.cpp +++ b/mlir/unittests/IR/InterfaceTest.cpp @@ -15,6 +15,7 @@ #include "../../test/lib/Dialect/Test/TestAttributes.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" #include "../../test/lib/Dialect/Test/TestTypes.h" using namespace mlir; diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp index 9d75615b39c0c..f94dc78445807 100644 --- a/mlir/unittests/IR/OperationSupportTest.cpp +++ b/mlir/unittests/IR/OperationSupportTest.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/OperationSupport.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/BitVector.h" diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp index 30b72618e45f0..75d5228c82d99 100644 --- a/mlir/unittests/IR/PatternMatchTest.cpp +++ b/mlir/unittests/IR/PatternMatchTest.cpp @@ -10,6 +10,7 @@ #include "gtest/gtest.h" #include "../../test/lib/Dialect/Test/TestDialect.h" +#include "../../test/lib/Dialect/Test/TestOps.h" using namespace mlir; diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp index 52347dcabe038..c83ac9088114c 100644 --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "TestOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h"