Skip to content

Commit 82c6b8f

Browse files
authored
[AArch64][SME] Spill p-regs as z-regs when streaming hazards are possible (#123752)
This patch adds a new option `-aarch64-enable-zpr-predicate-spills` (which is disabled by default), this option replaces predicate spills with vector spills in streaming[-compatible] functions. For example: ``` str p8, [sp, #7, mul vl] // 2-byte Folded Spill // ... ldr p8, [sp, #7, mul vl] // 2-byte Folded Reload ``` Becomes: ``` mov z0.b, p8/z, #1 str z0, [sp] // 16-byte Folded Spill // ... ldr z0, [sp] // 16-byte Folded Reload ptrue p4.b cmpne p8.b, p4/z, z0.b, #0 ``` This is done to avoid streaming memory hazards between FPR/vector and predicate spills, which currently occupy the same stack area even when the `-aarch64-stack-hazard-size` flag is set. This is implemented with two new pseudos SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO. The expansion of these pseudos handles scavenging the required registers (z0 in the above example) and, in the worst case spilling a register to an emergency stack slot in the expansion. The condition flags are also preserved around the `cmpne` in case they are live at the expansion point.
1 parent e82f938 commit 82c6b8f

11 files changed

+1474
-12
lines changed

Diff for: llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

+308-5
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
16341634
case AArch64::STR_PXI:
16351635
case AArch64::LDR_ZXI:
16361636
case AArch64::LDR_PXI:
1637+
case AArch64::PTRUE_B:
1638+
case AArch64::CPY_ZPzI_B:
1639+
case AArch64::CMPNE_PPzZI_B:
16371640
return I->getFlag(MachineInstr::FrameSetup) ||
16381641
I->getFlag(MachineInstr::FrameDestroy);
16391642
}
@@ -3265,7 +3268,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
32653268
StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
32663269
break;
32673270
case RegPairInfo::PPR:
3268-
StrOpc = AArch64::STR_PXI;
3271+
StrOpc =
3272+
Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
32693273
break;
32703274
case RegPairInfo::VG:
32713275
StrOpc = AArch64::STRXui;
@@ -3494,7 +3498,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
34943498
LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
34953499
break;
34963500
case RegPairInfo::PPR:
3497-
LdrOpc = AArch64::LDR_PXI;
3501+
LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
3502+
: AArch64::LDR_PXI;
34983503
break;
34993504
case RegPairInfo::VG:
35003505
continue;
@@ -3720,6 +3725,14 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
37203725
continue;
37213726
}
37223727

3728+
// Always save P4 when PPR spills are ZPR-sized and a predicate above p8 is
3729+
// spilled. If all of p0-p3 are used as return values, p4 must be free
3730+
// to reload p8-p15.
3731+
if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
3732+
AArch64::PPR_p8to15RegClass.contains(Reg)) {
3733+
SavedRegs.set(AArch64::P4);
3734+
}
3735+
37233736
// MachO's compact unwind format relies on all registers being stored in
37243737
// pairs.
37253738
// FIXME: the usual format is actually better if unwinding isn't needed.
@@ -4159,8 +4172,295 @@ int64_t AArch64FrameLowering::assignSVEStackObjectOffsets(
41594172
true);
41604173
}
41614174

