-
Notifications
You must be signed in to change notification settings - Fork 12.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][IR] Turn `FloatType` into a type interface
#118891
[mlir][IR] Turn `FloatType` into a type interface
#118891
Conversation
dae827c
to
c996d3f
Compare
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) Changes: This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.) Also removes one place where we had to hard-code all existing floating point types (`FloatType::classof`). No measurable compilation time changes for these lit tests:
Full diff: https://github.com/llvm/llvm-project/pull/118891.diff 7 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
#include "mlir/IR/Types.h"
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class FloatType;
+class MLIRContext;
+} // namespace mlir
+
#include "mlir/IR/BuiltinTypeInterfaces.h.inc"
#endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..8b0242672dfdb4 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,52 @@
include "mlir/IR/OpBase.td"
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+ let cppNamespace = "::mlir";
+ let description = [{
+ This type interface should be implemented by all floating-point types. It
+ defines the LLVM APFloat semantics and provides a few helper functions.
+ }];
+ let methods = [
+ InterfaceMethod<[{
+ Returns the APFloat semantics for this floating-point type.
+ }], "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
+ ];
+
+ let extraClassDeclaration = [{
+ // Convenience factories.
+ static FloatType getBF16(MLIRContext *ctx);
+ static FloatType getF16(MLIRContext *ctx);
+ static FloatType getF32(MLIRContext *ctx);
+ static FloatType getTF32(MLIRContext *ctx);
+ static FloatType getF64(MLIRContext *ctx);
+ static FloatType getF80(MLIRContext *ctx);
+ static FloatType getF128(MLIRContext *ctx);
+ static FloatType getFloat8E5M2(MLIRContext *ctx);
+ static FloatType getFloat8E4M3(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+ static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+ static FloatType getFloat8E3M4(MLIRContext *ctx);
+ static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+ static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+ static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+ static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+ /// Return the bitwidth of this float type.
+ unsigned getWidth();
+
+ /// Return the width of the mantissa of this type.
+ /// The width includes the integer bit.
+ unsigned getFPMantissaWidth();
+
+ /// Get or create a new FloatType with bitwidth scaled by `scale`.
+ /// Return null if the scaled element type cannot be represented.
+ FloatType scaleElementBitwidth(unsigned scale);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// MemRefElementTypeInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
namespace mlir {
class AffineExpr;
class AffineMap;
-class FloatType;
class IndexType;
class IntegerType;
class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
class ValueSemantics
: public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
- using Type::Type;
-
- // Convenience factories.
- static FloatType getBF16(MLIRContext *ctx);
- static FloatType getF16(MLIRContext *ctx);
- static FloatType getF32(MLIRContext *ctx);
- static FloatType getTF32(MLIRContext *ctx);
- static FloatType getF64(MLIRContext *ctx);
- static FloatType getF80(MLIRContext *ctx);
- static FloatType getF128(MLIRContext *ctx);
- static FloatType getFloat8E5M2(MLIRContext *ctx);
- static FloatType getFloat8E4M3(MLIRContext *ctx);
- static FloatType getFloat8E4M3FN(MLIRContext *ctx);
- static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
- static FloatType getFloat8E3M4(MLIRContext *ctx);
- static FloatType getFloat4E2M1FN(MLIRContext *ctx);
- static FloatType getFloat6E2M3FN(MLIRContext *ctx);
- static FloatType getFloat6E3M2FN(MLIRContext *ctx);
- static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
- /// Methods for support type inquiry through isa, cast, and dyn_cast.
- static bool classof(Type type);
-
- /// Return the bitwidth of this float type.
- unsigned getWidth();
-
- /// Return the width of the mantissa of this type.
- /// The width includes the integer bit.
- unsigned getFPMantissaWidth();
-
- /// Get or create a new FloatType with bitwidth scaled by `scale`.
- /// Return null if the scaled element type cannot be represented.
- FloatType scaleElementBitwidth(unsigned scale);
-
- /// Return the floating semantics of this float type.
- const llvm::fltSemantics &getFloatSemantics();
-};
-
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
llvm::isa<MemRefElementTypeInterface>(type);
}
-inline bool FloatType::classof(Type type) {
- return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
- Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
- Float8E5M2FNUZType, Float8E4M3FNUZType,
- Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
- BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
- Float64Type, Float80Type, Float128Type>(type);
-}
-
inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
return Float4E2M1FNType::get(ctx);
}
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..a0afda4e3b465e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
// Base class for Builtin dialect float types.
class Builtin_FloatType<string name, string mnemonic>
- : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+ : Builtin_Type<name, mnemonic, /*traits=*/[
+ DeclareTypeInterfaceMethods<FloatTypeInterface,
+ ["getFloatSemantics"]>]> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..1374e889833283 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
@@ -19,6 +20,34 @@ using namespace mlir::detail;
#include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+ return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+ if (!scale)
+ return FloatType();
+ MLIRContext *ctx = getContext();
+ if (isF16() || isBF16()) {
+ if (scale == 2)
+ return FloatType::getF32(ctx);
+ if (scale == 4)
+ return FloatType::getF64(ctx);
+ }
+ if (isF32())
+ if (scale == 2)
+ return FloatType::getF64(ctx);
+ return FloatType();
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+ return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..81e154328a4a2e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,73 +87,33 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
}
//===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
- return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
- if (llvm::isa<Float4E2M1FNType>(*this))
- return APFloat::Float4E2M1FN();
- if (llvm::isa<Float6E2M3FNType>(*this))
- return APFloat::Float6E2M3FN();
- if (llvm::isa<Float6E3M2FNType>(*this))
- return APFloat::Float6E3M2FN();
- if (llvm::isa<Float8E5M2Type>(*this))
- return APFloat::Float8E5M2();
- if (llvm::isa<Float8E4M3Type>(*this))
- return APFloat::Float8E4M3();
- if (llvm::isa<Float8E4M3FNType>(*this))
- return APFloat::Float8E4M3FN();
- if (llvm::isa<Float8E5M2FNUZType>(*this))
- return APFloat::Float8E5M2FNUZ();
- if (llvm::isa<Float8E4M3FNUZType>(*this))
- return APFloat::Float8E4M3FNUZ();
- if (llvm::isa<Float8E4M3B11FNUZType>(*this))
- return APFloat::Float8E4M3B11FNUZ();
- if (llvm::isa<Float8E3M4Type>(*this))
- return APFloat::Float8E3M4();
- if (llvm::isa<Float8E8M0FNUType>(*this))
- return APFloat::Float8E8M0FNU();
- if (llvm::isa<BFloat16Type>(*this))
- return APFloat::BFloat();
- if (llvm::isa<Float16Type>(*this))
- return APFloat::IEEEhalf();
- if (llvm::isa<FloatTF32Type>(*this))
- return APFloat::FloatTF32();
- if (llvm::isa<Float32Type>(*this))
- return APFloat::IEEEsingle();
- if (llvm::isa<Float64Type>(*this))
- return APFloat::IEEEdouble();
- if (llvm::isa<Float80Type>(*this))
- return APFloat::x87DoubleExtended();
- if (llvm::isa<Float128Type>(*this))
- return APFloat::IEEEquad();
- llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
- if (!scale)
- return FloatType();
- MLIRContext *ctx = getContext();
- if (isF16() || isBF16()) {
- if (scale == 2)
- return FloatType::getF32(ctx);
- if (scale == 4)
- return FloatType::getF64(ctx);
- }
- if (isF32())
- if (scale == 2)
- return FloatType::getF64(ctx);
- return FloatType();
-}
+// Float Types
+//===----------------------------------------------------------------------===//
-unsigned FloatType::getFPMantissaWidth() {
- return APFloat::semanticsPrecision(getFloatSemantics());
-}
+// Mapping from MLIR FloatType to APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM) \
+ const llvm::fltSemantics &TYPE::getFloatSemantics() const { \
+ return APFloat::SEM(); \
+ }
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
//===----------------------------------------------------------------------===//
// FunctionType
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
/// overrides default methods.
struct OverridingModel
: public TestExternalTypeInterface::ExternalModel<OverridingModel,
- FloatType> {
+ Float32Type> {
unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
return type.getIntOrFloatBitWidth() + arg;
}
|
How does (or will) this connect to the LLVM IR dialect/pure LLVM IR? |
This PR does not remove the MLIR type classes (e.g., `Float32Type`).
LLVM does not seem to support "fancy" FP types in the type system. We use integer types instead. See here. The one place that must be extended in LLVM to support new floating-point types is the floating-point semantics. That's not possible yet, but I plan to make that part extensible as a follow-up. That's a bigger change. |
I wrote a micro benchmark that parses + prints 32768 floats with random floating-point type: https://gist.github.com/matthias-springer/d22cc02b6097553b78bbba39fed8ca1e Going through the type interface causes a slowdown compared to the hard-coded sequence of `llvm::isa` checks.
|
The immediate changes here look reasonable. Have you done a sweep of the current users of FloatType to see if any of them should be tightened? Opening up to an interface suddenly expands the contract of FloatType. |
c996d3f
to
1ebf66d
Compare
From a user perspective this commit is mostly NFC. It does not matter if you have a concrete type or an interface. The API surface is almost identical. There is only one difference that I found: users can no longer attach an interface to `FloatType` itself; it must be attached to a concrete float type such as `Float32Type` instead. |
]; | ||
|
||
let extraClassDeclaration = [{ | ||
// Convenience factories. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: I think these should also be removed, but that's a larger change because existing code uses them a lot. (And it's unclear what to use instead. Maybe `builder.getF32Type()`, etc.) So I'm just keeping them here for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
This makes it possible to add new floating point types in downstream projects. Also removes one place where we had to hard-code all existing floating point types (`FloatType::classof`).
1ebf66d
to
8613fb3
Compare
This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.)

Also removes two functions where we had to hard-code all existing floating point types (`FloatType::classof` and `FloatType::getFloatSemantics`).

See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361

No measurable compilation time changes for these lit tests:
```
Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null
BEFORE
  Time (mean ± σ): 248.4 ms ± 3.2 ms [User: 237.0 ms, System: 20.1 ms]
  Range (min … max): 243.3 ms … 255.9 ms 30 runs
AFTER
  Time (mean ± σ): 246.8 ms ± 3.2 ms [User: 233.2 ms, System: 21.8 ms]
  Range (min … max): 240.2 ms … 252.1 ms 30 runs

Benchmark 2: mlir-opt ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null
BEFORE
  Time (mean ± σ): 37.3 ms ± 1.8 ms [User: 31.6 ms, System: 30.4 ms]
  Range (min … max): 34.6 ms … 42.0 ms 200 runs
AFTER
  Time (mean ± σ): 37.5 ms ± 2.0 ms [User: 31.5 ms, System: 29.2 ms]
  Range (min … max): 34.5 ms … 43.0 ms 200 runs

Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null
BEFORE
  Time (mean ± σ): 152.2 ms ± 2.5 ms [User: 140.1 ms, System: 12.2 ms]
  Range (min … max): 147.6 ms … 161.8 ms 200 runs
AFTER
  Time (mean ± σ): 151.9 ms ± 2.7 ms [User: 140.5 ms, System: 11.5 ms]
  Range (min … max): 147.2 ms … 159.1 ms 200 runs
```

A micro benchmark that parses + prints 32768 floats with random floating-point type shows a slowdown from 55.1 ms -> 48.3 ms.
This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.)

