Skip to content

Commit

Permalink
Addressing review.
Browse files Browse the repository at this point in the history
getTLIFunction is no longer an optional. It accepts a pointer for
ScalarFunc
  • Loading branch information
paschalis-mpeis committed Dec 18, 2023
1 parent 4332484 commit 8082f46
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions llvm/lib/CodeGen/ReplaceWithVeclib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,24 @@ STATISTIC(NumFuncUsedAdded,
"Number of functions added to `llvm.compiler.used`");

/// Returns a vector Function that it adds to the Module \p M. When an \p
/// OptOldFunc is given, it copies its attributes to the newly created Function.
/// ScalarFunc is not null, it copies its attributes to the newly created
/// Function.
Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
std::optional<Function *> OptOldFunc,
const StringRef TLIName) {
Function *ScalarFunc, const StringRef TLIName) {
Function *TLIFunc = M->getFunction(TLIName);
if (!TLIFunc) {
TLIFunc =
Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
if (OptOldFunc)
TLIFunc->copyAttributesFrom(*OptOldFunc);
if (ScalarFunc)
TLIFunc->copyAttributesFrom(ScalarFunc);

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
<< TLIName << "` of type `" << *(TLIFunc->getType())
<< "` to module.\n");

++NumTLIFuncDeclAdded;
// Add the freshly created function to llvm.compiler.used, similar to as it
// is done in InjectTLIMappings
// is done in InjectTLIMappings.
appendToCompilerUsed(*M, {TLIFunc});
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
<< "` to `@llvm.compiler.used`.\n");
Expand All @@ -72,11 +72,11 @@ Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
/// Replace the call to the vector intrinsic ( \p FuncToReplace ) with a call to
/// the corresponding function from the vector library ( \p TLIFunc ).
static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
Function *TLIFunc, FunctionType *VecFTy) {
Function *TLIVecFunc) {
IRBuilder<> IRBuilder(&CI);
SmallVector<Value *> Args(CI.args());
if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
if (Args.size() == VecFTy->getNumParams())
if (Args.size() == TLIVecFunc->getFunctionType()->getNumParams())
static_assert(true && "mask was already in place");

auto *MaskTy =
Expand All @@ -88,9 +88,7 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
// Preserve the operand bundles.
SmallVector<OperandBundleDef, 1> OpBundles;
CI.getOperandBundlesAsDefs(OpBundles);
CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
assert(VecFTy == TLIFunc->getFunctionType() &&
"Expecting function types to be identical");
CallInst *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
CI.replaceAllUsesWith(Replacement);
// Preserve fast math flags for FP math.
if (isa<FPMathOperator>(Replacement))
Expand All @@ -102,10 +100,10 @@ static void replaceWithTLIFunction(CallInst &CI, VFInfo &Info,
static std::optional<const VecDesc *> getVecDesc(const TargetLibraryInfo &TLI,
const StringRef &ScalarName,
const ElementCount &VF) {
if (auto *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true))
return VDMasked;
if (auto *VDNoMask = TLI.getVectorMappingInfo(ScalarName, VF, false))
return VDNoMask;
if (auto *VDMasked = TLI.getVectorMappingInfo(ScalarName, VF, true))
return VDMasked;
return std::nullopt;
}

Expand All @@ -117,20 +115,20 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
return false;

auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
// Replacement is only performed for intrinsic functions
// Replacement is only performed for intrinsic functions.
if (IntrinsicID == Intrinsic::not_intrinsic)
return false;

// Convert vector arguments to scalar type and check that all vector operands
// have identical vector width.
ElementCount VF = ElementCount::getFixed(0);
SmallVector<Type *> ScalarTypes;
SmallVector<Type *> ScalarArgTypes;
for (auto Arg : enumerate(CI.args())) {
auto *ArgTy = Arg.value()->getType();
if (isVectorIntrinsicWithScalarOpAtArg(IntrinsicID, Arg.index())) {
ScalarTypes.push_back(ArgTy);
ScalarArgTypes.push_back(ArgTy);
} else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
ScalarTypes.push_back(ArgTy->getScalarType());
ScalarArgTypes.push_back(ArgTy->getScalarType());
// Disallow vector arguments with different VFs. When processing the first
// vector argument, store it's VF, and for the rest ensure that they match
// it.
Expand All @@ -139,15 +137,15 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
else if (VF != VectorArgTy->getElementCount())
return false;
} else
// enters when it is supposed to be a vector argument but it isn't.
// Exit when it is supposed to be a vector argument but it isn't.
return false;
}

// Try to reconstruct the name for the scalar version of this intrinsic using
// the intrinsic ID and the argument types converted to scalar above.
std::string ScalarName =
(Intrinsic::isOverloaded(IntrinsicID)
? Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule())
? Intrinsic::getName(IntrinsicID, ScalarArgTypes, CI.getModule())
: Intrinsic::getName(IntrinsicID).str());

// The TargetLibraryInfo does not contain a vectorized version of the scalar
Expand All @@ -169,7 +167,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
// Replace the call to the intrinsic with a call to the vector library
// function.
Type *ScalarRetTy = CI.getType()->getScalarType();
FunctionType *ScalarFTy = FunctionType::get(ScalarRetTy, ScalarTypes, false);
FunctionType *ScalarFTy =
FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
const std::string MangledName = VD->getVectorFunctionABIVariantString();
auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
if (!OptInfo)
Expand All @@ -182,7 +181,7 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *FuncToReplace = CI.getCalledFunction();
Function *TLIFunc = getTLIFunction(CI.getModule(), VectorFTy, FuncToReplace,
VD->getVectorFnName());
replaceWithTLIFunction(CI, *OptInfo, TLIFunc, VectorFTy);
replaceWithTLIFunction(CI, *OptInfo, TLIFunc);

LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
<< FuncToReplace->getName() << "` with call to `"
Expand Down

0 comments on commit 8082f46

Please sign in to comment.