Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][spirv] Add conversions for Arith's maxnumf and minnumf #66696

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
Expand Down Expand Up @@ -1086,6 +1087,61 @@ class MinimumMaximumFOpPattern final : public OpConversionPattern<Op> {
}
};

//===----------------------------------------------------------------------===//
// MinNumFOp, MaxNumFOp
//===----------------------------------------------------------------------===//

/// Converts arith.maxnumf/minnumf to spirv.GL.FMax/FMin or
/// spirv.CL.fmax/fmin.
template <typename Op, typename SPIRVOp>
class MinNumMaxNumFOpPattern final : public OpConversionPattern<Op> {
template <typename TargetOp>
constexpr bool shouldInsertNanGuards() const {
return llvm::is_one_of<TargetOp, spirv::GLFMaxOp, spirv::GLFMinOp>::value;
}

public:
using OpConversionPattern<Op>::OpConversionPattern;
LogicalResult
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto *converter = this->template getTypeConverter<SPIRVTypeConverter>();
Type dstType = converter->convertType(op.getType());
if (!dstType)
return getTypeConversionFailure(rewriter, op);

// arith.maxnumf/minnumf:
// "If one of the arguments is NaN, then the result is the other
// argument."
// spirv.GL.FMax/FMin
// "which operand is the result is undefined if one of the operands
// is a NaN."
// spirv.CL.fmax/fmin:
// "If one argument is a NaN, Fmin returns the other argument."

Location loc = op.getLoc();
Value spirvOp =
rewriter.create<SPIRVOp>(loc, dstType, adaptor.getOperands());

if (!shouldInsertNanGuards<SPIRVOp>() ||
converter->getOptions().enableFastMathMode) {
rewriter.replaceOp(op, spirvOp);
return success();
}

Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());

Value select1 = rewriter.create<spirv::SelectOp>(loc, dstType, lhsIsNan,
adaptor.getRhs(), spirvOp);
Comment on lines +1132 to +1136
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we lower it to something like:

%lhsIsNan = spirv.IsNan %lhs
%x = spirv.Select %lhsIsNan, %rhs, %lhs
%rhsIsNan = spirv.IsNan %rhs
%y = spirv.Select %rhsIsNan, %lhs, %rhs
%res = spirv.GL.FMax %x, %y

I don't know which of the lowerings should be preferred, just wanted to explore the alternatives.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah no, the difference is that arith.*numf guarantees a NaN output when both inputs are NaN, and in spirv.GL we would get an undefined result. So the current lowering is the way to go.

Value select2 = rewriter.create<spirv::SelectOp>(loc, dstType, rhsIsNan,
adaptor.getLhs(), select1);

rewriter.replaceOp(op, select2);
return success();
}
};

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1138,13 +1194,17 @@ void mlir::arith::populateArithToSPIRVPatterns(

MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::GLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::GLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::GLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::GLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>,

MinimumMaximumFOpPattern<arith::MaximumFOp, spirv::CLFMaxOp>,
MinimumMaximumFOpPattern<arith::MinimumFOp, spirv::CLFMinOp>,
MinNumMaxNumFOpPattern<arith::MaxNumFOp, spirv::CLFMaxOp>,
MinNumMaxNumFOpPattern<arith::MinNumFOp, spirv::CLFMinOp>,
spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::CLSMaxOp>,
spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::CLUMaxOp>,
spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::CLSMinOp>,
Expand Down
61 changes: 53 additions & 8 deletions mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1124,9 +1124,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}

// CHECK-LABEL: @float32_minf_scalar
// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
Expand All @@ -1137,9 +1137,18 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}

// CHECK-LABEL: @float32_maxf_scalar
// CHECK-LABEL: @float32_minnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.CL.fmin %arg0, %arg1 : f32
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[MIN]]
return %0: f32
}

// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
Expand All @@ -1150,6 +1159,16 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}