4175+
/// Attempts to scavenge a register from \p ScavengeableRegs given the used
4176+
/// registers in \p UsedRegs.
4177+
static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
4178+
BitVector const &ScavengeableRegs) {
4179+
for (auto Reg : ScavengeableRegs.set_bits()) {
4180+
if (UsedRegs.available(Reg))
4181+
return Reg;
4182+
}
4183+
return AArch64::NoRegister;
4184+
}
4185+
4186+
/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
4187+
/// \p MachineInstrs.
4188+
static void propagateFrameFlags(MachineInstr &SourceMI,
4189+
ArrayRef<MachineInstr *> MachineInstrs) {
4190+
for (MachineInstr *MI : MachineInstrs) {
4191+
if (SourceMI.getFlag(MachineInstr::FrameSetup))
4192+
MI->setFlag(MachineInstr::FrameSetup);
4193+
if (SourceMI.getFlag(MachineInstr::FrameDestroy))
4194+
MI->setFlag(MachineInstr::FrameDestroy);
4195+
}
4196+
}
4197+
4198+
/// RAII helper class for scavenging or spilling a register. On construction
/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
/// AllocatableRegs), if no register can be found spills \p SpillCandidate to \p
/// MaybeSpillFI to free a register. The free'd register is returned via the \p
/// FreeReg output parameter. On destruction, if there is a spill, its previous
/// value is reloaded. The spilling and scavenging is only valid at the
/// insertion point \p MBBI, this class should _not_ be used in places that
/// create or manipulate basic blocks, moving the expected insertion point.
struct ScopedScavengeOrSpill {
  // Non-copyable/non-movable: the destructor emits the reload at the recorded
  // insertion point, so ownership of a spill must not be duplicated.
  ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
  ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;

  ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
                        MachineBasicBlock::iterator MBBI,
                        Register SpillCandidate, const TargetRegisterClass &RC,
                        LiveRegUnits const &UsedRegs,
                        BitVector const &AllocatableRegs,
                        std::optional<int> *MaybeSpillFI)
      : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
                                          *MF.getSubtarget().getInstrInfo())),
        TRI(*MF.getSubtarget().getRegisterInfo()) {
    // Cheap path first: scavenge a register that is simply unused here.
    FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs);
    if (FreeReg != AArch64::NoRegister)
      return;
    // No free register. A null MaybeSpillFI means the caller forbids spilling
    // at this point (prologue/epilogue).
    assert(MaybeSpillFI && "Expected emergency spill slot FI information "
                           "(attempted to spill in prologue/epilogue?)");
    // Create the emergency slot lazily on first use; the caller-owned optional
    // lets later spills of the same class reuse it.
    if (!MaybeSpillFI->has_value()) {
      MachineFrameInfo &MFI = MF.getFrameInfo();
      *MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
                                                 TRI.getSpillAlign(RC));
    }
    FreeReg = SpillCandidate;
    SpillFI = MaybeSpillFI->value();
    TII.storeRegToStackSlot(MBB, MBBI, FreeReg, false, *SpillFI, &RC, &TRI,
                            Register());
  }

  /// Returns true if the register was freed by spilling (not scavenged).
  bool hasSpilled() const { return SpillFI.has_value(); }

  /// Returns the free register (found from scavenging or spilling a register).
  Register freeRegister() const { return FreeReg; }

  Register operator*() const { return freeRegister(); }

  ~ScopedScavengeOrSpill() {
    // If a register was spilled, restore its previous value at the insertion
    // point before the scope ends.
    if (hasSpilled())
      TII.loadRegFromStackSlot(MBB, MBBI, FreeReg, *SpillFI, &RC, &TRI,
                               Register());
  }

