Skip to content

Commit 7253c6f

Browse files
authored
[InstCombine] Transform high latency, dependent FSQRT/FDIV into FMUL (#87474)
The proposed patch, in general, tries to transform the below code sequence: x = 1.0 / sqrt (a); r1 = x * x; // same as 1.0 / a r2 = a / sqrt(a); // same as sqrt (a) TO (If x, r1 and r2 are all used further in the code) r1 = 1.0 / a r2 = sqrt (a) x = r1 * r2 The transform tries to make high latency sqrt and div operations independent and also saves on one multiplication. The patch was tested with SPEC17 suite with cpu=neoverse-v2. The performance uplift achieved was: 544.nab_r ~4% No other regressions were observed. Also, no compile time differences were observed with the patch. Closes #54652
1 parent 263fed7 commit 7253c6f

File tree

2 files changed

+807
-0
lines changed

2 files changed

+807
-0
lines changed

llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp

+176
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
#include "InstCombineInternal.h"
1515
#include "llvm/ADT/APInt.h"
16+
#include "llvm/ADT/SmallPtrSet.h"
1617
#include "llvm/ADT/SmallVector.h"
1718
#include "llvm/Analysis/InstructionSimplify.h"
1819
#include "llvm/Analysis/ValueTracking.h"
@@ -657,6 +658,94 @@ Instruction *InstCombinerImpl::foldPowiReassoc(BinaryOperator &I) {
657658
return nullptr;
658659
}
659660

661+
// If we have the following pattern,
662+
// X = 1.0/sqrt(a)
663+
// R1 = X * X
664+
// R2 = a/sqrt(a)
665+
// then this method collects all the instructions that match R1 and R2.
666+
static bool getFSqrtDivOptPattern(Instruction *Div,
667+
SmallPtrSetImpl<Instruction *> &R1,
668+
SmallPtrSetImpl<Instruction *> &R2) {
669+
Value *A;
670+
if (match(Div, m_FDiv(m_FPOne(), m_Sqrt(m_Value(A)))) ||
671+
match(Div, m_FDiv(m_SpecificFP(-1.0), m_Sqrt(m_Value(A))))) {
672+
for (User *U : Div->users()) {
673+
Instruction *I = cast<Instruction>(U);
674+
if (match(I, m_FMul(m_Specific(Div), m_Specific(Div))))
675+
R1.insert(I);
676+
}
677+
678+
CallInst *CI = cast<CallInst>(Div->getOperand(1));
679+
for (User *U : CI->users()) {
680+
Instruction *I = cast<Instruction>(U);
681+
if (match(I, m_FDiv(m_Specific(A), m_Sqrt(m_Specific(A)))))
682+
R2.insert(I);
683+
}
684+
}
685+
return !R1.empty() && !R2.empty();
686+
}
687+
688+
// Check legality for transforming
689+
// x = 1.0/sqrt(a)
690+
// r1 = x * x;
691+
// r2 = a/sqrt(a);
692+
//
693+
// TO
694+
//
695+
// r1 = 1/a
696+
// r2 = sqrt(a)
697+
// x = r1 * r2
698+
// This transform works only when 'a' is known positive.
699+
static bool isFSqrtDivToFMulLegal(Instruction *X,
700+
SmallPtrSetImpl<Instruction *> &R1,
701+
SmallPtrSetImpl<Instruction *> &R2) {
702+
// Check if the required pattern for the transformation exists.
703+
if (!getFSqrtDivOptPattern(X, R1, R2))
704+
return false;
705+
706+
BasicBlock *BBx = X->getParent();
707+
BasicBlock *BBr1 = (*R1.begin())->getParent();
708+
BasicBlock *BBr2 = (*R2.begin())->getParent();
709+
710+
CallInst *FSqrt = cast<CallInst>(X->getOperand(1));
711+
if (!FSqrt->hasAllowReassoc() || !FSqrt->hasNoNaNs() ||
712+
!FSqrt->hasNoSignedZeros() || !FSqrt->hasNoInfs())
713+
return false;
714+
715+
// We change x = 1/sqrt(a) to x = sqrt(a) * 1/a . This change isn't allowed
716+
// by recip fp as it is strictly meant to transform ops of type a/b to
717+
// a * 1/b. So, this can be considered as algebraic rewrite and reassoc flag
718+
// has been used(rather abused)in the past for algebraic rewrites.
719+
if (!X->hasAllowReassoc() || !X->hasAllowReciprocal() || !X->hasNoInfs())
720+
return false;
721+
722+
// Check the constraints on X, R1 and R2 combined.
723+
// fdiv instruction and one of the multiplications must reside in the same
724+
// block. If not, the optimized code may execute more ops than before and
725+
// this may hamper the performance.
726+
if (BBx != BBr1 && BBx != BBr2)
727+
return false;
728+
729+
// Check the constraints on instructions in R1.
730+
if (any_of(R1, [BBr1](Instruction *I) {
731+
// When you have multiple instructions residing in R1 and R2
732+
// respectively, it's difficult to generate combinations of (R1,R2) and
733+
// then check if we have the required pattern. So, for now, just be
734+
// conservative.
735+
return (I->getParent() != BBr1 || !I->hasAllowReassoc());
736+
}))
737+
return false;
738+
739+
// Check the constraints on instructions in R2.
740+
return all_of(R2, [BBr2](Instruction *I) {
741+
// When you have multiple instructions residing in R1 and R2
742+
// respectively, it's difficult to generate combination of (R1,R2) and
743+
// then check if we have the required pattern. So, for now, just be
744+
// conservative.
745+
return (I->getParent() == BBr2 && I->hasAllowReassoc());
746+
});
747+
}
748+
660749
Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) {
661750
Value *Op0 = I.getOperand(0);
662751
Value *Op1 = I.getOperand(1);
@@ -1913,6 +2002,75 @@ static Instruction *foldFDivSqrtDivisor(BinaryOperator &I,
19132002
return BinaryOperator::CreateFMulFMF(Op0, NewSqrt, &I);
19142003
}
19152004

2005+
// Change
2006+
// X = 1/sqrt(a)
2007+
// R1 = X * X
2008+
// R2 = a * X
2009+
//
2010+
// TO
2011+
//
2012+
// FDiv = 1/a
2013+
// FSqrt = sqrt(a)
2014+
// FMul = FDiv * FSqrt
2015+
// Replace Uses Of R1 With FDiv
2016+
// Replace Uses Of R2 With FSqrt
2017+
// Replace Uses Of X With FMul
2018+
static Instruction *
2019+
convertFSqrtDivIntoFMul(CallInst *CI, Instruction *X,
2020+
const SmallPtrSetImpl<Instruction *> &R1,
2021+
const SmallPtrSetImpl<Instruction *> &R2,
2022+
InstCombiner::BuilderTy &B, InstCombinerImpl *IC) {
2023+
2024+
B.SetInsertPoint(X);
2025+
2026+
// Have an instruction that is representative of all of instructions in R1 and
2027+
// get the most common fpmath metadata and fast-math flags on it.
2028+
Value *SqrtOp = CI->getArgOperand(0);
2029+
auto *FDiv = cast<Instruction>(
2030+
B.CreateFDiv(ConstantFP::get(X->getType(), 1.0), SqrtOp));
2031+
auto *R1FPMathMDNode = (*R1.begin())->getMetadata(LLVMContext::MD_fpmath);
2032+
FastMathFlags R1FMF = (*R1.begin())->getFastMathFlags(); // Common FMF
2033+
for (Instruction *I : R1) {
2034+
R1FPMathMDNode = MDNode::getMostGenericFPMath(
2035+
R1FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
2036+
R1FMF &= I->getFastMathFlags();
2037+
IC->replaceInstUsesWith(*I, FDiv);
2038+
IC->eraseInstFromFunction(*I);
2039+
}
2040+
FDiv->setMetadata(LLVMContext::MD_fpmath, R1FPMathMDNode);
2041+
FDiv->copyFastMathFlags(R1FMF);
2042+
2043+
// Have a single sqrt call instruction that is representative of all of
2044+
// instructions in R2 and get the most common fpmath metadata and fast-math
2045+
// flags on it.
2046+
auto *FSqrt = cast<CallInst>(CI->clone());
2047+
FSqrt->insertBefore(CI);
2048+
auto *R2FPMathMDNode = (*R2.begin())->getMetadata(LLVMContext::MD_fpmath);
2049+
FastMathFlags R2FMF = (*R2.begin())->getFastMathFlags(); // Common FMF
2050+
for (Instruction *I : R2) {
2051+
R2FPMathMDNode = MDNode::getMostGenericFPMath(
2052+
R2FPMathMDNode, I->getMetadata(LLVMContext::MD_fpmath));
2053+
R2FMF &= I->getFastMathFlags();
2054+
IC->replaceInstUsesWith(*I, FSqrt);
2055+
IC->eraseInstFromFunction(*I);
2056+
}
2057+
FSqrt->setMetadata(LLVMContext::MD_fpmath, R2FPMathMDNode);
2058+
FSqrt->copyFastMathFlags(R2FMF);
2059+
2060+
Instruction *FMul;
2061+
// If X = -1/sqrt(a) initially,then FMul = -(FDiv * FSqrt)
2062+
if (match(X, m_FDiv(m_SpecificFP(-1.0), m_Specific(CI)))) {
2063+
Value *Mul = B.CreateFMul(FDiv, FSqrt);
2064+
FMul = cast<Instruction>(B.CreateFNeg(Mul));
2065+
} else
2066+
FMul = cast<Instruction>(B.CreateFMul(FDiv, FSqrt));
2067+
FMul->copyMetadata(*X);
2068+
FMul->copyFastMathFlags(FastMathFlags::intersectRewrite(R1FMF, R2FMF) |
2069+
FastMathFlags::unionValue(R1FMF, R2FMF));
2070+
IC->replaceInstUsesWith(*X, FMul);
2071+
return IC->eraseInstFromFunction(*X);
2072+
}
2073+
19162074
Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19172075
Module *M = I.getModule();
19182076

@@ -1937,6 +2095,24 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
19372095
return R;
19382096

19392097
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
2098+
2099+
// Convert
2100+
// x = 1.0/sqrt(a)
2101+
// r1 = x * x;
2102+
// r2 = a/sqrt(a);
2103+
//
2104+
// TO
2105+
//
2106+
// r1 = 1/a
2107+
// r2 = sqrt(a)
2108+
// x = r1 * r2
2109+
SmallPtrSet<Instruction *, 2> R1, R2;
2110+
if (isFSqrtDivToFMulLegal(&I, R1, R2)) {
2111+
CallInst *CI = cast<CallInst>(I.getOperand(1));
2112+
if (Instruction *D = convertFSqrtDivIntoFMul(CI, &I, R1, R2, Builder, this))
2113+
return D;
2114+
}
2115+
19402116
if (isa<Constant>(Op0))
19412117
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
19422118
if (Instruction *R = FoldOpIntoSelect(I, SI))

0 commit comments

Comments
 (0)