// CHECK-LABEL: @float32_maxnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.CL.fmax %arg0, %arg1 : vector<2xf32>
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
// CHECK: return %[[MAX]]
return %0: vector<2xf32>
}


// CHECK-LABEL: @scalar_srem
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32)
func.func @scalar_srem(%lhs: i32, %rhs: i32) {
Expand Down Expand Up @@ -1270,9 +1289,9 @@ func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
return
}

// CHECK-LABEL: @float32_minf_scalar
// CHECK-LABEL: @float32_minimumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @float32_minimumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
Expand All @@ -1283,9 +1302,22 @@ func.func @float32_minf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
return %0: f32
}

// CHECK-LABEL: @float32_maxf_scalar
// CHECK-LABEL: @float32_minnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @float32_minnumf_scalar(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[MIN:.+]] = spirv.GL.FMin %arg0, %arg1 : f32
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : f32
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : f32
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MIN]]
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[SELECT2]]
return %0: f32
}

// CHECK-LABEL: @float32_maximumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
func.func @float32_maximumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
Expand All @@ -1296,6 +1328,19 @@ func.func @float32_maxf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) ->
return %0: vector<2xf32>
}

// CHECK-LABEL: @float32_maxnumf_scalar
// CHECK-SAME: %[[LHS:.+]]: vector<2xf32>, %[[RHS:.+]]: vector<2xf32>
func.func @float32_maxnumf_scalar(%arg0 : vector<2xf32>, %arg1 : vector<2xf32>) -> vector<2xf32> {
// CHECK: %[[MAX:.+]] = spirv.GL.FMax %arg0, %arg1 : vector<2xf32>
// CHECK: %[[LHS_NAN:.+]] = spirv.IsNan %[[LHS]] : vector<2xf32>
// CHECK: %[[RHS_NAN:.+]] = spirv.IsNan %[[RHS]] : vector<2xf32>
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[LHS_NAN]], %[[RHS]], %[[MAX]]
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[RHS_NAN]], %[[LHS]], %[[SELECT1]]
%0 = arith.maxnumf %arg0, %arg1 : vector<2xf32>
// CHECK: return %[[SELECT2]]
return %0: vector<2xf32>
}

// Check int vector types.
// CHECK-LABEL: @int_vector234
func.func @int_vector234(%arg0: vector<2xi8>, %arg1: vector<4xi64>) {
Expand Down
26 changes: 22 additions & 4 deletions mlir/test/Conversion/ArithToSPIRV/fast-math.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,40 @@ module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader], []>, #spirv.resource_limits<>>
} {

// CHECK-LABEL: @minf
// CHECK-LABEL: @minimumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minf(%arg0 : f32, %arg1 : f32) -> f32 {
func.func @minimumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minimumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}

// CHECK-LABEL: @maxf
// CHECK-LABEL: @maximumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maxf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
func.func @maximumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maximumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}

// CHECK-LABEL: @minnumf
// CHECK-SAME: %[[LHS:.+]]: f32, %[[RHS:.+]]: f32
func.func @minnumf(%arg0 : f32, %arg1 : f32) -> f32 {
// CHECK: %[[F:.+]] = spirv.GL.FMin %[[LHS]], %[[RHS]]
%0 = arith.minnumf %arg0, %arg1 : f32
// CHECK: return %[[F]]
return %0: f32
}

// CHECK-LABEL: @maxnumf
// CHECK-SAME: %[[LHS:.+]]: vector<4xf32>, %[[RHS:.+]]: vector<4xf32>
func.func @maxnumf(%arg0 : vector<4xf32>, %arg1 : vector<4xf32>) -> vector<4xf32> {
// CHECK: %[[F:.+]] = spirv.GL.FMax %[[LHS]], %[[RHS]]
%0 = arith.maxnumf %arg0, %arg1 : vector<4xf32>
// CHECK: return %[[F]]
return %0: vector<4xf32>
}

} // end module