Skip to content

Commit

Permalink
[mlir][IR] Turn FloatType into a type interface (#118891)
Browse files Browse the repository at this point in the history
This makes it possible to add new MLIR floating point types in
downstream projects. (Adding new APFloat semantics in downstream
projects is not possible yet, so parsing/printing/converting float
literals of newly added types is not supported.)

Also removes two functions where we had to hard-code all existing
floating point types (`FloatType::classof`). See discussion here:
https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361

No measurable compilation time changes for these lit tests:
```
Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null
  BEFORE
  Time (mean ± σ):     248.4 ms ±   3.2 ms    [User: 237.0 ms, System: 20.1 ms]
  Range (min … max):   243.3 ms … 255.9 ms    30 runs

  AFTER
  Time (mean ± σ):     246.8 ms ±   3.2 ms    [User: 233.2 ms, System: 21.8 ms]
  Range (min … max):   240.2 ms … 252.1 ms    30 runs


Benchmark 2: mlir-opt ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null
  BEFORE
  Time (mean ± σ):      37.3 ms ±   1.8 ms    [User: 31.6 ms, System: 30.4 ms]
  Range (min … max):    34.6 ms …  42.0 ms    200 runs

  AFTER
  Time (mean ± σ):      37.5 ms ±   2.0 ms    [User: 31.5 ms, System: 29.2 ms]
  Range (min … max):    34.5 ms …  43.0 ms    200 runs


Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null
  BEFORE
  Time (mean ± σ):     152.2 ms ±   2.5 ms    [User: 140.1 ms, System: 12.2 ms]
  Range (min … max):   147.6 ms … 161.8 ms    200 runs

  AFTER
  Time (mean ± σ):     151.9 ms ±   2.7 ms    [User: 140.5 ms, System: 11.5 ms]
  Range (min … max):   147.2 ms … 159.1 ms    200 runs
```

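A micro benchmark that parses + prints 32768 floats with random
floating-point types shows a change from 55.1 ms -> 48.3 ms. (The message
originally called this a "slowdown", but the numbers as written are a speedup;
the two timings may have been transposed.)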
  • Loading branch information
matthias-springer authored Jan 15, 2025
1 parent 4c2e4ea commit c24ce32
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 124 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

#include "mlir/IR/Types.h"

namespace llvm {
struct fltSemantics;
} // namespace llvm

namespace mlir {
class FloatType;
class MLIRContext;
} // namespace mlir

#include "mlir/IR/BuiltinTypeInterfaces.h.inc"

#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
59 changes: 59 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,65 @@

include "mlir/IR/OpBase.td"

//===----------------------------------------------------------------------===//
// FloatTypeInterface
//===----------------------------------------------------------------------===//

// Type interface for floating-point types. With `cppNamespace = "::mlir"` and
// the interface name "FloatType", the C++ class is `::mlir::FloatType`.
def FloatTypeInterface : TypeInterface<"FloatType"> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
defines the LLVM APFloat semantics and provides a few helper functions.
}];

let methods = [
// No default implementation is given for this method.
InterfaceMethod<
/*desc=*/[{
Returns the APFloat semantics for this floating-point type.
}],
/*retTy=*/"const ::llvm::fltSemantics &",
/*methodName=*/"getFloatSemantics",
/*args=*/(ins)
>,
// The default implementation returns a null FloatType; a type may declare
// its own version instead.
InterfaceMethod<
/*desc=*/[{
Returns a float type with bitwidth scaled by `scale`. Returns a "null"
float type if the scaled element type cannot be represented.
}],
/*retTy=*/"::mlir::FloatType",
/*methodName=*/"scaleElementBitwidth",
/*args=*/(ins "unsigned":$scale),
/*methodBody=*/"",
/*defaultImplementation=*/"return ::mlir::FloatType();"
>
];

// Additional C++ declarations for the interface class: static factories for
// the float types listed below, plus two width helpers.
let extraClassDeclaration = [{
// Convenience factories.
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getTF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);

/// Return the bitwidth of this float type.
unsigned getWidth();

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth();
}];
}

//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 0 additions & 56 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ struct fltSemantics;
namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
Expand All @@ -44,52 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class FloatType : public Type {
public:
using Type::Type;

// Convenience factories.
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getTF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Return the bitwidth of this float type.
unsigned getWidth();

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth();

/// Get or create a new FloatType with bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
FloatType scaleElementBitwidth(unsigned scale);

/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics();
};

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}

inline bool FloatType::classof(Type type) {
return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
Float8E5M2FNUZType, Float8E4M3FNUZType,
Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
Float64Type, Float80Type, Float128Type>(type);
}

/// Convenience factory: obtains `Float4E2M1FNType` for `ctx` and hands it back
/// as a `FloatType`.
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
  FloatType floatTy = Float4E2M1FNType::get(ctx);
  return floatTy;
}
Expand Down
17 changes: 12 additions & 5 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
//===----------------------------------------------------------------------===//

// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic>
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
class Builtin_FloatType<string name, string mnemonic,
list<string> declaredInterfaceMethods = []>
: Builtin_Type<name, mnemonic, /*traits=*/[
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
Expand Down Expand Up @@ -322,14 +326,16 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type

def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16"> {
def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}

//===----------------------------------------------------------------------===//
// Float16Type

def Builtin_Float16 : Builtin_FloatType<"Float16", "f16"> {
def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}

Expand All @@ -343,7 +349,8 @@ def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
//===----------------------------------------------------------------------===//
// Float32Type

def Builtin_Float32 : Builtin_FloatType<"Float32", "f32"> {
def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}

Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/IR/BuiltinTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
Expand All @@ -19,6 +20,18 @@ using namespace mlir::detail;

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

/// Bitwidth of this float type, taken from the size in bits of its APFloat
/// semantics.
unsigned FloatType::getWidth() {
  const auto &semantics = getFloatSemantics();
  return APFloat::semanticsSizeInBits(semantics);
}

/// Mantissa width of this float type, taken from the precision of its APFloat
/// semantics.
unsigned FloatType::getFPMantissaWidth() {
  const auto &semantics = getFloatSemantics();
  return APFloat::semanticsPrecision(semantics);
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down
106 changes: 44 additions & 62 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,72 +87,54 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
return APFloat::semanticsSizeInBits(getFloatSemantics());
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (llvm::isa<Float4E2M1FNType>(*this))
return APFloat::Float4E2M1FN();
if (llvm::isa<Float6E2M3FNType>(*this))
return APFloat::Float6E2M3FN();
if (llvm::isa<Float6E3M2FNType>(*this))
return APFloat::Float6E3M2FN();
if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
if (llvm::isa<Float8E4M3Type>(*this))
return APFloat::Float8E4M3();
if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
if (llvm::isa<Float8E5M2FNUZType>(*this))
return APFloat::Float8E5M2FNUZ();
if (llvm::isa<Float8E4M3FNUZType>(*this))
return APFloat::Float8E4M3FNUZ();
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
return APFloat::Float8E4M3B11FNUZ();
if (llvm::isa<Float8E3M4Type>(*this))
return APFloat::Float8E3M4();
if (llvm::isa<Float8E8M0FNUType>(*this))
return APFloat::Float8E8M0FNU();
if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
if (llvm::isa<Float16Type>(*this))
return APFloat::IEEEhalf();
if (llvm::isa<FloatTF32Type>(*this))
return APFloat::FloatTF32();
if (llvm::isa<Float32Type>(*this))
return APFloat::IEEEsingle();
if (llvm::isa<Float64Type>(*this))
return APFloat::IEEEdouble();
if (llvm::isa<Float80Type>(*this))
return APFloat::x87DoubleExtended();
if (llvm::isa<Float128Type>(*this))
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
if (!scale)
return FloatType();
MLIRContext *ctx = getContext();
if (isF16() || isBF16()) {
if (scale == 2)
return FloatType::getF32(ctx);
if (scale == 4)
return FloatType::getF64(ctx);
// Float Types
//===----------------------------------------------------------------------===//

// Mapping from MLIR FloatType to APFloat semantics.
// FLOAT_TYPE_SEMANTICS(TYPE, SEM) expands to the out-of-line definition of
// `TYPE::getFloatSemantics() const`, whose body returns `APFloat::SEM()`.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
return APFloat::SEM(); \
}
if (isF32())
if (scale == 2)
return FloatType::getF64(ctx);
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS

/// Scales f16 by `scale`: a factor of 2 gives f32 and a factor of 4 gives f64.
/// Every other factor yields a null FloatType.
FloatType Float16Type::scaleElementBitwidth(unsigned scale) const {
  switch (scale) {
  case 2:
    return FloatType::getF32(getContext());
  case 4:
    return FloatType::getF64(getContext());
  default:
    return FloatType();
  }
}

unsigned FloatType::getFPMantissaWidth() {
return APFloat::semanticsPrecision(getFloatSemantics());
/// Scales bf16 by `scale`: a factor of 2 gives f32 and a factor of 4 gives
/// f64. Every other factor yields a null FloatType.
FloatType BFloat16Type::scaleElementBitwidth(unsigned scale) const {
  switch (scale) {
  case 2:
    return FloatType::getF32(getContext());
  case 4:
    return FloatType::getF64(getContext());
  default:
    return FloatType();
  }
}

/// Scales f32 by `scale`: only a factor of 2 is handled, giving f64. Every
/// other factor yields a null FloatType.
FloatType Float32Type::scaleElementBitwidth(unsigned scale) const {
  return scale == 2 ? FloatType::getF64(getContext()) : FloatType();
}

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/IR/InterfaceAttachmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct Model
/// overrides default methods.
struct OverridingModel
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
FloatType> {
Float32Type> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
return type.getIntOrFloatBitWidth() + arg;
}
Expand Down

0 comments on commit c24ce32

Please sign in to comment.