@@ -47,10 +47,13 @@ class RISCVFoldMasks : public MachineFunctionPass {
4747 StringRef getPassName () const override { return " RISC-V Fold Masks" ; }
4848
4949private:
50- bool convertToUnmasked (MachineInstr &MI, MachineInstr *MaskDef ) const ;
51- bool convertVMergeToVMv (MachineInstr &MI, MachineInstr *MaskDef ) const ;
50+ bool convertToUnmasked (MachineInstr &MI) const ;
51+ bool convertVMergeToVMv (MachineInstr &MI) const ;
5252
53- bool isAllOnesMask (MachineInstr *MaskDef) const ;
53+ bool isAllOnesMask (const MachineInstr *MaskDef) const ;
54+
55+ /// Maps uses of V0 to the corresponding def of V0.
56+ DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
5457};
5558
5659} // namespace
@@ -59,10 +62,9 @@ char RISCVFoldMasks::ID = 0;
5962
6063INITIALIZE_PASS (RISCVFoldMasks, DEBUG_TYPE, " RISC-V Fold Masks" , false , false )
6164
62- bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
63- if (!MaskDef)
64- return false ;
65- assert (MaskDef->isCopy () && MaskDef->getOperand (0 ).getReg () == RISCV::V0);
65+ bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const {
66+ assert (MaskDef && MaskDef->isCopy () &&
67+ MaskDef->getOperand (0 ).getReg () == RISCV::V0);
6668 Register SrcReg = TRI->lookThruCopyLike (MaskDef->getOperand (1 ).getReg (), MRI);
6769 if (!SrcReg.isVirtual ())
6870 return false ;
@@ -89,8 +91,7 @@ bool RISCVFoldMasks::isAllOnesMask(MachineInstr *MaskDef) const {
8991
9092// Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
9193// (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
92- bool RISCVFoldMasks::convertVMergeToVMv (MachineInstr &MI,
93- MachineInstr *V0Def) const {
94+ bool RISCVFoldMasks::convertVMergeToVMv (MachineInstr &MI) const {
9495#define CASE_VMERGE_TO_VMV (lmul ) \
9596 case RISCV::PseudoVMERGE_VVM_##lmul: \
9697 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
@@ -116,7 +117,7 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
116117 return false ;
117118
118119 assert (MI.getOperand (4 ).isReg () && MI.getOperand (4 ).getReg () == RISCV::V0);
119- if (!isAllOnesMask (V0Def ))
120+ if (!isAllOnesMask (V0Defs. lookup (&MI) ))
120121 return false ;
121122
122123 MI.setDesc (TII->get (NewOpc));
@@ -133,14 +134,13 @@ bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI,
133134 return true ;
134135}
135136
136- bool RISCVFoldMasks::convertToUnmasked (MachineInstr &MI,
137- MachineInstr *MaskDef) const {
137+ bool RISCVFoldMasks::convertToUnmasked (MachineInstr &MI) const {
138138 const RISCV::RISCVMaskedPseudoInfo *I =
139139 RISCV::getMaskedPseudoInfo (MI.getOpcode ());
140140 if (!I)
141141 return false ;
142142
143- if (!isAllOnesMask (MaskDef ))
143+ if (!isAllOnesMask (V0Defs. lookup (&MI) ))
144144 return false ;
145145
146146 // There are two classes of pseudos in the table - compares and
@@ -198,20 +198,29 @@ bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) {
198198 // $v0:vr = COPY %mask:vr
199199 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
200200 //
201- // Because $v0 isn't in SSA, keep track of it so we can check the mask operand
202- // on each pseudo.
203- MachineInstr *CurrentV0Def;
204- for (MachineBasicBlock &MBB : MF) {
205- CurrentV0Def = nullptr ;
206- for (MachineInstr &MI : MBB) {
207- Changed |= convertToUnmasked (MI, CurrentV0Def);
208- Changed |= convertVMergeToVMv (MI, CurrentV0Def);
201+ // Because $v0 isn't in SSA, keep track of its definition at each use so we
202+ // can check mask operands.
203+ for (const MachineBasicBlock &MBB : MF) {
204+ const MachineInstr *CurrentV0Def = nullptr ;
205+ for (const MachineInstr &MI : MBB) {
206+ auto IsV0 = [](const auto &MO) {
207+ return MO.isReg () && MO.getReg () == RISCV::V0;
208+ };
209+ if (any_of (MI.uses (), IsV0))
210+ V0Defs[&MI] = CurrentV0Def;
209211
210212 if (MI.definesRegister (RISCV::V0, TRI))
211213 CurrentV0Def = &MI;
212214 }
213215 }
214216
217+ for (MachineBasicBlock &MBB : MF) {
218+ for (MachineInstr &MI : MBB) {
219+ Changed |= convertToUnmasked (MI);
220+ Changed |= convertVMergeToVMv (MI);
221+ }
222+ }
223+
215224 return Changed;
216225}
217226
0 commit comments