Skip to content

Commit edaae6a

Browse files
committed
Check both register operands of AUTH_TCRETURN*
1 parent 089cc13 commit edaae6a

File tree

4 files changed

+42
-28
lines changed

4 files changed

+42
-28
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

+30-24
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,9 @@ class AArch64AsmPrinter : public AsmPrinter {
157157
bool ShouldTrap,
158158
const MCSymbol *OnFailure);
159159

160+
// Check authenticated LR before tail calling.
161+
void emitPtrauthTailCallHardening(const MachineInstr *TC);
162+
160163
// Emit the sequence for AUT or AUTPAC.
161164
void emitPtrauthAuthResign(const MachineInstr *MI);
162165

@@ -1870,6 +1873,30 @@ void AArch64AsmPrinter::emitPtrauthCheckAuthenticatedValue(
18701873
OutStreamer->emitLabel(SuccessSym);
18711874
}
18721875

1876+
// With Pointer Authentication, it may be needed to explicitly check the
1877+
// authenticated value in LR before performing a tail call.
1878+
// Otherwise, the callee may re-sign the invalid return address,
1879+
// introducing a signing oracle.
1880+
void AArch64AsmPrinter::emitPtrauthTailCallHardening(const MachineInstr *TC) {
1881+
if (!AArch64FI->shouldSignReturnAddress(*MF))
1882+
return;
1883+
1884+
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
1885+
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
1886+
return;
1887+
1888+
const AArch64RegisterInfo *TRI = STI->getRegisterInfo();
1889+
Register ScratchReg =
1890+
TC->readsRegister(AArch64::X16, TRI) ? AArch64::X17 : AArch64::X16;
1891+
assert(!TC->readsRegister(ScratchReg, TRI) &&
1892+
"Neither x16 nor x17 is available as a scratch register");
1893+
AArch64PACKey::ID Key =
1894+
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
1895+
emitPtrauthCheckAuthenticatedValue(
1896+
AArch64::LR, ScratchReg, Key, LRCheckMethod,
1897+
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
1898+
}
1899+
18731900
void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
18741901
const bool IsAUTPAC = MI->getOpcode() == AArch64::AUTPAC;
18751902

@@ -2312,27 +2339,6 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
23122339
OutStreamer->emitLabel(LOHLabel);
23132340
}
23142341

2315-
// With Pointer Authentication, it may be needed to explicitly check the
2316-
// authenticated value in LR when performing a tail call.
2317-
// Otherwise, the callee may re-sign the invalid return address,
2318-
// introducing a signing oracle.
2319-
auto CheckLRInTailCall = [this](Register CallDestinationReg) {
2320-
if (!AArch64FI->shouldSignReturnAddress(*MF))
2321-
return;
2322-
2323-
auto LRCheckMethod = STI->getAuthenticatedLRCheckMethod(*MF);
2324-
if (LRCheckMethod == AArch64PAuth::AuthCheckMethod::None)
2325-
return;
2326-
2327-
Register ScratchReg =
2328-
CallDestinationReg == AArch64::X16 ? AArch64::X17 : AArch64::X16;
2329-
AArch64PACKey::ID Key =
2330-
AArch64FI->shouldSignWithBKey() ? AArch64PACKey::IB : AArch64PACKey::IA;
2331-
emitPtrauthCheckAuthenticatedValue(
2332-
AArch64::LR, ScratchReg, Key, LRCheckMethod,
2333-
/*ShouldTrap=*/true, /*OnFailure=*/nullptr);
2334-
};
2335-
23362342
AArch64TargetStreamer *TS =
23372343
static_cast<AArch64TargetStreamer *>(OutStreamer->getTargetStreamer());
23382344
// Do any manual lowerings.
@@ -2479,7 +2485,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
24792485
? AArch64::X17
24802486
: AArch64::X16;
24812487

2482-
CheckLRInTailCall(MI->getOperand(0).getReg());
2488+
emitPtrauthTailCallHardening(MI);
24832489

