Skip to content
Merged
62 changes: 60 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1154,10 +1154,63 @@ static unsigned getNumSizeComponents(SPIRVType *imgType) {
return arrayed ? numComps + 1 : numComps;
}

/// Tells whether \p BuiltinNumber names one of the OpenCL extended
/// instructions for which a caller may have to promote arguments to a vector.
/// Every other instruction number yields false.
static bool builtinMayNeedPromotionToVec(uint32_t BuiltinNumber) {
  switch (BuiltinNumber) {
  // min / max family.
  case SPIRV::OpenCLExtInst::s_min:
  case SPIRV::OpenCLExtInst::u_min:
  case SPIRV::OpenCLExtInst::s_max:
  case SPIRV::OpenCLExtInst::u_max:
  case SPIRV::OpenCLExtInst::fmin:
  case SPIRV::OpenCLExtInst::fmax:
  case SPIRV::OpenCLExtInst::fmin_common:
  case SPIRV::OpenCLExtInst::fmax_common:
  // clamp family.
  case SPIRV::OpenCLExtInst::s_clamp:
  case SPIRV::OpenCLExtInst::u_clamp:
  case SPIRV::OpenCLExtInst::fclamp:
  // mix / step / smoothstep.
  case SPIRV::OpenCLExtInst::mix:
  case SPIRV::OpenCLExtInst::step:
  case SPIRV::OpenCLExtInst::smoothstep:
    return true;
  default:
    return false;
  }
}

//===----------------------------------------------------------------------===//
// Implementation functions for each builtin group
//===----------------------------------------------------------------------===//

/// Builds the operand list for the extended instruction \p BuiltinNumber
/// invoked by \p Call.
///
/// When the builtin is not in the promotion list, or the result has at most
/// one component, the call's own argument registers are returned unchanged.
/// Otherwise each argument whose SPIR-V type is not the call's return type is
/// replaced by a fresh register defined by an OpCompositeConstruct of the
/// return type that repeats the argument once per result component.
///
/// \returns the registers to use as the instruction's operands, in call order.
static SmallVector<Register>
getBuiltinCallArguments(const SPIRV::IncomingCall *Call, uint32_t BuiltinNumber,
                        MachineIRBuilder &MIRBuilder, SPIRVGlobalRegistry *GR) {
  Register ResultTypeReg = GR->getSPIRVTypeID(Call->ReturnType);
  const unsigned NumResultElts =
      GR->getScalarOrVectorComponentCount(ResultTypeReg);

  // Nothing to promote: hand the original arguments back as they are.
  if (!builtinMayNeedPromotionToVec(BuiltinNumber) || NumResultElts <= 1)
    return {Call->Arguments.begin(), Call->Arguments.end()};

  SmallVector<Register> Promoted;
  for (Register Arg : Call->Arguments) {
    // An argument that already has the return type is passed through.
    if (GR->getSPIRVTypeForVReg(Arg) == Call->ReturnType) {
      Promoted.push_back(Arg);
      continue;
    }

    // NOTE(review): any argument of a different type is treated as a scalar
    // to be splatted - confirm no listed builtin takes a vector operand whose
    // type differs from the result type.
    Register Splat = createVirtualRegister(Call->ReturnType, GR, MIRBuilder);
    auto Construct = MIRBuilder.buildInstr(SPIRV::OpCompositeConstruct)
                         .addDef(Splat)
                         .addUse(ResultTypeReg);
    for (unsigned Lane = 0; Lane < NumResultElts; ++Lane)
      Construct.addUse(Arg);
    Promoted.push_back(Splat);
  }
  return Promoted;
}

