Skip to content

Commit

Permalink
[AArch64] Implement intrinsics for SME FP8 FMOPA
Browse files Browse the repository at this point in the history
  • Loading branch information
SpencerAbson committed Nov 29, 2024
1 parent 55dd475 commit 010eae4
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 13 deletions.
10 changes: 10 additions & 0 deletions clang/include/clang/Basic/arm_sme.td
Original file line number Diff line number Diff line change
Expand Up @@ -824,4 +824,14 @@ let SMETargetGuard = "sme-lutv2" in {
def SVLUTI4_ZT_X4 : SInst<"svluti4_zt_{d}_x4", "4i2.u", "cUc", MergeNone, "aarch64_sme_luti4_zt_x4", [IsStreaming, IsInZT0], [ImmCheck<0, ImmCheck0_0>]>;
}

let SMETargetGuard = "sme-f8f32" in {
def SVMOPA_FP8_ZA32 : Inst<"svmopa_za32[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za32",
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_3>]>;
}

let SMETargetGuard = "sme-f8f16" in {
def SVMOPA_FP8_ZA16 : Inst<"svmopa_za16[_mf8]_m_fpm", "viPPdd>", "m", MergeNone, "aarch64_sme_fp8_fmopa_za16",
[IsStreaming, IsInOutZA, SetsFPMR, IsOverloadNone], [ImmCheck<0, ImmCheck0_1>]>;
}

} // let SVETargetGuard = InvalidMode
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10183,6 +10183,8 @@ CodeGenFunction::getSVEType(const SVETypeFlags &TypeFlags) {
case SVETypeFlags::EltTyInt64:
return llvm::ScalableVectorType::get(Builder.getInt64Ty(), 2);

case SVETypeFlags::EltTyMFloat8:
return llvm::ScalableVectorType::get(Builder.getInt8Ty(), 16);
case SVETypeFlags::EltTyFloat16:
return llvm::ScalableVectorType::get(Builder.getHalfTy(), 8);
case SVETypeFlags::EltTyBFloat16:
Expand Down Expand Up @@ -11234,6 +11236,10 @@ Value *CodeGenFunction::EmitAArch64SMEBuiltinExpr(unsigned BuiltinID,
BuiltinID == SME::BI__builtin_sme_svstr_za)
return EmitSMELdrStr(TypeFlags, Ops, Builtin->LLVMIntrinsic);

// Emit set FPMR for intrinsics that require it
if (TypeFlags.setsFPMR())
Builder.CreateCall(CGM.getIntrinsic(Intrinsic::aarch64_set_fpmr),
Ops.pop_back_val());
// Handle builtins which require their multi-vector operands to be swapped
swapCommutativeSMEOperands(BuiltinID, Ops);

Expand Down
55 changes: 55 additions & 0 deletions clang/test/CodeGen/AArch64/fp8-intrinsics/acle_sme2_fp8_fmopa.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
// REQUIRES: aarch64-registered-target

// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s
// RUN: %clang_cc1 -DSVE_OVERLOADED_FORMS -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -disable-O0-optnone -Werror -Wall -emit-llvm -o - -x c++ %s | opt -S -passes=mem2reg,tailcallelim | FileCheck %s -check-prefix=CPP-CHECK
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme-f8f16 -target-feature +sme-f8f32 -S -disable-O0-optnone -Werror -Wall -o /dev/null %s

#include <arm_sme.h>

#ifdef SVE_OVERLOADED_FORMS
#define SVE_ACLE_FUNC(A1,A2_UNUSED,A3) A1##A3
#else
#define SVE_ACLE_FUNC(A1,A2,A3) A1##A2##A3
#endif


// CHECK-LABEL: define dso_local void @test_svmopa_za16_mf8_m(
// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
// CHECK-NEXT: ret void
//
// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za16_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0:[0-9]+]] {
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
// CPP-CHECK-NEXT: ret void
//
void test_svmopa_za16_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
SVE_ACLE_FUNC(svmopa_za16,_mf8,_m_fpm)(1, pn, pm, zn, zm, fpmr);
}

