-
Notifications
You must be signed in to change notification settings - Fork 12.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][spirv] Add conversions for Arith's maxnumf
and minnumf
#66696
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir ChangesThis patch is part of a larger initiative aimed at fixing floating-point In this commit, we add conversion patterns for the newly introduced operations This patch addresses the 1.5 task of the mentioned RFC. Full diff: https://github.com/llvm/llvm-project/pull/66696.diff 3 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index a589fb8050f34db..aba6a21deccb0cf 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
@@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};
+//===----------------------------------------------------------------------===//
+// MinNumFOp, MaxNumFOp
+//===----------------------------------------------------------------------===//
+
+/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
+/// spirv.CL.fmax/fmin.
+template <typename Op, typename SPIRVOp>
+class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
+ template <typename TargetOp>
+ constexpr bool shouldInsertNanGuards() const {
+ return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
+ }
+
+public:
+ using OpConversionPattern<Op>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(Op op, typename Op::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
+ Type dstType = converter->convertType(op.getType());
+ if (!dstType)
+ return getTypeConversionFailure(rewriter, op);
+
+ // arith.maxnumf/minnumf:
+ // "If one of the arguments is NaN, then the result is the other
+ // argument."
+ // spirv.GL.FMax/FMin
+ // "which operand is the result is undefined if one of the operands
+ // is a NaN."
+ // spirv.CL.fmax/fmin:
+ // "If one argument is a NaN, Fmin returns the other argument."
+
+ Location loc = op.getLoc();
+ Value spirvOp =
+ rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());
+
+ if (!shouldInsertNanGuards<SPIRVOp>() ||
+ converter->getOptions().enableFastMathMode) {
+ rewriter.replaceOp(op, spirvOp);
+ return success();
+ }
+
+ Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
+ Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
+
+ Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
+ adaptor.getRhs(), spirvOp);
+ Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
+ adaptor.getLhs(), select1);
+
+ rewriter.replaceOp(op, select2);
+ return success();
+ }
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
+ MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
+ MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
@@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns(
MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
+ MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
+ MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 165877eb554e245..0221e4815a9397d 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[MIN]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
+ %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+ // CHECK: return %[[MAX]]
+ return %0: vector<2xf32>
+}
+
+
// CHECK-LABEL: @scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
@@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}
-// CHECK-LABEL: @float32_minf_scalar
+// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
@@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}
-// CHECK-LABEL: @float32_maxf_scalar
+// CHECK-LABEL: @float32_minnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
+ // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
+ // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
+ // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
+ // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[SELECT2]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
-func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
@@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}
+// CHECK-LABEL: @float32_maxnumf_scalar
+// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
+func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
+ // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
+ // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
+ // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
+ // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
+ // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
+ %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
+ // CHECK: return %[[SELECT2]]
+ return %0: vector<2xf32>
+}
+
// Check int vector types.
// CHECK-LABEL: @int_vector234
func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
index 9dea7d6623885e4..dbf0361c2ab35bb 100644
--- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
@@ -30,22 +30,40 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {
-// CHECK-LABEL: @minf
+// CHECK-LABEL: @minimumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
-func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
+func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minimumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}
-// CHECK-LABEL: @maxf
+// CHECK-LABEL: @maximumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
-func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}
+// CHECK-LABEL: @minnumf
+// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
+func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
+ // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
+ %0 = arith.minnumf %arg0, %arg1 : f32
+ // CHECK: return %[[F]]
+ return %0: f32
+}
+
+// CHECK-LABEL: @maxnumf
+// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
+func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+ // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
+ %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
+ // CHECK: return %[[F]]
+ return %0: vector<4xf32>
+}
+
} // end module
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for plumbing this through.
Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs()); | ||
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs()); | ||
|
||
Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan, | ||
adaptor.getRhs(), spirvOp); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we lower it to something like:
%lhsIsNan = spirv.IsNan %lhs
%x = spirv.Select %lhsIsNan, %rhs, %lhs
%rhsIsNan = spirv.IsNan %rhs
%y = spirv.Select %rhsIsNan, %lhs, %rhs
%res = spirv.GL.FMax %x, %y
I don't know which of the lowerings should be preferred, just wanted to explore the alternatives.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah no, the difference is that arith.*numf
guarantees a NaN output when both inputs are NaN, and in spirv.GL
we would get an undefined result. So the current lowering is the way to go.
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. In this commit, we add conversion patterns for the newly introduced operations `arith.minnumf` and `arith.maxnumf`. When converting to `spirv.CL`, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN because `CL` ops do exactly the same. However, `GL` ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops. This patch addresses the 1.5 task of the mentioned RFC.
1cbe674
to
5732946
Compare
This patch is part of a larger initiative aimed at fixing floating-point
max
andmin
operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.In this commit, we add conversion patterns for the newly introduced operations
arith.minnumf
andarith.maxnumf
. When converting tospirv.CL
, there is no need to insert additional guards to propagate non-NaN values when one of the arguments is NaN becauseCL
ops do exactly the same. However,GL
ops have undefined behavior when one of the arguments is NaN, so we should insert additional guards to enforce the semantics of Arith's ops.This patch addresses the 1.5 task of the mentioned RFC.