diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h index ed5e5ca22c5958..e8011b5488dc98 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h @@ -11,6 +11,15 @@ #include "mlir/IR/Types.h" +namespace llvm { +struct fltSemantics; +} // namespace llvm + +namespace mlir { +class FloatType; +class MLIRContext; +} // namespace mlir + #include "mlir/IR/BuiltinTypeInterfaces.h.inc" #endif // MLIR_IR_BUILTINTYPEINTERFACES_H diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index c9dcd546cf67c2..c36b738e38f42a 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -16,6 +16,65 @@ include "mlir/IR/OpBase.td" +def FloatTypeInterface : TypeInterface<"FloatType"> { + let cppNamespace = "::mlir"; + let description = [{ + This type interface should be implemented by all floating-point types. It + defines the LLVM APFloat semantics and provides a few helper functions. + }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns the APFloat semantics for this floating-point type. + }], + /*retTy=*/"const ::llvm::fltSemantics &", + /*methodName=*/"getFloatSemantics", + /*args=*/(ins) + >, + InterfaceMethod< + /*desc=*/[{ + Returns a float type with bitwidth scaled by `scale`. Returns a "null" + float type if the scaled element type cannot be represented. + }], + /*retTy=*/"::mlir::FloatType", + /*methodName=*/"scaleElementBitwidth", + /*args=*/(ins "unsigned":$scale), + /*methodBody=*/"", + /*defaultImplementation=*/"return ::mlir::FloatType();" + > + ]; + + let extraClassDeclaration = [{ + // Convenience factories. + static FloatType getBF16(MLIRContext *ctx); + static FloatType getF16(MLIRContext *ctx); + static FloatType getF32(MLIRContext *ctx); + static FloatType getTF32(MLIRContext *ctx); + static FloatType getF64(MLIRContext *ctx); + static FloatType getF80(MLIRContext *ctx); + static FloatType getF128(MLIRContext *ctx); + static FloatType getFloat8E5M2(MLIRContext *ctx); + static FloatType getFloat8E4M3(MLIRContext *ctx); + static FloatType getFloat8E4M3FN(MLIRContext *ctx); + static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); + static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); + static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx); + static FloatType getFloat8E3M4(MLIRContext *ctx); + static FloatType getFloat4E2M1FN(MLIRContext *ctx); + static FloatType getFloat6E2M3FN(MLIRContext *ctx); + static FloatType getFloat6E3M2FN(MLIRContext *ctx); + static FloatType getFloat8E8M0FNU(MLIRContext *ctx); + + /// Return the bitwidth of this float type. + unsigned getWidth(); + + /// Return the width of the mantissa of this type. + /// The width includes the integer bit. + unsigned getFPMantissaWidth(); + }]; +} + //===----------------------------------------------------------------------===// // MemRefElementTypeInterface //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h index 7f9c470ffec304..2b3c2b6d1753dc 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -25,7 +25,6 @@ struct fltSemantics; namespace mlir { class AffineExpr; class AffineMap; -class FloatType; class IndexType; class IntegerType; class MemRefType; @@ -44,52 +43,6 @@ template class ValueSemantics : public TypeTrait::TraitBase {}; -//===----------------------------------------------------------------------===// -// FloatType -//===----------------------------------------------------------------------===// - -class FloatType : public Type { -public: - using Type::Type; - - // Convenience factories. - static FloatType getBF16(MLIRContext *ctx); - static FloatType getF16(MLIRContext *ctx); - static FloatType getF32(MLIRContext *ctx); - static FloatType getTF32(MLIRContext *ctx); - static FloatType getF64(MLIRContext *ctx); - static FloatType getF80(MLIRContext *ctx); - static FloatType getF128(MLIRContext *ctx); - static FloatType getFloat8E5M2(MLIRContext *ctx); - static FloatType getFloat8E4M3(MLIRContext *ctx); - static FloatType getFloat8E4M3FN(MLIRContext *ctx); - static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx); - static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx); - static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx); - static FloatType getFloat8E3M4(MLIRContext *ctx); - static FloatType getFloat4E2M1FN(MLIRContext *ctx); - static FloatType getFloat6E2M3FN(MLIRContext *ctx); - static FloatType getFloat6E3M2FN(MLIRContext *ctx); - static FloatType getFloat8E8M0FNU(MLIRContext *ctx); - - /// Methods for support type inquiry through isa, cast, and dyn_cast. - static bool classof(Type type); - - /// Return the bitwidth of this float type. - unsigned getWidth(); - - /// Return the width of the mantissa of this type. - /// The width includes the integer bit. - unsigned getFPMantissaWidth(); - - /// Get or create a new FloatType with bitwidth scaled by `scale`. - /// Return null if the scaled element type cannot be represented. - FloatType scaleElementBitwidth(unsigned scale); - - /// Return the floating semantics of this float type. - const llvm::fltSemantics &getFloatSemantics(); -}; - //===----------------------------------------------------------------------===// // TensorType //===----------------------------------------------------------------------===// @@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) { llvm::isa(type); } -inline bool FloatType::classof(Type type) { - return llvm::isa(type); -} - inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) { return Float4E2M1FNType::get(ctx); } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index dca228097d782d..fc50b28c09e41c 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -79,8 +79,12 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> { //===----------------------------------------------------------------------===// // Base class for Builtin dialect float types. -class Builtin_FloatType - : Builtin_Type { +class Builtin_FloatType declaredInterfaceMethods = []> + : Builtin_Type]> { let extraClassDeclaration = [{ static }] # name # [{Type get(MLIRContext *context); }]; @@ -322,14 +326,16 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> { //===----------------------------------------------------------------------===// // BFloat16Type -def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> { +def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16", + /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "bfloat16 floating-point type"; } //===----------------------------------------------------------------------===// // Float16Type -def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> { +def Builtin_Float16 : Builtin_FloatType<"Float16", "f16", + /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "16-bit floating-point type"; } @@ -343,7 +349,8 @@ def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> { //===----------------------------------------------------------------------===// // Float32Type -def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> { +def Builtin_Float32 : Builtin_FloatType<"Float32", "f32", + /*declaredInterfaceMethods=*/["scaleElementBitwidth"]> { let summary = "32-bit floating-point type"; } diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp index ab9e65b5edfed3..c663f6c9094604 100644 --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -8,6 +8,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Diagnostics.h" +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/Sequence.h" using namespace mlir; @@ -19,6 +20,18 @@ using namespace mlir::detail; #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" +//===----------------------------------------------------------------------===// +// FloatType +//===----------------------------------------------------------------------===// + +unsigned FloatType::getWidth() { + return APFloat::semanticsSizeInBits(getFloatSemantics()); +} + +unsigned FloatType::getFPMantissaWidth() { + return APFloat::semanticsPrecision(getFloatSemantics()); +} + //===----------------------------------------------------------------------===// // ShapedType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 6546234429c8cb..41b794bc0aec59 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -87,72 +87,54 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { } //===----------------------------------------------------------------------===// -// Float Type -//===----------------------------------------------------------------------===// - -unsigned FloatType::getWidth() { - return APFloat::semanticsSizeInBits(getFloatSemantics()); -} - -/// Returns the floating semantics for the given type. -const llvm::fltSemantics &FloatType::getFloatSemantics() { - if (llvm::isa(*this)) - return APFloat::Float4E2M1FN(); - if (llvm::isa(*this)) - return APFloat::Float6E2M3FN(); - if (llvm::isa(*this)) - return APFloat::Float6E3M2FN(); - if (llvm::isa(*this)) - return APFloat::Float8E5M2(); - if (llvm::isa(*this)) - return APFloat::Float8E4M3(); - if (llvm::isa(*this)) - return APFloat::Float8E4M3FN(); - if (llvm::isa(*this)) - return APFloat::Float8E5M2FNUZ(); - if (llvm::isa(*this)) - return APFloat::Float8E4M3FNUZ(); - if (llvm::isa(*this)) - return APFloat::Float8E4M3B11FNUZ(); - if (llvm::isa(*this)) - return APFloat::Float8E3M4(); - if (llvm::isa(*this)) - return APFloat::Float8E8M0FNU(); - if (llvm::isa(*this)) - return APFloat::BFloat(); - if (llvm::isa(*this)) - return APFloat::IEEEhalf(); - if (llvm::isa(*this)) - return APFloat::FloatTF32(); - if (llvm::isa(*this)) - return APFloat::IEEEsingle(); - if (llvm::isa(*this)) - return APFloat::IEEEdouble(); - if (llvm::isa(*this)) - return APFloat::x87DoubleExtended(); - if (llvm::isa(*this)) - return APFloat::IEEEquad(); - llvm_unreachable("non-floating point type used"); -} - -FloatType FloatType::scaleElementBitwidth(unsigned scale) { - if (!scale) - return FloatType(); - MLIRContext *ctx = getContext(); - if (isF16() || isBF16()) { - if (scale == 2) - return FloatType::getF32(ctx); - if (scale == 4) - return FloatType::getF64(ctx); +// Float Types +//===----------------------------------------------------------------------===// + +// Mapping from MLIR FloatType to APFloat semantics. +#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \ + const llvm::fltSemantics &TYPE::getFloatSemantics() const { \ + return APFloat::SEM(); \ } - if (isF32()) - if (scale == 2) - return FloatType::getF64(ctx); +FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN) +FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN) +FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN) +FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2) +FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3) +FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN) +FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ) +FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ) +FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ) +FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4) +FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU) +FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat) +FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf) +FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32) +FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle) +FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble) +FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended) +FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad) +#undef FLOAT_TYPE_SEMANTICS + +FloatType Float16Type::scaleElementBitwidth(unsigned scale) const { + if (scale == 2) + return FloatType::getF32(getContext()); + if (scale == 4) + return FloatType::getF64(getContext()); return FloatType(); } -unsigned FloatType::getFPMantissaWidth() { - return APFloat::semanticsPrecision(getFloatSemantics()); +FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const { + if (scale == 2) + return FloatType::getF32(getContext()); + if (scale == 4) + return FloatType::getF64(getContext()); + return FloatType(); +} + +FloatType Float32Type::scaleElementBitwidth(unsigned scale) const { + if (scale == 2) + return FloatType::getF64(getContext()); + return FloatType(); } //===----------------------------------------------------------------------===// diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp index b6066dd5685dc6..1b5d3b8c31bd22 100644 --- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp +++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp @@ -43,7 +43,7 @@ struct Model /// overrides default methods. struct OverridingModel : public TestExternalTypeInterface::ExternalModel { + Float32Type> { unsigned getBitwidthPlusArg(Type type, unsigned arg) const { return type.getIntOrFloatBitWidth() + arg; }