@@ -157,6 +157,9 @@ class AArch64AsmPrinter : public AsmPrinter {
157
157
bool ShouldTrap,
158
158
const MCSymbol *OnFailure);
159
159
160
+ // Check authenticated LR before tail calling.
161
+ void emitPtrauthTailCallHardening (const MachineInstr *TC);
162
+
160
163
// Emit the sequence for AUT or AUTPAC.
161
164
void emitPtrauthAuthResign (const MachineInstr *MI);
162
165
@@ -1870,6 +1873,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
1870
1873
OutStreamer->emitLabel (SuccessSym);
1871
1874
}
1872
1875
1876
+ // With Pointer Authentication, it may be needed to explicitly check the
1877
+ // authenticated value in LR before performing a tail call.
1878
+ // Otherwise, the callee may re-sign the invalid return address,
1879
+ // introducing a signing oracle.
1880
+ void AArch64AsmPrinter::emitPtrauthTailCallHardening (const MachineInstr *TC) {
1881
+ if (!AArch64FI->shouldSignReturnAddress (*MF))
1882
+ return ;
1883
+
1884
+ auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod (*MF);
1885
+ if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
1886
+ return ;
1887
+
1888
+ const AArch64RegisterInfo *TRI = STI->getRegisterInfo ();
1889
+ Register ScratchReg =
1890
+ TC->readsRegister (AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
1891
+ assert (!TC->readsRegister (ScratchReg, TRI) &&
1892
+ " Neither x16 nor x17 is available as a scratch register" );
1893
+ AArch64PACKey::ID Key =
1894
+ AArch64FI->shouldSignWithBKey () ? AArch64PACKey::IB : AArch64PACKey::IA;
1895
+ emitPtrauthCheckAuthenticatedValue (
1896
+ AArch64::LR, ScratchReg, Key, LRCheckMethod,
1897
+ /* ShouldTrap=*/ true , /* OnFailure=*/ nullptr );
1898
+ }
1899
+
1873
1900
void AArch64AsmPrinter::emitPtrauthAuthResign (const MachineInstr *MI) {
1874
1901
const bool IsAUTPAC = MI->getOpcode () == AArch64::AUTPAC;
1875
1902
@@ -2312,27 +2339,6 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
2312
2339
OutStreamer->emitLabel (LOHLabel);
2313
2340
}
2314
2341
2315
- // With Pointer Authentication, it may be needed to explicitly check the
2316
- // authenticated value in LR when performing a tail call.
2317
- // Otherwise, the callee may re-sign the invalid return address,
2318
- // introducing a signing oracle.
2319
- auto CheckLRInTailCall = [this ](Register CallDestinationReg) {
2320
- if (!AArch64FI->shouldSignReturnAddress (*MF))
2321
- return ;
2322
-
2323
- auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod (*MF);
2324
- if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
2325
- return ;
2326
-
2327
- Register ScratchReg =
2328
- CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
2329
- AArch64PACKey::ID Key =
2330
- AArch64FI->shouldSignWithBKey () ? AArch64PACKey::IB : AArch64PACKey::IA;
2331
- emitPtrauthCheckAuthenticatedValue (
2332
- AArch64::LR, ScratchReg, Key, LRCheckMethod,
2333
- /* ShouldTrap=*/ true , /* OnFailure=*/ nullptr );
2334
- };
2335
-
2336
2342
AArch64TargetStreamer *TS =
2337
2343
static_cast <AArch64TargetStreamer *>(OutStreamer->getTargetStreamer ());
2338
2344
// Do any manual lowerings.
@@ -2479,7 +2485,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
2479
2485
? AArch64::X17
2480
2486
: AArch64::X16;
2481
2487
2482
- CheckLRInTailCall (MI-> getOperand ( 0 ). getReg () );
2488
+ emitPtrauthTailCallHardening (MI);
2483
2489
2484
2490
unsigned DiscReg = AddrDisc;
2485
2491
if (Disc) {
@@ -2511,7 +2517,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
2511
2517
case AArch64::TCRETURNrix17:
2512
2518
case AArch64::TCRETURNrinotx16:
2513
2519
case AArch64::TCRETURNriALL: {
2514
- CheckLRInTailCall (MI-> getOperand ( 0 ). getReg () );
2520
+ emitPtrauthTailCallHardening (MI);
2515
2521
2516
2522
MCInst TmpInst;
2517
2523
TmpInst.setOpcode (AArch64::BR);
@@ -2520,7 +2526,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
2520
2526
return ;
2521
2527
}
2522
2528
case AArch64::TCRETURNdi: {
2523
- CheckLRInTailCall (AArch64::NoRegister );
2529
+ emitPtrauthTailCallHardening (MI );
2524
2530
2525
2531
MCOperand Dest;
2526
2532
MCInstLowering.lowerOperand (MI->getOperand (0 ), Dest);
0 commit comments