From 64c8b66cc9972123c5f4aefe692c275898221aeb Mon Sep 17 00:00:00 2001 From: Kerry McLaughlin Date: Thu, 6 Jun 2024 11:02:19 +0100 Subject: [PATCH] [AArch64][SME] Add calling convention for __arm_get_current_vg (#93963) Adds a calling convention for calls to the `__arm_get_current_vg` support routine, which preserves X1-X15, X19-X29, SP, Z0-Z31 & P0-P15. See https://github.com/ARM-software/abi-aa/pull/263 --- llvm/include/llvm/AsmParser/LLToken.h | 1 + llvm/include/llvm/IR/CallingConv.h | 3 ++ llvm/lib/AsmParser/LLLexer.cpp | 1 + llvm/lib/AsmParser/LLParser.cpp | 4 +++ llvm/lib/IR/AsmWriter.cpp | 3 ++ .../AArch64/AArch64CallingConvention.td | 8 +++++ .../Target/AArch64/AArch64RegisterInfo.cpp | 33 ++++++++++++++++--- ...sme-support-routines-calling-convention.ll | 20 +++++++++++ 8 files changed, 68 insertions(+), 5 deletions(-) diff --git a/llvm/include/llvm/AsmParser/LLToken.h b/llvm/include/llvm/AsmParser/LLToken.h index 69821c22dcd619..db6780b70ca5aa 100644 --- a/llvm/include/llvm/AsmParser/LLToken.h +++ b/llvm/include/llvm/AsmParser/LLToken.h @@ -147,6 +147,7 @@ enum Kind { kw_aarch64_vector_pcs, kw_aarch64_sve_vector_pcs, kw_aarch64_sme_preservemost_from_x0, + kw_aarch64_sme_preservemost_from_x1, kw_aarch64_sme_preservemost_from_x2, kw_msp430_intrcc, kw_avr_intrcc, diff --git a/llvm/include/llvm/IR/CallingConv.h b/llvm/include/llvm/IR/CallingConv.h index a05d1a4d587845..55e32028e3ed08 100644 --- a/llvm/include/llvm/IR/CallingConv.h +++ b/llvm/include/llvm/IR/CallingConv.h @@ -267,6 +267,9 @@ namespace CallingConv { /// Calling convention used for RISC-V V-extension. RISCV_VectorCall = 110, + /// Preserve X1-X15, X19-X29, SP, Z0-Z31, P0-P15. + AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 = 111, + /// The highest possible ID. Must be some 2^k - 1. MaxID = 1023 }; diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp index d3ab306904da12..7d7fe19568e8a6 100644 --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -604,6 +604,7 @@ lltok::Kind LLLexer::LexIdentifier() { KEYWORD(aarch64_vector_pcs); KEYWORD(aarch64_sve_vector_pcs); KEYWORD(aarch64_sme_preservemost_from_x0); + KEYWORD(aarch64_sme_preservemost_from_x1); KEYWORD(aarch64_sme_preservemost_from_x2); KEYWORD(msp430_intrcc); KEYWORD(avr_intrcc); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp index 07c8aa23fc5e21..f0fde9ae4df5c3 100644 --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -2153,6 +2153,7 @@ void LLParser::parseOptionalDLLStorageClass(unsigned &Res) { /// ::= 'aarch64_vector_pcs' /// ::= 'aarch64_sve_vector_pcs' /// ::= 'aarch64_sme_preservemost_from_x0' +/// ::= 'aarch64_sme_preservemost_from_x1' /// ::= 'aarch64_sme_preservemost_from_x2' /// ::= 'msp430_intrcc' /// ::= 'avr_intrcc' @@ -2212,6 +2213,9 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) { case lltok::kw_aarch64_sme_preservemost_from_x0: CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0; break; + case lltok::kw_aarch64_sme_preservemost_from_x1: + CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1; + break; case lltok::kw_aarch64_sme_preservemost_from_x2: CC = CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2; break; diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp index 7a5f18fe2cbd52..0bf8be9ac55f9d 100644 --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -326,6 +326,9 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) { case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: Out << "aarch64_sme_preservemost_from_x0"; break; + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1: + Out << "aarch64_sme_preservemost_from_x1"; + break; case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: Out << "aarch64_sme_preservemost_from_x2"; break; diff --git a/llvm/lib/Target/AArch64/AArch64CallingConvention.td b/llvm/lib/Target/AArch64/AArch64CallingConvention.td index 32646c6ee68913..941990c53c4a7f 100644 --- a/llvm/lib/Target/AArch64/AArch64CallingConvention.td +++ b/llvm/lib/Target/AArch64/AArch64CallingConvention.td @@ -589,6 +589,14 @@ def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 (sequence "X%u",19, 28), LR, FP)>; +// SME ABI support routines such as __arm_get_current_vg preserve most registers. +def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 + : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), + (sequence "P%u", 0, 15), + (sequence "X%u", 1, 15), + (sequence "X%u",19, 28), + LR, FP)>; + // SME ABI support routines __arm_sme_state preserves most registers. def CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 : CalleeSavedRegs<(add (sequence "Z%u", 0, 31), diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp index e97d7e3b6ed81e..cc50b59dd8d7e1 100644 --- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -107,13 +107,22 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) report_fatal_error( - "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " - "only supported to improve calls to SME ACLE save/restore/disable-za " + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is only " + "supported to improve calls to SME ACLE save/restore/disable-za " "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) + report_fatal_error( + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 is " + "only supported to improve calls to SME ACLE __arm_get_current_vg " + "function, and is not intended to be used beyond that scope."); if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) report_fatal_error( - "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " "only supported to improve calls to SME ACLE __arm_sme_state " "and is not intended to be used beyond that scope."); if (MF->getSubtarget().getTargetLowering() @@ -153,13 +162,22 @@ AArch64RegisterInfo::getDarwinCalleeSavedRegs(const MachineFunction *MF) const { if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) report_fatal_error( - "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0 is " "only supported to improve calls to SME ACLE save/restore/disable-za " "functions, and is not intended to be used beyond that scope."); + if (MF->getFunction().getCallingConv() == + CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) + report_fatal_error( + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1 is " + "only supported to improve calls to SME ACLE __arm_get_current_vg " + "function, and is not intended to be used beyond that scope."); if (MF->getFunction().getCallingConv() == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) report_fatal_error( - "Calling convention AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " + "Calling convention " + "AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2 is " "only supported to improve calls to SME ACLE __arm_sme_state " "and is not intended to be used beyond that scope."); if (MF->getFunction().getCallingConv() == CallingConv::CXX_FAST_TLS) @@ -236,6 +254,8 @@ AArch64RegisterInfo::getDarwinCallPreservedMask(const MachineFunction &MF, "Calling convention SVE_VectorCall is unsupported on Darwin."); if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1_RegMask; if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask; if (CC == CallingConv::CFGuard_Check) @@ -282,6 +302,8 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF, : CSR_AArch64_SVE_AAPCS_RegMask; if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0) return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0_RegMask; + if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1) + return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1_RegMask; if (CC == CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2) return CSR_AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2_RegMask; if (CC == CallingConv::CFGuard_Check) @@ -643,6 +665,7 @@ bool AArch64RegisterInfo::isArgumentRegister(const MachineFunction &MF, case CallingConv::AArch64_VectorCall: case CallingConv::AArch64_SVE_VectorCall: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0: + case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X1: case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2: if (STI.isTargetWindows()) return HasReg(CC_AArch64_Win64PCS_ArgRegs, Reg); diff --git a/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll index 7535638137caa8..63c65334afe119 100644 --- a/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll +++ b/llvm/test/CodeGen/AArch64/sme-support-routines-calling-convention.ll @@ -25,6 +25,25 @@ define void @test_sme_calling_convention_x0() nounwind { ret void } +define i64 @test_sme_calling_convention_x1() nounwind { +; CHECK-LABEL: test_sme_calling_convention_x1: +; CHECK: // %bb.0: +; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill +; CHECK-NEXT: bl __arm_get_current_vg +; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload +; CHECK-NEXT: ret +; DARWIN-LABEL: test_sme_calling_convention_x1: +; DARWIN: stp x29, x30, [sp, #-16]! +; DARWIN: bl ___arm_get_current_vg +; DARWIN: ldp x29, x30, [sp], #16 +; DARWIN: ret +; +; CHECK-CSRMASK-LABEL: name: test_sme_calling_convention_x1 +; CHECK-CSRMASK: BL @__arm_get_current_vg, csr_aarch64_sme_abi_support_routines_preservemost_from_x1 + %vg = call aarch64_sme_preservemost_from_x1 i64 @__arm_get_current_vg() + ret i64 %vg +} + define i64 @test_sme_calling_convention_x2() nounwind { ; CHECK-LABEL: test_sme_calling_convention_x2: ; CHECK: // %bb.0: @@ -46,4 +65,5 @@ define i64 @test_sme_calling_convention_x2() nounwind { } declare void @__arm_tpidr2_save() +declare i64 @__arm_get_current_vg() declare {i64, i64} @__arm_sme_state()