@@ -267,38 +267,77 @@ Cost InstCostVisitor::estimateBranchInst(BranchInst &I) {
267267 return estimateBasicBlocks (WorkList);
268268}
269269
270- void InstCostVisitor::discoverStronglyConnectedComponent (PHINode *PN,
271- unsigned Depth) {
272- if (Depth > MaxDiscoveryDepth)
273- return ;
270+ // This function is finding candidates for a PHINode is part of a chain or graph
271+ // of PHINodes that all link to each other. That means, if the original input to
272+ // the chain is a constant all the other values are also that constant.
273+ //
274+ // The caller of this function will later check that no other nodes are involved
275+ // that are non-constant, and discard it from the possible conversions.
276+ //
277+ // For example:
278+ //
279+ // %a = load %0
280+ // %c = phi [%a, %d]
281+ // %d = phi [%e, %c]
282+ // %e = phi [%c, %f]
283+ // %f = phi [%j, %h]
284+ // %j = phi [%h, %j]
285+ // %h = phi [%g, %c]
286+ //
287+ // This is only showing the PHINodes, not the branches that choose the
288+ // different paths.
289+ //
290+ // A depth limit is used to avoid extreme recurusion.
291+ // A max number of incoming phi values ensures that expensive searches
292+ // are avoided.
293+ //
294+ // Returns false if the discovery was aborted due to the above conditions.
295+ bool InstCostVisitor::discoverTransitivelyIncomngValues (
296+ DenseSet<PHINode *> &PHINodes, PHINode *PN, unsigned Depth) {
297+ if (Depth > MaxDiscoveryDepth) {
298+ LLVM_DEBUG (dbgs () << " FnSpecialization: Discover PHI nodes too deep ("
299+ << Depth << " >" << MaxDiscoveryDepth << " )\n " );
300+ return false ;
301+ }
274302
275- if (PN->getNumIncomingValues () > MaxIncomingPhiValues)
276- return ;
303+ if (PN->getNumIncomingValues () > MaxIncomingPhiValues) {
304+ LLVM_DEBUG (
305+ dbgs () << " FnSpecialization: Discover PHI nodes has too many values ("
306+ << PN->getNumIncomingValues () << " >" << MaxIncomingPhiValues
307+ << " )\n " );
308+ return false ;
309+ }
277310
278- if (!StronglyConnectedPHIs.insert (PN).second )
279- return ;
311+ // Already seen this, no more processing needed.
312+ if (!PHINodes.insert (PN).second )
313+ return true ;
280314
281315 for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
282316 Value *V = PN->getIncomingValue (I);
283317 if (auto *Phi = dyn_cast<PHINode>(V)) {
284318 if (Phi == PN || DeadBlocks.contains (PN->getIncomingBlock (I)))
285319 continue ;
286- discoverStronglyConnectedComponent (Phi, Depth + 1 );
320+ if (!discoverTransitivelyIncomngValues (PHINodes, Phi, Depth + 1 ))
321+ return false ;
287322 }
288323 }
324+ return true ;
289325}
290326
291327Constant *InstCostVisitor::visitPHINode (PHINode &I) {
292328 if (I.getNumIncomingValues () > MaxIncomingPhiValues)
293329 return nullptr ;
294330
331+ // PHI nodes
332+ DenseSet<PHINode *> TransitivePHIs;
333+
295334 bool Inserted = VisitedPHIs.insert (&I).second ;
296- Constant *Const = nullptr ;
297335 SmallVector<PHINode *, 8 > UnknownIncomingValues;
298336
299- auto CanConstantFoldPhi = [&](PHINode *PN) -> bool {
300- UnknownIncomingValues. clear () ;
337+ auto canConstantFoldPhiTrivially = [&](PHINode *PN) -> Constant * {
338+ Constant *Const = nullptr ;
301339
340+ UnknownIncomingValues.clear ();
302341 for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
303342 Value *V = PN->getIncomingValue (I);
304343
@@ -311,21 +350,22 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
311350 if (!Const)
312351 Const = C;
313352 // Not all incoming values are the same constant. Bail immediately.
314- else if (C != Const)
315- return false ;
316- } else if (auto *Phi = dyn_cast<PHINode>(V)) {
317- // It's not a strongly connected phi. Collect it and bail at the end.
318- if (!StronglyConnectedPHIs.contains (Phi))
319- UnknownIncomingValues.push_back (Phi);
320- } else {
321- // We can't reason about anything else.
322- return false ;
353+ if (C != Const)
354+ return nullptr ;
355+ continue ;
323356 }
357+ if (auto *Phi = dyn_cast<PHINode>(V)) {
358+ UnknownIncomingValues.push_back (Phi);
359+ continue ;
360+ }
361+
362+ // We can't reason about anything else.
363+ return nullptr ;
324364 }
325- return UnknownIncomingValues.empty ();
365+ return UnknownIncomingValues.empty () ? Const : nullptr ;
326366 };
327367
328- if (CanConstantFoldPhi (&I))
368+ if (Constant *Const = canConstantFoldPhiTrivially (&I))
329369 return Const;
330370
331371 if (Inserted) {
@@ -335,18 +375,59 @@ Constant *InstCostVisitor::visitPHINode(PHINode &I) {
335375 return nullptr ;
336376 }
337377
378+ // Try to see if we can collect a nest of transitive phis. Bail if
379+ // it's too complex.
338380 for (PHINode *Phi : UnknownIncomingValues)
339- discoverStronglyConnectedComponent (Phi, 1 );
381+ if (!discoverTransitivelyIncomngValues (TransitivePHIs, Phi, 1 ))
382+ return nullptr ;
383+
384+ // A nested set of PHINodes can be constantfolded if:
385+ // - It has a constant input.
386+ // - It is always the SAME constant.
387+ auto canConstantFoldNestedPhi = [&](PHINode *PN) -> Constant * {
388+ Constant *Const = nullptr ;
340389
341- bool CannotConstantFoldPhi = false ;
342- for (PHINode *Phi : StronglyConnectedPHIs) {
343- if (!CanConstantFoldPhi (Phi)) {
344- CannotConstantFoldPhi = true ;
345- break ;
390+ for (unsigned I = 0 , E = PN->getNumIncomingValues (); I != E; ++I) {
391+ Value *V = PN->getIncomingValue (I);
392+
393+ // Disregard self-references and dead incoming values.
394+ if (auto *Inst = dyn_cast<Instruction>(V))
395+ if (Inst == PN || DeadBlocks.contains (PN->getIncomingBlock (I)))
396+ continue ;
397+
398+ if (Constant *C = findConstantFor (V, KnownConstants)) {
399+ if (!Const)
400+ Const = C;
401+ // Not all incoming values are the same constant. Bail immediately.
402+ if (C != Const)
403+ return nullptr ;
404+ continue ;
405+ }
406+ if (auto *Phi = dyn_cast<PHINode>(V)) {
407+ // It's not a Transitive phi. Bail out.
408+ if (!TransitivePHIs.contains (Phi))
409+ return nullptr ;
410+ continue ;
411+ }
412+
413+ // We can't reason about anything else.
414+ return nullptr ;
415+ }
416+ return Const;
417+ };
418+
419+ // All TransitivePHIs have to be the SAME constant.
420+ Constant *Retval = nullptr ;
421+ for (PHINode *Phi : TransitivePHIs) {
422+ if (Constant *Const = canConstantFoldNestedPhi (Phi)) {
423+ if (!Retval)
424+ Retval = Const;
425+ else if (Retval != Const)
426+ return nullptr ;
346427 }
347428 }
348- StronglyConnectedPHIs. clear ();
349- return CannotConstantFoldPhi ? nullptr : Const ;
429+
430+ return Retval ;
350431}
351432
352433Constant *InstCostVisitor::visitFreezeInst (FreezeInst &I) {
@@ -871,37 +952,37 @@ bool FunctionSpecializer::findSpecializations(Function *F, unsigned FuncSize,
871952 unsigned FuncGrowth) -> bool {
872953 // No check required.
873954 if (ForceSpecialization) {
874- LLVM_DEBUG (dbgs () << " Force is on\n " );
955+ LLVM_DEBUG (dbgs () << " FnSpecialization: Force is on\n " );
875956 return true ;
876957 }
877958 // Minimum inlining bonus.
878959 if (Score > MinInliningBonus * FuncSize / 100 ) {
879960 LLVM_DEBUG (dbgs ()
880- << " FnSpecialization: Min inliningbous: Score = " << Score
881- << " > " << MinInliningBonus * FuncSize / 100 << " \n " );
961+ << " FnSpecialization: Sufficient inlining bonus ( " << Score
962+ << " > " << MinInliningBonus * FuncSize / 100 << " ) \n " );
882963 return true ;
883964 }
884965 // Minimum codesize savings.
885966 if (B.CodeSize < MinCodeSizeSavings * FuncSize / 100 ) {
886967 LLVM_DEBUG (dbgs ()
887- << " FnSpecialization: Min CodeSize Saving: CodeSize = "
968+ << " FnSpecialization: Insufficinet CodeSize Saving ( "
888969 << B.CodeSize << " > "
889- << MinCodeSizeSavings * FuncSize / 100 << " \n " );
970+ << MinCodeSizeSavings * FuncSize / 100 << " ) \n " );
890971 return false ;
891972 }
892973 // Minimum latency savings.
893974 if (B.Latency < MinLatencySavings * FuncSize / 100 ) {
894- LLVM_DEBUG (dbgs ()
895- << " FnSpecialization: Min Latency Saving: Latency = "
896- << B.Latency << " > " << MinLatencySavings * FuncSize / 100
897- << " \n " );
975+ LLVM_DEBUG (dbgs () << " FnSpecialization: Insufficinet Latency Saving ("
976+ << B.Latency << " > "
977+ << MinLatencySavings * FuncSize / 100 << " )\n " );
898978 return false ;
899979 }
900980 // Maximum codesize growth.
901981 if (FuncGrowth / FuncSize > MaxCodeSizeGrowth) {
902- LLVM_DEBUG (dbgs () << " FnSpecialization: Max Func Growth: CodeSize = "
903- << FuncGrowth / FuncSize << " > "
904- << MaxCodeSizeGrowth << " \n " );
982+ LLVM_DEBUG (dbgs ()
983+ << " FnSpecialization: Function Growth exceeds threshold ("
984+ << FuncGrowth / FuncSize << " > " << MaxCodeSizeGrowth
985+ << " )\n " );
905986 return false ;
906987 }
907988 return true ;
0 commit comments