Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] Implement intrinsics for SME FP8 FMOPA #118115

Merged
merged 2 commits into from
Dec 9, 2024

Conversation

SpencerAbson
Copy link
Contributor

This patch implements the following intrinsics:

8-bit floating-point sum of outer products and accumulate.

  // Only if __ARM_FEATURE_SME_F8F16 != 0
    void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");

  // Only if __ARM_FEATURE_SME_F8F32 != 0
    void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");

In accordance with: ARM-software/acle#323

Co-authored-by: Momchil Velikov momchil.velikov@arm.com
Co-authored-by: Marian Lukac marian.lukac@arm.com

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen llvm:ir labels Nov 29, 2024
@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-llvm-ir

@llvm/pr-subscribers-clang

Author: None (SpencerAbson)

Changes

This patch implements the following intrinsics:

8-bit floating-point sum of outer products and accumulate.

  // Only if __ARM_FEATURE_SME_F8F16 != 0
    void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");

  // Only if __ARM_FEATURE_SME_F8F32 != 0
    void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");

In accordance with: ARM-software/acle#323

Co-authored-by: Momchil Velikov momchil.velikov@arm.com
Co-authored-by: Marian Lukac marian.lukac@arm.com


Patch is 20.77 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118115.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/arm_sme.td (+10)
  • (modified) clang/lib/CodeGen/CGBuiltin.cpp (+6)
  • (added) clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c (+55)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c (+18)
  • (added) clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c (+13)
  • (modified) clang/utils/TableGen/SveEmitter.cpp (+14-1)
  • (modified) llvm/include/llvm/IR/IntrinsicsAArch64.td (+11)
  • (modified) llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td (+7-10)
  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+24-2)
  • (added) llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll (+22)
diff --git a/clang/include/clang/Basic/arm_sme.td b/clang/include/clang/Basic/arm_sme.td
index 0f689e82bdb742..71b2c7cdd04f93 100644
--- a/clang/include/clang/Basic/arm_sme.td
+++ b/clang/include/clang/Basic/arm_sme.td
@@ -824,4 +824,14 @@ let SMETargetGuard = "sme-lutv2" in {
   def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", [IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>;
 }
 
+let SMETargetGuard = "sme-f8f32" in {
+  def SVMOPA_FP8_ZA32 : Inst<"svmopa_za32[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za32",
+                             [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_3>]>;
+}
+
+let SMETargetGuard = "sme-f8f16" in {
+  def SVMOPA_FP8_ZA16 : Inst<"svmopa_za16[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za16",
+                             [IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_1>]>;
+}
+
 } // let SVETargetGuard = InvalidMode
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index cb9c23b8e0a0d0..56595bb4704e74 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10183,6 +10183,8 @@ CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
   case SVETypeFlags::EltTyInt64:
     return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);
 
+  case SVETypeFlags::EltTyMFloat8:
+    return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
   case SVETypeFlags::EltTyFloat16:
     return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
   case SVETypeFlags::EltTyBFloat16:
@@ -11234,6 +11236,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
            BuiltinID == SME::BI__builtin_sme_svstr_za)
     return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);
 
+  // Emit set FPMR for intrinsics that require it
+  if (TypeFlags.setsFPMR())
+    Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr),
+                       Ops.pop_back_val());
   // Handle builtins which require their multi-vector operands to be swapped
   swapCommutativeSMEOperands(BuiltinID, Ops);
 
