Skip to content

Commit 44076c9

Browse files
authored
[AArch64][PAC] Move emission of LR checks in tail calls to AsmPrinter (#110705)
Move the emission of the checks performed on the authenticated LR value during tail calls to AArch64AsmPrinter class, so that different checker sequences can be reused by pseudo instructions expanded there. This adds one more option to AuthCheckMethod enumeration, the generic XPAC variant which is not restricted to checking the LR register.
1 parent 469520e commit 44076c9

12 files changed

+367
-321
lines changed

Diff for: llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

+120-31
Original file line numberDiff line numberDiff line change
@@ -152,9 +152,13 @@ class AArch64AsmPrinter : public AsmPrinter {
152152
void emitPtrauthCheckAuthenticatedValue(Register TestedReg,
153153
Register ScratchReg,
154154
AArch64PACKey::ID Key,
155+
AArch64PAuth::AuthCheckMethod Method,
155156
bool ShouldTrap,
156157
const MCSymbol *OnFailure);
157158

159+
// Check authenticated LR before tail calling.
160+
void emitPtrauthTailCallHardening(const MachineInstr *TC);
161+
158162
// Emit the sequence for AUT or AUTPAC.
159163
void emitPtrauthAuthResign(const MachineInstr *MI);
160164

@@ -1751,7 +1755,8 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
17511755
/// of proceeding to the next instruction (only if ShouldTrap is false).
17521756
void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17531757
Register TestedReg, Register ScratchReg, AArch64PACKey::ID Key,
1754-
bool ShouldTrap, const MCSymbol *OnFailure) {
1758+
AArch64PAuth::AuthCheckMethod Method, bool ShouldTrap,
1759+
const MCSymbol *OnFailure) {
17551760
// Insert a sequence to check if authentication of TestedReg succeeded,
17561761
// such as:
17571762
//
@@ -1777,38 +1782,70 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
17771782
// Lsuccess:
17781783
// ...
17791784
//
1780-
// This sequence is expensive, but we need more information to be able to
1781-
// do better.
1782-
//
1783-
// We can't TBZ the poison bit because EnhancedPAC2 XORs the PAC bits
1784-
// on failure.
1785-
// We can't TST the PAC bits because we don't always know how the address
1786-
// space is setup for the target environment (and the bottom PAC bit is
1787-
// based on that).
1788-
// Either way, we also don't always know whether TBI is enabled or not for
1789-
// the specific target environment.
1785+
// See the documentation on AuthCheckMethod enumeration constants for
1786+
// the specific code sequences that can be used to perform the check.
1787+
using AArch64PAuth::AuthCheckMethod;
17901788

1791-
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1789+
if (Method == AuthCheckMethod::None)
1790+
return;
1791+
if (Method == AuthCheckMethod::DummyLoad) {
1792+
EmitToStreamer(MCInstBuilder(AArch64::LDRWui)
1793+
.addReg(getWRegFromXReg(ScratchReg))
1794+
.addReg(TestedReg)
1795+
.addImm(0));
1796+
assert(ShouldTrap && !OnFailure && "DummyLoad always traps on error");
1797+
return;
1798+
}
17921799

17931800
MCSymbol *SuccessSym = createTempSymbol("auth_success_");
1801+
if (Method == AuthCheckMethod::XPAC || Method == AuthCheckMethod::XPACHint) {
1802+
// mov Xscratch, Xtested
1803+
emitMovXReg(ScratchReg, TestedReg);
17941804

1795-
// mov Xscratch, Xtested
1796-
emitMovXReg(ScratchReg, TestedReg);
1797-
1798-
// xpac(i|d) Xscratch
1799-
EmitToStreamer(MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1805+
if (Method == AuthCheckMethod::XPAC) {
1806+
// xpac(i|d) Xscratch
1807+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1808+
EmitToStreamer(
1809+
MCInstBuilder(XPACOpc).addReg(ScratchReg).addReg(ScratchReg));
1810+
} else {
1811+
// xpaclri
1812+
1813+
// Note that this method applies XPAC to TestedReg instead of ScratchReg.
1814+
assert(TestedReg == AArch64::LR &&
1815+
"XPACHint mode is only compatible with checking the LR register");
1816+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1817+
"XPACHint mode is only compatible with I-keys");
1818+
EmitToStreamer(MCInstBuilder(AArch64::XPACLRI));
1819+
}
18001820

1801-
// cmp Xtested, Xscratch
1802-
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1803-
.addReg(AArch64::XZR)
1804-
.addReg(TestedReg)
1805-
.addReg(ScratchReg)
1806-
.addImm(0));
1821+
// cmp Xtested, Xscratch
1822+
EmitToStreamer(MCInstBuilder(AArch64::SUBSXrs)
1823+
.addReg(AArch64::XZR)
1824+
.addReg(TestedReg)
1825+
.addReg(ScratchReg)
1826+
.addImm(0));
18071827

1808-
// b.eq Lsuccess
1809-
EmitToStreamer(MCInstBuilder(AArch64::Bcc)
1810-
.addImm(AArch64CC::EQ)
1811-
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1828+
// b.eq Lsuccess
1829+
EmitToStreamer(
1830+
MCInstBuilder(AArch64::Bcc)
1831+
.addImm(AArch64CC::EQ)
1832+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1833+
} else if (Method == AuthCheckMethod::HighBitsNoTBI) {
1834+
// eor Xscratch, Xtested, Xtested, lsl #1
1835+
EmitToStreamer(MCInstBuilder(AArch64::EORXrs)
1836+
.addReg(ScratchReg)
1837+
.addReg(TestedReg)
1838+
.addReg(TestedReg)
1839+
.addImm(1));
1840+
// tbz Xscratch, #62, Lsuccess
1841+
EmitToStreamer(
1842+
MCInstBuilder(AArch64::TBZX)
1843+
.addReg(ScratchReg)
1844+
.addImm(62)
1845+
.addExpr(MCSymbolRefExpr::create(SuccessSym, OutContext)));
1846+
} else {
1847+
llvm_unreachable("Unsupported check method");
1848+
}
18121849

18131850
if (ShouldTrap) {
18141851
assert(!OnFailure && "Cannot specify OnFailure with ShouldTrap");
@@ -1822,9 +1859,26 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18221859
// Note that this can introduce an authentication oracle (such as based on
18231860
// the high bits of the re-signed value).
18241861

1825-
// FIXME: Can we simply return the AUT result, already in TestedReg?
1826-
// mov Xtested, Xscratch
1827-
emitMovXReg(TestedReg, ScratchReg);
1862+
// FIXME: The XPAC method can be optimized by applying XPAC to TestedReg
1863+
// instead of ScratchReg, thus eliminating one `mov` instruction.
1864+
// Both XPAC and XPACHint can be further optimized by not using a
1865+
// conditional branch jumping over an unconditional one.
1866+
1867+
switch (Method) {
1868+
case AuthCheckMethod::XPACHint:
1869+
// LR is already XPAC-ed at this point.
1870+
break;
1871+
case AuthCheckMethod::XPAC:
1872+
// mov Xtested, Xscratch
1873+
emitMovXReg(TestedReg, ScratchReg);
1874+
break;
1875+
default:
1876+
// If Xtested was not XPAC-ed so far, emit XPAC here.
1877+
// xpac(i|d) Xtested
1878+
unsigned XPACOpc = getXPACOpcodeForKey(Key);
1879+
EmitToStreamer(
1880+
MCInstBuilder(XPACOpc).addReg(TestedReg).addReg(TestedReg));
1881+
}
18281882

18291883
if (OnFailure) {
18301884
// b Lend
@@ -1839,6 +1893,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18391893
OutStreamer->emitLabel(SuccessSym);
18401894
}
18411895

1896+
// With Pointer Authentication, it may be needed to explicitly check the
1897+
// authenticated value in LR before performing a tail call.
1898+
// Otherwise, the callee may re-sign the invalid return address,
1899+
// introducing a signing oracle.
1900+
void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
1901+
if (!AArch64FI->shouldSignReturnAddress(*MF))
1902+
return;
1903+
1904+
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
1905+
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
1906+
return;
1907+
1908+
const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
1909+
Register ScratchReg =
1910+
TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
1911+
assert(!TC->readsRegister(ScratchReg, TRI) &&
1912+
"Neither x16 nor x17 is available as a scratch register");
1913+
AArch64PACKey::ID Key =
1914+
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
1915+
emitPtrauthCheckAuthenticatedValue(
1916+
AArch64::LR, ScratchReg, Key, LRCheckMethod,
1917+
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
1918+
}
1919+
18421920
void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18431921
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
18441922

@@ -1850,7 +1928,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18501928
// ; sign x16 (if AUTPAC)
18511929
// Lend: ; if not trapping on failure
18521930
//
1853-
// with the checking sequence chosen depending on whether we should check
1931+
// with the checking sequence chosen depending on whether/how we should check
18541932
// the pointer and whether we should trap on failure.
18551933

18561934
// By default, auth/resign sequences check for auth failures.
@@ -1910,6 +1988,7 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
19101988
EndSym = createTempSymbol("resign_end_");
19111989

19121990
emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AUTKey,
1991+
AArch64PAuth::AuthCheckMethod::XPAC,
19131992
ShouldTrap, EndSym);
19141993
}
19151994

@@ -2194,6 +2273,7 @@ void AArch64AsmPrinter::LowerMOVaddrPAC(const MachineInstr &MI) {
21942273
: AArch64PACKey::DA);
21952274

21962275
emitPtrauthCheckAuthenticatedValue(AArch64::X16, AArch64::X17, AuthKey,
2276+
AArch64PAuth::AuthCheckMethod::XPAC,
21972277
/*ShouldTrap=*/true,
21982278
/*OnFailure=*/nullptr);
21992279
}
@@ -2326,6 +2406,7 @@ void AArch64AsmPrinter::LowerLOADgotAUTH(const MachineInstr &MI) {
23262406
(AuthOpcode == AArch64::AUTIA ? AArch64PACKey::IA : AArch64PACKey::DA);
23272407

23282408
emitPtrauthCheckAuthenticatedValue(AuthResultReg, AArch64::X17, AuthKey,
2409+
AArch64PAuth::AuthCheckMethod::XPAC,
23292410
/*ShouldTrap=*/true,
23302411
/*OnFailure=*/nullptr);
23312412

@@ -2395,6 +2476,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
23952476
// Do any manual lowerings.
23962477
switch (MI->getOpcode()) {
23972478
default:
2479+
assert(!AArch64InstrInfo::isTailCallReturnInst(*MI) &&
2480+
"Unhandled tail call instruction");
23982481
break;
23992482
case AArch64::HINT: {
24002483
// CurrentPatchableFunctionEntrySym can be CurrentFnBegin only for
@@ -2538,6 +2621,8 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25382621
? AArch64::X17
25392622
: AArch64::X16;
25402623

2624+
emitPtrauthTailCallHardening(MI);
2625+
25412626
unsigned DiscReg = AddrDisc;
25422627
if (Disc) {
25432628
if (AddrDisc != AArch64::NoRegister) {
@@ -2568,13 +2653,17 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25682653
case AArch64::TCRETURNrix17:
25692654
case AArch64::TCRETURNrinotx16:
25702655
case AArch64::TCRETURNriALL: {
2656+
emitPtrauthTailCallHardening(MI);
2657+
25712658
MCInst TmpInst;
25722659
TmpInst.setOpcode(AArch64::BR);
25732660
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
25742661
EmitToStreamer(*OutStreamer, TmpInst);
25752662
return;
25762663
}
25772664
case AArch64::TCRETURNdi: {
2665+
emitPtrauthTailCallHardening(MI);
2666+
25782667
MCOperand Dest;
25792668
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);
25802669
MCInst TmpInst;

Diff for: llvm/lib/Target/AArch64/AArch64InstrInfo.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,19 @@ unsigned AArch64InstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
106106
unsigned NumBytes = 0;
107107
const MCInstrDesc &Desc = MI.getDesc();
108108

109+
if (!MI.isBundle() && isTailCallReturnInst(MI)) {
110+
NumBytes = Desc.getSize() ? Desc.getSize() : 4;
111+
112+
const auto *MFI = MF->getInfo<AArch64FunctionInfo>();
113+
if (!MFI->shouldSignReturnAddress(MF))
114+
return NumBytes;
115+
116+
const auto &STI = MF->getSubtarget<AArch64Subtarget>();
117+
auto Method = STI.getAuthenticatedLRCheckMethod(*MF);
118+
NumBytes += AArch64PAuth::getCheckerSizeInBytes(Method);
119+
return NumBytes;
120+
}
121+
109122
// Size should be preferably set in
110123
// llvm/lib/Target/AArch64/AArch64InstrInfo.td (default case).
111124
// Specific cases handle instructions of variable sizes

Diff for: llvm/lib/Target/AArch64/AArch64InstrInfo.td

+12-6
Original file line numberDiff line numberDiff line change
@@ -1964,30 +1964,36 @@ let Predicates = [HasPAuth] in {
19641964
}
19651965

19661966
// Size 16: 4 fixed + 8 variable, to compute discriminator.
1967+
// The size returned by getInstSizeInBytes() is incremented according
1968+
// to the variant of LR check.
1969+
// As the check requires either x16 or x17 as a scratch register and
1970+
// authenticated tail call instructions have two register operands,
1971+
// make sure at least one register is usable as a scratch one - for that
1972+
// purpose, use tcGPRnotx16x17 register class for one of the operands.
19671973
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
19681974
Uses = [SP] in {
19691975
def AUTH_TCRETURN
1970-
: Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
1976+
: Pseudo<(outs), (ins tcGPRnotx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
19711977
i64imm:$Disc, tcGPR64:$AddrDisc),
19721978
[]>, Sched<[WriteBrReg]>;
19731979
def AUTH_TCRETURN_BTI
19741980
: Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
1975-
i64imm:$Disc, tcGPR64:$AddrDisc),
1981+
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
19761982
[]>, Sched<[WriteBrReg]>;
19771983
}
19781984

19791985
let Predicates = [TailCallAny] in
1980-
def : Pat<(AArch64authtcret tcGPR64:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
1986+
def : Pat<(AArch64authtcret tcGPRnotx16x17:$dst, (i32 timm:$FPDiff), (i32 timm:$Key),
19811987
(i64 timm:$Disc), tcGPR64:$AddrDisc),
1982-
(AUTH_TCRETURN tcGPR64:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
1988+
(AUTH_TCRETURN tcGPRnotx16x17:$dst, imm:$FPDiff, imm:$Key, imm:$Disc,
19831989
tcGPR64:$AddrDisc)>;
19841990

19851991
let Predicates = [TailCallX16X17] in
19861992
def : Pat<(AArch64authtcret tcGPRx16x17:$dst, (i32 timm:$FPDiff),
19871993
(i32 timm:$Key), (i64 timm:$Disc),
1988-
tcGPR64:$AddrDisc),
1994+
tcGPRnotx16x17:$AddrDisc),
19891995
(AUTH_TCRETURN_BTI tcGPRx16x17:$dst, imm:$FPDiff, imm:$Key,
1990-
imm:$Disc, tcGPR64:$AddrDisc)>;
1996+
imm:$Disc, tcGPRnotx16x17:$AddrDisc)>;
19911997
}
19921998

19931999
// v9.5-A pointer authentication extensions

0 commit comments

Comments
 (0)