diff --git a/llvm/include/llvm/CodeGen/RDFLiveness.h b/llvm/include/llvm/CodeGen/RDFLiveness.h
index ea48902717266..d39d3585e7bd5 100644
--- a/llvm/include/llvm/CodeGen/RDFLiveness.h
+++ b/llvm/include/llvm/CodeGen/RDFLiveness.h
@@ -18,6 +18,8 @@
 #include "llvm/MC/LaneBitmask.h"
 #include <map>
 #include <set>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 
 namespace llvm {
@@ -28,6 +30,30 @@ class MachineDominatorTree;
 class MachineRegisterInfo;
 class TargetRegisterInfo;
 
+} // namespace llvm
+
+namespace llvm {
+namespace rdf {
+namespace detail {
+
+using NodeRef = std::pair<NodeId, LaneBitmask>;
+
+} // namespace detail
+} // namespace rdf
+} // namespace llvm
+
+namespace std {
+
+template <> struct hash<llvm::rdf::detail::NodeRef> {
+  std::size_t operator()(llvm::rdf::detail::NodeRef R) const {
+    return std::hash<llvm::rdf::NodeId>{}(R.first) ^
+           std::hash<llvm::LaneBitmask::Type>{}(R.second.getAsInteger());
+  }
+};
+
+} // namespace std
+
+namespace llvm {
 namespace rdf {
 
 struct Liveness {
@@ -46,10 +72,9 @@ namespace rdf {
       std::map<MachineBasicBlock*,RegisterAggr> Map;
     };
 
-    using NodeRef = std::pair<NodeId,LaneBitmask>;
-    using NodeRefSet = std::set<NodeRef>;
-    // RegisterId in RefMap must be normalized.
-    using RefMap = std::map<RegisterId,NodeRefSet>;
+    using NodeRef = detail::NodeRef;
+    using NodeRefSet = std::unordered_set<NodeRef>;
+    using RefMap = std::unordered_map<RegisterId, NodeRefSet>;
 
     Liveness(MachineRegisterInfo &mri, const DataFlowGraph &g)
         : DFG(g), TRI(g.getTRI()), PRI(g.getPRI()), MDT(g.getDT()),
@@ -110,15 +135,14 @@ namespace rdf {
     // Cache of mapping from node ids (for RefNodes) to the containing
     // basic blocks. Not computing it each time for each node reduces
    // the liveness calculation time by a large fraction.
-    using NodeBlockMap = DenseMap<NodeId, MachineBasicBlock*>;
-    NodeBlockMap NBMap;
+    DenseMap<NodeId, MachineBasicBlock*> NBMap;
 
     // Phi information:
     //
     // RealUseMap
     // map: NodeId -> (map: RegisterId -> NodeRefSet)
     //   phi id -> (map: register -> set of reached non-phi uses)
-    std::map<NodeId, RefMap> RealUseMap;
+    DenseMap<NodeId, RefMap> RealUseMap;
 
     // Inverse iterated dominance frontier.
     std::map<MachineBasicBlock*,std::set<MachineBasicBlock*>> IIDF;
diff --git a/llvm/include/llvm/CodeGen/RDFRegisters.h b/llvm/include/llvm/CodeGen/RDFRegisters.h
index 4afaf80e46595..abeab62af3fa6 100644
--- a/llvm/include/llvm/CodeGen/RDFRegisters.h
+++ b/llvm/include/llvm/CodeGen/RDFRegisters.h
@@ -91,6 +91,11 @@ namespace rdf {
     bool operator< (const RegisterRef &RR) const {
       return Reg < RR.Reg || (Reg == RR.Reg && Mask < RR.Mask);
     }
+
+    size_t hash() const {
+      return std::hash<RegisterId>{}(Reg) ^
+             std::hash<LaneBitmask::Type>{}(Mask.getAsInteger());
+    }
   };
 
@@ -110,7 +115,11 @@
       return RegMasks.get(Register::stackSlot2Index(R));
     }
 
-    RegisterRef normalize(RegisterRef RR) const;
+    LLVM_ATTRIBUTE_DEPRECATED(RegisterRef normalize(RegisterRef RR),
+      "This function is now an identity function");
+    RegisterRef normalize(RegisterRef RR) const {
+      return RR;
+    }
 
     bool alias(RegisterRef RA, RegisterRef RB) const {
       if (!isRegMaskId(RA.Reg))
@@ -128,6 +137,10 @@
       return MaskInfos[Register::stackSlot2Index(MaskId)].Units;
     }
 
+    const BitVector &getUnitAliases(uint32_t U) const {
+      return AliasInfos[U].Regs;
+    }
+
     RegisterRef mapTo(RegisterRef RR, unsigned R) const;
     const TargetRegisterInfo &getTRI() const { return TRI; }
 
@@ -142,12 +155,16 @@
     struct MaskInfo {
       BitVector Units;
     };
+    struct AliasInfo {
+      BitVector Regs;
+    };
 
     const TargetRegisterInfo &TRI;
     IndexedSet<const uint32_t*> RegMasks;
     std::vector<RegInfo> RegInfos;
     std::vector<UnitInfo> UnitInfos;
     std::vector<MaskInfo> MaskInfos;
+    std::vector<AliasInfo> AliasInfos;
 
     bool aliasRR(RegisterRef RA, RegisterRef RB) const;
     bool aliasRM(RegisterRef RR, RegisterRef RM) const;
@@ -159,10 +176,15 @@
         : Units(pri.getTRI().getNumRegUnits()), PRI(pri) {}
     RegisterAggr(const RegisterAggr &RG) = default;
 
+    unsigned count() const { return Units.count(); }
     bool empty() const { return Units.none(); }
     bool hasAliasOf(RegisterRef RR) const;
     bool hasCoverOf(RegisterRef RR) const;
+    bool operator==(const RegisterAggr &A) const {
+      return DenseMapInfo<BitVector>::isEqual(Units, A.Units);
+    }
+
     static bool isCoverOf(RegisterRef RA, RegisterRef RB,
                           const PhysicalRegisterInfo &PRI) {
       return RegisterAggr(PRI).insert(RA).hasCoverOf(RB);
@@ -179,6 +201,10 @@
     RegisterRef clearIn(RegisterRef RR) const;
     RegisterRef makeRegRef() const;
 
+    size_t hash() const {
+      return DenseMapInfo<BitVector>::getHashValue(Units);
+    }
+
     void print(raw_ostream &OS) const;
 
     struct rr_iterator {
@@ -232,9 +258,26 @@
     LaneBitmask Mask;
   };
   raw_ostream &operator<< (raw_ostream &OS, const PrintLaneMaskOpt &P);
-
 } // end namespace rdf
 
 } // end namespace llvm
 
+namespace std {
+  template <> struct hash<llvm::rdf::RegisterRef> {
+    size_t operator()(llvm::rdf::RegisterRef A) const {
+      return A.hash();
+    }
+  };
+  template <> struct hash<llvm::rdf::RegisterAggr> {
+    size_t operator()(const llvm::rdf::RegisterAggr &A) const {
+      return A.hash();
+    }
+  };
+  template <> struct equal_to<llvm::rdf::RegisterAggr> {
+    bool operator()(const llvm::rdf::RegisterAggr &A,
+                    const llvm::rdf::RegisterAggr &B) const {
+      return A == B;
+    }
+  };
+}
 #endif // LLVM_LIB_TARGET_HEXAGON_RDFREGISTERS_H
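The std::hash and std::equal_to specializations added above are what let RegisterRef and RegisterAggr act as keys of the unordered containers this patch introduces in RDFLiveness.h and RDFLiveness.cpp. A minimal sketch of the intended usage (illustrative only; the alias names are ad hoc and this snippet is not part of the patch):

#include "llvm/CodeGen/RDFRegisters.h"
#include <unordered_map>

using namespace llvm::rdf;

// RegisterRef hashes over (Reg, Mask) and already provides operator==, so it
// can key an unordered_map directly.
using SubMap = std::unordered_map<RegisterRef, RegisterRef>;

// RegisterAggr relies on the new std::hash<> and std::equal_to<>
// specializations; its equality goes through DenseMapInfo<BitVector>.
using AggrCache = std::unordered_map<RegisterAggr, SubMap>;

These are the same container shapes the clearIn cache in Liveness::computePhiInfo uses further down in this patch.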
  }
-#ifndef NDEBUG
-//  RegisterRef NAR = PRI.normalize(AR);
-//  RegisterRef NBR = PRI.normalize(BR);
-//  assert(NAR.Reg != NBR.Reg);
-#endif
   // This isn't strictly correct, because the overlap may happen in the
   // part masked out.
   if (PRI.alias(AR, BR))
diff --git a/llvm/lib/CodeGen/RDFLiveness.cpp b/llvm/lib/CodeGen/RDFLiveness.cpp
index 0bcd27f8ea452..b2a29bf451a2a 100644
--- a/llvm/lib/CodeGen/RDFLiveness.cpp
+++ b/llvm/lib/CodeGen/RDFLiveness.cpp
@@ -23,8 +23,10 @@
 // <10.1145/2086696.2086706>.
 //
 #include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/CodeGen/MachineBasicBlock.h"
 #include "llvm/CodeGen/MachineDominanceFrontier.h"
 #include "llvm/CodeGen/MachineDominators.h"
@@ -45,6 +47,7 @@
 #include <cstdint>
 #include <iterator>
 #include <map>
+#include <unordered_map>
 #include <utility>
 #include <vector>
 
@@ -108,7 +111,7 @@ NodeList Liveness::getAllReachingDefs(RegisterRef RefRR,
       const RegisterAggr &DefRRs) {
   NodeList RDefs; // Return value.
   SetVector<NodeId> DefQ;
-  SetVector<NodeId> Owners;
+  DenseMap<const MachineInstr*, uint32_t> OrdMap;
 
   // Dead defs will be treated as if they were live, since they are actually
   // on the data-flow path. They cannot be ignored because even though they
@@ -151,18 +154,9 @@ NodeList Liveness::getAllReachingDefs(RegisterRef RefRR,
       for (auto S : DFG.getRelatedRefs(TA.Addr->getOwner(DFG), TA))
         if (NodeId RD = NodeAddr<RefNode*>(S).Addr->getReachingDef())
           DefQ.insert(RD);
-  }
-
-  // Remove all non-phi defs that are not aliased to RefRR, and collect
-  // the owners of the remaining defs.
-  SetVector<NodeId> Defs;
-  for (NodeId N : DefQ) {
-    auto TA = DFG.addr<DefNode*>(N);
-    bool IsPhi = TA.Addr->getFlags() & NodeAttrs::PhiRef;
-    if (!IsPhi && !PRI.alias(RefRR, TA.Addr->getRegRef(DFG)))
-      continue;
-    Defs.insert(TA.Id);
-    Owners.insert(TA.Addr->getOwner(DFG).Id);
+    // Don't visit sibling defs. They share the same reaching def (which
+    // will be visited anyway), but they define something not aliased to
+    // this ref.
   }
 
   // Return the MachineBasicBlock containing a given instruction.
@@ -174,38 +168,81 @@ NodeList Liveness::getAllReachingDefs(RegisterRef RefRR,
     NodeAddr<BlockNode*> BA = PA.Addr->getOwner(DFG);
     return BA.Addr->getCode();
   };
-  // Less(A,B) iff instruction A is further down in the dominator tree than B.
-  auto Less = [&Block,this] (NodeId A, NodeId B) -> bool {
+
+  SmallSet<NodeId,32> Defs;
+
+  // Remove all non-phi defs that are not aliased to RefRR, and segregate
+  // the remaining defs into buckets for containing blocks.
+  std::map<NodeId, NodeAddr<InstrNode*>> Owners;
+  std::map<MachineBasicBlock*, SmallVector<NodeId,32>> Blocks;
+  for (NodeId N : DefQ) {
+    auto TA = DFG.addr<DefNode*>(N);
+    bool IsPhi = TA.Addr->getFlags() & NodeAttrs::PhiRef;
+    if (!IsPhi && !PRI.alias(RefRR, TA.Addr->getRegRef(DFG)))
+      continue;
+    Defs.insert(TA.Id);
+    NodeAddr<InstrNode*> IA = TA.Addr->getOwner(DFG);
+    Owners[TA.Id] = IA;
+    Blocks[Block(IA)].push_back(IA.Id);
+  }
+
+  auto Precedes = [this,&OrdMap] (NodeId A, NodeId B) {
     if (A == B)
       return false;
-    auto OA = DFG.addr<InstrNode*>(A), OB = DFG.addr<InstrNode*>(B);
-    MachineBasicBlock *BA = Block(OA), *BB = Block(OB);
-    if (BA != BB)
-      return MDT.dominates(BB, BA);
-    // They are in the same block.
+    NodeAddr<InstrNode*> OA = DFG.addr<InstrNode*>(A);
+    NodeAddr<InstrNode*> OB = DFG.addr<InstrNode*>(B);
     bool StmtA = OA.Addr->getKind() == NodeAttrs::Stmt;
     bool StmtB = OB.Addr->getKind() == NodeAttrs::Stmt;
-    if (StmtA) {
-      if (!StmtB)   // OB is a phi and phis dominate statements.
-        return true;
-      MachineInstr *CA = NodeAddr<StmtNode*>(OA).Addr->getCode();
-      MachineInstr *CB = NodeAddr<StmtNode*>(OB).Addr->getCode();
-      // The order must be linear, so tie-break such equalities.
-      if (CA == CB)
-        return A < B;
-      return MDT.dominates(CB, CA);
-    } else {
-      // OA is a phi.
-      if (StmtB)
-        return false;
-      // Both are phis. There is no ordering between phis (in terms of
-      // the data-flow), so tie-break this via node id comparison.
+    if (StmtA && StmtB) {
+      const MachineInstr *InA = NodeAddr<StmtNode*>(OA).Addr->getCode();
+      const MachineInstr *InB = NodeAddr<StmtNode*>(OB).Addr->getCode();
+      assert(InA->getParent() == InB->getParent());
+      auto FA = OrdMap.find(InA);
+      if (FA != OrdMap.end())
+        return FA->second < OrdMap.find(InB)->second;
+      const MachineBasicBlock *BB = InA->getParent();
+      for (auto It = BB->begin(), E = BB->end(); It != E; ++It) {
+        if (It == InA->getIterator())
+          return true;
+        if (It == InB->getIterator())
+          return false;
+      }
+      llvm_unreachable("InA and InB should be in the same block");
+    }
+    // One of them is a phi node.
+    if (!StmtA && !StmtB) {
+      // Both are phis, which are unordered. Break the tie by id numbers.
       return A < B;
     }
+    // Only one of them is a phi. Phis always precede statements.
+    return !StmtA;
   };
-  std::vector<NodeId> Tmp(Owners.begin(), Owners.end());
-  llvm::sort(Tmp, Less);
+
+  auto GetOrder = [&OrdMap] (MachineBasicBlock &B) {
+    uint32_t Pos = 0;
+    for (MachineInstr &In : B)
+      OrdMap.insert({&In, ++Pos});
+  };
+
+  // For each block, sort the nodes in it.
+  std::vector<MachineBasicBlock*> TmpBB;
+  for (auto &Bucket : Blocks) {
+    TmpBB.push_back(Bucket.first);
+    if (Bucket.second.size() > 2)
+      GetOrder(*Bucket.first);
+    std::sort(Bucket.second.begin(), Bucket.second.end(), Precedes);
+  }
+
+  // Sort the blocks with respect to dominance.
+  std::sort(TmpBB.begin(), TmpBB.end(), [this](auto A, auto B) {
+    return MDT.dominates(A, B);
+  });
+
+  std::vector<NodeId> TmpInst;
+  for (auto I = TmpBB.rbegin(), E = TmpBB.rend(); I != E; ++I) {
+    auto &Bucket = Blocks[*I];
+    TmpInst.insert(TmpInst.end(), Bucket.rbegin(), Bucket.rend());
+  }
 
   // The vector is a list of instructions, so that defs coming from
   // the same instruction don't need to be artificially ordered.
@@ -220,6 +257,9 @@ NodeList Liveness::getAllReachingDefs(RegisterRef RefRR,
   //   *d3       If A \incl BuC, and B \incl AuC, then *d2 would be
   //             covered if we added A first, and A would be covered
   //             if we added B first.
+  // In this example we want both A and B, because we don't want to give
+  // either one priority over the other, since they belong to the same
+  // statement.
 
   RegisterAggr RRs(DefRRs);
 
@@ -227,7 +267,8 @@ NodeList Liveness::getAllReachingDefs(RegisterRef RefRR,
     return TA.Addr->getKind() == NodeAttrs::Def &&
           Defs.count(TA.Id);
   };
-  for (NodeId T : Tmp) {
+
+  for (NodeId T : TmpInst) {
     if (!FullChain && RRs.hasCoverOf(RefRR))
       break;
     auto TA = DFG.addr<DefNode*>(T);
@@ -436,7 +477,7 @@ void Liveness::computePhiInfo() {
   // phi use -> (map: reaching phi -> set of registers defined in between)
   std::map<NodeId,std::map<NodeId,RegisterAggr>> PhiUp;
   std::vector<NodeId> PhiUQ;  // Work list of phis for upward propagation.
-  std::map<NodeId,RegisterAggr> PhiDRs;  // Phi -> registers defined by it.
+  std::unordered_map<NodeId,RegisterAggr> PhiDRs;  // Phi -> registers defined by it.
 
   // Go over all phis.
   for (NodeAddr<PhiNode*> PhiA : Phis) {
@@ -474,7 +515,7 @@ void Liveness::computePhiInfo() {
         NodeAddr<UseNode*> A = DFG.addr<UseNode*>(UN);
         uint16_t F = A.Addr->getFlags();
         if ((F & (NodeAttrs::Undef | NodeAttrs::PhiRef)) == 0) {
-          RegisterRef R = PRI.normalize(A.Addr->getRegRef(DFG));
+          RegisterRef R = A.Addr->getRegRef(DFG);
           RealUses[R.Reg].insert({A.Id,R.Mask});
         }
         UN = A.Addr->getSibling();
@@ -612,6 +653,23 @@ void Liveness::computePhiInfo() {
   // is covered, or until reaching the final phi. Only assume that the
   // reference reaches the phi in the latter case.
+  // The operation "clearIn" can be expensive. For a given set of intervening
+  // defs, cache the result of subtracting these defs from a given register
+  // ref.
+  using SubMap = std::unordered_map<RegisterRef, RegisterRef>;
+  std::unordered_map<RegisterAggr, SubMap> Subs;
+  auto ClearIn = [] (RegisterRef RR, const RegisterAggr &Mid, SubMap &SM) {
+    if (Mid.empty())
+      return RR;
+    auto F = SM.find(RR);
+    if (F != SM.end())
+      return F->second;
+    RegisterRef S = Mid.clearIn(RR);
+    SM.insert({RR, S});
+    return S;
+  };
+
+  // Go over all phis.
   for (unsigned i = 0; i < PhiUQ.size(); ++i) {
     auto PA = DFG.addr<PhiNode*>(PhiUQ[i]);
     NodeList PUs = PA.Addr->members_if(DFG.IsRef, DFG);
@@ -619,17 +677,17 @@ void Liveness::computePhiInfo() {
 
     for (NodeAddr<UseNode*> UA : PUs) {
       std::map<NodeId,RegisterAggr> &PUM = PhiUp[UA.Id];
-      RegisterRef UR = PRI.normalize(UA.Addr->getRegRef(DFG));
+      RegisterRef UR = UA.Addr->getRegRef(DFG);
       for (const std::pair<const NodeId,RegisterAggr> &P : PUM) {
         bool Changed = false;
         const RegisterAggr &MidDefs = P.second;
-
         // Collect the set PropUp of uses that are reached by the current
         // phi PA, and are not covered by any intervening def between the
         // currently visited use UA and the upward phi P.
 
         if (MidDefs.hasCoverOf(UR))
           continue;
+        SubMap &SM = Subs[MidDefs];
 
         // General algorithm:
         //   for each (R,U) : U is use node of R, U is reached by PA
@@ -649,7 +707,7 @@ void Liveness::computePhiInfo() {
             LaneBitmask M = R.Mask & V.second;
             if (M.none())
              continue;
-            if (RegisterRef SS = MidDefs.clearIn(RegisterRef(R.Reg, M))) {
+            if (RegisterRef SS = ClearIn(RegisterRef(R.Reg, M), MidDefs, SM)) {
              NodeRefSet &RS = RealUseMap[P.first][SS.Reg];
              Changed |= RS.insert({V.first,SS.Mask}).second;
             }
@@ -1073,7 +1131,7 @@ void Liveness::traverse(MachineBasicBlock *B, RefMap &LiveIn) {
     for (NodeAddr<UseNode*> UA : IA.Addr->members_if(DFG.IsUse, DFG)) {
       if (UA.Addr->getFlags() & NodeAttrs::Undef)
         continue;
-      RegisterRef RR = PRI.normalize(UA.Addr->getRegRef(DFG));
+      RegisterRef RR = UA.Addr->getRegRef(DFG);
       for (NodeAddr<DefNode*> D : getAllReachingDefs(UA))
         if (getBlockWithRef(D.Id) != B)
           LiveIn[RR.Reg].insert({D.Id,RR.Mask});
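The GetOrder/OrdMap addition above avoids rescanning a basic block for every pairwise comparison: each interesting block is numbered once, and intra-block ordering then becomes two map lookups. A standalone sketch of that idea (illustrative only; numberBlock is an ad hoc name, not part of the patch):

#include "llvm/ADT/DenseMap.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineInstr.h"
#include <cstdint>

// Number every instruction in a block once; "does A precede B?" then becomes
// a pair of O(1) lookups instead of a scan of the block per comparison.
static llvm::DenseMap<const llvm::MachineInstr *, uint32_t>
numberBlock(const llvm::MachineBasicBlock &MBB) {
  llvm::DenseMap<const llvm::MachineInstr *, uint32_t> Order;
  uint32_t Pos = 0;
  for (const llvm::MachineInstr &MI : MBB)
    Order.insert({&MI, ++Pos});
  return Order;
}

As in the patch, the numbering is only worth building when a bucket holds more than two defs; for a single pair, one linear scan of the block is cheaper.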
diff --git a/llvm/lib/CodeGen/RDFRegisters.cpp b/llvm/lib/CodeGen/RDFRegisters.cpp
index bd8661816e718..c76447d95444a 100644
--- a/llvm/lib/CodeGen/RDFRegisters.cpp
+++ b/llvm/lib/CodeGen/RDFRegisters.cpp
@@ -92,10 +92,15 @@ PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
     }
     MaskInfos[M].Units = PU.flip();
   }
-}
-
-RegisterRef PhysicalRegisterInfo::normalize(RegisterRef RR) const {
-  return RR;
+
+  AliasInfos.resize(TRI.getNumRegUnits());
+  for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
+    BitVector AS(TRI.getNumRegs());
+    for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
+      for (MCSuperRegIterator S(*R, &TRI, true); S.isValid(); ++S)
+        AS.set(*S);
+    AliasInfos[U].Regs = AS;
+  }
 }
 
 std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
@@ -321,26 +326,17 @@ RegisterRef RegisterAggr::makeRegRef() const {
   if (U < 0)
     return RegisterRef();
 
-  auto AliasedRegs = [this] (uint32_t Unit, BitVector &Regs) {
-    for (MCRegUnitRootIterator R(Unit, &PRI.getTRI()); R.isValid(); ++R)
-      for (MCSuperRegIterator S(*R, &PRI.getTRI(), true); S.isValid(); ++S)
-        Regs.set(*S);
-  };
-
   // Find the set of all registers that are aliased to all the units
   // in this aggregate.
 
   // Get all the registers aliased to the first unit in the bit vector.
-  BitVector Regs(PRI.getTRI().getNumRegs());
-  AliasedRegs(U, Regs);
+  BitVector Regs = PRI.getUnitAliases(U);
   U = Units.find_next(U);
 
   // For each other unit, intersect it with the set of all registers
   // aliased that unit.
   while (U >= 0) {
-    BitVector AR(PRI.getTRI().getNumRegs());
-    AliasedRegs(U, AR);
-    Regs &= AR;
+    Regs &= PRI.getUnitAliases(U);
     U = Units.find_next(U);
   }
 
diff --git a/llvm/lib/Target/Hexagon/HexagonOptAddrMode.cpp b/llvm/lib/Target/Hexagon/HexagonOptAddrMode.cpp
index c718e5f2d9fbe..2cdfbe7845b63 100644
--- a/llvm/lib/Target/Hexagon/HexagonOptAddrMode.cpp
+++ b/llvm/lib/Target/Hexagon/HexagonOptAddrMode.cpp
@@ -246,7 +246,7 @@ void HexagonOptAddrMode::getAllRealUses(NodeAddr<StmtNode *> SA,
   for (NodeAddr<DefNode *> DA : SA.Addr->members_if(DFG->IsDef, *DFG)) {
     LLVM_DEBUG(dbgs() << "\t\t[DefNode]: "
                       << Print<NodeAddr<DefNode *>>(DA, *DFG) << "\n");
-    RegisterRef DR = DFG->getPRI().normalize(DA.Addr->getRegRef(*DFG));
+    RegisterRef DR = DA.Addr->getRegRef(*DFG);
 
     auto UseSet = LV->getAllReachedUses(DR, DA);
 
diff --git a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
index 50f8b3477acce..12aaabcc79645 100644
--- a/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
+++ b/llvm/lib/Target/X86/X86LoadValueInjectionLoadHardening.cpp
@@ -42,6 +42,7 @@
 #include "X86TargetMachine.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
@@ -104,9 +105,9 @@ static cl::opt<bool> EmitDotVerify(
     cl::init(false), cl::Hidden);
 
 static llvm::sys::DynamicLibrary OptimizeDL;
-typedef int (*OptimizeCutT)(unsigned int *nodes, unsigned int nodes_size,
-                            unsigned int *edges, int *edge_values,
-                            int *cut_edges /* out */, unsigned int edges_size);
+typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize,
+                            unsigned int *Edges, int *EdgeValues,
+                            int *CutEdges /* out */, unsigned int EdgesSize);
 static OptimizeCutT OptimizeCut = nullptr;
 
 namespace {
@@ -148,9 +149,10 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
 
 private:
   using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>;
+  using Edge = MachineGadgetGraph::Edge;
+  using Node = MachineGadgetGraph::Node;
   using EdgeSet = MachineGadgetGraph::EdgeSet;
   using NodeSet = MachineGadgetGraph::NodeSet;
-  using Gadget = std::pair<MachineInstr *, MachineInstr *>;
 
   const X86Subtarget *STI;
   const TargetInstrInfo *TII;
@@ -162,8 +164,8 @@ class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass {
                            const MachineDominanceFrontier &MDF) const;
   int hardenLoadsWithPlugin(MachineFunction &MF,
                             std::unique_ptr<MachineGadgetGraph> Graph) const;
-  int hardenLoadsWithGreedyHeuristic(
-      MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const;
+  int hardenLoadsWithHeuristic(MachineFunction &MF,
+                               std::unique_ptr<MachineGadgetGraph> Graph) const;
   int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G,
                                  EdgeSet &ElimEdges /* in, out */,
                                  NodeSet &ElimNodes /* in, out */) const;
@@ -198,7 +200,7 @@ struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits {
   using ChildIteratorType = typename Traits::ChildIteratorType;
   using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType;
 
-  DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {}
+  DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {}
 
   std::string getNodeLabel(NodeRef Node, GraphType *) {
     if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel)
@@ -243,7 +245,7 @@ void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage(
   AU.setPreservesCFG();
 }
 
-static void WriteGadgetGraph(raw_ostream &OS, MachineFunction &MF,
+static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF,
                              MachineGadgetGraph *G) {
   WriteGraph(OS, G, /*ShortNames*/ false,
for \"" + MF.getName() + "\" function"); @@ -279,7 +281,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction( return false; // didn't find any gadgets if (EmitDotVerify) { - WriteGadgetGraph(outs(), MF, Graph.get()); + writeGadgetGraph(outs(), MF, Graph.get()); return false; } @@ -292,7 +294,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction( raw_fd_ostream FileOut(FileName, FileError); if (FileError) errs() << FileError.message(); - WriteGadgetGraph(FileOut, MF, Graph.get()); + writeGadgetGraph(FileOut, MF, Graph.get()); FileOut.close(); LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n"); if (EmitDotOnly) @@ -313,7 +315,7 @@ bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction( } FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph)); } else { // Use the default greedy heuristic - FencesInserted = hardenLoadsWithGreedyHeuristic(MF, std::move(Graph)); + FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph)); } if (FencesInserted > 0) @@ -367,7 +369,7 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph( // Use RDF to find all the uses of `Def` rdf::NodeSet Uses; - RegisterRef DefReg = DFG.getPRI().normalize(Def.Addr->getRegRef(DFG)); + RegisterRef DefReg = Def.Addr->getRegRef(DFG); for (auto UseID : L.getAllReachedUses(DefReg, Def)) { auto Use = DFG.addr(UseID); if (Use.Addr->getFlags() & NodeAttrs::PhiRef) { // phi node @@ -540,17 +542,17 @@ X86LoadValueInjectionLoadHardeningPass::getGadgetGraph( // Returns the number of remaining gadget edges that could not be eliminated int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes( - MachineGadgetGraph &G, MachineGadgetGraph::EdgeSet &ElimEdges /* in, out */, - MachineGadgetGraph::NodeSet &ElimNodes /* in, out */) const { + MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, + NodeSet &ElimNodes /* in, out */) const { if (G.NumFences > 0) { // Eliminate fences and CFG edges that ingress and egress the fence, as // they are trivially mitigated. - for (const auto &E : G.edges()) { - const MachineGadgetGraph::Node *Dest = E.getDest(); + for (const Edge &E : G.edges()) { + const Node *Dest = E.getDest(); if (isFence(Dest->getValue())) { ElimNodes.insert(*Dest); ElimEdges.insert(E); - for (const auto &DE : Dest->edges()) + for (const Edge &DE : Dest->edges()) ElimEdges.insert(DE); } } @@ -558,29 +560,28 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes( // Find and eliminate gadget edges that have been mitigated. 
   int MitigatedGadgets = 0, RemainingGadgets = 0;
-  MachineGadgetGraph::NodeSet ReachableNodes{G};
-  for (const auto &RootN : G.nodes()) {
+  NodeSet ReachableNodes{G};
+  for (const Node &RootN : G.nodes()) {
     if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge))
       continue; // skip this node if it isn't a gadget source
 
     // Find all of the nodes that are CFG-reachable from RootN using DFS
     ReachableNodes.clear();
-    std::function<void(const MachineGadgetGraph::Node *, bool)>
-        FindReachableNodes =
-            [&](const MachineGadgetGraph::Node *N, bool FirstNode) {
-              if (!FirstNode)
-                ReachableNodes.insert(*N);
-              for (const auto &E : N->edges()) {
-                const MachineGadgetGraph::Node *Dest = E.getDest();
-                if (MachineGadgetGraph::isCFGEdge(E) &&
-                    !ElimEdges.contains(E) && !ReachableNodes.contains(*Dest))
-                  FindReachableNodes(Dest, false);
-              }
-            };
+    std::function<void(const Node *, bool)> FindReachableNodes =
+        [&](const Node *N, bool FirstNode) {
+          if (!FirstNode)
+            ReachableNodes.insert(*N);
+          for (const Edge &E : N->edges()) {
+            const Node *Dest = E.getDest();
+            if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) &&
+                !ReachableNodes.contains(*Dest))
+              FindReachableNodes(Dest, false);
+          }
+        };
     FindReachableNodes(&RootN, true);
 
     // Any gadget whose sink is unreachable has been mitigated
-    for (const auto &E : RootN.edges()) {
+    for (const Edge &E : RootN.edges()) {
       if (MachineGadgetGraph::isGadgetEdge(E)) {
         if (ReachableNodes.contains(*E.getDest())) {
           // This gadget's sink is reachable
@@ -598,8 +599,8 @@ int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes(
 std::unique_ptr<MachineGadgetGraph>
 X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges(
     std::unique_ptr<MachineGadgetGraph> Graph) const {
-  MachineGadgetGraph::NodeSet ElimNodes{*Graph};
-  MachineGadgetGraph::EdgeSet ElimEdges{*Graph};
+  NodeSet ElimNodes{*Graph};
+  EdgeSet ElimEdges{*Graph};
   int RemainingGadgets =
       elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes);
   if (ElimEdges.empty() && ElimNodes.empty()) {
@@ -630,11 +631,11 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
     auto Edges = std::make_unique<unsigned int[]>(Graph->edges_size());
     auto EdgeCuts = std::make_unique<int[]>(Graph->edges_size());
     auto EdgeValues = std::make_unique<int[]>(Graph->edges_size());
-    for (const auto &N : Graph->nodes()) {
+    for (const Node &N : Graph->nodes()) {
       Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin());
     }
     Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node
-    for (const auto &E : Graph->edges()) {
+    for (const Edge &E : Graph->edges()) {
       Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest());
       EdgeValues[Graph->getEdgeIndex(E)] = E.getValue();
     }
@@ -651,74 +652,67 @@ int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin(
     LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n");
     LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n");
 
-    Graph = GraphBuilder::trim(*Graph, MachineGadgetGraph::NodeSet{*Graph},
-                               CutEdges);
+    Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges);
   } while (true);
 
   return FencesInserted;
 }
 
-int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithGreedyHeuristic(
+int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic(
     MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const {
-  LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
-  Graph = trimMitigatedEdges(std::move(Graph));
-  LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
+  // If `MF` does not have any fences, then no gadgets would have been
+  // mitigated at this point.
+  if (Graph->NumFences > 0) {
+    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n");
+    Graph = trimMitigatedEdges(std::move(Graph));
+    LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n");
+  }
+
   if (Graph->NumGadgets == 0)
     return 0;
 
   LLVM_DEBUG(dbgs() << "Cutting edges...\n");
-  MachineGadgetGraph::NodeSet ElimNodes{*Graph}, GadgetSinks{*Graph};
-  MachineGadgetGraph::EdgeSet ElimEdges{*Graph}, CutEdges{*Graph};
-  auto IsCFGEdge = [&ElimEdges, &CutEdges](const MachineGadgetGraph::Edge &E) {
-    return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
-           MachineGadgetGraph::isCFGEdge(E);
-  };
-  auto IsGadgetEdge = [&ElimEdges,
-                       &CutEdges](const MachineGadgetGraph::Edge &E) {
-    return !ElimEdges.contains(E) && !CutEdges.contains(E) &&
-           MachineGadgetGraph::isGadgetEdge(E);
-  };
-
-  // FIXME: this is O(E^2), we could probably do better.
-  do {
-    // Find the cheapest CFG edge that will eliminate a gadget (by being
-    // egress from a SOURCE node or ingress to a SINK node), and cut it.
-    const MachineGadgetGraph::Edge *CheapestSoFar = nullptr;
-
-    // First, collect all gadget source and sink nodes.
-    MachineGadgetGraph::NodeSet GadgetSources{*Graph}, GadgetSinks{*Graph};
-    for (const auto &N : Graph->nodes()) {
-      if (ElimNodes.contains(N))
+  EdgeSet CutEdges{*Graph};
+
+  // Begin by collecting all ingress CFG edges for each node
+  DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap;
+  for (const Edge &E : Graph->edges())
+    if (MachineGadgetGraph::isCFGEdge(E))
+      IngressEdgeMap[E.getDest()].push_back(&E);
+
+  // For each gadget edge, make cuts that guarantee the gadget will be
+  // mitigated. A computationally efficient way to achieve this is to either:
+  // (a) cut all egress CFG edges from the gadget source, or
+  // (b) cut all ingress CFG edges to the gadget sink.
+  //
+  // Moreover, the algorithm tries not to make a cut into a loop by preferring
+  // to make a (b)-type cut if the gadget source resides at a greater loop depth
+  // than the gadget sink, or an (a)-type cut otherwise.
+  for (const Node &N : Graph->nodes()) {
+    for (const Edge &E : N.edges()) {
+      if (!MachineGadgetGraph::isGadgetEdge(E))
         continue;
-      for (const auto &E : N.edges()) {
-        if (IsGadgetEdge(E)) {
-          GadgetSources.insert(N);
-          GadgetSinks.insert(*E.getDest());
-        }
-      }
-    }
-
-    // Next, look for the cheapest CFG edge which, when cut, is guaranteed to
-    // mitigate at least one gadget by either:
-    // (a) being egress from a gadget source, or
-    // (b) being ingress to a gadget sink.
-    for (const auto &N : Graph->nodes()) {
-      if (ElimNodes.contains(N))
-        continue;
-      for (const auto &E : N.edges()) {
-        if (IsCFGEdge(E)) {
-          if (GadgetSources.contains(N) || GadgetSinks.contains(*E.getDest())) {
-            if (!CheapestSoFar || E.getValue() < CheapestSoFar->getValue())
-              CheapestSoFar = &E;
-          }
-        }
-      }
+      SmallVector<const Edge *, 2> EgressEdges;
+      SmallVector<const Edge *, 2> &IngressEdges = IngressEdgeMap[E.getDest()];
+      for (const Edge &EgressEdge : N.edges())
+        if (MachineGadgetGraph::isCFGEdge(EgressEdge))
+          EgressEdges.push_back(&EgressEdge);
+
+      int EgressCutCost = 0, IngressCutCost = 0;
+      for (const Edge *EgressEdge : EgressEdges)
+        if (!CutEdges.contains(*EgressEdge))
+          EgressCutCost += EgressEdge->getValue();
+      for (const Edge *IngressEdge : IngressEdges)
+        if (!CutEdges.contains(*IngressEdge))
+          IngressCutCost += IngressEdge->getValue();
+
+      auto &EdgesToCut =
+          IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges;
+      for (const Edge *E : EdgesToCut)
+        CutEdges.insert(*E);
     }
-
-    assert(CheapestSoFar && "Failed to cut an edge");
-    CutEdges.insert(*CheapestSoFar);
-    ElimEdges.insert(*CheapestSoFar);
-  } while (elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes));
+  }
 
   LLVM_DEBUG(dbgs() << "Cutting edges... Done\n");
   LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n");
@@ -734,8 +728,8 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
     MachineFunction &MF, MachineGadgetGraph &G,
     EdgeSet &CutEdges /* in, out */) const {
   int FencesInserted = 0;
-  for (const auto &N : G.nodes()) {
-    for (const auto &E : N.edges()) {
+  for (const Node &N : G.nodes()) {
+    for (const Edge &E : N.edges()) {
      if (CutEdges.contains(E)) {
        MachineInstr *MI = N.getValue(), *Prev;
        MachineBasicBlock *MBB; // Insert an LFENCE in this MBB
@@ -751,7 +745,7 @@ int X86LoadValueInjectionLoadHardeningPass::insertFences(
          Prev = MI->getPrevNode();
        // Remove all egress CFG edges from this branch because the inserted
        // LFENCE prevents gadgets from crossing the branch.
-        for (const auto &E : N.edges()) {
+        for (const Edge &E : N.edges()) {
          if (MachineGadgetGraph::isCFGEdge(E))
            CutEdges.insert(E);
        }
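The rewritten heuristic above no longer searches for a single cheapest edge per iteration; for every gadget edge it sums the weights of the not yet cut egress edges of the source and the not yet cut ingress edges of the sink, then cuts whichever set is cheaper. A toy sketch of that comparison with plain containers (illustrative only, no LLVM types; the names here are ad hoc):

#include <cstdio>
#include <set>
#include <vector>

struct ToyEdge { int Id; int Weight; };

// Sum the weights of the candidate edges that have not already been cut.
static int uncutCost(const std::vector<ToyEdge> &Candidates,
                     const std::set<int> &Cut) {
  int Cost = 0;
  for (const ToyEdge &E : Candidates)
    if (!Cut.count(E.Id))
      Cost += E.Weight;
  return Cost;
}

int main() {
  // One gadget: its source has two egress CFG edges, its sink one ingress edge.
  std::vector<ToyEdge> Egress = {{0, 10}, {1, 10}};
  std::vector<ToyEdge> Ingress = {{2, 5}};
  std::set<int> CutEdges;

  // Cut the cheaper of the two sets, mirroring the IngressCutCost vs.
  // EgressCutCost comparison in the patch.
  const std::vector<ToyEdge> &ToCut =
      uncutCost(Ingress, CutEdges) < uncutCost(Egress, CutEdges) ? Ingress
                                                                 : Egress;
  for (const ToyEdge &E : ToCut)
    CutEdges.insert(E.Id);
  std::printf("cut %zu edge(s)\n", CutEdges.size()); // prints: cut 1 edge(s)
  return 0;
}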