From bd2d055940bf6d4e6c83f3dbb0803607e5befb6e Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 20 Sep 2023 12:20:40 +0200 Subject: [PATCH 1/3] [mlir][Vector] Add fastmath flags to vector.reduction This revision pipes the fastmath attribute support through the vector.reduction op. This seemingly simple first step already requires quite some genuflexions, file and builder reorganization. In the process, retire the boolean reassoc flag deep in the LLVM dialect builders and just use the fastmath attribute. During conversions, templated builders for predicated intrinsics are partially cleaned up. In the future, to finalize the cleanups, one should consider adding fastmath to the VPIntrinsic ops. --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 13 +- .../mlir/Dialect/Vector/IR/CMakeLists.txt | 24 ++- mlir/include/mlir/Dialect/Vector/IR/Vector.td | 31 ++++ .../Dialect/Vector/IR/VectorAttributes.td | 85 ++++++++++ .../mlir/Dialect/Vector/IR/VectorOps.h | 7 +- .../mlir/Dialect/Vector/IR/VectorOps.td | 107 +++--------- .../VectorToLLVM/ConvertVectorToLLVM.cpp | 156 +++++++++--------- mlir/lib/Dialect/Vector/IR/CMakeLists.txt | 2 +- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 28 +++- mlir/python/mlir/dialects/VectorOps.td | 10 +- 10 files changed, 264 insertions(+), 199 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Vector/IR/Vector.td create mode 100644 mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 51017b5e050ff..5af84c9e8646f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -656,8 +656,9 @@ class LLVM_VecReductionI class LLVM_VecReductionAccBase : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0], [Pure, SameOperandsAndResultElementType]>, - Arguments<(ins element:$start_value, LLVM_VectorOf:$input, - DefaultValuedAttr:$reassoc)> { + Arguments<(ins element:$start_value, + LLVM_VectorOf:$input, + DefaultValuedAttr:$fastmathFlags)> { let llvmBuilder = [{ llvm::Module *module = builder.GetInsertBlock()->getModule(); llvm::Function *fn = llvm::Intrinsic::getDeclaration( @@ -667,17 +668,11 @@ class LLVM_VecReductionAccBase ", ") # [{ }); auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - llvm::FastMathFlags origFM = builder.getFastMathFlags(); - llvm::FastMathFlags tempFM = origFM; - tempFM.setAllowReassoc($reassoc); - builder.setFastMathFlags(tempFM); // set fastmath flag $res = builder.CreateCall(fn, operands); - builder.setFastMathFlags(origFM); // restore fastmath flag }]; let mlirBuilder = [{ - bool allowReassoc = inst->getFastMathFlags().allowReassoc(); $res = $_builder.create<$_qualCppClassName>($_location, - $_resultType, $start_value, $input, allowReassoc); + $_resultType, $start_value, $input, inst->getFastMathFlags()); }]; } diff --git a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt index 2e56afe727ac0..23bed7e0f447e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt @@ -1,10 +1,18 @@ -add_mlir_dialect(VectorOps vector) -add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc) +add_mlir_dialect(Vector vector) +add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector) +# Add Vector operations set(LLVM_TARGET_DEFINITIONS VectorOps.td) -mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls) -mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs) -mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls) -mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs) -add_public_tablegen_target(MLIRVectorOpsEnumsIncGen) -add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen) +mlir_tablegen(VectorOps.h.inc -gen-op-decls) +mlir_tablegen(VectorOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRVectorOpsIncGen) +add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen) + +# Add Vector attributes +set(LLVM_TARGET_DEFINITIONS VectorAttributes.td) +mlir_tablegen(VectorEnums.h.inc -gen-enum-decls) +mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls) +mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs) +add_public_tablegen_target(MLIRVectorAttributesIncGen) +add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen) diff --git a/mlir/include/mlir/Dialect/Vector/IR/Vector.td b/mlir/include/mlir/Dialect/Vector/IR/Vector.td new file mode 100644 index 0000000000000..c439ca083e2e0 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/Vector.td @@ -0,0 +1,31 @@ +//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR +#define MLIR_DIALECT_VECTOR_IR_VECTOR + +include "mlir/IR/OpBase.td" + +def Vector_Dialect : Dialect { + let name = "vector"; + let cppNamespace = "::mlir::vector"; + + let useDefaultAttributePrinterParser = 1; + let hasConstantMaterializer = 1; + let dependentDialects = ["arith::ArithDialect"]; +} + +// Base class for Vector dialect ops. +class Vector_Op traits = []> : + Op; + +#endif // MLIR_DIALECT_VECTOR_IR_VECTOR diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td new file mode 100644 index 0000000000000..2db944b4ceaf1 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td @@ -0,0 +1,85 @@ +//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the attributes used in the Vector dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES +#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES + +include "Vector.td" +include "mlir/IR/EnumAttr.td" + +// The "kind" of combining function for contractions and reductions. +def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">; +def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">; +def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">; +def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">; +def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">; +def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">; +def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">; +def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">; +def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">; +def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">; +def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">; +def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">; +def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">; + +def CombiningKind : I32BitEnumAttr< + "CombiningKind", + "Kind of combining function for contractions and reductions", + [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI, + COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI, + COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND, + COMBINING_KIND_OR, COMBINING_KIND_XOR, + COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> { + let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; +} + +/// An attribute that specifies the combining function for `vector.contract`, +/// and `vector.reduction`. +def Vector_CombiningKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ + I32EnumAttrCase<"parallel", 0>, + I32EnumAttrCase<"reduction", 1> +]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::vector"; +} + +def Vector_IteratorTypeEnum + : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def Vector_IteratorTypeArrayAttr + : TypedArrayAttrBase; + +def PrintPunctuation : I32EnumAttr<"PrintPunctuation", + "Punctuation for separating vectors or vector elements", [ + I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">, + I32EnumAttrCase<"NewLine", 1, "newline">, + I32EnumAttrCase<"Comma", 2, "comma">, + I32EnumAttrCase<"Open", 3, "open">, + I32EnumAttrCase<"Close", 4, "close"> +]> { + let cppNamespace = "::mlir::vector"; + let genSpecializedAttr = 0; +} + +def Vector_PrintPunctuation : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 4a624bd5f1ccd..fcf7eb4a616b0 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -14,6 +14,7 @@ #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" #include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" @@ -31,10 +32,10 @@ #include "llvm/ADT/StringExtras.h" // Pull in all enum type definitions and utility function declarations. -#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc" +#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc" #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc" namespace mlir { class MLIRContext; @@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, } // namespace mlir #define GET_OP_CLASSES +#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc" #include "mlir/Dialect/Vector/IR/VectorOps.h.inc" -#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc" #endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 28b5864914f69..ab77c31b418e5 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -10,9 +10,13 @@ // //===----------------------------------------------------------------------===// -#ifndef VECTOR_OPS -#define VECTOR_OPS +#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS +#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS +include "Vector.td" +include "VectorAttributes.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" +include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td" include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td" include "mlir/IR/EnumAttr.td" @@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" -def Vector_Dialect : Dialect { - let name = "vector"; - let cppNamespace = "::mlir::vector"; - - let useDefaultAttributePrinterParser = 1; - let hasConstantMaterializer = 1; - let dependentDialects = ["arith::ArithDialect"]; -} - -// Base class for Vector dialect ops. -class Vector_Op traits = []> : - Op; - -// The "kind" of combining function for contractions and reductions. -def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">; -def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">; -def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">; -def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">; -def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">; -def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">; -def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">; -def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">; -def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">; -def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">; -def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">; -def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">; -def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">; - -def CombiningKind : I32BitEnumAttr< - "CombiningKind", - "Kind of combining function for contractions and reductions", - [COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI, - COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI, - COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND, - COMBINING_KIND_OR, COMBINING_KIND_XOR, - COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> { - let cppNamespace = "::mlir::vector"; - let genSpecializedAttr = 0; -} - -/// An attribute that specifies the combining function for `vector.contract`, -/// and `vector.reduction`. -def Vector_CombiningKindAttr : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [ - I32EnumAttrCase<"parallel", 0>, - I32EnumAttrCase<"reduction", 1> -]> { - let genSpecializedAttr = 0; - let cppNamespace = "::mlir::vector"; -} - -def Vector_IteratorTypeEnum - : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - -def Vector_IteratorTypeArrayAttr - : TypedArrayAttrBase; - // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. def Vector_ContractionOp : @@ -274,12 +215,16 @@ def Vector_ReductionOp : Vector_Op<"reduction", [Pure, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, - DeclareOpInterfaceMethods]>, + DeclareOpInterfaceMethods + ]>, Arguments<(ins Vector_CombiningKindAttr:$kind, AnyVectorOfAnyRank:$vector, - Optional:$acc)>, + Optional:$acc, + DefaultValuedAttr< + Arith_FastMathAttr, + "::mlir::arith::FastMathFlags::none">:$fastmath)>, Results<(outs AnyType:$dest)> { let summary = "reduction operation"; let description = [{ @@ -309,9 +254,13 @@ def Vector_ReductionOp : }]; let builders = [ // Builder that infers the type of `dest`. - OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>, + OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc, + CArg<"::mlir::arith::FastMathFlags", + "::mlir::arith::FastMathFlags::none">:$fastMathFlags)>, // Builder that infers the type of `dest` and has no accumulator. - OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)> + OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, + CArg<"::mlir::arith::FastMathFlags", + "::mlir::arith::FastMathFlags::none">:$fastMathFlags)> ]; // TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional @@ -2466,22 +2415,6 @@ def Vector_TransposeOp : let hasVerifier = 1; } -def PrintPunctuation : I32EnumAttr<"PrintPunctuation", - "Punctuation for separating vectors or vector elements", [ - I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">, - I32EnumAttrCase<"NewLine", 1, "newline">, - I32EnumAttrCase<"Comma", 2, "comma">, - I32EnumAttrCase<"Open", 3, "open">, - I32EnumAttrCase<"Close", 4, "close"> -]> { - let cppNamespace = "::mlir::vector"; - let genSpecializedAttr = 0; -} - -def Vector_PrintPunctuation : EnumAttr { - let assemblyFormat = "`<` $value `>`"; -} - def Vector_PrintOp : Vector_Op<"print", []>, Arguments<(ins Optional { } // namespace template -static Value -createFPReductionComparisonOpLowering(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, Value accumulator) { - Value result = rewriter.create(loc, llvmType, vectorOperand); +static Value createFPReductionComparisonOpLowering( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, LLVM::FastmathFlagsAttr fmf) { + Value result = + rewriter.create(loc, llvmType, vectorOperand, fmf); if (accumulator) { result = @@ -641,87 +642,72 @@ static Value createMaskNeutralValue(ConversionPatternRewriter &rewriter, /// `fmaximum`/`fminimum`. /// More information: https://github.com/llvm/llvm-project/issues/64940 template -static Value lowerMaskedReductionWithRegular( - ConversionPatternRewriter &rewriter, Location loc, Type llvmType, - Value vectorOperand, Value accumulator, Value mask) { +static Value +lowerMaskedReductionWithRegular(ConversionPatternRewriter &rewriter, + Location loc, Type llvmType, + Value vectorOperand, Value accumulator, + Value mask, LLVM::FastmathFlagsAttr fmf) { const Value vectorMaskNeutral = createMaskNeutralValue( rewriter, loc, llvmType, vectorOperand.getType()); const Value selectedVectorByMask = rewriter.create( loc, mask, vectorOperand, vectorMaskNeutral); return createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, selectedVectorByMask, accumulator); + rewriter, loc, llvmType, selectedVectorByMask, accumulator, fmf); } -/// Overloaded methods to lower a reduction to an llvm instrinsic that requires -/// a start value. This start value format spans across fp reductions without -/// mask and all the masked reduction intrinsics. -template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator) { - accumulator = getOrCreateAccumulator(rewriter, loc, - llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand); -} - -template +template static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc, Type llvmType, Value vectorOperand, - Value accumulator, bool reassociateFPReds) { + Value accumulator, LLVM::FastmathFlagsAttr fmf) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - return rewriter.create(loc, llvmType, - /*startValue=*/accumulator, - vectorOperand, reassociateFPReds); + return rewriter.create(loc, llvmType, + /*startValue=*/accumulator, + vectorOperand, fmf); } +/// Overloaded methods to lower a *predicated* reduction to an llvm instrinsic +/// that requires a start value. This start value format spans across fp +/// reductions without mask and all the masked reduction intrinsics. template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask) { +static Value +lowerPredicatedReductionWithStartValue(ConversionPatternRewriter &rewriter, + Location loc, Type llvmType, + Value vectorOperand, Value accumulator) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); - Value vectorLength = - createVectorLengthValue(rewriter, loc, vectorOperand.getType()); return rewriter.create(loc, llvmType, /*startValue=*/accumulator, - vectorOperand, mask, vectorLength); + vectorOperand); } template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask, - bool reassociateFPReds) { +static Value lowerPredicatedReductionWithStartValue( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, Value mask) { accumulator = getOrCreateAccumulator(rewriter, loc, llvmType, accumulator); Value vectorLength = createVectorLengthValue(rewriter, loc, vectorOperand.getType()); return rewriter.create(loc, llvmType, /*startValue=*/accumulator, - vectorOperand, mask, vectorLength, - reassociateFPReds); + vectorOperand, mask, vectorLength); } template -static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, - Location loc, Type llvmType, - Value vectorOperand, - Value accumulator, Value mask) { +static Value lowerPredicatedReductionWithStartValue( + ConversionPatternRewriter &rewriter, Location loc, Type llvmType, + Value vectorOperand, Value accumulator, Value mask) { if (llvmType.isIntOrIndex()) - return lowerReductionWithStartValue( + return lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, vectorOperand, accumulator, mask); // FP dispatch. - return lowerReductionWithStartValue( + return lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, vectorOperand, accumulator, mask); } @@ -809,30 +795,39 @@ class VectorReductionOpConversion if (!isa(eltType)) return failure(); + arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); + fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + fmf.getValue() | (reassociateFPReductions ? LLVM::FastmathFlags::reassoc + : LLVM::FastmathFlags::none)); + // Floating-point reductions: add/mul/min/max Value result; if (kind == vector::CombiningKind::ADD) { result = lowerReductionWithStartValue( - rewriter, loc, llvmType, operand, acc, reassociateFPReductions); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MUL) { result = lowerReductionWithStartValue( - rewriter, loc, llvmType, operand, acc, reassociateFPReductions); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MINIMUMF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MAXIMUMF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MINF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else if (kind == vector::CombiningKind::MAXF) { result = createFPReductionComparisonOpLowering( - rewriter, loc, llvmType, operand, acc); + rewriter, loc, llvmType, operand, acc, fmf); } else return failure(); @@ -893,74 +888,79 @@ class MaskedReductionOpConversion Value acc = reductionOp.getAcc(); Location loc = reductionOp.getLoc(); + arith::FastMathFlagsAttr fMFAttr = reductionOp.getFastMathFlagsAttr(); + LLVM::FastmathFlagsAttr fmf = LLVM::FastmathFlagsAttr::get( + reductionOp.getContext(), + convertArithFastMathFlagsToLLVM(fMFAttr.getValue())); + Value result; switch (kind) { case vector::CombiningKind::ADD: - result = lowerReductionWithStartValue< + result = lowerPredicatedReductionWithStartValue< LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp, ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MUL: - result = lowerReductionWithStartValue< + result = lowerPredicatedReductionWithStartValue< LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp, ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINUI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINSI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXUI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXSI: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::AND: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::OR: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::XOR: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MINF: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case vector::CombiningKind::MAXF: - result = lowerReductionWithStartValue( + result = lowerPredicatedReductionWithStartValue( rewriter, loc, llvmType, operand, acc, maskOp.getMask()); break; case CombiningKind::MAXIMUMF: result = lowerMaskedReductionWithRegular( - rewriter, loc, llvmType, operand, acc, maskOp.getMask()); + rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); break; case CombiningKind::MINIMUMF: result = lowerMaskedReductionWithRegular( - rewriter, loc, llvmType, operand, acc, maskOp.getMask()); + rewriter, loc, llvmType, operand, acc, maskOp.getMask(), fmf); break; } diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index 596f6422807cc..9ec919423b342 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -8,7 +8,7 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRMaskableOpInterfaceIncGen MLIRMaskingOpInterfaceIncGen MLIRVectorOpsIncGen - MLIRVectorOpsEnumsIncGen + MLIRVectorAttributesIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index a8ad05f7bc1ca..54a3de6608505 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -42,9 +42,9 @@ #include #include -#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc" // Pull in all enum type and utility function definitions. -#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorEnums.cpp.inc" using namespace mlir; using namespace mlir::vector; @@ -256,7 +256,7 @@ struct BitmaskEnumStorage : public AttributeStorage { void VectorDialect::initialize() { addAttributes< #define GET_ATTRDEF_LIST -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc" >(); addOperations< @@ -415,15 +415,17 @@ void MultiDimReductionOp::getCanonicalizationPatterns( //===----------------------------------------------------------------------===// void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, - CombiningKind kind, Value vector) { - build(builder, result, kind, vector, /*acc=*/Value()); + CombiningKind kind, Value vector, + arith::FastMathFlags fastMathFlags) { + build(builder, result, kind, vector, /*acc=*/Value(), fastMathFlags); } void vector::ReductionOp::build(OpBuilder &builder, OperationState &result, - CombiningKind kind, Value vector, Value acc) { + CombiningKind kind, Value vector, Value acc, + arith::FastMathFlags fastMathFlags) { build(builder, result, llvm::cast(vector.getType()).getElementType(), kind, vector, - acc); + acc, fastMathFlags); } LogicalResult ReductionOp::verify() { @@ -447,9 +449,13 @@ ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) { Type redType; Type resType; CombiningKindAttr kindAttr; + arith::FastMathFlagsAttr fastMathAttr; if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind", result.attributes) || parser.parseComma() || parser.parseOperandList(operandsInfo) || + (succeeded(parser.parseOptionalKeyword("fastmath")) && + parser.parseCustomAttributeWithFallback(fastMathAttr, Type{}, "fastmath", + result.attributes)) || parser.parseColonType(redType) || parser.parseKeywordType("into", resType) || (!operandsInfo.empty() && @@ -470,6 +476,12 @@ void ReductionOp::print(OpAsmPrinter &p) { p << ", " << getVector(); if (getAcc()) p << ", " << getAcc(); + + if (getFastmathAttr() && + getFastmathAttr().getValue() != arith::FastMathFlags::none) { + p << ' ' << getFastmathAttrName().getValue(); + p.printStrippedAttrOrType(getFastmathAttr()); + } p << " : " << getVector().getType() << " into " << getDest().getType(); } @@ -6052,7 +6064,7 @@ Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, //===----------------------------------------------------------------------===// #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" +#include "mlir/Dialect/Vector/IR/VectorAttributes.cpp.inc" #define GET_OP_CLASSES #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td index 69a1028c9be61..f659f754b66a7 100644 --- a/mlir/python/mlir/dialects/VectorOps.td +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -1,4 +1,4 @@ -//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===// +//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_VECTOR_OPS -#define PYTHON_BINDINGS_VECTOR_OPS +#ifndef PYTHON_BINDINGS_VECTOR +#define PYTHON_BINDINGS_VECTOR -include "mlir/Dialect/Vector/IR/VectorOps.td" +include "mlir/Dialect/Vector/IR/Vector.td" -#endif +#endif // PYTHON_BINDINGS_VECTOR From efe2f1faab7a3925bf9c1d38394e812fb9f8c256 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 20 Sep 2023 16:35:49 +0200 Subject: [PATCH 2/3] Address comment and fix tests --- .../mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td | 25 ++++---------- .../vector-reduction-to-llvm.mlir | 33 ++++++++++++++++--- .../VectorToLLVM/vector-to-llvm.mlir | 8 ++--- mlir/test/Dialect/Vector/ops.mlir | 7 ++++ mlir/test/Target/LLVMIR/Import/intrinsic.ll | 8 ++--- .../test/Target/LLVMIR/llvmir-intrinsics.mlir | 4 +-- 6 files changed, 51 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td index 5af84c9e8646f..040f9895ad0db 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -654,27 +654,14 @@ class LLVM_VecReductionI // LLVM vector reduction over a single vector, with an initial value, // and with permission to reassociate the reduction operations. class LLVM_VecReductionAccBase - : LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0], - [Pure, SameOperandsAndResultElementType]>, + : LLVM_OneResultIntrOp, Arguments<(ins element:$start_value, LLVM_VectorOf:$input, - DefaultValuedAttr:$fastmathFlags)> { - let llvmBuilder = [{ - llvm::Module *module = builder.GetInsertBlock()->getModule(); - llvm::Function *fn = llvm::Intrinsic::getDeclaration( - module, - llvm::Intrinsic::vector_reduce_}] # mnem # [{, - { }] # !interleave(ListIntSubst.lst, - ", ") # [{ - }); - auto operands = moduleTranslation.lookupValues(opInst.getOperands()); - $res = builder.CreateCall(fn, operands); - }]; - let mlirBuilder = [{ - $res = $_builder.create<$_qualCppClassName>($_location, - $_resultType, $start_value, $input, inst->getFastMathFlags()); - }]; -} + DefaultValuedAttr:$fastmathFlags)>; class LLVM_VecReductionAccF : LLVM_VecReductionAccBase; diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir index fd2d6ae5a472f..13b7faed4790d 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir @@ -5,14 +5,14 @@ // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // // REASSOC-LABEL: @reduce_add_f32( // REASSOC-SAME: %[[A:.*]]: vector<16xf32>) // REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32 +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // REASSOC: return %[[V]] : f32 // func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { @@ -22,22 +22,45 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 { // ----- +// CHECK-LABEL: @reduce_add_f32_always_reassoc( +// CHECK-SAME: %[[A:.*]]: vector<16xf32>) +// CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) +/// Note: the reassoc flag remains even though the pass sets reassociate-fp-reduction to false. +/// Ponder whether this flag really is a property of the pass / pattern.. +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 +// CHECK: return %[[V]] : f32 +// +// REASSOC-LABEL: @reduce_add_f32_always_reassoc( +// REASSOC-SAME: %[[A:.*]]: vector<16xf32>) +// REASSOC: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 +// REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 +// REASSOC: return %[[V]] : f32 +// +func.func @reduce_add_f32_always_reassoc(%arg0: vector<16xf32>) -> f32 { + %0 = vector.reduction , %arg0 fastmath : vector<16xf32> into f32 + return %0 : f32 +} + +// ----- + // CHECK-LABEL: @reduce_mul_f32( // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // // REASSOC-LABEL: @reduce_mul_f32( // REASSOC-SAME: %[[A:.*]]: vector<16xf32>) // REASSOC: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32 // REASSOC: %[[V:.*]] = "llvm.intr.vector.reduce.fmul"(%[[C]], %[[A]]) -// REASSOC-SAME: <{reassoc = true}> : (f32, vector<16xf32>) -> f32 +// REASSOC-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // REASSOC: return %[[V]] : f32 // func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 { - %0 = vector.reduction , %arg0 : vector<16xf32> into f32 + %0 = vector.reduction , %arg0 fastmath : vector<16xf32> into f32 return %0 : f32 } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 7b29ef44c1f2f..53b87cfcce42a 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -1216,7 +1216,7 @@ func.func @reduce_0d_f32(%arg0: vector) -> f32 { // CHECK: %[[CA:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector to vector<1xf32> // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[CA]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<1xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<1xf32>) -> f32 // CHECK: return %[[V]] : f32 // ----- @@ -1229,7 +1229,7 @@ func.func @reduce_f16(%arg0: vector<16xf16>) -> f16 { // CHECK-SAME: %[[A:.*]]: vector<16xf16>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f16) : f16 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f16, vector<16xf16>) -> f16 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f16, vector<16xf16>) -> f16 // CHECK: return %[[V]] : f16 // ----- @@ -1242,7 +1242,7 @@ func.func @reduce_f32(%arg0: vector<16xf32>) -> f32 { // CHECK-SAME: %[[A:.*]]: vector<16xf32>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f32, vector<16xf32>) -> f32 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f32, vector<16xf32>) -> f32 // CHECK: return %[[V]] : f32 // ----- @@ -1255,7 +1255,7 @@ func.func @reduce_f64(%arg0: vector<16xf64>) -> f64 { // CHECK-SAME: %[[A:.*]]: vector<16xf64>) // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f64) : f64 // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.fadd"(%[[C]], %[[A]]) -// CHECK-SAME: <{reassoc = false}> : (f64, vector<16xf64>) -> f64 +// CHECK-SAME: <{fastmathFlags = #llvm.fastmath}> : (f64, vector<16xf64>) -> f64 // CHECK: return %[[V]] : f64 // ----- diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 4ea4379372e83..409951d31eedf 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1007,3 +1007,10 @@ func.func @contraction_masked_scalable(%A: vector<3x4xf32>, : vector<3x[8]x4xi1> -> vector<3x[8]xf32> return %0 : vector<3x[8]xf32> } + +// CHECK-LABEL: func.func @fastmath( +func.func @fastmath(%x: vector<42xf32>) -> f32 { + // CHECK: vector.reduction , %{{.*}} fastmath + %min = vector.reduction , %x fastmath : vector<42xf32> into f32 + return %min: f32 +} diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index 56c3cfbb5c7c2..8ce16fe5705cb 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -354,13 +354,13 @@ define void @vector_reductions(float %0, <8 x float> %1, <8 x i32> %2) { %12 = call i32 @llvm.vector.reduce.umax.v8i32(<8 x i32> %2) ; CHECK: "llvm.intr.vector.reduce.umin"(%{{.*}}) : (vector<8xi32>) -> i32 %13 = call i32 @llvm.vector.reduce.umin.v8i32(<8 x i32> %2) - ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %14 = call float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = false}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %15 = call float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fadd"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %16 = call reassoc float @llvm.vector.reduce.fadd.v8f32(float %0, <8 x float> %1) - ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{reassoc = true}> : (f32, vector<8xf32>) -> f32 + ; CHECK: "llvm.intr.vector.reduce.fmul"(%{{.*}}, %{{.*}}) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 %17 = call reassoc float @llvm.vector.reduce.fmul.v8f32(float %0, <8 x float> %1) ; CHECK: "llvm.intr.vector.reduce.xor"(%{{.*}}) : (vector<8xi32>) -> i32 %18 = call i32 @llvm.vector.reduce.xor.v8i32(<8 x i32> %2) diff --git a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir index 427c09976ef14..6bbd761b6e613 100644 --- a/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir @@ -375,9 +375,9 @@ llvm.func @vector_reductions(%arg0: f32, %arg1: vector<8xf32>, %arg2: vector<8xi // CHECK: call float @llvm.vector.reduce.fmul.v8f32 "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) : (f32, vector<8xf32>) -> f32 // CHECK: call reassoc float @llvm.vector.reduce.fadd.v8f32 - "llvm.intr.vector.reduce.fadd"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32 + "llvm.intr.vector.reduce.fadd"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 // CHECK: call reassoc float @llvm.vector.reduce.fmul.v8f32 - "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) {reassoc = true} : (f32, vector<8xf32>) -> f32 + "llvm.intr.vector.reduce.fmul"(%arg0, %arg1) <{fastmathFlags = #llvm.fastmath}> : (f32, vector<8xf32>) -> f32 // CHECK: call i32 @llvm.vector.reduce.xor.v8i32 "llvm.intr.vector.reduce.xor"(%arg2) : (vector<8xi32>) -> i32 llvm.return From 5c890ce6e81949288c30bf4e7d43a9db5f1fab02 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 20 Sep 2023 16:53:30 +0200 Subject: [PATCH 3/3] Address comments and fix python --- .../mlir/Dialect/Vector/IR/VectorAttributes.td | 2 +- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 4 ++-- mlir/python/mlir/dialects/Vector.td | 14 ++++++++++++++ mlir/python/mlir/dialects/VectorOps.td | 10 +++++----- 4 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 mlir/python/mlir/dialects/Vector.td diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td index 2db944b4ceaf1..f8f85b0d09d90 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td @@ -13,7 +13,7 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES #define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES -include "Vector.td" +include "mlir/Dialect/Vector/IR/Vector.td" include "mlir/IR/EnumAttr.td" // The "kind" of combining function for contractions and reductions. diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index ab77c31b418e5..4aff10c14fb81 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -13,8 +13,8 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS #define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS -include "Vector.td" -include "VectorAttributes.td" +include "mlir/Dialect/Vector/IR/Vector.td" +include "mlir/Dialect/Vector/IR/VectorAttributes.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td" diff --git a/mlir/python/mlir/dialects/Vector.td b/mlir/python/mlir/dialects/Vector.td new file mode 100644 index 0000000000000..f659f754b66a7 --- /dev/null +++ b/mlir/python/mlir/dialects/Vector.td @@ -0,0 +1,14 @@ +//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR +#define PYTHON_BINDINGS_VECTOR + +include "mlir/Dialect/Vector/IR/Vector.td" + +#endif // PYTHON_BINDINGS_VECTOR diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td index f659f754b66a7..69a1028c9be61 100644 --- a/mlir/python/mlir/dialects/VectorOps.td +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -1,4 +1,4 @@ -//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// +//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_VECTOR -#define PYTHON_BINDINGS_VECTOR +#ifndef PYTHON_BINDINGS_VECTOR_OPS +#define PYTHON_BINDINGS_VECTOR_OPS -include "mlir/Dialect/Vector/IR/Vector.td" +include "mlir/Dialect/Vector/IR/VectorOps.td" -#endif // PYTHON_BINDINGS_VECTOR +#endif