Also removes two functions where we had to hard-code all existing floating point types (`FloatType::classof` and `FloatType::getFloatSemantics`).

See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361

No measurable compilation time changes for these lit tests:
```
Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null
BEFORE
  Time (mean ± σ): 248.4 ms ± 3.2 ms [User: 237.0 ms, System: 20.1 ms]
  Range (min … max): 243.3 ms … 255.9 ms 30 runs
AFTER
  Time (mean ± σ): 246.8 ms ± 3.2 ms [User: 233.2 ms, System: 21.8 ms]
  Range (min … max): 240.2 ms … 252.1 ms 30 runs

Benchmark 2: mlir-opt ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null
BEFORE
  Time (mean ± σ): 37.3 ms ± 1.8 ms [User: 31.6 ms, System: 30.4 ms]
  Range (min … max): 34.6 ms … 42.0 ms 200 runs
AFTER
  Time (mean ± σ): 37.5 ms ± 2.0 ms [User: 31.5 ms, System: 29.2 ms]
  Range (min … max): 34.5 ms … 43.0 ms 200 runs

Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null
BEFORE
  Time (mean ± σ): 152.2 ms ± 2.5 ms [User: 140.1 ms, System: 12.2 ms]
  Range (min … max): 147.6 ms … 161.8 ms 200 runs
AFTER
  Time (mean ± σ): 151.9 ms ± 2.7 ms [User: 140.5 ms, System: 11.5 ms]
  Range (min … max): 147.2 ms … 159.1 ms 200 runs
```

A micro benchmark that parses + prints 32768 floats with random floating-point type shows a slowdown from 55.1 ms -> 48.3 ms.
This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.)
Also removes two functions where we had to hard-code all existing floating point types (`FloatType::classof` and `FloatType::getFloatSemantics`).
See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361
No measurable compilation time changes for these lit tests:
A micro benchmark that parses + prints 32768 floats with random floating-point type shows a slowdown from 55.1 ms -> 48.3 ms.