Skip to content

Commit

Permalink
[mlir][spirv] Add conversions for Arith's maxnumf and minnumf (#6…
Browse files Browse the repository at this point in the history
…6696)

This patch is part of a larger initiative aimed at fixing floating-point
`max` and `min` operations in MLIR:
https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671.

In this commit, we add conversion patterns for the newly introduced
operations `arith.minnumf` and `arith.maxnumf`. When converting to
`spirv.CL`, there is no need to insert additional guards to propagate
non-NaN values when one of the arguments is NaN because `CL` ops do
exactly the same. However, `GL` ops have undefined behavior when one of
the arguments is NaN, so we should insert additional guards to enforce
the semantics of Arith's ops.

This patch addresses the 1.5 task of the mentioned RFC.
  • Loading branch information
unterumarmung authored Sep 19, 2023
1 parent 80c01dd commit 641124a
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 12 deletions.
60 changes: 60 additions & 0 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
Expand Down Expand Up @@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};

//===----------------------------------------------------------------------===//
// MinNumFOp, MaxNumFOp
//===----------------------------------------------------------------------===//

/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
/// spirv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
template <typename TargetOp>
constexpr bool shouldInsertNanGuards() const {
return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
}

public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);

// arith.maxnumf/minnumf:
// "If one of the arguments is NaN, then the result is the other
// argument."
// spirv.GL.FMax/FMin
// "which operand is the result is undefined if one of the operands
// is a NaN."
// spirv.CL.fmax/fmin:
// "If one argument is a NaN, Fmin returns the other argument."

Location loc = op.getLoc();
Value spirvOp =
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());

if (!shouldInsertNanGuards<SPIRVOp>() ||
converter->getOptions().enableFastMathMode) {
rewriter.replaceOp(op, spirvOp);
return success();
}

Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());

Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
adaptor.getRhs(), spirvOp);
Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
adaptor.getLhs(), select1);

rewriter.replaceOp(op, select2);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1138,13 +1194,17 @@ void mlir::arith::populateArithToSPIRVPatterns(

MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,

MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
Expand Down
61 changes: 53 additions & 8 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}

// CHECK-LABEL: @float32_minf_scalar
// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
Expand All @@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}

// CHECK-LABEL: @float32_maxf_scalar
// CHECK-LABEL: @float32_minnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[MIN]]
return %0: f32
}

// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
Expand All @@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}

// CHECK-LABEL: @float32_maxnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
// CHECK: return %[[MAX]]
return %0: vector<2xf32>
}


// CHECK-LABEL: @scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
Expand Down Expand Up @@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}

// CHECK-LABEL: @float32_minf_scalar
// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
Expand All @@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}

// CHECK-LABEL: @float32_maxf_scalar
// CHECK-LABEL: @float32_minnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[SELECT2]]
return %0: f32
}

// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
Expand All @@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}

// CHECK-LABEL: @float32_maxnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
// CHECK: return %[[SELECT2]]
return %0: vector<2xf32>
}

// Check int vector types.
// CHECK-LABEL: @int_vector234
func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
Expand Down
26 changes: 22 additions & 4 deletions mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,40 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: @minf
// CHECK-LABEL: @minimumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minimumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}

// CHECK-LABEL: @maxf
// CHECK-LABEL: @maximumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}

// CHECK-LABEL: @minnumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}

// CHECK-LABEL: @maxnumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}

} // end module

0 comments on commit 641124a

Please sign in to comment.