// CHECK-LABEL: define dso_local void @test_svmopa_za32_mf8_m(
// CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
// CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
// CHECK-NEXT: ret void
//
// CPP-CHECK-LABEL: define dso_local void @_Z22test_svmopa_za32_mf8_mu10__SVBool_tS_u13__SVMfloat8_tS0_m(
// CPP-CHECK-SAME: <vscale x 16 x i1> [[PN:%.*]], <vscale x 16 x i1> [[PM:%.*]], <vscale x 16 x i8> [[ZN:%.*]], <vscale x 16 x i8> [[ZM:%.*]], i64 noundef [[FPMR:%.*]]) #[[ATTR0]] {
// CPP-CHECK-NEXT: [[ENTRY:.*:]]
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.set.fpmr(i64 [[FPMR]])
// CPP-CHECK-NEXT: tail call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> [[PN]], <vscale x 16 x i1> [[PM]], <vscale x 16 x i8> [[ZN]], <vscale x 16 x i8> [[ZM]])
// CPP-CHECK-NEXT: ret void
//
void test_svmopa_za32_mf8_m(svbool_t pn, svbool_t pm, svmfloat8_t zn,
svmfloat8_t zm, fpm_t fpmr) __arm_streaming __arm_inout("za") {
SVE_ACLE_FUNC(svmopa_za32,_mf8,_m_fpm)(3, pn, pm, zn, zm, fpmr);
}
18 changes: 18 additions & 0 deletions clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_imm.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -target-feature +sme2 -target-feature +sme-f8f16 -target-feature +sme-f8f32 -fsyntax-only -verify %s

// REQUIRES: aarch64-registered-target

#include <arm_sme.h>

void test_svmopa(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
fpm_t fpmr) __arm_streaming __arm_inout("za") {
// expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 1]}}
svmopa_za16_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
// expected-error@+1 {{argument value 2 is outside the valid range [0, 1]}}
svmopa_za16_mf8_m_fpm(2, pn, pm, zn, zm, fpmr);

// expected-error@+1 {{argument value 18446744073709551615 is outside the valid range [0, 3]}}
svmopa_za32_mf8_m_fpm(-1, pn, pm, zn, zm, fpmr);
// expected-error@+1 {{argument value 4 is outside the valid range [0, 3]}}
svmopa_za32_mf8_m_fpm(4, pn, pm, zn, zm, fpmr);
}
13 changes: 13 additions & 0 deletions clang/test/Sema/aarch64-fp8-intrinsics/acle_sme2_fp8_mopa.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: %clang_cc1 -triple aarch64 -target-feature +sme -verify -emit-llvm-only %s

// REQUIRES: aarch64-registered-target

#include <arm_sme.h>

