Skip to content

Commit

Permalink
[mlir][spirv] Add folding for [I|Logical][Not]Equal (#74194)
Browse files Browse the repository at this point in the history
  • Loading branch information
inbelic authored Dec 20, 2023
1 parent cf048e1 commit 4c83c27
Show file tree
Hide file tree
Showing 4 changed files with 256 additions and 11 deletions.
9 changes: 8 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVLogicalOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,8 @@ def SPIRV_IEqualOp : SPIRV_LogicalBinaryOp<"IEqual",
%5 = spirv.IEqual %2, %3 : vector<4xi32>
```
}];

let hasFolder = 1;
}

// -----
Expand All @@ -395,6 +397,8 @@ def SPIRV_INotEqualOp : SPIRV_LogicalBinaryOp<"INotEqual",

```
}];

let hasFolder = 1;
}

// -----
Expand Down Expand Up @@ -501,6 +505,8 @@ def SPIRV_LogicalEqualOp : SPIRV_LogicalBinaryOp<"LogicalEqual",
%2 = spirv.LogicalEqual %0, %1 : vector<4xi1>
```
}];

let hasFolder = 1;
}

// -----
Expand Down Expand Up @@ -557,7 +563,8 @@ def SPIRV_LogicalNotEqualOp : SPIRV_LogicalBinaryOp<"LogicalNotEqual",
%2 = spirv.LogicalNotEqual %0, %1 : vector<4xi1>
```
}];
let hasFolder = true;

let hasFolder = 1;
}

