diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index a589fb8050f34d..aba6a21deccb0c 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" #include @@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// MinNumFOp, MaxNumFOp +//===----------------------------------------------------------------------===// + +/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or +/// spirv.CL.fmax/fmin. +template +class MinNumMaxNumFOpPattern final : public OpConversionPattern { + template + constexpr bool shouldInsertNanGuards() const { + return llvm::is_one_of::value; + } + +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Op op, typename Op::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *converter = this->template getTypeConverter(); + Type dstType = converter->convertType(op.getType()); + if (!dstType) + return getTypeConversionFailure(rewriter, op); + + // arith.maxnumf/minnumf: + // "If one of the arguments is NaN, then the result is the other + // argument." + // spirv.GL.FMax/FMin + // "which operand is the result is undefined if one of the operands + // is a NaN." + // spirv.CL.fmax/fmin: + // "If one argument is a NaN, Fmin returns the other argument." + + Location loc = op.getLoc(); + Value spirvOp = + rewriter.create(loc, dstType, adaptor.getOperands()); + + if (!shouldInsertNanGuards() || + converter->getOptions().enableFastMathMode) { + rewriter.replaceOp(op, spirvOp); + return success(); + } + + Value lhsIsNan = rewriter.create(loc, adaptor.getLhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.getRhs()); + + Value select1 = rewriter.create(loc, dstType, lhsIsNan, + adaptor.getRhs(), spirvOp); + Value select2 = rewriter.create(loc, dstType, rhsIsNan, + adaptor.getLhs(), select1); + + rewriter.replaceOp(op, select2); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -1138,6 +1194,8 @@ void mlir::arith::populateArithToSPIRVPatterns( MinimumMaximumFOpPattern, MinimumMaximumFOpPattern, + MinNumMaxNumFOpPattern, + MinNumMaxNumFOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, @@ -1145,6 +1203,8 @@ void mlir::arith::populateArithToSPIRVPatterns( MinimumMaximumFOpPattern, MinimumMaximumFOpPattern, + MinNumMaxNumFOpPattern, + MinNumMaxNumFOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, spirv::ElementwiseOpPattern, diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir index 165877eb554e24..0221e4815a9397 100644 --- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir @@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) { return } -// CHECK-LABEL: @float32_minf_scalar +// CHECK-LABEL: @float32_minimumf_scalar // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 -func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { +func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32 // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32 // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32 @@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { return %0: f32 } -// CHECK-LABEL: @float32_maxf_scalar +// CHECK-LABEL: @float32_minnumf_scalar +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32 + %0 = arith.minnumf %arg0, %arg1 : f32 + // CHECK: return %[[MIN]] + return %0: f32 +} + +// CHECK-LABEL: @float32_maximumf_scalar // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32> -func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { +func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32> // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32> // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32> @@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> return %0: vector<2xf32> } +// CHECK-LABEL: @float32_maxnumf_scalar +// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32> +func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { + // CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32> + %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32> + // CHECK: return %[[MAX]] + return %0: vector<2xf32> +} + + // CHECK-LABEL: @scalar_srem // CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) func.func @scalar_srem(%lhs: i32, %rhs: i32) { @@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) { return } -// CHECK-LABEL: @float32_minf_scalar +// CHECK-LABEL: @float32_minimumf_scalar // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 -func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { +func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32 // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32 // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32 @@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { return %0: f32 } -// CHECK-LABEL: @float32_maxf_scalar +// CHECK-LABEL: @float32_minnumf_scalar +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32 + // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32 + // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32 + // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]] + // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]] + %0 = arith.minnumf %arg0, %arg1 : f32 + // CHECK: return %[[SELECT2]] + return %0: f32 +} + +// CHECK-LABEL: @float32_maximumf_scalar // CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32> -func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { +func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32> // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32> // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32> @@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> return %0: vector<2xf32> } +// CHECK-LABEL: @float32_maxnumf_scalar +// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32> +func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> { + // CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32> + // CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32> + // CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32> + // CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]] + // CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]] + %0 = arith.maxnumf %arg0, %arg1 : vector<2xf32> + // CHECK: return %[[SELECT2]] + return %0: vector<2xf32> +} + // Check int vector types. // CHECK-LABEL: @int_vector234 func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) { diff --git a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir index 9dea7d6623885e..dbf0361c2ab35b 100644 --- a/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir +++ b/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir @@ -30,22 +30,40 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits<>> } { -// CHECK-LABEL: @minf +// CHECK-LABEL: @minimumf // CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 -func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 { +func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 { // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]] %0 = arith.minimumf %arg0, %arg1 : f32 // CHECK: return %[[F]] return %0: f32 } -// CHECK-LABEL: @maxf +// CHECK-LABEL: @maximumf // CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32> -func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> { +func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> { // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]] %0 = arith.maximumf %arg0, %arg1 : vector<4xf32> // CHECK: return %[[F]] return %0: vector<4xf32> } +// CHECK-LABEL: @minnumf +// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32 +func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 { + // CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]] + %0 = arith.minnumf %arg0, %arg1 : f32 + // CHECK: return %[[F]] + return %0: f32 +} + +// CHECK-LABEL: @maxnumf +// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32> +func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> { + // CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]] + %0 = arith.maxnumf %arg0, %arg1 : vector<4xf32> + // CHECK: return %[[F]] + return %0: vector<4xf32> +} + } // end module