void test_features(svbool_t pn, svbool_t pm, svmfloat8_t zn, svmfloat8_t zm,
fpm_t fpmr) __arm_streaming __arm_inout("za") {
// expected-error@+1 {{'svmopa_za16_mf8_m_fpm' needs target feature sme,sme-f8f16}}
svmopa_za16_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
// expected-error@+1 {{'svmopa_za32_mf8_m_fpm' needs target feature sme,sme-f8f32}}
svmopa_za32_mf8_m_fpm(0, pn, pm, zn, zm, fpmr);
}
15 changes: 14 additions & 1 deletion clang/utils/TableGen/SveEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,6 @@ void SVEType::applyTypespec(StringRef TS) {
ElementBitwidth = 16;
break;
case 'm':
Signed = false;
MFloat = true;
Float = false;
BFloat = false;
Expand Down Expand Up @@ -702,6 +701,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
Signed = false;
Expand All @@ -712,6 +712,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
Signed = true;
Expand All @@ -723,6 +724,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
Signed = true;
Expand All @@ -735,6 +737,7 @@ void SVEType::applyModifier(char Mod) {
Signed = true;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
break;
Expand All @@ -744,6 +747,7 @@ void SVEType::applyModifier(char Mod) {
Signed = true;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
break;
Expand All @@ -753,6 +757,7 @@ void SVEType::applyModifier(char Mod) {
Signed = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 32;
NumVectors = 0;
break;
Expand All @@ -765,6 +770,7 @@ void SVEType::applyModifier(char Mod) {
Signed = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = Bitwidth = 64;
NumVectors = 0;
break;
Expand All @@ -783,25 +789,29 @@ void SVEType::applyModifier(char Mod) {
case 'g':
Signed = false;
Float = false;
MFloat = false;
BFloat = false;
ElementBitwidth = 64;
break;
case '[':
Signed = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = 8;
break;
case 't':
Signed = true;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = 32;
break;
case 'z':
Signed = false;
Float = false;
BFloat = false;
MFloat = false;
ElementBitwidth = 32;
break;
case 'O':
Expand All @@ -815,6 +825,7 @@ void SVEType::applyModifier(char Mod) {
Svcount = false;
Float = true;
BFloat = false;
MFloat = false;
ElementBitwidth = 32;
break;
case 'N':
Expand Down Expand Up @@ -922,6 +933,7 @@ void SVEType::applyModifier(char Mod) {
Predicate = false;
Svcount = false;
Float = false;
MFloat = false;
BFloat = true;
ElementBitwidth = 16;
break;
Expand All @@ -932,6 +944,7 @@ void SVEType::applyModifier(char Mod) {
NumVectors = 0;
Float = false;
BFloat = false;
MFloat = false;
break;
case '~':
Float = false;
Expand Down
11 changes: 11 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsAArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -2983,6 +2983,13 @@ let TargetPrefix = "aarch64" in {
LLVMMatchType<0>,
llvm_anyvector_ty], [ImmArg<ArgIndex<0>>]>;

class SME_FP8_OuterProduct_Intrinsic
: DefaultAttrsIntrinsic<[],
[llvm_i32_ty,
llvm_nxv16i1_ty, llvm_nxv16i1_ty,
llvm_nxv16i8_ty, llvm_nxv16i8_ty],
[ImmArg<ArgIndex<0>>, IntrInaccessibleMemOnly, IntrHasSideEffects]>;

def int_aarch64_sme_mopa : SME_OuterProduct_Intrinsic;
def int_aarch64_sme_mops : SME_OuterProduct_Intrinsic;

Expand All @@ -2998,6 +3005,10 @@ let TargetPrefix = "aarch64" in {
def int_aarch64_sme_usmopa_wide : SME_OuterProduct_Intrinsic;
def int_aarch64_sme_usmops_wide : SME_OuterProduct_Intrinsic;

// FP8 outer product
def int_aarch64_sme_fp8_fmopa_za16 : SME_FP8_OuterProduct_Intrinsic;
def int_aarch64_sme_fp8_fmopa_za32 : SME_FP8_OuterProduct_Intrinsic;

class SME_AddVectorToTile_Intrinsic
: DefaultAttrsIntrinsic<[],
[llvm_i32_ty,
Expand Down
17 changes: 7 additions & 10 deletions llvm/lib/Target/AArch64/AArch64SMEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -990,31 +990,30 @@ defm FDOT_VG2_M2ZZI_BtoH : sme2p1_multi_vec_array_vg2_index_f8f16<"fdot", 0b11
defm FDOT_VG4_M4ZZI_BtoH : sme2p1_multi_vec_array_vg4_index_f8f16<"fdot", 0b100, ZZZZ_b_mul_r, ZPR4b8>;
defm FDOT_VG2_M2ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010001, MatrixOp16, ZZ_b, ZPR4b8>;
defm FDOT_VG4_M4ZZ_BtoH : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110001, MatrixOp16, ZZZZ_b, ZPR4b8>;
// TODO: Replace nxv16i8 by nxv16f8

defm FDOT_VG2_M2Z2Z_BtoH : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FDOT_VG4_M4Z4Z_BtoH : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;

def FMLAL_MZZI_BtoH : sme2_mla_ll_array_index_16b<"fmlal", 0b11, 0b00>;
defm FMLAL_VG2_M2ZZI_BtoH : sme2_multi_vec_array_vg2_index_16b<"fmlal", 0b10, 0b111>;
defm FMLAL_VG4_M4ZZI_BtoH : sme2_multi_vec_array_vg4_index_16b<"fmlal", 0b10, 0b110>;
def FMLAL_VG2_MZZ_BtoH : sme2_mla_long_array_single_16b<"fmlal">;
// TODO: Replace nxv16i8 by nxv16f8

defm FMLAL_VG2_M2ZZ_BtoH : sme2_fp_mla_long_array_vg2_single<"fmlal", 0b001, MatrixOp16, ZZ_b, ZPR4b8, nxv16i8, null_frag>;
defm FMLAL_VG4_M4ZZ_BtoH : sme2_fp_mla_long_array_vg4_single<"fmlal", 0b001, MatrixOp16, ZZZZ_b, ZPR4b8, nxv16i8, null_frag>;
defm FMLAL_VG2_M2Z2Z_BtoH : sme2_fp_mla_long_array_vg2_multi<"fmlal", 0b100, MatrixOp16, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FMLAL_VG4_M4Z4Z_BtoH : sme2_fp_mla_long_array_vg4_multi<"fmlal", 0b100, MatrixOp16, ZZZZ_b_mul_r, nxv16i8, null_frag>;

defm FMOPA_MPPZZ_BtoH : sme2p1_fmop_tile_f8f16<"fmopa", 0b1, 0b0, 0b01>;

defm FMOPA_MPPZZ_BtoH : sme2_fp8_fmopa_za16<"fmopa", int_aarch64_sme_fp8_fmopa_za16>;
} //[HasSMEF8F16]

let Predicates = [HasSMEF8F32] in {
// TODO : Replace nxv16i8 by nxv16f8

defm FDOT_VG2_M2ZZI_BtoS : sme2_multi_vec_array_vg2_index_32b<"fdot", 0b01, 0b0111, ZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
defm FDOT_VG4_M4ZZI_BtoS : sme2_multi_vec_array_vg4_index_32b<"fdot", 0b0001, ZZZZ_b_mul_r, ZPR4b8, nxv16i8, null_frag>;
defm FDOT_VG2_M2ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0010011, MatrixOp32, ZZ_b, ZPR4b8>;
defm FDOT_VG4_M4ZZ_BtoS : sme2_dot_mla_add_sub_array_vg24_single<"fdot", 0b0110011, MatrixOp32, ZZZZ_b, ZPR4b8>;
// TODO : Replace nxv16i8 by nxv16f8

defm FDOT_VG2_M2Z2Z_BtoS : sme2_dot_mla_add_sub_array_vg2_multi<"fdot", 0b0100110, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FDOT_VG4_M4Z4Z_BtoS : sme2_dot_mla_add_sub_array_vg4_multi<"fdot", 0b0100110, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;

Expand All @@ -1024,16 +1023,14 @@ def FVDOTT_VG4_M2ZZI_BtoS : sme2_fp8_multi_vec_array_vg4_index<"fvdott", 0b1>;
defm FMLALL_MZZI_BtoS : sme2_mla_ll_array_index_32b<"fmlall", 0b01, 0b000, null_frag>;
defm FMLALL_VG2_M2ZZI_BtoS : sme2_mla_ll_array_vg2_index_32b<"fmlall", 0b10, 0b100, null_frag>;
defm FMLALL_VG4_M4ZZI_BtoS : sme2_mla_ll_array_vg4_index_32b<"fmlall", 0b00, 0b1000, null_frag>;
// TODO: Replace nxv16i8 by nxv16f8

defm FMLALL_MZZ_BtoS : sme2_mla_ll_array_single<"fmlall", 0b01000, MatrixOp32, ZPR8, ZPR4b8, nxv16i8, null_frag>;
defm FMLALL_VG2_M2ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b000001, MatrixOp32, ZZ_b, ZPR4b8>;
defm FMLALL_VG4_M4ZZ_BtoS : sme2_mla_ll_array_vg24_single<"fmlall", 0b010001, MatrixOp32, ZZZZ_b, ZPR4b8>;
defm FMLALL_VG2_M2Z2Z_BtoS : sme2_mla_ll_array_vg2_multi<"fmlall", 0b01000, MatrixOp32, ZZ_b_mul_r, nxv16i8, null_frag>;
defm FMLALL_VG4_M4Z4Z_BtoS : sme2_mla_ll_array_vg4_multi<"fmlall", 0b01000, MatrixOp32, ZZZZ_b_mul_r, nxv16i8, null_frag>;


defm FMOPA_MPPZZ_BtoS : sme_outer_product_fp32<0b0, 0b01, ZPR8, "fmopa", null_frag>;

defm FMOPA_MPPZZ_BtoS : sme2_fp8_fmopa_za32<"fmopa", int_aarch64_sme_fp8_fmopa_za32>;
} //[HasSMEF8F32]

let Predicates = [HasSME2, HasSVEBFSCALE] in {
Expand Down
26 changes: 24 additions & 2 deletions llvm/lib/Target/AArch64/SMEInstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ multiclass sme_outer_product_fp32<bit S, bits<2> sz, ZPRRegOp zpr_ty, string mne
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_3, nxv4i1, nxv4f32>;
}

multiclass sme2_fp8_fmopa_za32<string mnemonic, SDPatternOperator intrinsic> {
def NAME : sme_fp_outer_product_inst<0, 0b01, 0b00, TileOp32, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
bits<2> ZAda;
let Inst{1-0} = ZAda;
let Inst{2} = 0b0;

let Uses = [FPMR, FPCR];
}

let mayStore = 1, mayLoad = 1 in
def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileS>, SMEPseudo2Instr<NAME, 0>;

def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_3, nxv16i1, nxv16i8>;
}

multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op> {
def NAME : sme_fp_outer_product_inst<S, 0b10, 0b00, TileOp64, ZPR64, mnemonic>, SMEPseudo2Instr<NAME, 1> {
bits<3> ZAda;
Expand All @@ -316,12 +331,19 @@ multiclass sme_outer_product_fp64<bit S, string mnemonic, SDPatternOperator op>
def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, op, timm32_0_7, nxv2i1, nxv2f64>;
}

multiclass sme2p1_fmop_tile_f8f16<string mnemonic, bit bf, bit s, bits<2> op> {
def NAME : sme_fp_outer_product_inst<s, {0,bf}, op, TileOp16, ZPR8, mnemonic> {
multiclass sme2_fp8_fmopa_za16<string mnemonic, SDPatternOperator intrinsic> {
def NAME : sme_fp_outer_product_inst<0, {0, 0b1}, 0b01, TileOp16, ZPR8, mnemonic>, SMEPseudo2Instr<NAME, 1> {
bits<1> ZAda;
let Inst{2-1} = 0b00;
let Inst{0} = ZAda;

let Uses = [FPMR, FPCR];
}

let mayStore = 1, mayLoad = 1 in
def NAME # _PSEUDO : sme_outer_product_pseudo<ZPR8, SMEMatrixTileH>, SMEPseudo2Instr<NAME, 0>;

def : SME_ZA_Tile_TwoPred_TwoVec_Pat<NAME, intrinsic, timm32_0_1, nxv16i1, nxv16i8>;
}

multiclass sme2p1_fmop_tile_fp16<string mnemonic, bit bf, bit s, ValueType vt, SDPatternOperator intrinsic = null_frag> {
Expand Down
22 changes: 22 additions & 0 deletions llvm/test/CodeGen/AArch64/sme2-fp8-intrinsics-fmopa.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sme-f8f16,+sme-f8f32 -force-streaming < %s | FileCheck %s

define void @test_fmopa_16(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) {
; CHECK-LABEL: test_fmopa_16:
; CHECK: // %bb.0:
; CHECK-NEXT: fmopa za1.h, p0/m, p1/m, z0.b, z1.b
; CHECK-NEXT: ret
call void @llvm.aarch64.sme.fp8.fmopa.za16(i32 1, <vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm,
<vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm)
ret void
}

define void @test_fmopa_32(<vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm, <vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm) #0 {
; CHECK-LABEL: test_fmopa_32:
; CHECK: // %bb.0:
; CHECK-NEXT: fmopa za3.s, p0/m, p1/m, z0.b, z1.b
; CHECK-NEXT: ret
call void @llvm.aarch64.sme.fp8.fmopa.za32(i32 3, <vscale x 16 x i1> %pn, <vscale x 16 x i1> %pm,
<vscale x 16 x i8> %vn, <vscale x 16 x i8> %vm)
ret void
}

0 comments on commit 010eae4

Please sign in to comment.