24842490
unsigned DiscReg = AddrDisc;
24852491
if (Disc) {
@@ -2511,7 +2517,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25112517
case AArch64::TCRETURNrix17:
25122518
case AArch64::TCRETURNrinotx16:
25132519
case AArch64::TCRETURNriALL: {
2514-
CheckLRInTailCall(MI->getOperand(0).getReg());
2520+
emitPtrauthTailCallHardening(MI);
25152521

25162522
MCInst TmpInst;
25172523
TmpInst.setOpcode(AArch64::BR);
@@ -2520,7 +2526,7 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
25202526
return;
25212527
}
25222528
case AArch64::TCRETURNdi: {
2523-
CheckLRInTailCall(AArch64::NoRegister);
2529+
emitPtrauthTailCallHardening(MI);
25242530

25252531
MCOperand Dest;
25262532
MCInstLowering.lowerOperand(MI->getOperand(0), Dest);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

+6-2
Original file line numberDiff line numberDiff line change
@@ -1905,15 +1905,19 @@ let Predicates = [HasPAuth] in {
19051905
// Size 16: 4 fixed + 8 variable, to compute discriminator.
19061906
// The size returned by getInstSizeInBytes() is incremented according
19071907
// to the variant of LR check.
1908+
// As the check requires either x16 or x17 as a scratch register and
1909+
// authenticated tail call instructions have two register operands,
1910+
// make sure at least one register is usable as a scratch one - for that
1911+
// purpose, use tcGPRnotx16x17 register class for the second operand.
19081912
let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1, Size = 16,
19091913
Uses = [SP] in {
19101914
def AUTH_TCRETURN
19111915
: Pseudo<(outs), (ins tcGPR64:$dst, i32imm:$FPDiff, i32imm:$Key,
1912-
i64imm:$Disc, tcGPR64:$AddrDisc),
1916+
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
19131917
[]>, Sched<[WriteBrReg]>;
19141918
def AUTH_TCRETURN_BTI
19151919
: Pseudo<(outs), (ins tcGPRx16x17:$dst, i32imm:$FPDiff, i32imm:$Key,
1916-
i64imm:$Disc, tcGPR64:$AddrDisc),
1920+
i64imm:$Disc, tcGPRnotx16x17:$AddrDisc),
19171921
[]>, Sched<[WriteBrReg]>;
19181922
}
19191923

llvm/lib/Target/AArch64/AArch64RegisterInfo.td

+4
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,10 @@ def tcGPR64 : RegisterClass<"AArch64", [i64], 64, (sub GPR64common, X19, X20, X2
238238
def tcGPRx17 : RegisterClass<"AArch64", [i64], 64, (add X17)>;
239239
def tcGPRx16x17 : RegisterClass<"AArch64", [i64], 64, (add X16, X17)>;
240240
def tcGPRnotx16 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16)>;
241+
// LR checking code expects either x16 or x17 to be available as a scratch
242+
// register - for that reason restrict one of two register operands of
243+
// AUTH_TCRETURN* pseudos.
244+
def tcGPRnotx16x17 : RegisterClass<"AArch64", [i64], 64, (sub tcGPR64, X16, X17)>;
241245

242246
// Register set that excludes registers that are reserved for procedure calls.
243247
// This is used for pseudo-instructions that are actually implemented using a

llvm/test/CodeGen/AArch64/ptrauth-call.ll

+2-2
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,9 @@ define void @test_tailcall_omit_mov_x16_x16(ptr %objptr) #0 {
173173
; CHECK: mov x17, x0
174174
; CHECK: movk x17, #6503, lsl #48
175175
; CHECK: autda x16, x17
176-
; CHECK: ldr x1, [x16]
176+
; CHECK: ldr x2, [x16]
177177
; CHECK: movk x16, #54167, lsl #48
178-
; CHECK: braa x1, x16
178+
; CHECK: braa x2, x16
179179
%vtable.signed = load ptr, ptr %objptr, align 8
180180
%objptr.int = ptrtoint ptr %objptr to i64
181181
%vtable.discr = tail call i64 @llvm.ptrauth.blend(i64 %objptr.int, i64 6503)

0 commit comments

Comments
 (0)