// -----
Expand Down
77 changes: 75 additions & 2 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -662,19 +662,52 @@ OpFoldResult spirv::LogicalAndOp::fold(FoldAdaptor adaptor) {
return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.LogicalEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult
spirv::LogicalEqualOp::fold(spirv::LogicalEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
if (isa<IntegerType>(getType()))
return trueAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.LogicalNotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::LogicalNotEqualOp::fold(FoldAdaptor adaptor) {
if (std::optional<bool> rhs =
getScalarOrSplatBoolAttr(adaptor.getOperand2())) {
// x && false = x
// x != false -> x
if (!rhs.value())
return getOperand1();
}

return Attribute();
// x == x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
if (isa<IntegerType>(getType()))
return falseAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
});
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -709,6 +742,46 @@ OpFoldResult spirv::LogicalOrOp::fold(FoldAdaptor adaptor) {
return Attribute();
}

//===----------------------------------------------------------------------===//
// spirv.IEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::IEqualOp::fold(spirv::IEqualOp::FoldAdaptor adaptor) {
// x == x -> true
if (getOperand1() == getOperand2()) {
auto trueAttr = BoolAttr::get(getContext(), true);
if (isa<IntegerType>(getType()))
return trueAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, trueAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getAllOnes(1) : APInt::getZero(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.INotEqualOp
//===----------------------------------------------------------------------===//

OpFoldResult spirv::INotEqualOp::fold(spirv::INotEqualOp::FoldAdaptor adaptor) {
// x == x -> false
if (getOperand1() == getOperand2()) {
auto falseAttr = BoolAttr::get(getContext(), false);
if (isa<IntegerType>(getType()))
return falseAttr;
if (auto vecTy = dyn_cast<VectorType>(getType()))
return SplatElementsAttr::get(vecTy, falseAttr);
}

return constFoldBinaryOp<IntegerAttr>(
adaptor.getOperands(), getType(), [](const APInt &a, const APInt &b) {
return a == b ? APInt::getZero(1) : APInt::getAllOnes(1);
});
}

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogical
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Conversion/SPIRVToLLVM/logical-ops-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
// CHECK-LABEL: @logical_equal_scalar
spirv.func @logical_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalEqual %arg0, %arg0 : i1
%0 = spirv.LogicalEqual %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_equal_vector
spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "eq" %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalEqual %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand All @@ -25,14 +25,14 @@ spirv.func @logical_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None
// CHECK-LABEL: @logical_not_equal_scalar
spirv.func @logical_not_equal_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
%0 = spirv.LogicalNotEqual %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_not_equal_vector
spirv.func @logical_not_equal_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.icmp "ne" %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalNotEqual %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalNotEqual %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand Down Expand Up @@ -63,14 +63,14 @@ spirv.func @logical_not_vector(%arg0: vector<4xi1>) "None" {
// CHECK-LABEL: @logical_and_scalar
spirv.func @logical_and_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalAnd %arg0, %arg0 : i1
%0 = spirv.LogicalAnd %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_and_vector
spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.and %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalAnd %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalAnd %arg0, %arg1 : vector<4xi1>
spirv.Return
}

Expand All @@ -81,13 +81,13 @@ spirv.func @logical_and_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None"
// CHECK-LABEL: @logical_or_scalar
spirv.func @logical_or_scalar(%arg0: i1, %arg1: i1) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : i1
%0 = spirv.LogicalOr %arg0, %arg0 : i1
%0 = spirv.LogicalOr %arg0, %arg1 : i1
spirv.Return
}

// CHECK-LABEL: @logical_or_vector
spirv.func @logical_or_vector(%arg0: vector<4xi1>, %arg1: vector<4xi1>) "None" {
// CHECK: llvm.or %{{.*}}, %{{.*}} : vector<4xi1>
%0 = spirv.LogicalOr %arg0, %arg0 : vector<4xi1>
%0 = spirv.LogicalOr %arg0, %arg1 : vector<4xi1>
spirv.Return
}
165 changes: 165 additions & 0 deletions mlir/test/Dialect/SPIRV/Transforms/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,48 @@ func.func @convert_logical_not_to_not_equal(%arg0: vector<3xi64>, %arg1: vector<
spirv.ReturnValue %3 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.LogicalEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @logical_equal_same
func.func @logical_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>

%0 = spirv.LogicalEqual %arg0, %arg0 : i1
%1 = spirv.LogicalEqual %arg1, %arg1 : vector<3xi1>
// CHECK: return %[[CTRUE]], %[[CVTRUE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_logical_equal
func.func @const_fold_scalar_logical_equal() -> (i1, i1) {
%true = spirv.Constant true
%false = spirv.Constant false

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.LogicalEqual %true, %false : i1
%1 = spirv.LogicalEqual %false, %false : i1

// CHECK: return %[[CFALSE]], %[[CTRUE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_logical_equal
func.func @const_fold_vector_logical_equal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, true, false]>
%0 = spirv.LogicalEqual %cv0, %cv1 : vector<3xi1>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

Expand All @@ -1064,6 +1106,43 @@ func.func @convert_logical_not_equal_false(%arg: vector<4xi1>) -> vector<4xi1> {
spirv.ReturnValue %0 : vector<4xi1>
}

// CHECK-LABEL: @logical_not_equal_same
func.func @logical_not_equal_same(%arg0 : i1, %arg1 : vector<3xi1>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
%0 = spirv.LogicalNotEqual %arg0, %arg0 : i1
%1 = spirv.LogicalNotEqual %arg1, %arg1 : vector<3xi1>

// CHECK: return %[[CFALSE]], %[[CVFALSE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_logical_not_equal
func.func @const_fold_scalar_logical_not_equal() -> (i1, i1) {
%true = spirv.Constant true
%false = spirv.Constant false

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.LogicalNotEqual %true, %false : i1
%1 = spirv.LogicalNotEqual %false, %false : i1

// CHECK: return %[[CTRUE]], %[[CFALSE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_logical_not_equal
func.func @const_fold_vector_logical_not_equal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[true, false, true]> : vector<3xi1>
%cv1 = spirv.Constant dense<[true, false, false]> : vector<3xi1>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, false, true]>
%0 = spirv.LogicalNotEqual %cv0, %cv1 : vector<3xi1>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

func.func @convert_logical_not_to_equal(%arg0: vector<3xi64>, %arg1: vector<3xi64>) -> vector<3xi1> {
Expand Down Expand Up @@ -1139,6 +1218,92 @@ func.func @convert_logical_or_true_false_vector(%arg: vector<3xi1>) -> (vector<3

// -----

//===----------------------------------------------------------------------===//
// spirv.IEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @iequal_same
func.func @iequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CVTRUE:.*]] = spirv.Constant dense<true>
%0 = spirv.IEqual %arg0, %arg0 : i32
%1 = spirv.IEqual %arg1, %arg1 : vector<3xi32>

// CHECK: return %[[CTRUE]], %[[CVTRUE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_iequal
func.func @const_fold_scalar_iequal() -> (i1, i1) {
%c5 = spirv.Constant 5 : i32
%c6 = spirv.Constant 6 : i32

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.IEqual %c5, %c6 : i32
%1 = spirv.IEqual %c5, %c5 : i32

// CHECK: return %[[CFALSE]], %[[CTRUE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_iequal
func.func @const_fold_vector_iequal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[true, false, true]>
%0 = spirv.IEqual %cv0, %cv1 : vector<3xi32>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.INotEqual
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @inotequal_same
func.func @inotequal_same(%arg0 : i32, %arg1 : vector<3xi32>) -> (i1, vector<3xi1>) {
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
// CHECK-DAG: %[[CVFALSE:.*]] = spirv.Constant dense<false>
%0 = spirv.INotEqual %arg0, %arg0 : i32
%1 = spirv.INotEqual %arg1, %arg1 : vector<3xi32>

// CHECK: return %[[CFALSE]], %[[CVFALSE]]
return %0, %1 : i1, vector<3xi1>
}

// CHECK-LABEL: @const_fold_scalar_inotequal
func.func @const_fold_scalar_inotequal() -> (i1, i1) {
%c5 = spirv.Constant 5 : i32
%c6 = spirv.Constant 6 : i32

// CHECK-DAG: %[[CTRUE:.*]] = spirv.Constant true
// CHECK-DAG: %[[CFALSE:.*]] = spirv.Constant false
%0 = spirv.INotEqual %c5, %c6 : i32
%1 = spirv.INotEqual %c5, %c5 : i32

// CHECK: return %[[CTRUE]], %[[CFALSE]]
return %0, %1 : i1, i1
}

// CHECK-LABEL: @const_fold_vector_inotequal
func.func @const_fold_vector_inotequal() -> vector<3xi1> {
%cv0 = spirv.Constant dense<[-1, -4, 2]> : vector<3xi32>
%cv1 = spirv.Constant dense<[-1, -3, 2]> : vector<3xi32>

// CHECK: %[[RET:.*]] = spirv.Constant dense<[false, true, false]>
%0 = spirv.INotEqual %cv0, %cv1 : vector<3xi32>

// CHECK: return %[[RET]]
return %0 : vector<3xi1>
}

// -----

//===----------------------------------------------------------------------===//
// spirv.LeftShiftLogical
//===----------------------------------------------------------------------===//
Expand Down

0 comments on commit 4c83c27

Please sign in to comment.