private:
  MachineBasicBlock &MBB;
  MachineBasicBlock::iterator MBBI;
  const TargetRegisterClass &RC;
  const AArch64InstrInfo &TII;
  const TargetRegisterInfo &TRI;
  Register FreeReg = AArch64::NoRegister;
  std::optional<int> SpillFI;
};
4257+
4258+
/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
struct EmergencyStackSlots {
  // Lazily-created frame index for spilling a ZPR when no vector register can
  // be scavenged for an expansion.
  std::optional<int> ZPRSpillFI;
  // Lazily-created frame index for spilling a p0-p7 predicate needed by the
  // fill expansion.
  std::optional<int> PPRSpillFI;
  // Lazily-created frame index for spilling a GPR used to preserve NZCV.
  std::optional<int> GPRSpillFI;
};
4265+
4266+
/// Registers available for scavenging (ZPR, PPR3b, GPR).
struct ScavengeableRegs {
  // Allocatable SVE vector registers (hold the predicate as a data vector).
  BitVector ZPRRegs;
  // Allocatable p0-p7 predicates (only p0-p7 can be the governing operand of
  // the cmpne used when filling).
  BitVector PPR3bRegs;
  // Allocatable 64-bit GPRs (used to save/restore NZCV via MRS/MSR).
  BitVector GPRRegs;
};
4272+
4273+
static bool isInPrologueOrEpilogue(const MachineInstr &MI) {
4274+
return MI.getFlag(MachineInstr::FrameSetup) ||
4275+
MI.getFlag(MachineInstr::FrameDestroy);
4276+
}
4277+
4278+
/// Expands:
4279+
/// ```
4280+
/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
4281+
/// ```
4282+
/// To:
4283+
/// ```
4284+
/// $z0 = CPY_ZPzI_B $p0, 1, 0
4285+
/// STR_ZXI $z0, $stack.0, 0
4286+
/// ```
4287+
/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
4288+
/// spilling if necessary).
4289+
static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
4290+
MachineInstr &MI,
4291+
const TargetRegisterInfo &TRI,
4292+
LiveRegUnits const &UsedRegs,
4293+
ScavengeableRegs const &SR,
4294+
EmergencyStackSlots &SpillSlots) {
4295+
MachineFunction &MF = *MBB.getParent();
4296+
auto *TII =
4297+
static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
4298+
4299+
ScopedScavengeOrSpill ZPredReg(
4300+
MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
4301+
isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);
4302+
4303+
SmallVector<MachineInstr *, 2> MachineInstrs;
4304+
const DebugLoc &DL = MI.getDebugLoc();
4305+
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
4306+
.addReg(*ZPredReg, RegState::Define)
4307+
.add(MI.getOperand(0))
4308+
.addImm(1)
4309+
.addImm(0)
4310+
.getInstr());
4311+
MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
4312+
.addReg(*ZPredReg)
4313+
.add(MI.getOperand(1))
4314+
.addImm(MI.getOperand(2).getImm())
4315+
.setMemRefs(MI.memoperands())
4316+
.getInstr());
4317+
propagateFrameFlags(MI, MachineInstrs);
4318+
}
4319+
4320+
/// Expands:
/// ```
/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
/// ```
/// To:
/// ```
/// $z0 = LDR_ZXI %stack.0, 0
/// $p0 = PTRUE_B 31, implicit $vg
/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv
/// ```
/// While ensuring a ZPR ($z0 in this example) is free for the predicate (
/// spilling if necessary). If the status flags are in use at the point of
/// expansion they are preserved (by moving them to/from a GPR). This may cause
/// an additional spill if no GPR is free at the expansion point.
///
/// Returns true if a p0-p7 predicate had to be spilled to perform the fill;
/// the spill/fill pseudos this creates need a further expansion pass.
static bool expandFillPPRFromZPRSlotPseudo(MachineBasicBlock &MBB,
                                           MachineInstr &MI,
                                           const TargetRegisterInfo &TRI,
                                           LiveRegUnits const &UsedRegs,
                                           ScavengeableRegs const &SR,
                                           EmergencyStackSlots &SpillSlots) {
  MachineFunction &MF = *MBB.getParent();
  auto *TII =
      static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());

  // ZPR that receives the reloaded vector. No emergency slot is permitted in
  // the prologue/epilogue (null FI pointer), so scavenging must succeed there.
  ScopedScavengeOrSpill ZPredReg(
      MF, MBB, MI, AArch64::Z0, AArch64::ZPRRegClass, UsedRegs, SR.ZPRRegs,
      isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.ZPRSpillFI);

  // Predicate (restricted to p0-p7) used as the all-true governing predicate
  // of the CMPNE below.
  ScopedScavengeOrSpill PredReg(
      MF, MBB, MI, AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs, SR.PPR3bRegs,
      isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.PPRSpillFI);

  // Elide NZCV spills if we know it is not used.
  bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
  std::optional<ScopedScavengeOrSpill> NZCVSaveReg;
  if (IsNZCVUsed)
    NZCVSaveReg.emplace(
        MF, MBB, MI, AArch64::X0, AArch64::GPR64RegClass, UsedRegs, SR.GPRRegs,
        isInPrologueOrEpilogue(MI) ? nullptr : &SpillSlots.GPRSpillFI);
  SmallVector<MachineInstr *, 4> MachineInstrs;
  const DebugLoc &DL = MI.getDebugLoc();
  // Reload the spilled predicate as a data vector.
  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
                              .addReg(*ZPredReg, RegState::Define)
                              .add(MI.getOperand(1))
                              .addImm(MI.getOperand(2).getImm())
                              .setMemRefs(MI.memoperands())
                              .getInstr());
  // Save NZCV into a GPR: the CMPNE below implicit-defs it.
  if (IsNZCVUsed)
    MachineInstrs.push_back(
        BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
            .addReg(NZCVSaveReg->freeRegister(), RegState::Define)
            .addImm(AArch64SysReg::NZCV)
            .addReg(AArch64::NZCV, RegState::Implicit)
            .getInstr());
  // All-true predicate to govern the compare.
  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
                              .addReg(*PredReg, RegState::Define)
                              .addImm(31));
  // Recreate the predicate: lanes loaded as non-zero become active.
  MachineInstrs.push_back(
      BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
          .addReg(MI.getOperand(0).getReg(), RegState::Define)
          .addReg(*PredReg)
          .addReg(*ZPredReg)
          .addImm(0)
          .addReg(AArch64::NZCV, RegState::ImplicitDefine)
          .getInstr());
  // Restore the previously saved NZCV value.
  if (IsNZCVUsed)
    MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
                                .addImm(AArch64SysReg::NZCV)
                                .addReg(NZCVSaveReg->freeRegister())
                                .addReg(AArch64::NZCV, RegState::ImplicitDefine)
                                .getInstr());

  propagateFrameFlags(MI, MachineInstrs);
  return PredReg.hasSpilled();
}
4395+
4396+
/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
4397+
/// operations within the MachineBasicBlock \p MBB.
4398+
static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
4399+
const TargetRegisterInfo &TRI,
4400+
ScavengeableRegs const &SR,
4401+
EmergencyStackSlots &SpillSlots) {
4402+
LiveRegUnits UsedRegs(TRI);
4403+
UsedRegs.addLiveOuts(MBB);
4404+
bool HasPPRSpills = false;
4405+
for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
4406+
UsedRegs.stepBackward(MI);
4407+
switch (MI.getOpcode()) {
4408+
case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
4409+
HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR,
4410+
SpillSlots);
4411+
MI.eraseFromParent();
4412+
break;
4413+
case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
4414+
expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, SR, SpillSlots);
4415+
MI.eraseFromParent();
4416+
break;
4417+
default:
4418+
break;
4419+
}
4420+
}
4421+
4422+
return HasPPRSpills;
4423+
}
4424+
41624425
void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
41634426
MachineFunction &MF, RegScavenger *RS) const {
4427+
4428+
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
4429+
const TargetSubtargetInfo &TSI = MF.getSubtarget();
4430+
const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();
4431+
4432+
// If predicates spills are 16-bytes we may need to expand
4433+
// SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
4434+
if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
4435+
auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
4436+
BitVector Regs = TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
4437+
assert(Regs.count() > 0 && "Expected scavengeable registers");
4438+
return Regs;
4439+
};
4440+
4441+
ScavengeableRegs SR{};
4442+
SR.ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
4443+
// Only p0-7 are possible as the second operand of cmpne (needed for fills).
4444+
SR.PPR3bRegs = ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
4445+
SR.GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);
4446+
4447+
EmergencyStackSlots SpillSlots;
4448+
for (MachineBasicBlock &MBB : MF) {
4449+
// In the case we had to spill a predicate (in the range p0-p7) to reload
4450+
// a predicate (>= p8), additional spill/fill pseudos will be created.
4451+
// These need an additional expansion pass. Note: There will only be at
4452+
// most two expansion passes, as spilling/filling a predicate in the range
4453+
// p0-p7 never requires spilling another predicate.
4454+
for (int Pass = 0; Pass < 2; Pass++) {
4455+
bool HasPPRSpills =
4456+
expandSMEPPRToZPRSpillPseudos(MBB, TRI, SR, SpillSlots);
4457+
assert((Pass == 0 || !HasPPRSpills) && "Did not expect PPR spills");
4458+
if (!HasPPRSpills)
4459+
break;
4460+
}
4461+
}
4462+
}
4463+
41644464
MachineFrameInfo &MFI = MF.getFrameInfo();
41654465