diff --git a/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c
new file mode 100644
index 00000000000000..95d6383ab30efe
--- /dev/null
+++ b/clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c
@@ -0,0 +1,55 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
+// REQUIRES: aarch64-registered-target
+
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
+// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
+// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
+// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s
+
+#include <arm_sme.h>
+
+#ifdef SVE_OVERLOADED_FORMS
+#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
+#else
+#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
+#endif
+
+
+// CHECK-LABEL: define dso_local void @test_svmopa_za16_mf8_m(
+// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za16_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
+// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmopa_za16_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
+                            svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
+  SVE_ACLE_FUNC(svmopa_za16,_mf8,_m_fpm)(1, pn, pm, zn, zm, fpmr);
+}
+
+// CHECK-LABEL: define dso_local void @test_svmopa_za32_mf8_m(
+// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CHECK-NEXT:    ret void
+//
+// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za32_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
+// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
+// CPP-CHECK-NEXT:  [[ENTRY:.*:]]
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
+// CPP-CHECK-NEXT:    tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
+// CPP-CHECK-NEXT:    ret void
+//
+void test_svmopa_za32_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
+                            svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
+  SVE_ACLE_FUNC(svmopa_za32,_mf8,_m_fpm)(3, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c
new file mode 100644
index 00000000000000..62cad9cfa4c8fd
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c
@@ -0,0 +1,18 @@
+// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -target-feature +sme2 -target-feature +sme-f8f16 -target-feature +sme-f8f32 -fsyntax-only -verify  %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sme.h>
+
+void test_svmopa(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
+                 fpm_t fpmr) __arm_streaming __arm_inout("za") {
+    // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 1]}}
+    svmopa_za16_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
+    // expected-error@+1 {{argument value 2 is outside the valid range [0, 1]}}
+    svmopa_za16_mf8_m_fpm(2, pn, pm, zn, zm, fpmr);
+
+    // expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
+    svmopa_za32_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
+    // expected-error@+1 {{argument value 4 is outside the valid range [0, 3]}}
+    svmopa_za32_mf8_m_fpm(4, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c
new file mode 100644
index 00000000000000..86426abcd43291
--- /dev/null
+++ b/clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c
@@ -0,0 +1,13 @@
+// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -verify -emit-llvm-only %s
+
+// REQUIRES: aarch64-registered-target
+
+#include <arm_sme.h>
+
+void test_features(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
+                   fpm_t fpmr) __arm_streaming __arm_inout("za") {
+    // expected-error@+1 {{'svmopa_za16_mf8_m_fpm' needs target feature sme,sme-f8f16}}
+    svmopa_za16_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
+    // expected-error@+1 {{'svmopa_za32_mf8_m_fpm' needs target feature sme,sme-f8f32}}
+    svmopa_za32_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
+}
diff --git a/clang/utils/TableGen/SveEmitter.cpp b/clang/utils/TableGen/SveEmitter.cpp
index e9fa01ea98dced..e24e93e8f29d8f 100644
--- a/clang/utils/TableGen/SveEmitter.cpp
+++ b/clang/utils/TableGen/SveEmitter.cpp
@@ -587,7 +587,6 @@ void SVEType::applyTypespec(StringRef TS) {
       ElementBitwidth = 16;
       break;
     case 'm':
-      Signed = false;
       MFloat = true;
       Float = false;
       BFloat = false;
@@ -702,6 +701,7 @@ void SVEType::applyModifier(char Mod) {
     Svcount = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 64;
     NumVectors = 0;
     Signed = false;
@@ -712,6 +712,7 @@ void SVEType::applyModifier(char Mod) {
     Svcount = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 32;
     NumVectors = 0;
     Signed = true;
@@ -723,6 +724,7 @@ void SVEType::applyModifier(char Mod) {
     Svcount = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 32;
     NumVectors = 0;
     Signed = true;
@@ -735,6 +737,7 @@ void SVEType::applyModifier(char Mod) {
     Signed = true;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 32;
     NumVectors = 0;
     break;
@@ -744,6 +747,7 @@ void SVEType::applyModifier(char Mod) {
     Signed = true;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 64;
     NumVectors = 0;
     break;
@@ -753,6 +757,7 @@ void SVEType::applyModifier(char Mod) {
     Signed = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 32;
     NumVectors = 0;
     break;
@@ -765,6 +770,7 @@ void SVEType::applyModifier(char Mod) {
     Signed = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = Bitwidth = 64;
     NumVectors = 0;
     break;
@@ -783,6 +789,7 @@ void SVEType::applyModifier(char Mod) {
   case 'g':
     Signed = false;
     Float = false;
+    MFloat = false;
     BFloat = false;
     ElementBitwidth = 64;
     break;
@@ -790,18 +797,21 @@ void SVEType::applyModifier(char Mod) {
     Signed = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = 8;
     break;
   case 't':
     Signed = true;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = 32;
     break;
   case 'z':
     Signed = false;
     Float = false;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = 32;
     break;
   case 'O':
@@ -815,6 +825,7 @@ void SVEType::applyModifier(char Mod) {
     Svcount = false;
     Float = true;
     BFloat = false;
+    MFloat = false;
     ElementBitwidth = 32;
     break;
   case 'N':
@@ -922,6 +933,7 @@ void SVEType::applyModifier(char Mod) {
     Predicate = false;
     Svcount = false;
     Float = false;
+    MFloat = false;
     BFloat = true;
     ElementBitwidth = 16;
     break;
@@ -932,6 +944,7 @@ void SVEType::applyModifier(char Mod) {
     NumVectors = 0;
     Float = false;
     BFloat = false;
+    MFloat = false;
     break;
   case '~':
     Float = false;
diff --git a/llvm/include/llvm/IR/IntrinsicsAArch64.td b/llvm/include/llvm/IR/IntrinsicsAArch64.td
index a91616b9556828..0fde957ecbba6e 100644
--- a/llvm/include/llvm/IR/IntrinsicsAArch64.td
+++ b/llvm/include/llvm/IR/IntrinsicsAArch64.td
@@ -2983,6 +2983,13 @@ let TargetPrefix = "aarch64" in {
            LLVMMatchType<0>,
            llvm_anyvector_ty], [ImmArg<ArgIndex<0>>]>;
 
+  class SME_FP8_OuterProduct_Intrinsic
+      : DefaultAttrsIntrinsic<[],
+          [llvm_i32_ty,
+           llvm_nxv16i1_ty, llvm_nxv16i1_ty,
+           llvm_nxv16i8_ty, llvm_nxv16i8_ty],
+          [ImmArg<ArgIndex<0>>, IntrInaccessibleMemOnly, IntrHasSideEffects]>;
+
   def int_aarch64_sme_mopa : SME_OuterProduct_Intrinsic;
   def int_aarch64_sme_mops : SME_OuterProduct_Intrinsic;
 
@@ -2998,6 +3005,10 @@ let TargetPrefix = "aarch64" in {
   def int_aarch64_sme_usmopa_wide : SME_OuterProduct_Intrinsic;
   def int_aarch64_sme_usmops_wide : SME_OuterProduct_Intrinsic;
 
+  // FP8 outer product
+  def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic;
+  def int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic;
+
   class SME_AddVectorToTile_Intrinsic
       : DefaultAttrsIntrinsic<[],
           [llvm_i32_ty,
diff --git a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
index 37ac915d1d8808..9c657787d3492b 100644
--- a/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
@@ -990,7 +990,7 @@ defm FDOT_VG2_M2ZZI_BtoH  : sme2p1_multi_vec_array_vg2_index_f8f16<"fdot",  0b11
 defm FDOT_VG4_M4ZZI_BtoH  : sme2p1_multi_vec_array_vg4_index_f8f16<"fdot",    0b100, ZZZZ_b_mul_r, ZPR4b8>;
 defm FDOT_VG2_M2ZZ_BtoH   :  sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>;
 defm FDOT_VG4_M4ZZ_BtoH   :  sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>;
-// TODO: Replace nxv16i8 by nxv16f8
+
 defm FDOT_VG2_M2Z2Z_BtoH  : sme2_dot_mla_add_sub_array_vg2_multi<"fdot",    0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
 defm FDOT_VG4_M4Z4Z_BtoH  : sme2_dot_mla_add_sub_array_vg4_multi<"fdot",    0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
 
@@ -998,23 +998,22 @@ def  FMLAL_MZZI_BtoH      : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>;
 defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>;
 defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>;
 def  FMLAL_VG2_MZZ_BtoH   : sme2_mla_long_array_single_16b<"fmlal">;
-// TODO: Replace nxv16i8 by nxv16f8
+
 defm FMLAL_VG2_M2ZZ_BtoH  : sme2_fp_mla_long_array_vg2_single<"fmlal",  0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>;
 defm FMLAL_VG4_M4ZZ_BtoH  :  sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>;
 defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal",   0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
 defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal",   0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;
 
-defm FMOPA_MPPZZ_BtoH     : sme2p1_fmop_tile_f8f16<"fmopa", 0b1, 0b0, 0b01>;
-
+defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
 } //[HasSMEF8F16]
 
 let Predicates = [HasSMEF8F32] in {
-// TODO : Replace nxv16i8 by nxv16f8
+
 defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
 defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
 defm FDOT_VG2_M2ZZ_BtoS  : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>;
 defm FDOT_VG4_M4ZZ_BtoS  : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>;
-// TODO : Replace nxv16i8 by nxv16f8
+
 defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot",   0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
 defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot",   0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
 
@@ -1024,16 +1023,14 @@ def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>;
 defm FMLALL_MZZI_BtoS      : sme2_mla_ll_array_index_32b<"fmlall",     0b01, 0b000, null_frag>;
 defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>;
 defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>;
-// TODO: Replace nxv16i8 by nxv16f8
+
 defm FMLALL_MZZ_BtoS       : sme2_mla_ll_array_single<"fmlall",      0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>;
 defm FMLALL_VG2_M2ZZ_BtoS  : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>;
 defm FMLALL_VG4_M4ZZ_BtoS  : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>;
 defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall",   0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
 defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall",   0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;
 
-
-defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>;
-
+defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
 } //[HasSMEF8F32]
 
 let Predicates = [HasSME2, HasSVEBFSCALE] in {
diff --git a/llvm/lib/Target/AArch64/SMEInstrFormats.td b/llvm/lib/Target/AArch64/SMEInstrFormats.td
index 776472e72af05a..e6535f957e2024 100644
--- a/llvm/lib/Target/AArch64/SMEInstrFormats.td
+++ b/llvm/lib/Target/AArch64/SMEInstrFormats.td
@@ -305,6 +305,21 @@ multiclass sme_outer_product_fp32<bit S, bits<2> sz, ZPRRegOp zpr_ty, string mne
   def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_3, nxv4i1, nxv4f32>;
 }
 
+multiclass sme2_fp8_fmopa_za32<string mnemonic, SDPatternOperator intrinsic> {
+    def NAME : sme_fp_outer_product_inst<0, 0b01, 0b00, TileOp32, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
+      bits<2> ZAda;
+      let Inst{1-0} = ZAda;
+      let Inst{2}   = 0b0;
+
+      let Uses = [FPMR, FPCR];
+    }
+
+    let mayStore = 1, mayLoad = 1 in
+    def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;
+
+    def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_3, nxv16i1, nxv16i8>;
+}
+
 multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
   def NAME : sme_fp_outer_product_inst<S, 0b10, 0b00, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<3> ZAda;
@@ -316,12 +331,19 @@ multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op>
   def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_7, nxv2i1, nxv2f64>;
 }
 
-multiclass sme2p1_fmop_tile_f8f16<string mnemonic, bit bf, bit s, bits<2> op> {
-  def NAME : sme_fp_outer_product_inst<s, {0,bf}, op, TileOp16, ZPR8, mnemonic> {
+multiclass sme2_fp8_fmopa_za16<string mnemonic, SDPatternOperator intrinsic> {
+  def NAME : sme_fp_outer_product_inst<0, {0, 0b1}, 0b01, TileOp16, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
     bits<1> ZAda;
     let Inst{2-1} = 0b00;
     let Inst{0}   = ZAda;
+
+    let Uses = [FPMR, FPCR];
   }
+
+  let mayStore = 1, mayLoad = 1 in
+  def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileH>, SMEPseudo2Instr<NAME, 0>;
+
+  def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_1, nxv16i1, nxv16i8>;
 }
 
 multiclass sme2p1_fmop_tile_fp16<string mnemonic, bit bf, bit s, ValueType vt, SDPatternOperator intrinsic = null_frag> {
diff --git a/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll
new file mode 100644
index 00000000000000..6e88cdf4e7fec3
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll
@@ -0,0 +1,22 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
+; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme-f8f16,+sme-f8f32 -force-streaming < %s | FileCheck %s
+
+define void @test_fmopa_16(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) {
+; CHECK-LABEL: test_fmopa_16:
+; CHEC...
[truncated]

Copy link
Contributor

@CarolineConcatto CarolineConcatto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the patch Spencer.

llvm/lib/Target/AArch64/SMEInstrFormats.td Show resolved Hide resolved
Copy link
Contributor

@CarolineConcatto CarolineConcatto left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@SpencerAbson SpencerAbson merged commit 99f6ca9 into llvm:main Dec 9, 2024
8 checks passed
broxigarchen pushed a commit to broxigarchen/llvm-project that referenced this pull request Dec 10, 2024
This patch implements the following intrinsics:

8-bit floating-point sum of outer products and accumulate.
``` c
  // Only if __ARM_FEATURE_SME_F8F16 != 0
    void svmopa_za16[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");

  // Only if __ARM_FEATURE_SME_F8F32 != 0
    void svmopa_za32[_mf8]_m_fpm(uint64_t tile, svbool_t pn, svbool_t pm,
                                 svmfloat8_t zn, svmfloat8_t zm, fpm_t fpm)
                                 __arm_streaming __arm_inout("za");
```

In accordance with: ARM-software/acle#323

Co-authored-by: Momchil Velikov momchil.velikov@arm.com
Co-authored-by: Marian Lukac marian.lukac@arm.com
@jthackray jthackray requested review from jthackray and removed request for momchil-velikov December 13, 2024 15:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants