[AArch64][SME] Spill p-regs as z-regs when streaming hazards are possible #123752

Merged: 8 commits merged into llvm:main from predicate_z_regs_p1 on Feb 3, 2025

Conversation

@MacDue (Member) commented on Jan 21, 2025

This patch adds a new option, `-aarch64-enable-zpr-predicate-spills` (disabled by default), which replaces predicate spills with vector spills in streaming[-compatible] functions.

For example:

```
str	p8, [sp, #7, mul vl]            // 2-byte Folded Spill
// ...
ldr	p8, [sp, #7, mul vl]            // 2-byte Folded Reload
```

Becomes:

```
mov	z0.b, p8/z, #1
str	z0, [sp]                        // 16-byte Folded Spill
// ...
ldr	z0, [sp]                        // 16-byte Folded Reload
ptrue	p4.b
cmpne	p8.b, p4/z, z0.b, #0
```

This is done to avoid streaming memory hazards between FPR/vector and predicate spills, which currently occupy the same stack area even when the `-aarch64-stack-hazard-size` flag is set.

This is implemented with two new pseudos, `SPILL_PPR_TO_ZPR_SLOT_PSEUDO` and `FILL_PPR_FROM_ZPR_SLOT_PSEUDO`. The expansion of these pseudos handles scavenging the required registers (z0 in the above example) and, in the worst case, spilling a register to an emergency stack slot. The condition flags are also preserved around the `cmpne` in case they are live at the expansion point.
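For illustration, here is a rough MIR-level sketch of one possible fill expansion when NZCV is live at the expansion point; the scavenged registers ($z0, $p4, $x8) and the stack index are hypothetical, not taken verbatim from the patch:

```
; Sketch only: FILL_PPR_FROM_ZPR_SLOT_PSEUDO expansion with NZCV live.
$z0 = LDR_ZXI %stack.0, 0                            ; reload the predicate as a data vector
$x8 = MRS NZCV, implicit $nzcv                       ; save condition flags to a scavenged GPR
$p4 = PTRUE_B 31, implicit $vg                       ; all-true predicate (must be one of p0-p7)
$p8 = CMPNE_PPzZI_B $p4, $z0, 0, implicit-def $nzcv  ; rebuild p8 from the vector
MSR NZCV, $x8, implicit-def $nzcv                    ; restore condition flags
```

When NZCV is known to be dead, the MRS/MSR pair (and any GPR scavenging for it) is elided.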

@llvmbot (Member) commented on Jan 21, 2025

@llvm/pr-subscribers-tablegen

@llvm/pr-subscribers-backend-aarch64

Author: Benjamin Maxwell (MacDue)


Patch is 83.42 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123752.diff

10 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64FrameLowering.cpp (+331-4)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.cpp (+15-1)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp (+2-2)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.h (+1-1)
  • (modified) llvm/lib/Target/AArch64/AArch64RegisterInfo.td (+10-1)
  • (modified) llvm/lib/Target/AArch64/AArch64Subtarget.cpp (+22)
  • (modified) llvm/lib/Target/AArch64/AArch64Subtarget.h (+2)
  • (modified) llvm/lib/Target/AArch64/SMEInstrFormats.td (+14)
  • (added) llvm/test/CodeGen/AArch64/spill-fill-zpr-predicates.mir (+1035)
  • (modified) llvm/test/CodeGen/AArch64/ssve-stack-hazard-remarks.ll (+12-1)
diff --git a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
index eabe64361938b4..64c3ecaf21ea31 100644
--- a/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64FrameLowering.cpp
@@ -1630,6 +1630,9 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
   case AArch64::STR_PXI:
   case AArch64::LDR_ZXI:
   case AArch64::LDR_PXI:
+  case AArch64::PTRUE_B:
+  case AArch64::CPY_ZPzI_B:
+  case AArch64::CMPNE_PPzZI_B:
     return I->getFlag(MachineInstr::FrameSetup) ||
            I->getFlag(MachineInstr::FrameDestroy);
   }
@@ -3261,7 +3264,8 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
       StrOpc = RPI.isPaired() ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
       break;
     case RegPairInfo::PPR:
-      StrOpc = AArch64::STR_PXI;
+      StrOpc =
+          Size == 16 ? AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO : AArch64::STR_PXI;
       break;
     case RegPairInfo::VG:
       StrOpc = AArch64::STRXui;
@@ -3490,7 +3494,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
       LdrOpc = RPI.isPaired() ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
       break;
     case RegPairInfo::PPR:
-      LdrOpc = AArch64::LDR_PXI;
+      LdrOpc = Size == 16 ? AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO
+                          : AArch64::LDR_PXI;
       break;
     case RegPairInfo::VG:
       continue;
@@ -3716,6 +3721,14 @@ void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF,
       continue;
     }
 
+    // Always save P4 when PPR spills are ZPR-sized and a predicate in p8-p15
+    // is spilled. If all of p0-p3 are used as return values, p4 must be free
+    // to reload p8-p15.
+    if (RegInfo->getSpillSize(AArch64::PPRRegClass) == 16 &&
+        AArch64::PPR_p8to15RegClass.contains(Reg)) {
+      SavedRegs.set(AArch64::P4);
+    }
+
     // MachO's compact unwind format relies on all registers being stored in
     // pairs.
     // FIXME: the usual format is actually better if unwinding isn't needed.
@@ -4155,8 +4168,318 @@ int64_t AArch64FrameLowering::assignSVEStackObjectOffsets(
                                         true);
 }
 
+/// Attempts to scavenge a register from \p ScavengeableRegs given the used
+/// registers in \p UsedRegs.
+static Register tryScavengeRegister(LiveRegUnits const &UsedRegs,
+                                    BitVector const &ScavengeableRegs) {
+  for (auto Reg : ScavengeableRegs.set_bits()) {
+    if (UsedRegs.available(Reg))
+      return Reg;
+  }
+  return AArch64::NoRegister;
+}
+
+/// Propagates frame-setup/destroy flags from \p SourceMI to all instructions in
+/// \p MachineInstrs.
+static void propagateFrameFlags(MachineInstr &SourceMI,
+                                ArrayRef<MachineInstr *> MachineInstrs) {
+  for (MachineInstr *MI : MachineInstrs) {
+    if (SourceMI.getFlag(MachineInstr::FrameSetup))
+      MI->setFlag(MachineInstr::FrameSetup);
+    if (SourceMI.getFlag(MachineInstr::FrameDestroy))
+      MI->setFlag(MachineInstr::FrameDestroy);
+  }
+}
+
+/// RAII helper class for scavenging or spilling a register. On construction,
+/// attempts to find a free register of class \p RC (given \p UsedRegs and \p
+/// AllocatableRegs); if no register can be found, spills \p SpillCandidate to
+/// \p MaybeSpillFI to free a register. The freed register is returned via the
+/// \p FreeReg output parameter. On destruction, if there is a spill, its
+/// previous value is reloaded. The spilling and scavenging is only valid at
+/// the insertion point \p MBBI; this class should _not_ be used in places
+/// that create or manipulate basic blocks, moving the expected insertion
+/// point.
+struct ScopedScavengeOrSpill {
+  ScopedScavengeOrSpill(const ScopedScavengeOrSpill &) = delete;
+  ScopedScavengeOrSpill(ScopedScavengeOrSpill &&) = delete;
+
+  ScopedScavengeOrSpill(MachineFunction &MF, MachineBasicBlock &MBB,
+                        MachineBasicBlock::iterator MBBI, Register &FreeReg,
+                        Register SpillCandidate, const TargetRegisterClass &RC,
+                        LiveRegUnits const &UsedRegs,
+                        BitVector const &AllocatableRegs,
+                        std::optional<int> &MaybeSpillFI)
+      : MBB(MBB), MBBI(MBBI), RC(RC), TII(static_cast<const AArch64InstrInfo &>(
+                                          *MF.getSubtarget().getInstrInfo())),
+        TRI(*MF.getSubtarget().getRegisterInfo()) {
+    FreeReg = tryScavengeRegister(UsedRegs, AllocatableRegs);
+    if (FreeReg != AArch64::NoRegister)
+      return;
+    if (!MaybeSpillFI) {
+      MachineFrameInfo &MFI = MF.getFrameInfo();
+      MaybeSpillFI = MFI.CreateSpillStackObject(TRI.getSpillSize(RC),
+                                                TRI.getSpillAlign(RC));
+    }
+    FreeReg = SpilledReg = SpillCandidate;
+    SpillFI = *MaybeSpillFI;
+    TII.storeRegToStackSlot(MBB, MBBI, SpilledReg, false, SpillFI, &RC, &TRI,
+                            Register());
+  }
+
+  bool hasSpilled() const { return SpilledReg != AArch64::NoRegister; }
+
+  ~ScopedScavengeOrSpill() {
+    if (hasSpilled())
+      TII.loadRegFromStackSlot(MBB, MBBI, SpilledReg, SpillFI, &RC, &TRI,
+                               Register());
+  }
+
+private:
+  MachineBasicBlock &MBB;
+  MachineBasicBlock::iterator MBBI;
+  const TargetRegisterClass &RC;
+  const AArch64InstrInfo &TII;
+  const TargetRegisterInfo &TRI;
+  Register SpilledReg = AArch64::NoRegister;
+  int SpillFI = -1;
+};
+
+/// Emergency stack slots for expanding SPILL_PPR_TO_ZPR_SLOT_PSEUDO and
+/// FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
+struct EmergencyStackSlots {
+  std::optional<int> ZPRSpillFI;
+  std::optional<int> PPRSpillFI;
+  std::optional<int> GPRSpillFI;
+};
+
+/// Expands:
+/// ```
+/// SPILL_PPR_TO_ZPR_SLOT_PSEUDO $p0, %stack.0, 0
+/// ```
+/// To:
+/// ```
+/// $z0 = CPY_ZPzI_B $p0, 1, 0
+/// STR_ZXI $z0, $stack.0, 0
+/// ```
+/// While ensuring a ZPR ($z0 in this example) is free for the predicate
+/// (spilling if necessary).
+static void expandSpillPPRToZPRSlotPseudo(MachineBasicBlock &MBB,
+                                          MachineInstr &MI,
+                                          const TargetRegisterInfo &TRI,
+                                          LiveRegUnits const &UsedRegs,
+                                          BitVector const &ZPRRegs,
+                                          EmergencyStackSlots &SpillSlots) {
+  MachineFunction &MF = *MBB.getParent();
+  auto *TII =
+      static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  Register ZPredReg = AArch64::NoRegister;
+  ScopedScavengeOrSpill FindZPRReg(MF, MBB, MachineBasicBlock::iterator(MI),
+                                   ZPredReg, AArch64::Z0, AArch64::ZPRRegClass,
+                                   UsedRegs, ZPRRegs, SpillSlots.ZPRSpillFI);
+
+#ifndef NDEBUG
+  bool InPrologueOrEpilogue = MI.getFlag(MachineInstr::FrameSetup) ||
+                              MI.getFlag(MachineInstr::FrameDestroy);
+  assert((!FindZPRReg.hasSpilled() || !InPrologueOrEpilogue) &&
+         "SPILL_PPR_TO_ZPR_SLOT_PSEUDO expansion should not spill in prologue "
+         "or epilogue");
+#endif
+
+  SmallVector<MachineInstr *, 2> MachineInstrs;
+  const DebugLoc &DL = MI.getDebugLoc();
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::CPY_ZPzI_B))
+                              .addReg(ZPredReg, RegState::Define)
+                              .add(MI.getOperand(0))
+                              .addImm(1)
+                              .addImm(0)
+                              .getInstr());
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::STR_ZXI))
+                              .addReg(ZPredReg)
+                              .add(MI.getOperand(1))
+                              .addImm(MI.getOperand(2).getImm())
+                              .setMemRefs(MI.memoperands())
+                              .getInstr());
+  propagateFrameFlags(MI, MachineInstrs);
+}
+
+/// Expands:
+/// ```
+/// $p0 = FILL_PPR_FROM_ZPR_SLOT_PSEUDO %stack.0, 0
+/// ```
+/// To:
+/// ```
+/// $z0 = LDR_ZXI %stack.0, 0
+/// $p0 = PTRUE_B 31, implicit $vg
+/// $p0 = CMPNE_PPzZI_B $p0, $z0, 0, implicit-def $nzcv, implicit-def $nzcv
+/// ```
+/// While ensuring a ZPR ($z0 in this example) is free for the predicate
+/// (spilling if necessary). If the status flags are in use at the point of
+/// expansion, they are preserved (by moving them to/from a GPR). This may
+/// cause an additional spill if no GPR is free at the expansion point.
+static bool expandFillPPRFromZPRSlotPseudo(
+    MachineBasicBlock &MBB, MachineInstr &MI, const TargetRegisterInfo &TRI,
+    LiveRegUnits const &UsedRegs, BitVector const &ZPRRegs,
+    BitVector const &PPR3bRegs, BitVector const &GPRRegs,
+    EmergencyStackSlots &SpillSlots) {
+  MachineFunction &MF = *MBB.getParent();
+  auto *TII =
+      static_cast<const AArch64InstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  Register ZPredReg = AArch64::NoRegister;
+  ScopedScavengeOrSpill FindZPRReg(MF, MBB, MachineBasicBlock::iterator(MI),
+                                   ZPredReg, AArch64::Z0, AArch64::ZPRRegClass,
+                                   UsedRegs, ZPRRegs, SpillSlots.ZPRSpillFI);
+
+  Register PredReg = AArch64::NoRegister;
+  std::optional<ScopedScavengeOrSpill> FindPPR3bReg;
+  if (AArch64::PPR_3bRegClass.contains(MI.getOperand(0).getReg()))
+    PredReg = MI.getOperand(0).getReg();
+  else
+    FindPPR3bReg.emplace(MF, MBB, MachineBasicBlock::iterator(MI), PredReg,
+                         AArch64::P0, AArch64::PPR_3bRegClass, UsedRegs,
+                         PPR3bRegs, SpillSlots.PPRSpillFI);
+
+  // Elide NZCV spills if we know it is not used.
+  Register NZCVSaveReg = AArch64::NoRegister;
+  bool IsNZCVUsed = !UsedRegs.available(AArch64::NZCV);
+  std::optional<ScopedScavengeOrSpill> FindGPRReg;
+  if (IsNZCVUsed)
+    FindGPRReg.emplace(MF, MBB, MachineBasicBlock::iterator(MI), NZCVSaveReg,
+                       AArch64::X0, AArch64::GPR64RegClass, UsedRegs, GPRRegs,
+                       SpillSlots.GPRSpillFI);
+
+#ifndef NDEBUG
+  bool Spilled = FindZPRReg.hasSpilled() ||
+                 (FindPPR3bReg && FindPPR3bReg->hasSpilled()) ||
+                 (FindGPRReg && FindGPRReg->hasSpilled());
+  bool InPrologueOrEpilogue = MI.getFlag(MachineInstr::FrameSetup) ||
+                              MI.getFlag(MachineInstr::FrameDestroy);
+  assert((!Spilled || !InPrologueOrEpilogue) &&
+         "FILL_PPR_FROM_ZPR_SLOT_PSEUDO expansion should not spill in prologue "
+         "or epilogue");
+#endif
+
+  SmallVector<MachineInstr *, 4> MachineInstrs;
+  const DebugLoc &DL = MI.getDebugLoc();
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::LDR_ZXI))
+                              .addReg(ZPredReg, RegState::Define)
+                              .add(MI.getOperand(1))
+                              .addImm(MI.getOperand(2).getImm())
+                              .setMemRefs(MI.memoperands())
+                              .getInstr());
+  if (IsNZCVUsed)
+    MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MRS))
+                                .addReg(NZCVSaveReg, RegState::Define)
+                                .addImm(AArch64SysReg::NZCV)
+                                .addReg(AArch64::NZCV, RegState::Implicit)
+                                .getInstr());
+  MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::PTRUE_B))
+                              .addReg(PredReg, RegState::Define)
+                              .addImm(31));
+  MachineInstrs.push_back(
+      BuildMI(MBB, MI, DL, TII->get(AArch64::CMPNE_PPzZI_B))
+          .addReg(MI.getOperand(0).getReg(), RegState::Define)
+          .addReg(PredReg)
+          .addReg(ZPredReg)
+          .addImm(0)
+          .addReg(AArch64::NZCV, RegState::ImplicitDefine)
+          .getInstr());
+  if (IsNZCVUsed)
+    MachineInstrs.push_back(BuildMI(MBB, MI, DL, TII->get(AArch64::MSR))
+                                .addImm(AArch64SysReg::NZCV)
+                                .addReg(NZCVSaveReg)
+                                .addReg(AArch64::NZCV, RegState::ImplicitDefine)
+                                .getInstr());
+
+  propagateFrameFlags(MI, MachineInstrs);
+  return FindPPR3bReg && FindPPR3bReg->hasSpilled();
+}
+
+/// Expands all FILL_PPR_FROM_ZPR_SLOT_PSEUDO and SPILL_PPR_TO_ZPR_SLOT_PSEUDO
+/// operations within the MachineBasicBlock \p MBB.
+static bool expandSMEPPRToZPRSpillPseudos(MachineBasicBlock &MBB,
+                                          const TargetRegisterInfo &TRI,
+                                          BitVector const &ZPRRegs,
+                                          BitVector const &PPR3bRegs,
+                                          BitVector const &GPRRegs,
+                                          EmergencyStackSlots &SpillSlots) {
+  LiveRegUnits UsedRegs(TRI);
+  UsedRegs.addLiveOuts(MBB);
+  bool HasPPRSpills = false;
+  for (MachineInstr &MI : make_early_inc_range(reverse(MBB))) {
+    UsedRegs.stepBackward(MI);
+    switch (MI.getOpcode()) {
+    case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
+      HasPPRSpills |= expandFillPPRFromZPRSlotPseudo(
+          MBB, MI, TRI, UsedRegs, ZPRRegs, PPR3bRegs, GPRRegs, SpillSlots);
+      MI.eraseFromParent();
+      break;
+    case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
+      expandSpillPPRToZPRSlotPseudo(MBB, MI, TRI, UsedRegs, ZPRRegs,
+                                    SpillSlots);
+      MI.eraseFromParent();
+      break;
+    default:
+      break;
+    }
+  }
+
+  return HasPPRSpills;
+}
+
 void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
     MachineFunction &MF, RegScavenger *RS) const {
+
+  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
+  const TargetSubtargetInfo &TSI = MF.getSubtarget();
+  const TargetRegisterInfo &TRI = *TSI.getRegisterInfo();
+  if (AFI->hasStackFrame() && TRI.getSpillSize(AArch64::PPRRegClass) == 16) {
+    const uint32_t *CSRMask =
+        TRI.getCallPreservedMask(MF, MF.getFunction().getCallingConv());
+    const MachineFrameInfo &MFI = MF.getFrameInfo();
+    assert(MFI.isCalleeSavedInfoValid());
+
+    auto ComputeScavengeableRegisters = [&](unsigned RegClassID) {
+      BitVector ScavengeableRegs =
+          TRI.getAllocatableSet(MF, TRI.getRegClass(RegClassID));
+      if (CSRMask)
+        ScavengeableRegs.clearBitsInMask(CSRMask);
+      // TODO: Allow reusing callee-saved registers that have been saved.
+      return ScavengeableRegs;
+    };
+
+    // If predicate spills are 16 bytes, we may need to expand
+    // SPILL_PPR_TO_ZPR_SLOT_PSEUDO/FILL_PPR_FROM_ZPR_SLOT_PSEUDO.
+    // These are handled separately as we need to compute register liveness to
+    // scavenge a ZPR and PPR during the expansion.
+    BitVector ZPRRegs = ComputeScavengeableRegisters(AArch64::ZPRRegClassID);
+    // Only p0-7 are possible as the second operand of cmpne (needed for fills).
+    BitVector PPR3bRegs =
+        ComputeScavengeableRegisters(AArch64::PPR_3bRegClassID);
+    BitVector GPRRegs = ComputeScavengeableRegisters(AArch64::GPR64RegClassID);
+
+    bool SpillsAboveP7 =
+        any_of(MFI.getCalleeSavedInfo(), [](const CalleeSavedInfo &CSI) {
+          return AArch64::PPR_p8to15RegClass.contains(CSI.getReg());
+        });
+    // We spill p4 in determineCalleeSaves() if a predicate in p8-p15 is
+    // spilled, as it may be needed to reload callee saves (if p0-p3 are used
+    // as returns).
+    if (SpillsAboveP7)
+      PPR3bRegs.set(AArch64::P4);
+
+    EmergencyStackSlots SpillSlots;
+    for (MachineBasicBlock &MBB : MF) {
+      for (int Pass = 0; Pass < 2; Pass++) {
+        bool HasPPRSpills = expandSMEPPRToZPRSpillPseudos(
+            MBB, TRI, ZPRRegs, PPR3bRegs, GPRRegs, SpillSlots);
+        if (!HasPPRSpills)
+          break;
+      }
+    }
+  }
+
   MachineFrameInfo &MFI = MF.getFrameInfo();
 
   assert(getStackGrowthDirection() == TargetFrameLowering::StackGrowsDown &&
@@ -4166,7 +4489,6 @@ void AArch64FrameLowering::processFunctionBeforeFrameFinalized(
   int64_t SVEStackSize =
       assignSVEStackObjectOffsets(MFI, MinCSFrameIndex, MaxCSFrameIndex);
 
-  AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
   AFI->setStackSizeSVE(alignTo(SVEStackSize, 16U));
   AFI->setMinMaxSVECSFrameIndex(MinCSFrameIndex, MaxCSFrameIndex);
 
@@ -5200,7 +5522,12 @@ void AArch64FrameLowering::emitRemarks(
 
           unsigned RegTy = StackAccess::AccessType::GPR;
           if (MFI.getStackID(FrameIdx) == TargetStackID::ScalableVector) {
-            if (AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()))
+            // SPILL_PPR_TO_ZPR_SLOT_PSEUDO and FILL_PPR_FROM_ZPR_SLOT_PSEUDO
+            // spill/fill the predicate as a data vector (so count as an FPR access).
+            if (!is_contained({AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO,
+                               AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO},
+                              MI.getOpcode()) &&
+                AArch64::PPRRegClass.contains(MI.getOperand(0).getReg()))
               RegTy = StackAccess::PPR;
             else
               RegTy = StackAccess::FPR;
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
index 6b8a7e9559e005..8d31760a6fb75d 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
@@ -81,7 +81,7 @@ static cl::opt<unsigned>
 AArch64InstrInfo::AArch64InstrInfo(const AArch64Subtarget &STI)
     : AArch64GenInstrInfo(AArch64::ADJCALLSTACKDOWN, AArch64::ADJCALLSTACKUP,
                           AArch64::CATCHRET),
-      RI(STI.getTargetTriple()), Subtarget(STI) {}
+      RI(STI.getTargetTriple(), STI.getHwMode()), Subtarget(STI) {}
 
 /// GetInstSize - Return the number of bytes of code the specified
 /// instruction may be.  This returns the maximum number of bytes.
@@ -2438,6 +2438,8 @@ unsigned AArch64InstrInfo::getLoadStoreImmIdx(unsigned Opc) {
   case AArch64::STZ2Gi:
   case AArch64::STZGi:
   case AArch64::TAGPstack:
+  case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
+  case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
     return 2;
   case AArch64::LD1B_D_IMM:
   case AArch64::LD1B_H_IMM:
@@ -4223,6 +4225,8 @@ bool AArch64InstrInfo::getMemOpInfo(unsigned Opcode, TypeSize &Scale,
     MinOffset = -256;
     MaxOffset = 254;
     break;
+  case AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO:
+  case AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO:
   case AArch64::LDR_ZXI:
   case AArch64::STR_ZXI:
     Scale = TypeSize::getScalable(16);
@@ -5354,6 +5358,11 @@ void AArch64InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
              "Unexpected register store without SVE store instructions");
       Opc = AArch64::STR_ZXI;
       StackID = TargetStackID::ScalableVector;
+    } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
+      assert(Subtarget.isSVEorStreamingSVEAvailable() &&
+             "Unexpected predicate store without SVE store instructions");
+      Opc = AArch64::SPILL_PPR_TO_ZPR_SLOT_PSEUDO;
+      StackID = TargetStackID::ScalableVector;
     }
     break;
   case 24:
@@ -5528,6 +5537,11 @@ void AArch64InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
              "Unexpected register load without SVE load instructions");
       Opc = AArch64::LDR_ZXI;
       StackID = TargetStackID::ScalableVector;
+    } else if (AArch64::PPRRegClass.hasSubClassEq(RC)) {
+      assert(Subtarget.isSVEorStreamingSVEAvailable() &&
+             "Unexpected predicate load without SVE load instructions");
+      Opc = AArch64::FILL_PPR_FROM_ZPR_SLOT_PSEUDO;
+      StackID = TargetStackID::ScalableVector;
     }
     break;
   case 24:
diff --git a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
index 5973b63b5a8024..e9730348ba58e5 100644
--- a/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
@@ -38,8 +38,8 @@ using namespace llvm;
 #define GET_REGINFO_TARGET_DESC
 #include "AArch64GenRegisterInfo.inc"
 
-AArch64RegisterInfo::AArch64RegisterInfo(const Triple &TT)
-    : AArch64GenRegisterInfo(AArch64::LR), TT(TT) {
+AArch64RegisterInfo::AArch64RegisterInfo(const Triple &TT, unsigned HwMode)
+    : AArch64GenRegisterInfo(AArch64::LR, 0, 0, 0, Hw...
[truncated]

; EXPAND-NEXT: $p6 = frame-destroy PTRUE_B 31, implicit $vg
; EXPAND-NEXT: $p6 = frame-destroy CMPNE_PPzZI_B $p6, $z0, 0, implicit-def $nzcv, implicit-def $nzcv
; EXPAND-NEXT: $z0 = frame-destroy LDR_ZXI $sp, 10 :: (load (s128) from %stack.2)
; EXPAND-NEXT: $p5 = frame-destroy PTRUE_B 31, implicit $vg
@MacDue (Member Author):
Note: I have a small follow-up patch ready that reduces the number of ptrue instructions created.


// Note: This hardware mode is enabled in AArch64Subtarget::getHwModeSet()
// (without the use of the table-gen'd predicates).
def SMEWithStreamingMemoryHazards : HwMode<"", [Predicate<"false">]>;
Collaborator:
Does this not have to check the bitmask to verify bit 0 is set? (as you've set it in AArch64Subtarget)

@MacDue (Member Author):
The predicate and feature flags of the hardware mode are not used; the implementation in AArch64Subtarget overrides the default implementation (which only checks CPU features). The predicate is only used for hardware-mode-specific DAG patterns, of which we have none (https://reviews.llvm.org/D146012).

@MacDue (Member Author), Jan 23, 2025:
(I initially attempted to use the predicate in table-gen to enable this mode, but was surprised to find out it's not actually used to enable the hardware mode).

Collaborator:
The confusing part to me is that I don't see how the value of 1 << 0 then relates to RegInfo<16, 16, 16>.
What if AArch64Subtarget::getHwModeSet would set 1 << 1 instead?

@MacDue (Member Author):
Can remove the `<< 0` (I was just doing it the way table-gen does). But really, this is just selecting between hardware mode 0 (the default, with 2 x vscale predicate spills) and hardware mode 1 (with 16 x vscale predicate spills).

I think getHwModeSet returns a bitset (no bits set = default, bit 0 set = mode 1, bit 1 set = mode 2), and I think multiple bits can be set. The bits are chosen by table-gen, which does not seem to give them names.

@MacDue (Member Author):
So:

What if AArch64Subtarget::getHwModeSet would set 1 << 1 instead?

That'd activate mode 2, and something would crash, because that mode does not exist.

Collaborator:
I believe what you're saying, but it's odd to me that TableGen doesn't generate an enum for this, because this means we need to make the implicit assumption that SMEWithStreamingMemoryHazards == 1, even though this is not expressed anywhere.

@MacDue (Member Author):
I've updated table-gen to generate an AArch64HwModeBits::SMEWithZPRPredicateSpills enum automatically, so this becomes `return to_underlying(AArch64HwModeBits::SMEWithZPRPredicateSpills);`, which is much less magic :)
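For context, a minimal sketch of what that subtarget override could look like, assuming getHwModeSet() returns an unsigned bitset as described above; the option name and gating condition below are assumptions, not the patch's exact code:

```cpp
// Sketch only: how AArch64Subtarget might opt in to the new hardware mode.
// EnableZPRPredicateSpills is a stand-in for the cl::opt backing
// -aarch64-enable-zpr-predicate-spills; the real patch may gate this further
// (e.g. on streaming[-compatible] functions).
unsigned AArch64Subtarget::getHwModeSet() const {
  // Returning the SMEWithZPRPredicateSpills bit selects the RegInfo<16, 16, 16>
  // variant for PPRs, i.e. 16 x vscale-byte (ZPR-sized) predicate spills.
  if (EnableZPRPredicateSpills)
    return llvm::to_underlying(AArch64HwModeBits::SMEWithZPRPredicateSpills);
  return 0; // no bits set: default mode, 2 x vscale-byte predicate spills
}
```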

@MacDue force-pushed the predicate_z_regs_p1 branch from 7c72d0e to 2a74bc6 on January 24, 2025, 13:08

@sdesmalen-arm (Collaborator) left a comment:

LGTM, thanks @MacDue

Turns out LiveRegUnits will already reserve unsaved callee-saved
registers so we don't need to worry about doing this here.
@MacDue force-pushed the predicate_z_regs_p1 branch from 8ffa920 to 5110d11 on February 3, 2025, 10:17
@MacDue merged commit 82c6b8f into llvm:main on Feb 3, 2025
6 of 8 checks passed
@MacDue deleted the predicate_z_regs_p1 branch on February 3, 2025, 12:04
MacDue added a commit to MacDue/llvm-project that referenced this pull request Feb 10, 2025
swift-ci pushed a commit to swiftlang/llvm-project that referenced this pull request Feb 11, 2025
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025