Skip to content

Commit

Permalink
[mlir][ArmSVE] Add convert.from/to.svbool intrinsics (#68418)
Browse files Browse the repository at this point in the history
These will be used in future pass to ensure that loads/stores of masks
are legal (as the LLVM backend does not support this for any type
smaller
than an svbool, which is vector<[16]xi1>).

Depends on #68399
  • Loading branch information
MacDue authored Oct 10, 2023
1 parent 962a049 commit 3d70ba6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
24 changes: 24 additions & 0 deletions mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def ArmSVE_Dialect : Dialect {
}];
}

//===----------------------------------------------------------------------===//
// ArmSVE type definitions
//===----------------------------------------------------------------------===//

def SVBool : ScalableVectorOfRankAndLengthAndType<
[1], [16], [I1]>;

def SVEPredicate : ScalableVectorOfRankAndLengthAndType<
[1], [16, 8, 4, 2, 1], [I1]>;

//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -302,4 +312,18 @@ def ScalableMaskedDivFIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"fdiv">,
Arguments<(ins AnyScalableVector, AnyScalableVector, AnyScalableVector)>;

def ConvertFromSvboolIntrOp :
ArmSVE_IntrOp<"convert.from.svbool",
[TypeIs<"res", SVEPredicate>],
/*overloadedOperands=*/[],
/*overloadedResults=*/[0]>,
Arguments<(ins SVBool:$svbool)>;

def ConvertToSvboolIntrOp :
ArmSVE_IntrOp<"convert.to.svbool",
[TypeIs<"res", SVBool>],
/*overloadedOperands=*/[0],
/*overloadedResults=*/[]>,
Arguments<(ins SVEPredicate:$mask)>;

#endif // ARMSVE_OPS
44 changes: 44 additions & 0 deletions mlir/test/Target/LLVMIR/arm-sve.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,47 @@ llvm.func @get_vector_scale() -> i64 {
%0 = "llvm.intr.vscale"() : () -> i64
llvm.return %0 : i64
}

// CHECK-LABEL: @arm_sve_convert_from_svbool(
// CHECK-SAME: <vscale x 16 x i1> %[[SVBOOL:[0-9]+]])
llvm.func @arm_sve_convert_from_svbool(%nxv16i1 : vector<[16]xi1>) {
// CHECK: %[[RES0:.*]] = call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %[[SVBOOL]])
%res0 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
: (vector<[16]xi1>) -> vector<[8]xi1>
// CHECK: %[[RES1:.*]] = call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %[[SVBOOL]])
%res1 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
: (vector<[16]xi1>) -> vector<[4]xi1>
// CHECK: %[[RES2:.*]] = call <vscale x 2 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv2i1(<vscale x 16 x i1> %[[SVBOOL]])
%res2 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
: (vector<[16]xi1>) -> vector<[2]xi1>
// CHECK: %[[RES3:.*]] = call <vscale x 1 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv1i1(<vscale x 16 x i1> %[[SVBOOL]])
%res3 = "arm_sve.intr.convert.from.svbool"(%nxv16i1)
: (vector<[16]xi1>) -> vector<[1]xi1>
llvm.return
}

// CHECK-LABEL: arm_sve_convert_to_svbool(
// CHECK-SAME: <vscale x 8 x i1> %[[P8:[0-9]+]],
// CHECK-SAME: <vscale x 4 x i1> %[[P4:[0-9]+]],
// CHECK-SAME: <vscale x 2 x i1> %[[P2:[0-9]+]],
// CHECK-SAME: <vscale x 1 x i1> %[[P1:[0-9]+]])
llvm.func @arm_sve_convert_to_svbool(
%nxv8i1 : vector<[8]xi1>,
%nxv4i1 : vector<[4]xi1>,
%nxv2i1 : vector<[2]xi1>,
%nxv1i1 : vector<[1]xi1>
) {
// CHECK-NEXT: %[[RES0:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %[[P8]])
%res0 = "arm_sve.intr.convert.to.svbool"(%nxv8i1)
: (vector<[8]xi1>) -> vector<[16]xi1>
// CHECK-NEXT: %[[RES1:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %[[P4]])
%res1 = "arm_sve.intr.convert.to.svbool"(%nxv4i1)
: (vector<[4]xi1>) -> vector<[16]xi1>
// CHECK-NEXT: %[[RES2:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv2i1(<vscale x 2 x i1> %[[P2]])
%res2 = "arm_sve.intr.convert.to.svbool"(%nxv2i1)
: (vector<[2]xi1>) -> vector<[16]xi1>
// CHECK-NEXT: %[[RES3:.*]] = call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv1i1(<vscale x 1 x i1> %[[P1]])
%res3 = "arm_sve.intr.convert.to.svbool"(%nxv1i1)
: (vector<[1]xi1>) -> vector<[16]xi1>
llvm.return
}

0 comments on commit 3d70ba6

Please sign in to comment.