41664466
assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
@@ -4170,7 +4470,6 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
41704470
int64_t SVEStackSize =
41714471
assignSVEStackObjectOffsets(MFI, MinCSFrameIndex, MaxCSFrameIndex);
41724472

4173-
AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
41744473
AFI->setStackSizeSVE(alignTo(SVEStackSize, 16U));
41754474
AFI->setMinMaxSVECSFrameIndex(MinCSFrameIndex, MaxCSFrameIndex);
41764475

@@ -5204,9 +5503,13 @@ void AArch64FrameLowering::emitRemarks(
52045503

52055504
unsigned RegTy = StackAccess::AccessType::GPR;
52065505
if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector) {
5207-
if (AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()))
5506+
// SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
5507+
// spill/fill the predicate as a data vector (so count as an FPR access).
5508+
if (MI.getOpcode() != AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO &&
5509+
MI.getOpcode() != AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO &&
5510+
AArch64::PPRRegClass.contains(MI.getOperand(0).getReg())) {
52085511
RegTy = StackAccess::PPR;
5209-
else
5512+
} else
52105513
RegTy = StackAccess::FPR;
52115514
} else if (AArch64InstrInfo::isFpOrNEON(MI)) {
52125515
RegTy = StackAccess::FPR;

0 commit comments

Comments
 (0)