Skip to content

Commit

Permalink
[mlir][Vector] Add fastmath flags to vector.reduction (llvm#66905)
Browse files Browse the repository at this point in the history
This revision pipes the fastmath attribute support through the
vector.reduction op. This seemingly simple first step already requires
quite some genuflexions, file and builder reorganization. In the
process, retire the boolean reassoc flag deep in the LLVM dialect
builders and just use the fastmath attribute.

During conversions, templated builders for predicated intrinsics are
partially cleaned up. In the future, to finalize the cleanups, one
should consider adding fastmath to the VPIntrinsic ops.
  • Loading branch information
nicolasvasilache authored Sep 20, 2023
1 parent ebefe83 commit 1b8b556
Show file tree
Hide file tree
Showing 15 changed files with 322 additions and 226 deletions.
34 changes: 8 additions & 26 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,32 +654,14 @@ class LLVM_VecReductionI<string mnem>
// LLVM vector reduction over a single vector, with an initial value,
// and with permission to reassociate the reduction operations.
class LLVM_VecReductionAccBase<string mnem, Type element>
: LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
[Pure, SameOperandsAndResultElementType]>,
Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$input,
DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
let llvmBuilder = [{
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn = llvm::Intrinsic::getDeclaration(
module,
llvm::Intrinsic::vector_reduce_}] # mnem # [{,
{ }] # !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst,
", ") # [{
});
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
llvm::FastMathFlags origFM = builder.getFastMathFlags();
llvm::FastMathFlags tempFM = origFM;
tempFM.setAllowReassoc($reassoc);
builder.setFastMathFlags(tempFM); // set fastmath flag
$res = builder.CreateCall(fn, operands);
builder.setFastMathFlags(origFM); // restore fastmath flag
}];
let mlirBuilder = [{
bool allowReassoc = inst->getFastMathFlags().allowReassoc();
$res = $_builder.create<$_qualCppClassName>($_location,
$_resultType, $start_value, $input, allowReassoc);
}];
}
: LLVM_OneResultIntrOp</*mnem=*/"vector.reduce." # mnem,
/*overloadedResults=*/[],
/*overloadedOperands=*/[1],
/*traits=*/[Pure, SameOperandsAndResultElementType],
/*equiresFastmath=*/1>,
Arguments<(ins element:$start_value,
LLVM_VectorOf<element>:$input,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;

class LLVM_VecReductionAccF<string mnem>
: LLVM_VecReductionAccBase<mnem, AnyFloat>;
Expand Down
24 changes: 16 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
add_mlir_dialect(VectorOps vector)
add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
add_mlir_dialect(Vector vector)
add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector)

# Add Vector operations
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
mlir_tablegen(VectorOps.h.inc -gen-op-decls)
mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRVectorOpsIncGen)
add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen)

# Add Vector attributes
set(LLVM_TARGET_DEFINITIONS VectorAttributes.td)
mlir_tablegen(VectorEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorAttributesIncGen)
add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen)
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/Vector.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR
#define MLIR_DIALECT_VECTOR_IR_VECTOR

include "mlir/IR/OpBase.td"

def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";

let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
}

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
85 changes: 85 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the attributes used in the Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;

def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}

def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;

def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
Expand All @@ -31,10 +32,10 @@
#include "llvm/ADT/StringExtras.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"

namespace mlir {
class MLIRContext;
Expand Down Expand Up @@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"

#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
107 changes: 20 additions & 87 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
//
//===----------------------------------------------------------------------===//

#ifndef VECTOR_OPS
#define VECTOR_OPS
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS

include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
Expand All @@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"

def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";

let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
}

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;

def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}

def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;

// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
Expand Down Expand Up @@ -274,12 +215,16 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVectorOfAnyRank:$vector,
Optional<AnyType>:$acc)>,
Optional<AnyType>:$acc,
DefaultValuedAttr<
Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs AnyType:$dest)> {
let summary = "reduction operation";
let description = [{
Expand Down Expand Up @@ -309,9 +254,13 @@ def Vector_ReductionOp :
}];
let builders = [
// Builder that infers the type of `dest`.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
// Builder that infers the type of `dest` and has no accumulator.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];

// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
Expand Down Expand Up @@ -2469,22 +2418,6 @@ def Vector_TransposeOp :
let hasVerifier = 1;
}

def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_PrintOp :
Vector_Op<"print", []>,
Arguments<(ins Optional<Type<Or<[
Expand Down Expand Up @@ -2939,4 +2872,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
}];
}

#endif // VECTOR_OPS
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
Loading

0 comments on commit 1b8b556

Please sign in to comment.