diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 2df2fe4c5ce8e..917b27a40f26f 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -215,6 +215,8 @@ def Vector_ReductionOp : Vector_Op<"reduction", [Pure, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, + OptionalTypesMatchWith<"dest and acc have the same type", + "dest", "acc", "::llvm::cast($_self)">, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods @@ -263,9 +265,8 @@ def Vector_ReductionOp : "::mlir::arith::FastMathFlags::none">:$fastMathFlags)> ]; - // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional - // operands. - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "$kind `,` $vector (`,` $acc^)? (`fastmath` `` $fastmath^)?" + " attr-dict `:` type($vector) `into` type($dest)"; let hasCanonicalizer = 1; let hasVerifier = 1; } diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 236dd74839dfb..7866ac24c1ccb 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -568,6 +568,14 @@ class TypesMatchWith + : TypesMatchWith.ret # "()" + # " || !get" # snakeCaseToCamelCase.ret # "() || " # comparator>; + // Special variant of `TypesMatchWith` that provides a comparator suitable for // ranged arguments. class RangedTypesMatchWith { string defaultValue = value; } +// Helper which makes the first letter of a string uppercase. +// e.g. cat -> Cat +class firstCharToUpper +{ + string ret = !if(!gt(!size(str), 0), + !toupper(!substr(str, 0, 1)) # !substr(str, 1), + ""); +} + +class _snakeCaseHelper { + int idx = !find(str, "_"); + string ret = !if(!ge(idx, 0), + !substr(str, 0, idx) # firstCharToUpper.ret, + str); +} + +// Converts a snake_case string to CamelCase. +// TODO: Replace with a !tocamelcase bang operator. +class snakeCaseToCamelCase +{ + string ret = !foldl(firstCharToUpper.ret, + !range(0, !size(str)), acc, idx, _snakeCaseHelper.ret); +} + #endif // UTILS_TD diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 044b6cc07d3d6..b63018dbd5d6a 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -485,47 +485,6 @@ LogicalResult ReductionOp::verify() { return success(); } -ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { - SmallVector operandsInfo; - Type redType; - Type resType; - CombiningKindAttr kindAttr; - arith::FastMathFlagsAttr fastMathAttr; - if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind", - result.attributes) || - parser.parseComma() || parser.parseOperandList(operandsInfo) || - (succeeded(parser.parseOptionalKeyword("fastmath")) && - parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath", - result.attributes)) || - parser.parseColonType(redType) || - parser.parseKeywordType("into", resType) || - (!operandsInfo.empty() && - parser.resolveOperand(operandsInfo[0], redType, result.operands)) || - (operandsInfo.size() > 1 && - parser.resolveOperand(operandsInfo[1], resType, result.operands)) || - parser.addTypeToList(resType, result.types)) - return failure(); - if (operandsInfo.empty() || operandsInfo.size() > 2) - return parser.emitError(parser.getNameLoc(), - "unsupported number of operands"); - return success(); -} - -void ReductionOp::print(OpAsmPrinter &p) { - p << " "; - getKindAttr().print(p); - p << ", " << getVector(); - if (getAcc()) - p << ", " << getAcc(); - - if (getFastmathAttr() && - getFastmathAttr().getValue() != arith::FastMathFlags::none) { - p << ' ' << getFastmathAttrName().getValue(); - p.printStrippedAttrOrType(getFastmathAttr()); - } - p << " : " << getVector().getType() << " into " << getDest().getType(); -} - // MaskableOpInterface methods. /// Returns the mask type expected by this operation. diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 5967a8d69bbfc..504ac89659fdb 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1169,7 +1169,7 @@ func.func @reduce_unsupported_attr(%arg0: vector<16xf32>) -> i32 { // ----- func.func @reduce_unsupported_third_argument(%arg0: vector<16xf32>, %arg1: f32) -> f32 { - // expected-error@+1 {{'vector.reduction' unsupported number of operands}} + // expected-error@+1 {{expected ':'}} %0 = vector.reduction , %arg0, %arg1, %arg1 : vector<16xf32> into f32 } diff --git a/mlir/test/mlir-tblgen/utils.td b/mlir/test/mlir-tblgen/utils.td new file mode 100644 index 0000000000000..28e0fecb2881b --- /dev/null +++ b/mlir/test/mlir-tblgen/utils.td @@ -0,0 +1,23 @@ +// RUN: mlir-tblgen -I %S/../../include %s | FileCheck %s + +include "mlir/IR/Utils.td" + +// CHECK-DAG: string value = "CamelCaseTest" +class already_camel_case { + string value = snakeCaseToCamelCase<"CamelCaseTest">.ret; +} + +// CHECK-DAG: string value = "Foo" +class single_word { + string value = snakeCaseToCamelCase<"foo">.ret; +} + +// CHECK-DAG: string value = "ThisIsATest" +class snake_case { + string value = snakeCaseToCamelCase<"this_is_a_test">.ret; +} + +// CHECK-DAG: string value = "ThisIsATestAgain" +class extra_underscores { + string value = snakeCaseToCamelCase<"__this__is_a_test__again__">.ret; +}