@@ -1509,6 +1509,11 @@ static bool IsSVECalleeSave(MachineBasicBlock::iterator I) {
15091509 switch (I->getOpcode ()) {
15101510 default :
15111511 return false ;
1512+ case AArch64::PTRUE_C_B:
1513+ case AArch64::LD1B_2Z_IMM:
1514+ case AArch64::ST1B_2Z_IMM:
1515+ return I->getMF ()->getSubtarget <AArch64Subtarget>().hasSVE2p1 () ||
1516+ I->getMF ()->getSubtarget <AArch64Subtarget>().hasSME2 ();
15121517 case AArch64::STR_ZXI:
15131518 case AArch64::STR_PXI:
15141519 case AArch64::LDR_ZXI:
@@ -2782,6 +2787,16 @@ struct RegPairInfo {
27822787
27832788} // end anonymous namespace
27842789
2790+ unsigned findFreePredicateAsCounterReg (MachineFunction &MF) {
2791+ const MachineRegisterInfo &MRI = MF.getRegInfo ();
2792+ for (MCRegister PReg :
2793+ {AArch64::PN8, AArch64::PN9, AArch64::PN10, AArch64::PN11, AArch64::PN12,
2794+ AArch64::PN13, AArch64::PN14, AArch64::PN15}) {
2795+ if (!MRI.isReserved (PReg))
2796+ return PReg;
2797+ }
2798+ llvm_unreachable (" cannot find a free predicate" );
2799+ }
27852800static void computeCalleeSaveRegisterPairs (
27862801 MachineFunction &MF, ArrayRef<CalleeSavedInfo> CSI,
27872802 const TargetRegisterInfo *TRI, SmallVectorImpl<RegPairInfo> &RegPairs,
@@ -2792,6 +2807,7 @@ static void computeCalleeSaveRegisterPairs(
27922807
27932808 bool IsWindows = isTargetWindows (MF);
27942809 bool NeedsWinCFI = needsWinCFI (MF);
2810+ const auto &Subtarget = MF.getSubtarget <AArch64Subtarget>();
27952811 AArch64FunctionInfo *AFI = MF.getInfo <AArch64FunctionInfo>();
27962812 MachineFrameInfo &MFI = MF.getFrameInfo ();
27972813 CallingConv::ID CC = MF.getFunction ().getCallingConv ();
@@ -2860,7 +2876,11 @@ static void computeCalleeSaveRegisterPairs(
28602876 RPI.Reg2 = NextReg;
28612877 break ;
28622878 case RegPairInfo::PPR:
2879+ break ;
28632880 case RegPairInfo::ZPR:
2881+ if (Subtarget.hasSVE2p1 () || Subtarget.hasSME2 ())
2882+ if (((RPI.Reg1 - AArch64::Z0) & 1 ) == 0 && (NextReg == RPI.Reg1 + 1 ))
2883+ RPI.Reg2 = NextReg;
28642884 break ;
28652885 }
28662886 }
@@ -2905,7 +2925,7 @@ static void computeCalleeSaveRegisterPairs(
29052925 assert (OffsetPre % Scale == 0 );
29062926
29072927 if (RPI.isScalable ())
2908- ScalableByteOffset += StackFillDir * Scale;
2928+ ScalableByteOffset += StackFillDir * (RPI. isPaired () ? 2 * Scale : Scale) ;
29092929 else
29102930 ByteOffset += StackFillDir * (RPI.isPaired () ? 2 * Scale : Scale);
29112931
@@ -2916,9 +2936,6 @@ static void computeCalleeSaveRegisterPairs(
29162936 (IsWindows && RPI.Reg2 == AArch64::LR)))
29172937 ByteOffset += StackFillDir * 8 ;
29182938
2919- assert (!(RPI.isScalable () && RPI.isPaired ()) &&
2920- " Paired spill/fill instructions don't exist for SVE vectors" );
2921-
29222939 // Round up size of non-pair to pair size if we need to pad the
29232940 // callee-save area to ensure 16-byte alignment.
29242941 if (NeedGapToAlignStack && !NeedsWinCFI &&
@@ -3005,6 +3022,7 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30053022 }
30063023 return true ;
30073024 }
3025+ bool PtrueCreated = false ;
30083026 for (const RegPairInfo &RPI : llvm::reverse (RegPairs)) {
30093027 unsigned Reg1 = RPI.Reg1 ;
30103028 unsigned Reg2 = RPI.Reg2 ;
@@ -3039,10 +3057,10 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30393057 Alignment = Align (16 );
30403058 break ;
30413059 case RegPairInfo::ZPR:
3042- StrOpc = AArch64::STR_ZXI;
3043- Size = 16 ;
3044- Alignment = Align (16 );
3045- break ;
3060+ StrOpc = RPI. isPaired () ? AArch64::ST1B_2Z_IMM : AArch64::STR_ZXI;
3061+ Size = 16 ;
3062+ Alignment = Align (16 );
3063+ break ;
30463064 case RegPairInfo::PPR:
30473065 StrOpc = AArch64::STR_PXI;
30483066 Size = 2 ;
@@ -3066,19 +3084,37 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30663084 std::swap (Reg1, Reg2);
30673085 std::swap (FrameIdxReg1, FrameIdxReg2);
30683086 }
3087+
3088+ unsigned PnReg;
3089+ unsigned PairRegs;
3090+ if (RPI.isPaired () && RPI.isScalable ()) {
3091+ PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3092+ if (!PtrueCreated) {
3093+ PtrueCreated = true ;
3094+ PnReg = findFreePredicateAsCounterReg (MF);
3095+ BuildMI (MBB, MI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3096+ .setMIFlags (MachineInstr::FrameDestroy);
3097+ }
3098+ }
30693099 MachineInstrBuilder MIB = BuildMI (MBB, MI, DL, TII.get (StrOpc));
30703100 if (!MRI.isReserved (Reg1))
30713101 MBB.addLiveIn (Reg1);
30723102 if (RPI.isPaired ()) {
30733103 if (!MRI.isReserved (Reg2))
30743104 MBB.addLiveIn (Reg2);
3075- MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
3105+ if (RPI.isScalable ())
3106+ MIB.addReg (PairRegs);
3107+ else
3108+ MIB.addReg (Reg2, getPrologueDeath (MF, Reg2));
30763109 MIB.addMemOperand (MF.getMachineMemOperand (
30773110 MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
30783111 MachineMemOperand::MOStore, Size, Alignment));
30793112 }
3080- MIB.addReg (Reg1, getPrologueDeath (MF, Reg1))
3081- .addReg (AArch64::SP)
3113+ if (RPI.isPaired () && RPI.isScalable ())
3114+ MIB.addReg (PnReg);
3115+ else
3116+ MIB.addReg (Reg1, getPrologueDeath (MF, Reg1));
3117+ MIB.addReg (AArch64::SP)
30823118 .addImm (RPI.Offset ) // [sp, #offset*scale],
30833119 // where factor*scale is implicit
30843120 .setMIFlag (MachineInstr::FrameSetup);
@@ -3090,8 +3126,11 @@ bool AArch64FrameLowering::spillCalleeSavedRegisters(
30903126
30913127 // Update the StackIDs of the SVE stack slots.
30923128 MachineFrameInfo &MFI = MF.getFrameInfo ();
3093- if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR)
3094- MFI.setStackID (RPI.FrameIdx , TargetStackID::ScalableVector);
3129+ if (RPI.Type == RegPairInfo::ZPR || RPI.Type == RegPairInfo::PPR) {
3130+ MFI.setStackID (FrameIdxReg1, TargetStackID::ScalableVector);
3131+ if (RPI.isPaired ())
3132+ MFI.setStackID (FrameIdxReg2, TargetStackID::ScalableVector);
3133+ }
30953134
30963135 }
30973136 return true ;
@@ -3111,7 +3150,8 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31113150
31123151 computeCalleeSaveRegisterPairs (MF, CSI, TRI, RegPairs, hasFP (MF));
31133152
3114- auto EmitMI = [&](const RegPairInfo &RPI) -> MachineBasicBlock::iterator {
3153+ bool PtrueCreated = false ;
3154+ auto EmitMI = [&, PtrueCreated = false ](const RegPairInfo &RPI) mutable -> MachineBasicBlock::iterator {
31153155 unsigned Reg1 = RPI.Reg1 ;
31163156 unsigned Reg2 = RPI.Reg2 ;
31173157
@@ -3143,7 +3183,7 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31433183 Alignment = Align (16 );
31443184 break ;
31453185 case RegPairInfo::ZPR:
3146- LdrOpc = AArch64::LDR_ZXI;
3186+ LdrOpc = RPI. isPaired () ? AArch64::LD1B_2Z_IMM : AArch64::LDR_ZXI;
31473187 Size = 16 ;
31483188 Alignment = Align (16 );
31493189 break ;
@@ -3168,15 +3208,31 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters(
31683208 std::swap (Reg1, Reg2);
31693209 std::swap (FrameIdxReg1, FrameIdxReg2);
31703210 }
3211+
3212+ unsigned PnReg;
3213+ unsigned PairRegs;
3214+ if (RPI.isPaired () && RPI.isScalable ()) {
3215+ PairRegs = AArch64::Z0_Z1 + (RPI.Reg1 - AArch64::Z0);
3216+ if (!PtrueCreated) {
3217+ PtrueCreated = true ;
3218+ PnReg = findFreePredicateAsCounterReg (MF);
3219+ BuildMI (MBB, MBBI, DL, TII.get (AArch64::PTRUE_C_B), PnReg)
3220+ .setMIFlags (MachineInstr::FrameDestroy);
3221+ }
3222+ }
3223+
31713224 MachineInstrBuilder MIB = BuildMI (MBB, MBBI, DL, TII.get (LdrOpc));
31723225 if (RPI.isPaired ()) {
3173- MIB.addReg (Reg2, getDefRegState (true ));
3226+ MIB.addReg (RPI. isScalable () ? PairRegs : Reg2, getDefRegState (true ));
31743227 MIB.addMemOperand (MF.getMachineMemOperand (
31753228 MachinePointerInfo::getFixedStack (MF, FrameIdxReg2),
31763229 MachineMemOperand::MOLoad, Size, Alignment));
31773230 }
3178- MIB.addReg (Reg1, getDefRegState (true ))
3179- .addReg (AArch64::SP)
3231+ if (RPI.isPaired () && RPI.isScalable ())
3232+ MIB.addReg (PnReg);
3233+ else
3234+ MIB.addReg (Reg1, getDefRegState (true ));
3235+ MIB.addReg (AArch64::SP)
31803236 .addImm (RPI.Offset ) // [sp, #offset*scale]
31813237 // where factor*scale is implicit
31823238 .setMIFlag (MachineInstr::FrameDestroy);
0 commit comments