-
Notifications
You must be signed in to change notification settings - Fork 12.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ArmSVE] Add convert.from/to.svbool intrinsics #68418
Conversation
These will be used in future pass to ensure that loads/stores of masks are legal (as the LLVM backend does not suppor this for any type smaller than an svbool, which is vector<[16]xi1>).
fac5375
to
2ee3ec5
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm ChangesThese will be used in future pass to ensure that loads/stores of masks Depends on #68399 Full diff: https://github.com/llvm/llvm-project/pull/68418.diff 2 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index 58dec6091f27f6e..d4294b4dd9fd4e8 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
}];
}
+//===----------------------------------------------------------------------===//
+// ArmSVE type definitions
+//===----------------------------------------------------------------------===//
+
+def SVBool : ScalableVectorOfRankAndLengthAndType<
+ [1], [16], [I1]>;
+
+def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
+ [1], [16, 8, 4, 2, 1], [I1]>;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;
+def ConvertFromSvboolIntrOp :
+ ArmSVE_IntrOp<"convert.from.svbool",
+ [TypeIs<"res", SVEPredicate>],
+ /*overloadedOperands=*/[],
+ /*overloadedResults=*/[0]>,
+ Arguments<(ins SVBool:$svbool)>;
+
+def ConvertToSvboolIntrOp :
+ ArmSVE_IntrOp<"convert.to.svbool",
+ [TypeIs<"res", SVBool>],
+ /*overloadedOperands=*/[0],
+ /*overloadedResults=*/[]>,
+ Arguments<(ins SVEPredicate:$mask)>;
+
#endif // ARMSVE_OPS
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 999df8079e0727a..172a2f7d12d440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
%0 = "llvm.intr.vscale"() : () -> i64
llvm.return %0 : i64
}
+
+// CHECK-LABEL: @arm_sve_convert_from_svbool(
+// CHECK-SAME: <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
+llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
+ // CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
+ %res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+ : (vector<[16]xi1>) -> vector<[8]xi1>
+ // CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
+ %res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+ : (vector<[16]xi1>) -> vector<[4]xi1>
+ // CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
+ %res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+ : (vector<[16]xi1>) -> vector<[2]xi1>
+ // CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
+ %res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
+ : (vector<[16]xi1>) -> vector<[1]xi1>
+ llvm.return
+}
+
+// CHECK-LABEL: arm_sve_convert_to_svbool(
+// CHECK-SAME: <vscale x 8 x i1> %[[P8:[0-9]+]],
+// CHECK-SAME: <vscale x 4 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME: <vscale x 2 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME: <vscale x 1 x i1> %[[P1:[0-9]+]])
+llvm.func @arm_sve_convert_to_svbool(
+ %nxv8i1 : vector<[8]xi1>,
+ %nxv4i1 : vector<[4]xi1>,
+ %nxv2i1 : vector<[2]xi1>,
+ %nxv1i1 : vector<[1]xi1>
+) {
+ // CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
+ %res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
+ : (vector<[8]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
+ %res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
+ : (vector<[4]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
+ %res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
+ : (vector<[2]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
+ %res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
+ : (vector<[1]xi1>) -> vector<[16]xi1>
+ llvm.return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, ta!
This adds slightly higher-level ops for converting masks between svbool and SVE predicate types. The main reason to use these over the intrinsics is these ops support vectors of masks (via unrolling). E.g. ``` // Convert a svbool mask to a mask of SVE predicates: %svbool = vector.load %memref[%c0, %c0] : memref<2x?xi1>, vector<2x[16]xi1> %mask = arm_sve.convert_from_svbool %svbool : vector<2x[8]xi1> // => Results in vector<2x[8]xi1> ``` Or: ``` // Convert a mask of SVE predicates to a svbool mask: %mask = vector.create_mask %c2, %dim_size : vector<2x[2]xi1> %svbool = arm_sve.convert_to_svbool %mask : vector<2x[2]xi1> // => Results in vector<2x[16]xi1> ``` Depends on #68418
These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type smaller
than an svbool, which is vector<[16]xi1>).
Depends on #68399