Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Add conversions for Arith's maxnumf and minnumf #66696

Merged
merged 1 commit into from
Sep 19, 2023

Conversation

unterumarmung
Copy link
Contributor

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced operations arith.minnumf and arith.maxnumf. When converting to spirv.CL, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN because CL ops do exactly the same. However, GL ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops.

This patch addresses the 1.5 task of the mentioned RFC.

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Changes

This patch is part of a larger initiative aimed at fixing floating-point max and min operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced operations arith.minnumf and arith.maxnumf. When converting to spirv.CL, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN because CL ops do exactly the same. However, GL ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops.

This patch addresses the 1.5 task of the mentioned RFC.


Full diff: https://github.com/llvm/llvm-project/pull/66696.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+60)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+53-8)
  • (modified) mlir/test/Conversion/ArithToSPIRV/fast-math.mlir (+22-4)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a589fb8050f34db..aba6a21deccb0cf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -19,6 +19,7 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/MathExtras.h"
 #include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+  template <typename TargetOp>
+  constexpr bool shouldInsertNanGuards() const {
+    return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+  }
+
+public:
+  using OpConversionPattern<Op>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type dstType = converter->convertType(op.getType());
+    if (!dstType)
+      return getTypeConversionFailure(rewriter, op);
+
+    // arith.maxnumf/minnumf:
+    //   "If one of the arguments is NaN, then the result is the other
+    //   argument."
+    // spirv.GL.FMax/FMin
+    //   "which operand is the result is undefined if one of the operands
+    //   is a NaN."
+    // spirv.CL.fmax/fmin:
+    //   "If one argument is a NaN, Fmin returns the other argument."
+
+    Location loc = op.getLoc();
+    Value spirvOp =
+        rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+    if (!shouldInsertNanGuards<SPIRVOp>() ||
+        converter->getOptions().enableFastMathMode) {
+      rewriter.replaceOp(op, spirvOp);
+      return success();
+    }
+
+    Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+    Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+    Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+                                                     adaptor.getRhs(), spirvOp);
+    Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+                                                     adaptor.getLhs(), select1);
+
+    rewriter.replaceOp(op, select2);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
@@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
 
     MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
     MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
+    MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
+    MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 165877eb554e245..0221e4815a9397d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[MIN]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[MAX]]
+  return %0: vector<2xf32>
+}
+
+
 // CHECK-LABEL: @scalar_srem
 // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
 func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
   return
 }
 
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
   return %0: f32
 }
 
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[SELECT2]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
 // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
   // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
   // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
   // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
   return %0: vector<2xf32>
 }
 
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+  // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
+  // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
+  // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
+  // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
+  // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+  // CHECK: return %[[SELECT2]]
+  return %0: vector<2xf32>
+}
+
 // Check int vector types.
 // CHECK-LABEL: @int_vector234
 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9dea7d6623885e4..dbf0361c2ab35bb 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -30,22 +30,40 @@ module attributes {
   spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
 } {
 
-// CHECK-LABEL: @minf
+// CHECK-LABEL: @minimumf
 // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
   // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
   %0 = arith.minimumf %arg0, %arg1 : f32
   // CHECK: return %[[F]]
   return %0: f32
 }
 
-// CHECK-LABEL: @maxf
+// CHECK-LABEL: @maximumf
 // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
-func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
   // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
   %0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
   // CHECK: return %[[F]]
   return %0: vector<4xf32>
 }
 
+// CHECK-LABEL: @minnumf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
+  // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
+  %0 = arith.minnumf %arg0, %arg1 : f32
+  // CHECK: return %[[F]]
+  return %0: f32
+}
+
+// CHECK-LABEL: @maxnumf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+  // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
+  %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+  // CHECK: return %[[F]]
+  return %0: vector<4xf32>
+}
+
 } // end module

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for plumbing this through.

Comment on lines +1132 to +1136
Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());

Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
adaptor.getRhs(), spirvOp);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we lower it to something like:

%lhsIsNan = spirv.IsNan %lhs
%x = spirv.Select %lhsIsNan, %rhs, %lhs
%rhsIsNan = spirv.IsNan %rhs
%y = spirv.Select %rhsIsNan, %lhs, %rhs
%res = spirv.GL.FMax %x, %y

I don't know which of the lowerings should be preferred, just wanted to explore the alternatives.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, the difference is that arith.*numf guarantees a NaN output when both inputs are NaN, and in spirv.GL we would get an undefined result. So the current lowering is the way to go.

This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN because `CL` ops do exactly the same. However, `GL` ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops.

This patch addresses the 1.5 task of the mentioned RFC.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants