Skip to content

Commit

Permalink
[mlir][IR] Turn FloatType into a type interface
Browse files Browse the repository at this point in the history
This makes it possible to add new floating point types in downstream projects. Also removes one place where we had to hard-code all existing floating point types (`FloatType::classof`).
  • Loading branch information
matthias-springer committed Dec 6, 2024
1 parent e7412a5 commit c996d3f
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 124 deletions.
9 changes: 9 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@

#include "mlir/IR/Types.h"

namespace llvm {
struct fltSemantics;
} // namespace llvm

namespace mlir {
class FloatType;
class MLIRContext;
} // namespace mlir

#include "mlir/IR/BuiltinTypeInterfaces.h.inc"

#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
46 changes: 46 additions & 0 deletions mlir/include/mlir/IR/BuiltinTypeInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,52 @@

include "mlir/IR/OpBase.td"

def FloatTypeInterface : TypeInterface<"FloatType"> {
let cppNamespace = "::mlir";
let description = [{
This type interface should be implemented by all floating-point types. It
defines the LLVM APFloat semantics and provides a few helper functions.
}];
let methods = [
InterfaceMethod<[{
Returns the APFloat semantics for this floating-point type.
}], "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
];

let extraClassDeclaration = [{
// Convenience factories.
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getTF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);

/// Return the bitwidth of this float type.
unsigned getWidth();

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth();

/// Get or create a new FloatType with bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
FloatType scaleElementBitwidth(unsigned scale);
}];
}

//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
Expand Down
56 changes: 0 additions & 56 deletions mlir/include/mlir/IR/BuiltinTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ struct fltSemantics;
namespace mlir {
class AffineExpr;
class AffineMap;
class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
Expand All @@ -44,52 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

class FloatType : public Type {
public:
using Type::Type;

// Convenience factories.
static FloatType getBF16(MLIRContext *ctx);
static FloatType getF16(MLIRContext *ctx);
static FloatType getF32(MLIRContext *ctx);
static FloatType getTF32(MLIRContext *ctx);
static FloatType getF64(MLIRContext *ctx);
static FloatType getF80(MLIRContext *ctx);
static FloatType getF128(MLIRContext *ctx);
static FloatType getFloat8E5M2(MLIRContext *ctx);
static FloatType getFloat8E4M3(MLIRContext *ctx);
static FloatType getFloat8E4M3FN(MLIRContext *ctx);
static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
static FloatType getFloat8E3M4(MLIRContext *ctx);
static FloatType getFloat4E2M1FN(MLIRContext *ctx);
static FloatType getFloat6E2M3FN(MLIRContext *ctx);
static FloatType getFloat6E3M2FN(MLIRContext *ctx);
static FloatType getFloat8E8M0FNU(MLIRContext *ctx);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type);

/// Return the bitwidth of this float type.
unsigned getWidth();

/// Return the width of the mantissa of this type.
/// The width includes the integer bit.
unsigned getFPMantissaWidth();

/// Get or create a new FloatType with bitwidth scaled by `scale`.
/// Return null if the scaled element type cannot be represented.
FloatType scaleElementBitwidth(unsigned scale);

/// Return the floating semantics of this float type.
const llvm::fltSemantics &getFloatSemantics();
};

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}

inline bool FloatType::classof(Type type) {
return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
Float8E5M2FNUZType, Float8E4M3FNUZType,
Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
Float64Type, Float80Type, Float128Type>(type);
}

inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
return Float4E2M1FNType::get(ctx);
}
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {

// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic>
: Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
: Builtin_Type<name, mnemonic, /*traits=*/[
DeclareTypeInterfaceMethods<FloatTypeInterface,
["getFloatSemantics"]>]> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
Expand Down
29 changes: 29 additions & 0 deletions mlir/lib/IR/BuiltinTypeInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"

using namespace mlir;
Expand All @@ -19,6 +20,34 @@ using namespace mlir::detail;

#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// FloatType
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
return APFloat::semanticsSizeInBits(getFloatSemantics());
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
if (!scale)
return FloatType();
MLIRContext *ctx = getContext();
if (isF16() || isBF16()) {
if (scale == 2)
return FloatType::getF32(ctx);
if (scale == 4)
return FloatType::getF64(ctx);
}
if (isF32())
if (scale == 2)
return FloatType::getF64(ctx);
return FloatType();
}

unsigned FloatType::getFPMantissaWidth() {
return APFloat::semanticsPrecision(getFloatSemantics());
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
Expand Down
92 changes: 26 additions & 66 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,73 +87,33 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
return APFloat::semanticsSizeInBits(getFloatSemantics());
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
if (llvm::isa<Float4E2M1FNType>(*this))
return APFloat::Float4E2M1FN();
if (llvm::isa<Float6E2M3FNType>(*this))
return APFloat::Float6E2M3FN();
if (llvm::isa<Float6E3M2FNType>(*this))
return APFloat::Float6E3M2FN();
if (llvm::isa<Float8E5M2Type>(*this))
return APFloat::Float8E5M2();
if (llvm::isa<Float8E4M3Type>(*this))
return APFloat::Float8E4M3();
if (llvm::isa<Float8E4M3FNType>(*this))
return APFloat::Float8E4M3FN();
if (llvm::isa<Float8E5M2FNUZType>(*this))
return APFloat::Float8E5M2FNUZ();
if (llvm::isa<Float8E4M3FNUZType>(*this))
return APFloat::Float8E4M3FNUZ();
if (llvm::isa<Float8E4M3B11FNUZType>(*this))
return APFloat::Float8E4M3B11FNUZ();
if (llvm::isa<Float8E3M4Type>(*this))
return APFloat::Float8E3M4();
if (llvm::isa<Float8E8M0FNUType>(*this))
return APFloat::Float8E8M0FNU();
if (llvm::isa<BFloat16Type>(*this))
return APFloat::BFloat();
if (llvm::isa<Float16Type>(*this))
return APFloat::IEEEhalf();
if (llvm::isa<FloatTF32Type>(*this))
return APFloat::FloatTF32();
if (llvm::isa<Float32Type>(*this))
return APFloat::IEEEsingle();
if (llvm::isa<Float64Type>(*this))
return APFloat::IEEEdouble();
if (llvm::isa<Float80Type>(*this))
return APFloat::x87DoubleExtended();
if (llvm::isa<Float128Type>(*this))
return APFloat::IEEEquad();
llvm_unreachable("non-floating point type used");
}

FloatType FloatType::scaleElementBitwidth(unsigned scale) {
if (!scale)
return FloatType();
MLIRContext *ctx = getContext();
if (isF16() || isBF16()) {
if (scale == 2)
return FloatType::getF32(ctx);
if (scale == 4)
return FloatType::getF64(ctx);
}
if (isF32())
if (scale == 2)
return FloatType::getF64(ctx);
return FloatType();
}
// Float Types
//===----------------------------------------------------------------------===//

unsigned FloatType::getFPMantissaWidth() {
return APFloat::semanticsPrecision(getFloatSemantics());
}
// Mapping from MLIR FloatType to APFloat semantics.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
return APFloat::SEM(); \
}
FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
#undef FLOAT_TYPE_SEMANTICS

//===----------------------------------------------------------------------===//
// FunctionType
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/IR/InterfaceAttachmentTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct Model
/// overrides default methods.
struct OverridingModel
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
FloatType> {
Float32Type> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
return type.getIntOrFloatBitWidth() + arg;
}
Expand Down

0 comments on commit c996d3f

Please sign in to comment.