13
13
14
14
#include " InstCombineInternal.h"
15
15
#include " llvm/ADT/APInt.h"
16
+ #include " llvm/ADT/SmallPtrSet.h"
16
17
#include " llvm/ADT/SmallVector.h"
17
18
#include " llvm/Analysis/InstructionSimplify.h"
18
19
#include " llvm/Analysis/ValueTracking.h"
@@ -657,6 +658,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
657
658
return nullptr ;
658
659
}
659
660
661
+ // If we have the following pattern,
662
+ // X = 1.0/sqrt(a)
663
+ // R1 = X * X
664
+ // R2 = a/sqrt(a)
665
+ // then this method collects all the instructions that match R1 and R2.
666
+ static bool getFSqrtDivOptPattern (Instruction *Div,
667
+ SmallPtrSetImpl<Instruction *> &R1,
668
+ SmallPtrSetImpl<Instruction *> &R2) {
669
+ Value *A;
670
+ if (match (Div, m_FDiv (m_FPOne (), m_Sqrt (m_Value (A)))) ||
671
+ match (Div, m_FDiv (m_SpecificFP (-1.0 ), m_Sqrt (m_Value (A))))) {
672
+ for (User *U : Div->users ()) {
673
+ Instruction *I = cast<Instruction>(U);
674
+ if (match (I, m_FMul (m_Specific (Div), m_Specific (Div))))
675
+ R1.insert (I);
676
+ }
677
+
678
+ CallInst *CI = cast<CallInst>(Div->getOperand (1 ));
679
+ for (User *U : CI->users ()) {
680
+ Instruction *I = cast<Instruction>(U);
681
+ if (match (I, m_FDiv (m_Specific (A), m_Sqrt (m_Specific (A)))))
682
+ R2.insert (I);
683
+ }
684
+ }
685
+ return !R1.empty () && !R2.empty ();
686
+ }
687
+
688
+ // Check legality for transforming
689
+ // x = 1.0/sqrt(a)
690
+ // r1 = x * x;
691
+ // r2 = a/sqrt(a);
692
+ //
693
+ // TO
694
+ //
695
+ // r1 = 1/a
696
+ // r2 = sqrt(a)
697
+ // x = r1 * r2
698
+ // This transform works only when 'a' is known positive.
699
+ static bool isFSqrtDivToFMulLegal (Instruction *X,
700
+ SmallPtrSetImpl<Instruction *> &R1,
701
+ SmallPtrSetImpl<Instruction *> &R2) {
702
+ // Check if the required pattern for the transformation exists.
703
+ if (!getFSqrtDivOptPattern (X, R1, R2))
704
+ return false ;
705
+
706
+ BasicBlock *BBx = X->getParent ();
707
+ BasicBlock *BBr1 = (*R1.begin ())->getParent ();
708
+ BasicBlock *BBr2 = (*R2.begin ())->getParent ();
709
+
710
+ CallInst *FSqrt = cast<CallInst>(X->getOperand (1 ));
711
+ if (!FSqrt->hasAllowReassoc () || !FSqrt->hasNoNaNs () ||
712
+ !FSqrt->hasNoSignedZeros () || !FSqrt->hasNoInfs ())
713
+ return false ;
714
+
715
+ // We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
716
+ // by recip fp as it is strictly meant to transform ops of type a/b to
717
+ // a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
718
+ // has been used(rather abused)in the past for algebraic rewrites.
719
+ if (!X->hasAllowReassoc () || !X->hasAllowReciprocal () || !X->hasNoInfs ())
720
+ return false ;
721
+
722
+ // Check the constraints on X, R1 and R2 combined.
723
+ // fdiv instruction and one of the multiplications must reside in the same
724
+ // block. If not, the optimized code may execute more ops than before and
725
+ // this may hamper the performance.
726
+ if (BBx != BBr1 && BBx != BBr2)
727
+ return false ;
728
+
729
+ // Check the constraints on instructions in R1.
730
+ if (any_of (R1, [BBr1](Instruction *I) {
731
+ // When you have multiple instructions residing in R1 and R2
732
+ // respectively, it's difficult to generate combinations of (R1,R2) and
733
+ // then check if we have the required pattern. So, for now, just be
734
+ // conservative.
735
+ return (I->getParent () != BBr1 || !I->hasAllowReassoc ());
736
+ }))
737
+ return false ;
738
+
739
+ // Check the constraints on instructions in R2.
740
+ return all_of (R2, [BBr2](Instruction *I) {
741
+ // When you have multiple instructions residing in R1 and R2
742
+ // respectively, it's difficult to generate combination of (R1,R2) and
743
+ // then check if we have the required pattern. So, for now, just be
744
+ // conservative.
745
+ return (I->getParent () == BBr2 && I->hasAllowReassoc ());
746
+ });
747
+ }
748
+
660
749
Instruction *InstCombinerImpl::foldFMulReassoc (BinaryOperator &I) {
661
750
Value *Op0 = I.getOperand (0 );
662
751
Value *Op1 = I.getOperand (1 );
@@ -1913,6 +2002,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
1913
2002
return BinaryOperator::CreateFMulFMF (Op0, NewSqrt, &I);
1914
2003
}
1915
2004
2005
+ // Change
2006
+ // X = 1/sqrt(a)
2007
+ // R1 = X * X
2008
+ // R2 = a * X
2009
+ //
2010
+ // TO
2011
+ //
2012
+ // FDiv = 1/a
2013
+ // FSqrt = sqrt(a)
2014
+ // FMul = FDiv * FSqrt
2015
+ // Replace Uses Of R1 With FDiv
2016
+ // Replace Uses Of R2 With FSqrt
2017
+ // Replace Uses Of X With FMul
2018
+ static Instruction *
2019
+ convertFSqrtDivIntoFMul (CallInst *CI, Instruction *X,
2020
+ const SmallPtrSetImpl<Instruction *> &R1,
2021
+ const SmallPtrSetImpl<Instruction *> &R2,
2022
+ InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2023
+
2024
+ B.SetInsertPoint (X);
2025
+
2026
+ // Have an instruction that is representative of all of instructions in R1 and
2027
+ // get the most common fpmath metadata and fast-math flags on it.
2028
+ Value *SqrtOp = CI->getArgOperand (0 );
2029
+ auto *FDiv = cast<Instruction>(
2030
+ B.CreateFDiv (ConstantFP::get (X->getType (), 1.0 ), SqrtOp));
2031
+ auto *R1FPMathMDNode = (*R1.begin ())->getMetadata (LLVMContext::MD_fpmath);
2032
+ FastMathFlags R1FMF = (*R1.begin ())->getFastMathFlags (); // Common FMF
2033
+ for (Instruction *I : R1) {
2034
+ R1FPMathMDNode = MDNode::getMostGenericFPMath (
2035
+ R1FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2036
+ R1FMF &= I->getFastMathFlags ();
2037
+ IC->replaceInstUsesWith (*I, FDiv);
2038
+ IC->eraseInstFromFunction (*I);
2039
+ }
2040
+ FDiv->setMetadata (LLVMContext::MD_fpmath, R1FPMathMDNode);
2041
+ FDiv->copyFastMathFlags (R1FMF);
2042
+
2043
+ // Have a single sqrt call instruction that is representative of all of
2044
+ // instructions in R2 and get the most common fpmath metadata and fast-math
2045
+ // flags on it.
2046
+ auto *FSqrt = cast<CallInst>(CI->clone ());
2047
+ FSqrt->insertBefore (CI);
2048
+ auto *R2FPMathMDNode = (*R2.begin ())->getMetadata (LLVMContext::MD_fpmath);
2049
+ FastMathFlags R2FMF = (*R2.begin ())->getFastMathFlags (); // Common FMF
2050
+ for (Instruction *I : R2) {
2051
+ R2FPMathMDNode = MDNode::getMostGenericFPMath (
2052
+ R2FPMathMDNode, I->getMetadata (LLVMContext::MD_fpmath));
2053
+ R2FMF &= I->getFastMathFlags ();
2054
+ IC->replaceInstUsesWith (*I, FSqrt);
2055
+ IC->eraseInstFromFunction (*I);
2056
+ }
2057
+ FSqrt->setMetadata (LLVMContext::MD_fpmath, R2FPMathMDNode);
2058
+ FSqrt->copyFastMathFlags (R2FMF);
2059
+
2060
+ Instruction *FMul;
2061
+ // If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2062
+ if (match (X, m_FDiv (m_SpecificFP (-1.0 ), m_Specific (CI)))) {
2063
+ Value *Mul = B.CreateFMul (FDiv, FSqrt);
2064
+ FMul = cast<Instruction>(B.CreateFNeg (Mul));
2065
+ } else
2066
+ FMul = cast<Instruction>(B.CreateFMul (FDiv, FSqrt));
2067
+ FMul->copyMetadata (*X);
2068
+ FMul->copyFastMathFlags (FastMathFlags::intersectRewrite (R1FMF, R2FMF) |
2069
+ FastMathFlags::unionValue (R1FMF, R2FMF));
2070
+ IC->replaceInstUsesWith (*X, FMul);
2071
+ return IC->eraseInstFromFunction (*X);
2072
+ }
2073
+
1916
2074
Instruction *InstCombinerImpl::visitFDiv (BinaryOperator &I) {
1917
2075
Module *M = I.getModule ();
1918
2076
@@ -1937,6 +2095,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
1937
2095
return R;
1938
2096
1939
2097
Value *Op0 = I.getOperand (0 ), *Op1 = I.getOperand (1 );
2098
+
2099
+ // Convert
2100
+ // x = 1.0/sqrt(a)
2101
+ // r1 = x * x;
2102
+ // r2 = a/sqrt(a);
2103
+ //
2104
+ // TO
2105
+ //
2106
+ // r1 = 1/a
2107
+ // r2 = sqrt(a)
2108
+ // x = r1 * r2
2109
+ SmallPtrSet<Instruction *, 2 > R1, R2;
2110
+ if (isFSqrtDivToFMulLegal (&I, R1, R2)) {
2111
+ CallInst *CI = cast<CallInst>(I.getOperand (1 ));
2112
+ if (Instruction *D = convertFSqrtDivIntoFMul (CI, &I, R1, R2, Builder, this ))
2113
+ return D;
2114
+ }
2115
+
1940
2116
if (isa<Constant>(Op0))
1941
2117
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
1942
2118
if (Instruction *R = FoldOpIntoSelect (I, SI))
0 commit comments