-
Notifications
You must be signed in to change notification settings - Fork 12.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64] Implement intrinsics for SME FP8 FMOPA #118115
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: None (SpencerAbson) ChangesThis patch implements the following intrinsics: 8-bit floating-point sum of outer products and accumulate. // Only if __ARM_FEATURE_SME_F8F16 != 0
void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
__arm_streaming __arm_inout("za");
// Only if __ARM_FEATURE_SME_F8F32 != 0
void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
__arm_streaming __arm_inout("za"); In accordance with: ARM-software/acle#323 Co-authored-by: Momchil Velikov momchil.velikov@arm.com Patch is 20.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118115.diff 10 Files Affected:
diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index 0f689e82bdb742..71b2c7cdd04f93 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -824,4 +824,14 @@ let SMETargetGuard = "sme-lutv2" in {
def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", [IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>;
}
+let SMETargetGuard = "sme-f8f32" in {
+ def SVMOPA_FP8_ZA32 : Inst<"svmopa_za32[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za32",
+ [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_3>]>;
+}
+
+let SMETargetGuard = "sme-f8f16" in {
+ def SVMOPA_FP8_ZA16 : Inst<"svmopa_za16[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za16",
+ [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_1>]>;
+}
+
} // let SVETargetGuard = InvalidMode
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index cb9c23b8e0a0d0..56595bb4704e74 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10183,6 +10183,8 @@ CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
case SVETypeFlags::EltTyInt64:
return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);
+ case SVETypeFlags::EltTyMFloat8:
+ return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
case SVETypeFlags::EltTyFloat16:
return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
case SVETypeFlags::EltTyBFloat16:
@@ -11234,6 +11236,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
BuiltinID == SME::BI__builtin_sme_svstr_za)
return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);
+ // Emit set FPMR for intrinsics that require it
+ if (TypeFlags.setsFPMR())
+ Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr),
+ Ops.pop_back_val());
// Handle builtins which require their multi-vector operands to be swapped
swapCommutativeSMEOperands(BuiltinID, Ops);
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c
new file mode 100644
index 00000000000000..95d6383ab30efe
--- /dev/null
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c
@@ -0,0 +1,55 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// REQUIRES: aarch64-registered-target
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
+// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
+
+#include <arm_sme.h>
+
+#ifdef SVE_OVERLOADED_FORMS
+#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
+#else
+#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
+#endif
+
+
+// CHECK-LABEL: define dso_local void @test_svmopa_za16_mf8_m(
+// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT: ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za16_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
+// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CPP-CHECK-NEXT: [[ENTRY:.*:]]
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CPP-CHECK-NEXT: ret void
+//
+void test_svmopa_za16_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
+ svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
+ SVE_ACLE_FUNC(svmopa_za16,_mf8,_m_fpm)(1, pn, pm, zn, zm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local void @test_svmopa_za32_mf8_m(
+// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT: ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za32_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
+// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CPP-CHECK-NEXT: [[ENTRY:.*:]]
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CPP-CHECK-NEXT: ret void
+//
+void test_svmopa_za32_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
+ svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
+ SVE_ACLE_FUNC(svmopa_za32,_mf8,_m_fpm)(3, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c
new file mode 100644
index 00000000000000..62cad9cfa4c8fd
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -target-feature +sme2 -target-feature +sme-f8f16 -target-feature +sme-f8f32 -fsyntax-only -verify %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sme.h>
+
+void test_svmopa(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
+ fpm_t fpmr) __arm_streaming __arm_inout("za") {
+ // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 1]}}
+ svmopa_za16_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
+ // expected-error@+1 {{argument value 2 is outside the valid range [0, 1]}}
+ svmopa_za16_mf8_m_fpm(2, pn, pm, zn, zm, fpmr);
+
+ // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
+ svmopa_za32_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
+ // expected-error@+1 {{argument value 4 is outside the valid range [0, 3]}}
+ svmopa_za32_mf8_m_fpm(4, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c
new file mode 100644
index 00000000000000..86426abcd43291
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c
@@ -0,0 +1,13 @@
+// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -verify -emit-llvm-only %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sme.h>
+
+void test_features(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
+ fpm_t fpmr) __arm_streaming __arm_inout("za") {
+ // expected-error@+1 {{'svmopa_za16_mf8_m_fpm' needs target feature sme,sme-f8f16}}
+ svmopa_za16_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
+ // expected-error@+1 {{'svmopa_za32_mf8_m_fpm' needs target feature sme,sme-f8f32}}
+ svmopa_za32_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index e9fa01ea98dced..e24e93e8f29d8f 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -587,7 +587,6 @@ void SVEType::applyTypespec(StringRef TS) {
ElementBitwidth = 16;
break;
case 'm':
- Signed = false;
MFloat = true;
Float = false;
BFloat = false;
@@ -702,6 +701,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
Signed = false;
@@ -712,6 +712,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
Signed = true;
@@ -723,6 +724,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
Signed = true;
@@ -735,6 +737,7 @@ void SVEType::applyModifier(char Mod) {
Signed = true;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
break;
@@ -744,6 +747,7 @@ void SVEType::applyModifier(char Mod) {
Signed = true;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
break;
@@ -753,6 +757,7 @@ void SVEType::applyModifier(char Mod) {
Signed = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
break;
@@ -765,6 +770,7 @@ void SVEType::applyModifier(char Mod) {
Signed = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
break;
@@ -783,6 +789,7 @@ void SVEType::applyModifier(char Mod) {
case 'g':
Signed = false;
Float = false;
+ MFloat = false;
BFloat = false;
ElementBitwidth = 64;
break;
@@ -790,18 +797,21 @@ void SVEType::applyModifier(char Mod) {
Signed = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = 8;
break;
case 't':
Signed = true;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = 32;
break;
case 'z':
Signed = false;
Float = false;
BFloat = false;
+ MFloat = false;
ElementBitwidth = 32;
break;
case 'O':
@@ -815,6 +825,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = true;
BFloat = false;
+ MFloat = false;
ElementBitwidth = 32;
break;
case 'N':
@@ -922,6 +933,7 @@ void SVEType::applyModifier(char Mod) {
Predicate = false;
Svcount = false;
Float = false;
+ MFloat = false;
BFloat = true;
ElementBitwidth = 16;
break;
@@ -932,6 +944,7 @@ void SVEType::applyModifier(char Mod) {
NumVectors = 0;
Float = false;
BFloat = false;
+ MFloat = false;
break;
case '~':
Float = false;
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index a91616b9556828..0fde957ecbba6e 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2983,6 +2983,13 @@ let TargetPrefix = "aarch64" in {
LLVMMatchType<0>,
llvm_anyvector_ty], [ImmArg<ArgIndex<0>>]>;
+ class SME_FP8_OuterProduct_Intrinsic
+ : DefaultAttrsIntrinsic<[],
+ [llvm_i32_ty,
+ llvm_nxv16i1_ty, llvm_nxv16i1_ty,
+ llvm_nxv16i8_ty, llvm_nxv16i8_ty],
+ [ImmArg<ArgIndex<0>>, IntrInaccessibleMemOnly, IntrHasSideEffects]>;
+
def int_aarch64_sme_mopa : SME_OuterProduct_Intrinsic;
def int_aarch64_sme_mops : SME_OuterProduct_Intrinsic;
@@ -2998,6 +3005,10 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_usmopa_wide : SME_OuterProduct_Intrinsic;
def int_aarch64_sme_usmops_wide : SME_OuterProduct_Intrinsic;
+ // FP8 outer product
+ def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic;
+ def int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic;
+
class SME_AddVectorToTile_Intrinsic
: DefaultAttrsIntrinsic<[],
[llvm_i32_ty,
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 37ac915d1d8808..9c657787d3492b 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -990,7 +990,7 @@ defm FDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_f8f16<"fdot", 0b11
defm FDOT_VG4_M4ZZI_BtoH : sme2p1_multi_vec_array_vg4_index_f8f16<"fdot", 0b100, ZZZZ_b_mul_r, ZPR4b8>;
defm FDOT_VG2_M2ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>;
defm FDOT_VG4_M4ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>;
-// TODO: Replace nxv16i8 by nxv16f8
+
defm FDOT_VG2_M2Z2Z_BtoH : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FDOT_VG4_M4Z4Z_BtoH : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
@@ -998,23 +998,22 @@ def FMLAL_MZZI_BtoH : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>;
defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>;
defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>;
def FMLAL_VG2_MZZ_BtoH : sme2_mla_long_array_single_16b<"fmlal">;
-// TODO: Replace nxv16i8 by nxv16f8
+
defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>;
defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>;
defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
-defm FMOPA_MPPZZ_BtoH : sme2p1_fmop_tile_f8f16<"fmopa", 0b1, 0b0, 0b01>;
-
+defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
} //[HasSMEF8F16]
let Predicates = [HasSMEF8F32] in {
-// TODO : Replace nxv16i8 by nxv16f8
+
defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
defm FDOT_VG2_M2ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>;
defm FDOT_VG4_M4ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>;
-// TODO : Replace nxv16i8 by nxv16f8
+
defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
@@ -1024,16 +1023,14 @@ def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>;
defm FMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"fmlall", 0b01, 0b000, null_frag>;
defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>;
defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>;
-// TODO: Replace nxv16i8 by nxv16f8
+
defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>;
defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>;
defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>;
defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
-
-defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>;
-
+defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
} //[HasSMEF8F32]
let Predicates = [HasSME2, HasSVEBFSCALE] in {
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 776472e72af05a..e6535f957e2024 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -305,6 +305,21 @@ multiclass sme_outer_product_fp32<bit S, bits<2> sz, ZPRRegOp zpr_ty, string mne
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_3, nxv4i1, nxv4f32>;
}
+multiclass sme2_fp8_fmopa_za32<string mnemonic, SDPatternOperator intrinsic> {
+ def NAME : sme_fp_outer_product_inst<0, 0b01, 0b00, TileOp32, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+ bits<2> ZAda;
+ let Inst{1-0} = ZAda;
+ let Inst{2} = 0b0;
+
+ let Uses = [FPMR, FPCR];
+ }
+
+ let mayStore = 1, mayLoad = 1 in
+ def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
+
+ def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_3, nxv16i1, nxv16i8>;
+}
+
multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
def NAME : sme_fp_outer_product_inst<S, 0b10, 0b00, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
bits<3> ZAda;
@@ -316,12 +331,19 @@ multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op>
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_7, nxv2i1, nxv2f64>;
}
-multiclass sme2p1_fmop_tile_f8f16<string mnemonic, bit bf, bit s, bits<2> op> {
- def NAME : sme_fp_outer_product_inst<s, {0,bf}, op, TileOp16, ZPR8, mnemonic> {
+multiclass sme2_fp8_fmopa_za16<string mnemonic, SDPatternOperator intrinsic> {
+ def NAME : sme_fp_outer_product_inst<0, {0, 0b1}, 0b01, TileOp16, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
bits<1> ZAda;
let Inst{2-1} = 0b00;
let Inst{0} = ZAda;
+
+ let Uses = [FPMR, FPCR];
}
+
+ let mayStore = 1, mayLoad = 1 in
+ def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileH>, SMEPseudo2Instr<NAME, 0>;
+
+ def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_1, nxv16i1, nxv16i8>;
}
multiclass sme2p1_fmop_tile_fp16<string mnemonic, bit bf, bit s, ValueType vt, SDPatternOperator intrinsic = null_frag> {
diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll
new file mode 100644
index 00000000000000..6e88cdf4e7fec3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll
@@ -0,0 +1,22 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme-f8f16,+sme-f8f32 -force-streaming < %s | FileCheck %s
+
+define void @test_fmopa_16(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) {
+; CHECK-LABEL: test_fmopa_16:
+; CHEC...
[truncated]
|
010eae4
to
a92b7a9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the patch Spencer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
a92b7a9
to
9edab2a
Compare
This patch implements the following intrinsics: 8-bit floating-point sum of outer products and accumulate. ``` c // Only if __ARM_FEATURE_SME_F8F16 != 0 void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za"); // Only if __ARM_FEATURE_SME_F8F32 != 0 void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm) __arm_streaming __arm_inout("za"); ``` In accordance with: ARM-software/acle#323 Co-authored-by: Momchil Velikov momchil.velikov@arm.com Co-authored-by: Marian Lukac marian.lukac@arm.com
This patch implements the following intrinsics:
8-bit floating-point sum of outer products and accumulate.
In accordance with: ARM-software/acle#323
Co-authored-by: Momchil Velikov momchil.velikov@arm.com
Co-authored-by: Marian Lukac marian.lukac@arm.com