[mlir][IR] Turn FloatType into a type interface #118891

Merged

@matthias-springer matthias-springer commented Dec 5, 2024

This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.)

Also removes two functions where we had to hard-code all existing floating point types (FloatType::classof and the isa chain in FloatType::getFloatSemantics). See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361
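
To make the difference concrete, here is a minimal standalone sketch in plain C++ (not MLIR code) of a closed type list versus an interface lookup. `Semantics`, `FloatInterfaceModel`, `TypeRegistry`, and the type names are hypothetical stand-ins chosen for illustration; none of them are MLIR APIs. The sketch also glosses over the caveat above: in MLIR a downstream type would currently have to return one of the existing APFloat semantics, whereas the model invents its own.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for llvm::fltSemantics.
struct Semantics {
  unsigned sizeInBits;
  unsigned precision;
};

// "Before": a closed list. Supporting a new float type means editing this.
inline bool isFloatClosedList(const std::string &typeName) {
  return typeName == "f16" || typeName == "f32" || typeName == "f64";
}

// "After": each type supplies its own implementation of the interface.
struct FloatInterfaceModel {
  const Semantics &(*getFloatSemantics)();
};

class TypeRegistry {
public:
  void attachFloatInterface(const std::string &typeName,
                            FloatInterfaceModel model) {
    models[typeName] = model;
  }

  // Analogue of isa<FloatType>: true iff the type implements the interface.
  bool isFloat(const std::string &typeName) const {
    return models.find(typeName) != models.end();
  }

  const Semantics &getFloatSemantics(const std::string &typeName) const {
    auto it = models.find(typeName);
    assert(it != models.end() && "type does not implement the interface");
    return it->second.getFloatSemantics();
  }

private:
  std::unordered_map<std::string, FloatInterfaceModel> models;
};

inline const Semantics &f32Semantics() {
  static const Semantics s{32, 24};
  return s;
}

// A made-up downstream type that the closed list above knows nothing about.
inline const Semantics &downstreamFp8Semantics() {
  static const Semantics s{8, 4};
  return s;
}
```

After `registry.attachFloatInterface("mydialect.fp8", FloatInterfaceModel{&downstreamFp8Semantics})`, `registry.isFloat("mydialect.fp8")` is true without touching any central list, while `isFloatClosedList("mydialect.fp8")` stays false until someone edits it.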

No measurable compilation time changes for these lit tests:

Benchmark 1: mlir-opt ./mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir -split-input-file -convert-vector-to-llvm -o /dev/null
  BEFORE
  Time (mean ± σ):     248.4 ms ±   3.2 ms    [User: 237.0 ms, System: 20.1 ms]
  Range (min … max):   243.3 ms … 255.9 ms    30 runs

  AFTER
  Time (mean ± σ):     246.8 ms ±   3.2 ms    [User: 233.2 ms, System: 21.8 ms]
  Range (min … max):   240.2 ms … 252.1 ms    30 runs


Benchmark 2: mlir-opt ./mlir/test/Dialect/Arith/canonicalize.mlir -split-input-file -canonicalize -o /dev/null
  BEFORE
  Time (mean ± σ):      37.3 ms ±   1.8 ms    [User: 31.6 ms, System: 30.4 ms]
  Range (min … max):    34.6 ms …  42.0 ms    200 runs

  AFTER
  Time (mean ± σ):      37.5 ms ±   2.0 ms    [User: 31.5 ms, System: 29.2 ms]
  Range (min … max):    34.5 ms …  43.0 ms    200 runs


Benchmark 3: mlir-opt ./mlir/test/Dialect/Tensor/canonicalize.mlir -split-input-file -canonicalize -allow-unregistered-dialect -o /dev/null
  BEFORE
  Time (mean ± σ):     152.2 ms ±   2.5 ms    [User: 140.1 ms, System: 12.2 ms]
  Range (min … max):   147.6 ms … 161.8 ms    200 runs

  AFTER
  Time (mean ± σ):     151.9 ms ±   2.7 ms    [User: 140.5 ms, System: 11.5 ms]
  Range (min … max):   147.2 ms … 159.1 ms    200 runs

A micro benchmark that parses + prints 32768 floats with random floating-point types shows a slowdown in mean run time from 43.3 ms to 50.3 ms.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/float_type_interface branch from dae827c to c996d3f Compare December 6, 2024 00:42
@matthias-springer matthias-springer marked this pull request as ready for review December 6, 2024 00:48
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:ods labels Dec 6, 2024
@llvmbot

llvmbot commented Dec 6, 2024

@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

This makes it possible to add new MLIR floating point types in downstream projects. (Adding new APFloat semantics in downstream projects is not possible yet, so parsing/printing/converting float literals of newly added types is not supported.)

Also removes one place where we had to hard-code all existing floating point types (FloatType::classof). See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361

Full diff: https://github.com/llvm/llvm-project/pull/118891.diff

7 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.h (+9)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+46)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (-56)
  • (modified) mlir/include/mlir/IR/BuiltinTypes.td (+3-1)
  • (modified) mlir/lib/IR/BuiltinTypeInterfaces.cpp (+29)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+26-66)
  • (modified) mlir/unittests/IR/InterfaceAttachmentTest.cpp (+1-1)
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
index ed5e5ca22c5958..e8011b5488dc98 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h
@@ -11,6 +11,15 @@
 
 #include "mlir/IR/Types.h"
 
+namespace llvm {
+struct fltSemantics;
+} // namespace llvm
+
+namespace mlir {
+class FloatType;
+class MLIRContext;
+} // namespace mlir
+
 #include "mlir/IR/BuiltinTypeInterfaces.h.inc"
 
 #endif // MLIR_IR_BUILTINTYPEINTERFACES_H
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index c9dcd546cf67c2..8b0242672dfdb4 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -16,6 +16,52 @@
 
 include "mlir/IR/OpBase.td"
 
+def FloatTypeInterface : TypeInterface<"FloatType"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This type interface should be implemented by all floating-point types. It
+    defines the LLVM APFloat semantics and provides a few helper functions.
+  }];
+  let methods = [
+    InterfaceMethod<[{
+      Returns the APFloat semantics for this floating-point type.
+    }], "const llvm::fltSemantics &", "getFloatSemantics", (ins)>,
+  ];
+
+  let extraClassDeclaration = [{
+    // Convenience factories.
+    static FloatType getBF16(MLIRContext *ctx);
+    static FloatType getF16(MLIRContext *ctx);
+    static FloatType getF32(MLIRContext *ctx);
+    static FloatType getTF32(MLIRContext *ctx);
+    static FloatType getF64(MLIRContext *ctx);
+    static FloatType getF80(MLIRContext *ctx);
+    static FloatType getF128(MLIRContext *ctx);
+    static FloatType getFloat8E5M2(MLIRContext *ctx);
+    static FloatType getFloat8E4M3(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FN(MLIRContext *ctx);
+    static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
+    static FloatType getFloat8E3M4(MLIRContext *ctx);
+    static FloatType getFloat4E2M1FN(MLIRContext *ctx);
+    static FloatType getFloat6E2M3FN(MLIRContext *ctx);
+    static FloatType getFloat6E3M2FN(MLIRContext *ctx);
+    static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
+
+    /// Return the bitwidth of this float type.
+    unsigned getWidth();
+
+    /// Return the width of the mantissa of this type.
+    /// The width includes the integer bit.
+    unsigned getFPMantissaWidth();
+
+    /// Get or create a new FloatType with bitwidth scaled by `scale`.
+    /// Return null if the scaled element type cannot be represented.
+    FloatType scaleElementBitwidth(unsigned scale);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // MemRefElementTypeInterface
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 7f9c470ffec304..2b3c2b6d1753dc 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -25,7 +25,6 @@ struct fltSemantics;
 namespace mlir {
 class AffineExpr;
 class AffineMap;
-class FloatType;
 class IndexType;
 class IntegerType;
 class MemRefType;
@@ -44,52 +43,6 @@ template <typename ConcreteType>
 class ValueSemantics
     : public TypeTrait::TraitBase<ConcreteType, ValueSemantics> {};
 
-//===----------------------------------------------------------------------===//
-// FloatType
-//===----------------------------------------------------------------------===//
-
-class FloatType : public Type {
-public:
-  using Type::Type;
-
-  // Convenience factories.
-  static FloatType getBF16(MLIRContext *ctx);
-  static FloatType getF16(MLIRContext *ctx);
-  static FloatType getF32(MLIRContext *ctx);
-  static FloatType getTF32(MLIRContext *ctx);
-  static FloatType getF64(MLIRContext *ctx);
-  static FloatType getF80(MLIRContext *ctx);
-  static FloatType getF128(MLIRContext *ctx);
-  static FloatType getFloat8E5M2(MLIRContext *ctx);
-  static FloatType getFloat8E4M3(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FN(MLIRContext *ctx);
-  static FloatType getFloat8E5M2FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E4M3B11FNUZ(MLIRContext *ctx);
-  static FloatType getFloat8E3M4(MLIRContext *ctx);
-  static FloatType getFloat4E2M1FN(MLIRContext *ctx);
-  static FloatType getFloat6E2M3FN(MLIRContext *ctx);
-  static FloatType getFloat6E3M2FN(MLIRContext *ctx);
-  static FloatType getFloat8E8M0FNU(MLIRContext *ctx);
-
-  /// Methods for support type inquiry through isa, cast, and dyn_cast.
-  static bool classof(Type type);
-
-  /// Return the bitwidth of this float type.
-  unsigned getWidth();
-
-  /// Return the width of the mantissa of this type.
-  /// The width includes the integer bit.
-  unsigned getFPMantissaWidth();
-
-  /// Get or create a new FloatType with bitwidth scaled by `scale`.
-  /// Return null if the scaled element type cannot be represented.
-  FloatType scaleElementBitwidth(unsigned scale);
-
-  /// Return the floating semantics of this float type.
-  const llvm::fltSemantics &getFloatSemantics();
-};
-
 //===----------------------------------------------------------------------===//
 // TensorType
 //===----------------------------------------------------------------------===//
@@ -448,15 +401,6 @@ inline bool BaseMemRefType::isValidElementType(Type type) {
          llvm::isa<MemRefElementTypeInterface>(type);
 }
 
-inline bool FloatType::classof(Type type) {
-  return llvm::isa<Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType,
-                   Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType,
-                   Float8E5M2FNUZType, Float8E4M3FNUZType,
-                   Float8E4M3B11FNUZType, Float8E3M4Type, Float8E8M0FNUType,
-                   BFloat16Type, Float16Type, FloatTF32Type, Float32Type,
-                   Float64Type, Float80Type, Float128Type>(type);
-}
-
 inline FloatType FloatType::getFloat4E2M1FN(MLIRContext *ctx) {
   return Float4E2M1FNType::get(ctx);
 }
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
index dca228097d782d..a0afda4e3b465e 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -80,7 +80,9 @@ def Builtin_Complex : Builtin_Type<"Complex", "complex"> {
 
 // Base class for Builtin dialect float types.
 class Builtin_FloatType<string name, string mnemonic>
-    : Builtin_Type<name, mnemonic, /*traits=*/[], "::mlir::FloatType"> {
+    : Builtin_Type<name, mnemonic, /*traits=*/[
+        DeclareTypeInterfaceMethods<FloatTypeInterface,
+                                    ["getFloatSemantics"]>]> {
   let extraClassDeclaration = [{
     static }] # name # [{Type get(MLIRContext *context);
   }];
diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
index ab9e65b5edfed3..1374e889833283 100644
--- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp
+++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
+#include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/Sequence.h"
 
 using namespace mlir;
@@ -19,6 +20,34 @@ using namespace mlir::detail;
 
 #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// FloatType
+//===----------------------------------------------------------------------===//
+
+unsigned FloatType::getWidth() {
+  return APFloat::semanticsSizeInBits(getFloatSemantics());
+}
+
+FloatType FloatType::scaleElementBitwidth(unsigned scale) {
+  if (!scale)
+    return FloatType();
+  MLIRContext *ctx = getContext();
+  if (isF16() || isBF16()) {
+    if (scale == 2)
+      return FloatType::getF32(ctx);
+    if (scale == 4)
+      return FloatType::getF64(ctx);
+  }
+  if (isF32())
+    if (scale == 2)
+      return FloatType::getF64(ctx);
+  return FloatType();
+}
+
+unsigned FloatType::getFPMantissaWidth() {
+  return APFloat::semanticsPrecision(getFloatSemantics());
+}
+
 //===----------------------------------------------------------------------===//
 // ShapedType
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 6546234429c8cb..81e154328a4a2e 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -87,73 +87,33 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
 }
 
 //===----------------------------------------------------------------------===//
-// Float Type
-//===----------------------------------------------------------------------===//
-
-unsigned FloatType::getWidth() {
-  return APFloat::semanticsSizeInBits(getFloatSemantics());
-}
-
-/// Returns the floating semantics for the given type.
-const llvm::fltSemantics &FloatType::getFloatSemantics() {
-  if (llvm::isa<Float4E2M1FNType>(*this))
-    return APFloat::Float4E2M1FN();
-  if (llvm::isa<Float6E2M3FNType>(*this))
-    return APFloat::Float6E2M3FN();
-  if (llvm::isa<Float6E3M2FNType>(*this))
-    return APFloat::Float6E3M2FN();
-  if (llvm::isa<Float8E5M2Type>(*this))
-    return APFloat::Float8E5M2();
-  if (llvm::isa<Float8E4M3Type>(*this))
-    return APFloat::Float8E4M3();
-  if (llvm::isa<Float8E4M3FNType>(*this))
-    return APFloat::Float8E4M3FN();
-  if (llvm::isa<Float8E5M2FNUZType>(*this))
-    return APFloat::Float8E5M2FNUZ();
-  if (llvm::isa<Float8E4M3FNUZType>(*this))
-    return APFloat::Float8E4M3FNUZ();
-  if (llvm::isa<Float8E4M3B11FNUZType>(*this))
-    return APFloat::Float8E4M3B11FNUZ();
-  if (llvm::isa<Float8E3M4Type>(*this))
-    return APFloat::Float8E3M4();
-  if (llvm::isa<Float8E8M0FNUType>(*this))
-    return APFloat::Float8E8M0FNU();
-  if (llvm::isa<BFloat16Type>(*this))
-    return APFloat::BFloat();
-  if (llvm::isa<Float16Type>(*this))
-    return APFloat::IEEEhalf();
-  if (llvm::isa<FloatTF32Type>(*this))
-    return APFloat::FloatTF32();
-  if (llvm::isa<Float32Type>(*this))
-    return APFloat::IEEEsingle();
-  if (llvm::isa<Float64Type>(*this))
-    return APFloat::IEEEdouble();
-  if (llvm::isa<Float80Type>(*this))
-    return APFloat::x87DoubleExtended();
-  if (llvm::isa<Float128Type>(*this))
-    return APFloat::IEEEquad();
-  llvm_unreachable("non-floating point type used");
-}
-
-FloatType FloatType::scaleElementBitwidth(unsigned scale) {
-  if (!scale)
-    return FloatType();
-  MLIRContext *ctx = getContext();
-  if (isF16() || isBF16()) {
-    if (scale == 2)
-      return FloatType::getF32(ctx);
-    if (scale == 4)
-      return FloatType::getF64(ctx);
-  }
-  if (isF32())
-    if (scale == 2)
-      return FloatType::getF64(ctx);
-  return FloatType();
-}
+// Float Types
+//===----------------------------------------------------------------------===//
 
-unsigned FloatType::getFPMantissaWidth() {
-  return APFloat::semanticsPrecision(getFloatSemantics());
-}
+// Mapping from MLIR FloatType to APFloat semantics.
+#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
+  const llvm::fltSemantics &TYPE::getFloatSemantics() const {                  \
+    return APFloat::SEM();                                                     \
+  }
+FLOAT_TYPE_SEMANTICS(Float4E2M1FNType, Float4E2M1FN)
+FLOAT_TYPE_SEMANTICS(Float6E2M3FNType, Float6E2M3FN)
+FLOAT_TYPE_SEMANTICS(Float6E3M2FNType, Float6E3M2FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
+FLOAT_TYPE_SEMANTICS(Float8E4M3Type, Float8E4M3)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNType, Float8E4M3FN)
+FLOAT_TYPE_SEMANTICS(Float8E5M2FNUZType, Float8E5M2FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3FNUZType, Float8E4M3FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E4M3B11FNUZType, Float8E4M3B11FNUZ)
+FLOAT_TYPE_SEMANTICS(Float8E3M4Type, Float8E3M4)
+FLOAT_TYPE_SEMANTICS(Float8E8M0FNUType, Float8E8M0FNU)
+FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
+FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
+FLOAT_TYPE_SEMANTICS(FloatTF32Type, FloatTF32)
+FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
+FLOAT_TYPE_SEMANTICS(Float64Type, IEEEdouble)
+FLOAT_TYPE_SEMANTICS(Float80Type, x87DoubleExtended)
+FLOAT_TYPE_SEMANTICS(Float128Type, IEEEquad)
+#undef FLOAT_TYPE_SEMANTICS
 
 //===----------------------------------------------------------------------===//
 // FunctionType
diff --git a/mlir/unittests/IR/InterfaceAttachmentTest.cpp b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
index b6066dd5685dc6..1b5d3b8c31bd22 100644
--- a/mlir/unittests/IR/InterfaceAttachmentTest.cpp
+++ b/mlir/unittests/IR/InterfaceAttachmentTest.cpp
@@ -43,7 +43,7 @@ struct Model
 /// overrides default methods.
 struct OverridingModel
     : public TestExternalTypeInterface::ExternalModel<OverridingModel,
-                                                      FloatType> {
+                                                      Float32Type> {
   unsigned getBitwidthPlusArg(Type type, unsigned arg) const {
     return type.getIntOrFloatBitWidth() + arg;
   }
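
The core of the diff above is that each concrete float type now answers `getFloatSemantics` itself, and `getWidth` / `getFPMantissaWidth` are derived from that answer. A minimal standalone sketch of the same pattern follows, assuming a hypothetical `Semantics` struct and `sem::` functions in place of `llvm::fltSemantics` and the `APFloat` accessors; the type structs are empty placeholders, not the MLIR classes. The bit counts are the usual format parameters (precision counts the integer bit, as in the `getFPMantissaWidth` comment).

```cpp
#include <cassert>

// Hypothetical stand-in for llvm::fltSemantics.
struct Semantics {
  unsigned sizeInBits;
  unsigned precision; // mantissa width including the integer bit
};

// Stand-ins for APFloat::IEEEhalf(), APFloat::BFloat(), and so on.
namespace sem {
inline const Semantics &IEEEhalf() { static const Semantics s{16, 11}; return s; }
inline const Semantics &BFloat() { static const Semantics s{16, 8}; return s; }
inline const Semantics &IEEEsingle() { static const Semantics s{32, 24}; return s; }
inline const Semantics &Float8E5M2() { static const Semantics s{8, 3}; return s; }
} // namespace sem

// Placeholder concrete types; each only declares the interface method.
struct Float16Type { const Semantics &getFloatSemantics() const; };
struct BFloat16Type { const Semantics &getFloatSemantics() const; };
struct Float32Type { const Semantics &getFloatSemantics() const; };
struct Float8E5M2Type { const Semantics &getFloatSemantics() const; };

// One line per type replaces one branch of the old isa chain.
#define FLOAT_TYPE_SEMANTICS(TYPE, SEM)                                        \
  const Semantics &TYPE::getFloatSemantics() const { return sem::SEM(); }
FLOAT_TYPE_SEMANTICS(Float16Type, IEEEhalf)
FLOAT_TYPE_SEMANTICS(BFloat16Type, BFloat)
FLOAT_TYPE_SEMANTICS(Float32Type, IEEEsingle)
FLOAT_TYPE_SEMANTICS(Float8E5M2Type, Float8E5M2)
#undef FLOAT_TYPE_SEMANTICS

// Shared helpers written once against the interface method.
template <typename FloatT> unsigned getWidth(const FloatT &type) {
  const Semantics &s = type.getFloatSemantics();
  assert(s.precision <= s.sizeInBits && "precision cannot exceed total width");
  return s.sizeInBits;
}

template <typename FloatT> unsigned getFPMantissaWidth(const FloatT &type) {
  return type.getFloatSemantics().precision;
}
```

In this model `getWidth(Float16Type{})` and `getWidth(BFloat16Type{})` both give 16, while `getFPMantissaWidth` distinguishes them (11 versus 8), which is why the helpers need the semantics rather than just a bit count.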

@qcolombet

How does this (will) connect to the LLVM IR dialect/pure LLVM IR?
I am wondering how say FP8 (whatever variant) gets lowered to LLVM given we don't have a Type instance that can represent that IIRC.

@matthias-springer

matthias-springer commented Dec 6, 2024

This PR does not remove the MLIR type classes (e.g., Float16Type). They are still there. The only difference is that FloatType (the super class) is no longer a hand-written class but a type interface. From a functional perspective this is almost NFC; the only difference is that you can no longer attach an interface to FloatType (because it's an interface now), that's why one test changed.

> I am wondering how say FP8 (whatever variant) gets lowered to LLVM given we don't have a Type instance that can represent that IIRC.

LLVM does not seem to support "fancy" FP types in the type system. We use integer types instead. See here.
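
As a rough illustration of why an integer type is enough to carry such values, the sketch below decodes an 8-bit E5M2 bit pattern (1 sign bit, 5 exponent bits, 2 mantissa bits, bias 15, IEEE-style infinities and NaNs) stored in a `std::uint8_t`. This is not how the MLIR lowering is implemented; `decodeE5M2` is a hypothetical helper written only to show that the bits in an `i8` fully determine the value.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>

// Interpret `bits` as an E5M2 value and widen it to float.
inline float decodeE5M2(std::uint8_t bits) {
  const int sign = (bits >> 7) & 0x1;
  const int exponent = (bits >> 2) & 0x1F;
  const int mantissa = bits & 0x3;

  float magnitude;
  if (exponent == 0x1F) {
    // All-ones exponent: infinity if the mantissa is zero, NaN otherwise.
    magnitude = mantissa == 0 ? std::numeric_limits<float>::infinity()
                              : std::numeric_limits<float>::quiet_NaN();
  } else if (exponent == 0) {
    // Zero or subnormal: no implicit leading one, exponent fixed at 1 - bias.
    magnitude = std::ldexp(static_cast<float>(mantissa) / 4.0f, 1 - 15);
  } else {
    // Normal number: implicit leading one plus two fraction bits.
    magnitude =
        std::ldexp(1.0f + static_cast<float>(mantissa) / 4.0f, exponent - 15);
  }
  return sign ? -magnitude : magnitude;
}
```

For example, `decodeE5M2(0x3C)` is 1.0 and `decodeE5M2(0xC0)` is -2.0; arithmetic on such values still needs explicit conversion or target-specific support, since the integer type itself knows nothing about the encoding.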

The one place that must be extended in LLVM to support new floating-point types is the floating-point semantics. That's not possible yet, but I plan to make that part extensible as a follow-up. That's a bigger change.

@matthias-springer

I wrote a micro benchmark that parses + prints 32768 floats with random floating-point type: https://gist.github.com/matthias-springer/d22cc02b6097553b78bbba39fed8ca1e

Going through the type interface causes a slowdown compared to the hard-coded sequence of if checks in getFloatSemantics.

Benchmark 1: mlir-opt test.mlir -allow-unregistered-dialect
  BEFORE
  Time (mean ± σ):      43.3 ms ±   1.8 ms    [User: 31.9 ms, System: 11.4 ms]
  Range (min … max):    39.8 ms …  48.3 ms    200 runs

  AFTER
  Time (mean ± σ):      50.3 ms ±   1.8 ms    [User: 38.8 ms, System: 11.5 ms]
  Range (min … max):    47.3 ms …  55.1 ms    200 runs
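
For a sense of scale, the small helper below turns the two mean times above into a relative slowdown and an average extra cost per float. It is only arithmetic on the reported means (43.3 ms and 50.3 ms over 32768 floats); the function names are made up for this note, and the per-float figure assumes the whole difference is attributable to the float handling.

```cpp
#include <cassert>
#include <cmath>

// Relative slowdown in percent when going from `beforeMs` to `afterMs`.
inline double relativeSlowdownPercent(double beforeMs, double afterMs) {
  assert(beforeMs > 0.0 && "baseline time must be positive");
  return (afterMs - beforeMs) / beforeMs * 100.0;
}

// Average extra time per item in nanoseconds (1 ms = 1e6 ns).
inline double extraNanosecondsPerItem(double beforeMs, double afterMs,
                                      unsigned itemCount) {
  assert(itemCount > 0 && "need at least one item");
  return (afterMs - beforeMs) * 1.0e6 / static_cast<double>(itemCount);
}
```

With the means above, `relativeSlowdownPercent(43.3, 50.3)` comes out at roughly 16%, and `extraNanosecondsPerItem(43.3, 50.3, 32768)` at a little over 200 ns per parsed and printed float.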

@River707

The immediate changes here look reasonable. Have you done a sweep of the current users of FloatType to see if any of them should be tightened? Opening up to an interface suddenly expands the contract of FloatType.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/float_type_interface branch from c996d3f to 1ebf66d Compare December 16, 2024 09:43
@matthias-springer

> The immediate changes here look reasonable. Have you done a sweep of the current users of FloatType to see if any of them should be tightened? Opening up to an interface suddenly expands the contract of FloatType.

From a user perspective this commit is mostly NFC. It does not matter if you have a concrete type or an interface. The API surface is almost identical. There is only one difference that I found: users can no longer attach an interface to FloatType, they have to attach it to a concrete float type (see changed test).

];

let extraClassDeclaration = [{
// Convenience factories.
@matthias-springer matthias-springer Dec 16, 2024

Note: I think these should also be removed, but that's a larger change because existing code uses them a lot. (And it's unclear what to use instead. Maybe builder.getF32Type() etc...) So I'm just keeping them here for now.

@River707 River707 left a comment

LGTM!

This makes it possible to add new floating point types in downstream projects. Also removes one place where we had to hard-code all existing floating point types (`FloatType::classof`).
@matthias-springer matthias-springer force-pushed the users/matthias-springer/float_type_interface branch from 1ebf66d to 8613fb3 Compare January 15, 2025 07:34
@matthias-springer matthias-springer merged commit c24ce32 into main Jan 15, 2025
8 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/float_type_interface branch January 15, 2025 08:47
paulhuggett pushed a commit to paulhuggett/llvm-project that referenced this pull request Jan 16, 2025
DKLoehr pushed a commit to DKLoehr/llvm-project that referenced this pull request Jan 17, 2025