static bool generateExtInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR, const CallBase &CB) {
Expand All @@ -1179,16 +1232,21 @@ static bool generateExtInst(const SPIRV::IncomingCall *Call,
: SPIRV::OpenCLExtInst::fmax;
}

Register ReturnTypeId = GR->getSPIRVTypeID(Call->ReturnType);
SmallVector<Register> Arguments =
getBuiltinCallArguments(Call, Number, MIRBuilder, GR);

// Build extended instruction.
auto MIB =
MIRBuilder.buildInstr(SPIRV::OpExtInst)
.addDef(Call->ReturnRegister)
.addUse(GR->getSPIRVTypeID(Call->ReturnType))
.addUse(ReturnTypeId)
.addImm(static_cast<uint32_t>(SPIRV::InstructionSet::OpenCL_std))
.addImm(Number);

for (auto Argument : Call->Arguments)
for (Register Argument : Arguments)
MIB.addUse(Argument);

MIB.getInstr()->copyIRFlags(CB);
if (OrigNumber == SPIRV::OpenCLExtInst::fmin_common ||
OrigNumber == SPIRV::OpenCLExtInst::fmax_common) {
Expand Down
102 changes: 3 additions & 99 deletions llvm/lib/Target/SPIRV/SPIRVRegularizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,10 @@
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "llvm/Demangle/Demangle.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Transforms/Utils/Cloning.h"

#include <list>

Expand All @@ -25,9 +24,7 @@
using namespace llvm;

namespace {
struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
DenseMap<Function *, Function *> Old2NewFuncs;

struct SPIRVRegularizer : public FunctionPass {
public:
static char ID;
SPIRVRegularizer() : FunctionPass(ID) {}
Expand All @@ -37,11 +34,8 @@ struct SPIRVRegularizer : public FunctionPass, InstVisitor<SPIRVRegularizer> {
void getAnalysisUsage(AnalysisUsage &AU) const override {
FunctionPass::getAnalysisUsage(AU);
}
void visitCallInst(CallInst &CI);

private:
void visitCallScalToVec(CallInst *CI, StringRef MangledName,
StringRef DemangledName);
void runLowerConstExpr(Function &F);
};
} // namespace
Expand Down Expand Up @@ -157,98 +151,8 @@ void SPIRVRegularizer::runLowerConstExpr(Function &F) {
}
}

// Fixes up direct calls to OCL builtins that take vector arguments when one of
// the arguments is passed as a scalar standing for a splat. The callee is
// recognised by the prefix of its Itanium-demangled name; indirect calls and
// names that do not demangle are ignored.
void SPIRVRegularizer::visitCallInst(CallInst &CI) {
  Function *Callee = CI.getCalledFunction();
  if (Callee == nullptr)
    return;

  StringRef MangledName = Callee->getName();
  char *DemangledBuf = itaniumDemangle(Callee->getName().data());
  if (DemangledBuf == nullptr)
    return;

  // Non-owning view of the buffer; it must not be used after the free() below.
  StringRef DemangledName(DemangledBuf);

  // TODO: add support for other builtins.
  const bool IsMinMaxBuiltin =
      DemangledName.starts_with("fmin") || DemangledName.starts_with("fmax") ||
      DemangledName.starts_with("min") || DemangledName.starts_with("max");
  if (IsMinMaxBuiltin)
    visitCallScalToVec(&CI, MangledName, DemangledName);

  free(DemangledBuf);
}

// Rewrites a call whose arguments mix vector and non-vector types so that
// operand 1 becomes a vector splat of type Arg0Ty (the type of operand 0) and
// the callee becomes a clone of the original function declared with the
// parameter types (type of the old function's first parameter, Arg0Ty).
// Calls whose arguments all agree with operand 0 on being a vector are left
// unchanged. MangledName and DemangledName are not referenced in the body.
// NOTE(review): only operand 1 is replaced and the new signature has exactly
// two parameters, so this presumably expects a two-argument call with the
// vector first and the scalar second - confirm against the callers.
void SPIRVRegularizer::visitCallScalToVec(CallInst *CI, StringRef MangledName,
StringRef DemangledName) {
// Check whether every argument agrees with argument 0 on being a vector or
// not - that is the simple case and needs no rewrite.
auto Uniform = true;
Type *Arg0Ty = CI->getOperand(0)->getType();
auto IsArg0Vector = isa<VectorType>(Arg0Ty);
for (unsigned I = 1, E = CI->arg_size(); Uniform && (I != E); ++I)
Uniform = isa<VectorType>(CI->getOperand(I)->getType()) == IsArg0Vector;
if (Uniform)
return;

// Look up (or reserve a slot for) the replacement of the called function in
// Old2NewFuncs, so each old function gets at most one replacement.
auto *OldF = CI->getCalledFunction();
Function *NewF = nullptr;
auto [It, Inserted] = Old2NewFuncs.try_emplace(OldF);
if (Inserted) {
// First call seen for this callee: create the replacement function in the
// same module, with the same linkage and the same requested name.
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
SmallVector<Type *, 2> ArgTypes = {OldF->getArg(0)->getType(), Arg0Ty};
auto *NewFTy =
FunctionType::get(OldF->getReturnType(), ArgTypes, OldF->isVarArg());
NewF = Function::Create(NewFTy, OldF->getLinkage(), OldF->getName(),
*OldF->getParent());
// Give the new arguments the old names and map each old argument to its
// new counterpart for the clone below.
ValueToValueMapTy VMap;
auto NewFArgIt = NewF->arg_begin();
for (auto &Arg : OldF->args()) {
auto ArgName = Arg.getName();
NewFArgIt->setName(ArgName);
VMap[&Arg] = &(*NewFArgIt++);
}
SmallVector<ReturnInst *, 8> Returns;
CloneFunctionInto(NewF, OldF, VMap,
CloneFunctionChangeType::LocalChangesOnly, Returns);
// Apply the original function's attribute list to the clone and record it
// in the map slot reserved above.
NewF->setAttributes(Attrs);
It->second = NewF;
} else {
// A replacement was already created for an earlier call; reuse it.
NewF = It->second;
}
assert(NewF);

// This produces an instruction sequence that implements a splat of
// CI->getOperand(1) to a vector Arg0Ty. However, we use InsertElementInst
// and ShuffleVectorInst to generate the same code as the SPIR-V translator.
// For instance (transcoding/OpMin.ll), this call
// call spir_func <2 x i32> @_Z3minDv2_ii(<2 x i32> <i32 1, i32 10>, i32 5)
// is translated to
// %8 = OpUndef %v2uint
// %14 = OpConstantComposite %v2uint %uint_1 %uint_10
// ...
// %10 = OpCompositeInsert %v2uint %uint_5 %8 0
// %11 = OpVectorShuffle %v2uint %10 %8 0 0
// %call = OpExtInst %v2uint %1 s_min %14 %11
// Insert operand 1 into element 0 of a poison vector, then shuffle with an
// all-zero mask; both instructions are placed immediately before the call.
auto ConstInt = ConstantInt::get(IntegerType::get(CI->getContext(), 32), 0);
PoisonValue *PVal = PoisonValue::get(Arg0Ty);
Instruction *Inst = InsertElementInst::Create(
PVal, CI->getOperand(1), ConstInt, "", CI->getIterator());
ElementCount VecElemCount = cast<VectorType>(Arg0Ty)->getElementCount();
Constant *ConstVec = ConstantVector::getSplat(VecElemCount, ConstInt);
Value *NewVec =
new ShuffleVectorInst(Inst, PVal, ConstVec, "", CI->getIterator());
// Redirect the call: operand 1 is now the splat, the callee is NewF, and the
// call's function type is updated to NewF's.
CI->setOperand(1, NewVec);
CI->replaceUsesOfWith(OldF, NewF);
CI->mutateFunctionType(NewF->getFunctionType());
}

// Runs the regularizer on F: lowers constant expressions, visits the
// instructions of F (visitCallInst may register replacement functions in
// Old2NewFuncs), then swaps every registered old function for its replacement.
// Always reports the function as modified.
bool SPIRVRegularizer::runOnFunction(Function &F) {
  runLowerConstExpr(F);
  visit(F);

  // Each replacement takes over the original's name before the original is
  // removed from the module.
  for (auto &OldNew : Old2NewFuncs) {
    Function *OldF = OldNew.first;
    Function *NewF = OldNew.second;
    NewF->takeName(OldF);
    OldF->eraseFromParent();
  }

  // The old functions were just erased, so the map's keys are now dangling.
  // Drop them; otherwise the next runOnFunction call on this pass object would
  // walk the stale entries again and call takeName/eraseFromParent on freed
  // functions.
  Old2NewFuncs.clear();
  return true;
}

Expand Down
Loading