diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 51017b5e050ff..040f9895ad0db 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -654,32 +654,14 @@ class LLVM_VecReductionI // LLVM vector reduction over a single vector, with an initial value, // and with permission to reassociate the reduction operations. class LLVM_VecReductionAccBase - : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0], - [Pure, SameOperandsAndResultElementType]>, - Arguments<(ins element:$start_value, LLVM_VectorOf:$input, - DefaultValuedAttr:$reassoc)> { - let llvmBuilder = [{ - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = llvm::Intrinsic::getDeclaration( - module, - llvm::Intrinsic::vector_reduce_}] # mnem # [{, - { }] # !interleave(ListIntSubst.lst, - ", ") # [{ - }); - auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - llvm::FastMathFlags origFM = builder.getFastMathFlags(); - llvm::FastMathFlags tempFM = origFM; - tempFM.setAllowReassoc($reassoc); - builder.setFastMathFlags(tempFM); // set fastmath flag - $res = builder.CreateCall(fn, operands); - builder.setFastMathFlags(origFM); // restore fastmath flag - }]; - let mlirBuilder = [{ - bool allowReassoc = inst->getFastMathFlags().allowReassoc(); - $res = $_builder.create<$_qualCppClassName>($_location, - $_resultType, $start_value, $input, allowReassoc); - }]; -} + : LLVM_OneResultIntrOp, + Arguments<(ins element:$start_value, + LLVM_VectorOf:$input, + DefaultValuedAttr:$fastmathFlags)>; class LLVM_VecReductionAccF : LLVM_VecReductionAccBase; diff --git a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt index 2e56afe727ac0..23bed7e0f447e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt @@ -1,10 +1,18 @@ -add_mlir_dialect(VectorOps vector) -add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc) +add_mlir_dialect(Vector vector) +add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector) +# Add Vector operations set(LLVM_TARGET_DEFINITIONS VectorOps.td) -mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) -mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRVectorOpsEnumsIncGen) -add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen) +mlir_tablegen(VectorOps.h.inc -gen-op-decls) +mlir_tablegen(VectorOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRVectorOpsIncGen) +add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen) + +# Add Vector attributes +set(LLVM_TARGET_DEFINITIONS VectorAttributes.td) +mlir_tablegen(VectorEnums.h.inc -gen-enum-decls) +mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRVectorAttributesIncGen) +add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td new file mode 100644 index 0000000000000..c439ca083e2e0 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td @@ -0,0 +1,31 @@ +//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR +#define MLIR_DIALECT_VECTOR_IR_VECTOR + +include "mlir/IR/OpBase.td" + +def Vector_Dialect : Dialect { + let name = "vector"; + let cppNamespace = "::mlir::vector"; + + let useDefaultAttributePrinterParser = 1; + let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithDialect"]; +} + +// Base class for Vector dialect ops. +class Vector_Op traits = []> : + Op; + +#endif // MLIR_DIALECT_VECTOR_IR_VECTOR diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td new file mode 100644 index 0000000000000..f8f85b0d09d90 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td @@ -0,0 +1,85 @@ +//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the attributes used in the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES +#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES + +include "mlir/Dialect/Vector/IR/Vector.td" +include "mlir/IR/EnumAttr.td" + +// The "kind" of combining function for contractions and reductions. +def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">; +def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">; +def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">; +def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">; +def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">; +def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">; +def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">; +def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">; +def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">; +def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">; +def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">; +def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">; +def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">; + +def CombiningKind : I32BitEnumAttr< + "CombiningKind", + "Kind of combining function for contractions and reductions", + [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI, + COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI, + COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND, + COMBINING_KIND_OR, COMBINING_KIND_XOR, + COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> { + let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; +} + +/// An attribute that specifies the combining function for `vector.contract`, +/// and `vector.reduction`. +def Vector_CombiningKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"reduction", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vector"; +} + +def Vector_IteratorTypeEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def Vector_IteratorTypeArrayAttr + : TypedArrayAttrBase; + +def PrintPunctuation : I32EnumAttr<"PrintPunctuation", + "Punctuation for separating vectors or vector elements", [ + I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">, + I32EnumAttrCase<"NewLine", 1, "newline">, + I32EnumAttrCase<"Comma", 2, "comma">, + I32EnumAttrCase<"Open", 3, "open">, + I32EnumAttrCase<"Close", 4, "close"> +]> { + let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; +} + +def Vector_PrintPunctuation : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 4a624bd5f1ccd..fcf7eb4a616b0 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" @@ -31,10 +32,10 @@ #include "llvm/ADT/StringExtras.h" // Pull in all enum type definitions and utility function declarations. -#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc" +#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" namespace mlir { class MLIRContext; @@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, } // namespace mlir #define GET_OP_CLASSES +#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" -#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc" #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 28b5864914f69..4aff10c14fb81 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -10,9 +10,13 @@ // //===----------------------------------------------------------------------===// -#ifndef VECTOR_OPS -#define VECTOR_OPS +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS +#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS +include "mlir/Dialect/Vector/IR/Vector.td" +include "mlir/Dialect/Vector/IR/VectorAttributes.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td" include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td" include "mlir/IR/EnumAttr.td" @@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -def Vector_Dialect : Dialect { - let name = "vector"; - let cppNamespace = "::mlir::vector"; - - let useDefaultAttributePrinterParser = 1; - let hasConstantMaterializer = 1; - let dependentDialects = ["arith::ArithDialect"]; -} - -// Base class for Vector dialect ops. -class Vector_Op traits = []> : - Op; - -// The "kind" of combining function for contractions and reductions. -def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">; -def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">; -def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">; -def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">; -def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">; -def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">; -def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">; -def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">; -def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">; -def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">; -def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">; -def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">; -def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">; - -def CombiningKind : I32BitEnumAttr< - "CombiningKind", - "Kind of combining function for contractions and reductions", - [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI, - COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI, - COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND, - COMBINING_KIND_OR, COMBINING_KIND_XOR, - COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> { - let cppNamespace = "::mlir::vector"; - let genSpecializedAttr = 0; -} - -/// An attribute that specifies the combining function for `vector.contract`, -/// and `vector.reduction`. -def Vector_CombiningKindAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ - I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"reduction", 1> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::vector"; -} - -def Vector_IteratorTypeEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def Vector_IteratorTypeArrayAttr - : TypedArrayAttrBase; - // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -274,12 +215,16 @@ def Vector_ReductionOp : Vector_Op<"reduction", [Pure, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods + ]>, Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVectorOfAnyRank:$vector, - Optional:$acc)>, + Optional:$acc, + DefaultValuedAttr< + Arith_FastMathAttr, + "::mlir::arith::FastMathFlags::none">:$fastmath)>, Results<(outs AnyType:$dest)> { let summary = "reduction operation"; let description = [{ @@ -309,9 +254,13 @@ def Vector_ReductionOp : }]; let builders = [ // Builder that infers the type of `dest`. - OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>, + OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc, + CArg<"::mlir::arith::FastMathFlags", + "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>, // Builder that infers the type of `dest` and has no accumulator. - OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)> + OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, + CArg<"::mlir::arith::FastMathFlags", + "::mlir::arith::FastMathFlags::none">:$fastMathFlags)> ]; // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional @@ -2466,22 +2415,6 @@ def Vector_TransposeOp : let hasVerifier = 1; } -def PrintPunctuation : I32EnumAttr<"PrintPunctuation", - "Punctuation for separating vectors or vector elements", [ - I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">, - I32EnumAttrCase<"NewLine", 1, "newline">, - I32EnumAttrCase<"Comma", 2, "comma">, - I32EnumAttrCase<"Open", 3, "open">, - I32EnumAttrCase<"Close", 4, "close"> -]> { - let cppNamespace = "::mlir::vector"; - let genSpecializedAttr = 0; -} - -def Vector_PrintPunctuation : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def Vector_PrintOp : Vector_Op<"print", []>, Arguments<(ins Optional { } // namespace template -static Value -createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, Value accumulator) { - Value result = rewriter.create(loc, llvmType, vectorOperand); +static Value createFPReductionComparisonOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { + Value result = + rewriter.create(loc, llvmType, vectorOperand, fmf); if (accumulator) { result = @@ -641,87 +642,72 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, /// `fmaximum`/`fminimum`. /// More information: https://github.com/llvm/llvm-project/issues/64940 template -static Value lowerMaskedReductionWithRegular( - ConversionPatternRewriter &rewriter, Location loc, Type llvmType, - Value vectorOperand, Value accumulator, Value mask) { +static Value +lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, + Location loc, Type llvmType, + Value vectorOperand, Value accumulator, + Value mask, LLVM::FastmathFlagsAttr fmf) { const Value vectorMaskNeutral = createMaskNeutralValue( rewriter, loc, llvmType, vectorOperand.getType()); const Value selectedVectorByMask = rewriter.create( loc, mask, vectorOperand, vectorMaskNeutral); return createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, selectedVectorByMask, accumulator); + rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); } -/// Overloaded methods to lower a reduction to an llvm instrinsic that requires -/// a start value. This start value format spans across fp reductions without -/// mask and all the masked reduction intrinsics. -template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator) { - accumulator = getOrCreateAccumulator(rewriter, loc, - llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand); -} - -template +template static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, - Value accumulator, bool reassociateFPReds) { + Value accumulator, LLVM::FastmathFlagsAttr fmf) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, reassociateFPReds); + return rewriter.create(loc, llvmType, + /*startValue=*/accumulator, + vectorOperand, fmf); } +/// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic +/// that requires a start value. This start value format spans across fp +/// reductions without mask and all the masked reduction intrinsics. template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask) { +static Value +lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, + Location loc, Type llvmType, + Value vectorOperand, Value accumulator) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - Value vectorLength = - createVectorLengthValue(rewriter, loc, vectorOperand.getType()); return rewriter.create(loc, llvmType, /*startValue=*/accumulator, - vectorOperand, mask, vectorLength); + vectorOperand); } template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask, - bool reassociateFPReds) { +static Value lowerPredicatedReductionWithStartValue( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, Value mask) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); return rewriter.create(loc, llvmType, /*startValue=*/accumulator, - vectorOperand, mask, vectorLength, - reassociateFPReds); + vectorOperand, mask, vectorLength); } template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask) { +static Value lowerPredicatedReductionWithStartValue( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, Value mask) { if (llvmType.isIntOrIndex()) - return lowerReductionWithStartValue( + return lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, vectorOperand, accumulator, mask); // FP dispatch. - return lowerReductionWithStartValue( + return lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, vectorOperand, accumulator, mask); } @@ -809,30 +795,39 @@ class VectorReductionOpConversion if (!isa(eltType)) return failure(); + arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); + fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc + : LLVM::FastmathFlags::none)); + // Floating-point reductions: add/mul/min/max Value result; if (kind == vector::CombiningKind::ADD) { result = lowerReductionWithStartValue( - rewriter, loc, llvmType, operand, acc, reassociateFPReductions); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MUL) { result = lowerReductionWithStartValue( - rewriter, loc, llvmType, operand, acc, reassociateFPReductions); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MINIMUMF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MAXIMUMF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MINF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MAXF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else return failure(); @@ -893,74 +888,79 @@ class MaskedReductionOpConversion Value acc = reductionOp.getAcc(); Location loc = reductionOp.getLoc(); + arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); + Value result; switch (kind) { case vector::CombiningKind::ADD: - result = lowerReductionWithStartValue< + result = lowerPredicatedReductionWithStartValue< LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MUL: - result = lowerReductionWithStartValue< + result = lowerPredicatedReductionWithStartValue< LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINUI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINSI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXUI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXSI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::AND: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::OR: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::XOR: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINF: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXF: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case CombiningKind::MAXIMUMF: result = lowerMaskedReductionWithRegular( - rewriter, loc, llvmType, operand, acc, maskOp.getMask()); + rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); break; case CombiningKind::MINIMUMF: result = lowerMaskedReductionWithRegular( - rewriter, loc, llvmType, operand, acc, maskOp.getMask()); + rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); break; } diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index 596f6422807cc..9ec919423b342 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -8,7 +8,7 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRMaskableOpInterfaceIncGen MLIRMaskingOpInterfaceIncGen MLIRVectorOpsIncGen - MLIRVectorOpsEnumsIncGen + MLIRVectorAttributesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a8ad05f7bc1ca..54a3de6608505 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -42,9 +42,9 @@ #include #include -#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc" // Pull in all enum type and utility function definitions. -#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc" using namespace mlir; using namespace mlir::vector; @@ -256,7 +256,7 @@ struct BitmaskEnumStorage : public AttributeStorage { void VectorDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc" >(); addOperations< @@ -415,15 +415,17 @@ void MultiDimReductionOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, - CombiningKind kind, Value vector) { - build(builder, result, kind, vector, /*acc=*/Value()); + CombiningKind kind, Value vector, + arith::FastMathFlags fastMathFlags) { + build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags); } void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, - CombiningKind kind, Value vector, Value acc) { + CombiningKind kind, Value vector, Value acc, + arith::FastMathFlags fastMathFlags) { build(builder, result, llvm::cast(vector.getType()).getElementType(), kind, vector, - acc); + acc, fastMathFlags); } LogicalResult ReductionOp::verify() { @@ -447,9 +449,13 @@ ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { Type redType; Type resType; CombiningKindAttr kindAttr; + arith::FastMathFlagsAttr fastMathAttr; if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind", result.attributes) || parser.parseComma() || parser.parseOperandList(operandsInfo) || + (succeeded(parser.parseOptionalKeyword("fastmath")) && + parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath", + result.attributes)) || parser.parseColonType(redType) || parser.parseKeywordType("into", resType) || (!operandsInfo.empty() && @@ -470,6 +476,12 @@ void ReductionOp::print(OpAsmPrinter &p) { p << ", " << getVector(); if (getAcc()) p << ", " << getAcc(); + + if (getFastmathAttr() && + getFastmathAttr().getValue() != arith::FastMathFlags::none) { + p << ' ' << getFastmathAttrName().getValue(); + p.printStrippedAttrOrType(getFastmathAttr()); + } p << " : " << getVector().getType() << " into " << getDest().getType(); } @@ -6052,7 +6064,7 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, //===----------------------------------------------------------------------===// #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" diff --git a/mlir/python/mlir/dialects/Vector.td b/mlir/python/mlir/dialects/Vector.td new file mode 100644 index 0000000000000..f659f754b66a7 --- /dev/null +++ b/mlir/python/mlir/dialects/Vector.td @@ -0,0 +1,14 @@ +//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR +#define PYTHON_BINDINGS_VECTOR + +include "mlir/Dialect/Vector/IR/Vector.td" + +#endif // PYTHON_BINDINGS_VECTOR diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir index fd2d6ae5a472f..13b7faed4790d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir @@ -5,14 +5,14 @@ // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // // REASSOC-LABEL: @reduce_add_f32( // REASSOC-SAME: %[[A:.*]]: vector<16xf32>) // REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32 +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // REASSOC: return %[[V]] : f32 // func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { @@ -22,22 +22,45 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { // ----- +// CHECK-LABEL: @reduce_add_f32_always_reassoc( +// CHECK-SAME: %[[A:.*]]: vector<16xf32>) +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) +/// Note: the reassoc flag remains even though the pass sets reassociate-fp-reduction to false. +/// Ponder whether this flag really is a property of the pass / pattern.. +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 +// CHECK: return %[[V]] : f32 +// +// REASSOC-LABEL: @reduce_add_f32_always_reassoc( +// REASSOC-SAME: %[[A:.*]]: vector<16xf32>) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 +// REASSOC: return %[[V]] : f32 +// +func.func @reduce_add_f32_always_reassoc(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction , %arg0 fastmath : vector<16xf32> into f32 + return %0 : f32 +} + +// ----- + // CHECK-LABEL: @reduce_mul_f32( // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // // REASSOC-LABEL: @reduce_mul_f32( // REASSOC-SAME: %[[A:.*]]: vector<16xf32>) // REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]]) -// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32 +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // REASSOC: return %[[V]] : f32 // func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 { - %0 = vector.reduction , %arg0 : vector<16xf32> into f32 + %0 = vector.reduction , %arg0 fastmath : vector<16xf32> into f32 return %0 : f32 } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7b29ef44c1f2f..53b87cfcce42a 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1216,7 +1216,7 @@ func.func @reduce_0d_f32(%arg0: vector) -> f32 { // CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<1xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<1xf32>) -> f32 // CHECK: return %[[V]] : f32 // ----- @@ -1229,7 +1229,7 @@ func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 { // CHECK-SAME: %[[A:.*]]: vector<16xf16>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f16, vector<16xf16>) -> f16 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f16, vector<16xf16>) -> f16 // CHECK: return %[[V]] : f16 // ----- @@ -1242,7 +1242,7 @@ func.func @reduce_f32(%arg0: vector<16xf32>) -> f32 { // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // ----- @@ -1255,7 +1255,7 @@ func.func @reduce_f64(%arg0: vector<16xf64>) -> f64 { // CHECK-SAME: %[[A:.*]]: vector<16xf64>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f64, vector<16xf64>) -> f64 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f64, vector<16xf64>) -> f64 // CHECK: return %[[V]] : f64 // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 4ea4379372e83..409951d31eedf 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1007,3 +1007,10 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>, : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> } + +// CHECK-LABEL: func.func @fastmath( +func.func @fastmath(%x: vector<42xf32>) -> f32 { + // CHECK: vector.reduction , %{{.*}} fastmath + %min = vector.reduction , %x fastmath : vector<42xf32> into f32 + return %min: f32 +} diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 56c3cfbb5c7c2..8ce16fe5705cb 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -354,13 +354,13 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) { %12 = call i32 @llvm.vector.reduce.umax.v8i32(<8 x i32> %2) ; CHECK: "llvm.intr.vector.reduce.umin"(%{{.*}}) : (vector<8xi32>) -> i32 %13 = call i32 @llvm.vector.reduce.umin.v8i32(<8 x i32> %2) - ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %14 = call float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %15 = call float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %16 = call reassoc float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %17 = call reassoc float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1) ; CHECK: "llvm.intr.vector.reduce.xor"(%{{.*}}) : (vector<8xi32>) -> i32 %18 = call i32 @llvm.vector.reduce.xor.v8i32(<8 x i32> %2) diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 427c09976ef14..6bbd761b6e613 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -375,9 +375,9 @@ llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi // CHECK: call float @llvm.vector.reduce.fmul.v8f32 "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) : (f32, vector<8xf32>) -> f32 // CHECK: call reassoc float @llvm.vector.reduce.fadd.v8f32 - "llvm.intr.vector.reduce.fadd"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32 + "llvm.intr.vector.reduce.fadd"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 // CHECK: call reassoc float @llvm.vector.reduce.fmul.v8f32 - "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32 + "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 // CHECK: call i32 @llvm.vector.reduce.xor.v8i32 "llvm.intr.vector.reduce.xor"(%arg2) : (vector<8xi32>) -> i32 llvm.return