Skip to content

Commit

Permalink
Fix SPIR-V friendly LLVM IR for conversion functions (#899)
Browse files Browse the repository at this point in the history
In most cases convert functions are translated to set of LLVM IR
instructions but in case of saturated conversion or if conversion has
rounding mode, the translation should go through SPIR-V friendly LLVM IR.
  • Loading branch information
Fznamznon authored Feb 18, 2021
1 parent d1213fe commit 9f3c10f
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 43 deletions.
11 changes: 11 additions & 0 deletions lib/SPIRV/SPIRVInternal.h
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,17 @@ std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
ArrayRef<Type *> ArgTys);

/// Mangle a function in SPIR-V friendly IR manner
/// \param UniqName full unmangled name of the SPIR-V built-in function that
/// contains possible postfixes that depend not on opcode but on decorations or
/// return type, for example __spirv_UConvert_Rint_sat.
/// \param OC opcode of corresponding built-in instruction. Used to gather info
/// for unsigned/constant arguments.
/// \param Types of arguments of SPIR-V built-in function
/// \return IA64 mangled name.
std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
spv::Op OC, ArrayRef<Type *> ArgTys);

/// Remove cast from a value.
Value *removeCast(Value *V);

Expand Down
51 changes: 35 additions & 16 deletions lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2385,13 +2385,6 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
BV->getName(), BB));
}

case OpImageQuerySize:
case OpImageQuerySizeLod: {
return mapValue(
BV, transSPIRVBuiltinFromInst(static_cast<SPIRVInstruction *>(BV), BB,
/*AddRetTypePostfix=*/true));
}

case OpBitReverse: {
auto *BR = static_cast<SPIRVUnary *>(BV);
auto Ty = transType(BV->getType());
Expand Down Expand Up @@ -2670,7 +2663,7 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
auto BI = static_cast<SPIRVInstruction *>(BV);
Value *Inst = nullptr;
if (BI->hasFPRoundingMode() || BI->isSaturatedConversion())
Inst = transOCLBuiltinFromInst(BI, BB);
Inst = transSPIRVBuiltinFromInst(BI, BB);
else
Inst = transConvertInst(BV, F, BB);
return mapValue(BV, Inst);
Expand Down Expand Up @@ -3252,10 +3245,16 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
HasFuncPtrArg = true;
}
}
if (!HasFuncPtrArg)
mangleOpenClBuiltin(FuncName, ArgTys, MangledName);
else
if (!HasFuncPtrArg) {
if (BM->getDesiredBIsRepresentation() != BIsRepresentation::SPIRVFriendlyIR)
mangleOpenClBuiltin(FuncName, ArgTys, MangledName);
else
MangledName =
getSPIRVFriendlyIRFunctionName(FuncName, BI->getOpCode(), ArgTys);

} else {
MangledName = decorateSPIRVFunction(FuncName);
}
Function *Func = M->getFunction(MangledName);
FunctionType *FT = FunctionType::get(RetTy, ArgTys, false);
// ToDo: Some intermediate functions have duplicate names with
Expand Down Expand Up @@ -3399,22 +3398,42 @@ std::string getSPIRVFuncSuffix(SPIRVInstruction *BI) {
break;
}
}
if (BI->hasDecorate(DecorationSaturatedConversion)) {
Suffix += kSPIRVPostfix::Divider;
Suffix += kSPIRVPostfix::Sat;
}
SPIRVFPRoundingModeKind Kind;
if (BI->hasFPRoundingMode(&Kind)) {
Suffix += kSPIRVPostfix::Divider;
Suffix += SPIRSPIRVFPRoundingModeMap::rmap(Kind);
}
return Suffix;
}

Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
BasicBlock *BB,
bool AddRetTypePostfix) {
BasicBlock *BB) {
assert(BB && "Invalid BB");
const auto OC = BI->getOpCode();
bool AddRetTypePostfix = false;
if (OC == OpImageQuerySizeLod || OC == OpImageQuerySize)
AddRetTypePostfix = true;

bool IsRetSigned = false;
if (isCvtOpCode(OC)) {
AddRetTypePostfix = true;
if (OC == OpConvertUToF || OC == OpSatConvertUToS)
IsRetSigned = true;
}

if (AddRetTypePostfix) {
const Type *RetTy =
BI->hasType() ? transType(BI->getType()) : Type::getVoidTy(*Context);
return transBuiltinFromInst(getSPIRVFuncName(BI->getOpCode(), RetTy) +
return transBuiltinFromInst(getSPIRVFuncName(OC, RetTy, IsRetSigned) +
getSPIRVFuncSuffix(BI),
BI, BB);
}
return transBuiltinFromInst(
getSPIRVFuncName(BI->getOpCode(), getSPIRVFuncSuffix(BI)), BI, BB);
return transBuiltinFromInst(getSPIRVFuncName(OC, getSPIRVFuncSuffix(BI)), BI,
BB);
}

bool SPIRVToLLVM::translate() {
Expand Down
3 changes: 1 addition & 2 deletions lib/SPIRV/SPIRVReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,7 @@ class SPIRVToLLVM {
Instruction *transBuiltinFromInst(const std::string &FuncName,
SPIRVInstruction *BI, BasicBlock *BB);
Instruction *transOCLBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB,
bool AddRetTypePostfix = false);
Instruction *transSPIRVBuiltinFromInst(SPIRVInstruction *BI, BasicBlock *BB);
void transOCLVectorLoadStore(std::string &UnmangledName,
std::vector<SPIRVWord> &BArgs);

Expand Down
31 changes: 31 additions & 0 deletions lib/SPIRV/SPIRVToOCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ void SPIRVToOCL::visitCallInst(CallInst &CI) {
visitCallSPIRVImageMediaBlockBuiltin(&CI, OC);
return;
}
if (isCvtOpCode(OC)) {
visitCallSPIRVCvtBuiltin(&CI, OC, DemangledName);
return;
}
if (OCLSPIRVBuiltinMap::rfind(OC))
visitCallSPIRVBuiltin(&CI, OC);
}
Expand Down Expand Up @@ -498,6 +502,33 @@ void SPIRVToOCL::visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC) {
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC,
StringRef DemangledName) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
M, CI,
[=](CallInst *Call, std::vector<Value *> &Args) {
std::string CastBuiltInName;
if (isCvtFromUnsignedOpCode(OC))
CastBuiltInName = "u";
CastBuiltInName += kOCLBuiltinName::ConvertPrefix;
Type *DstTy = Call->getType();
CastBuiltInName +=
mapLLVMTypeToOCLType(DstTy, !isCvtToUnsignedOpCode(OC));
if (DemangledName.find("_sat") != StringRef::npos || isSatCvtOpCode(OC))
CastBuiltInName += "_sat";
Value *Src = Call->getOperand(0);
assert(Src && "Invalid SPIRV convert builtin call");
Type *SrcTy = Src->getType();
auto Loc = DemangledName.find("_rt");
if (Loc != StringRef::npos &&
!(isa<IntegerType>(SrcTy) && isa<IntegerType>(DstTy)))
CastBuiltInName += DemangledName.substr(Loc, 4).str();
return CastBuiltInName;
},
&Attrs);
}

void SPIRVToOCL::visitCallSPIRVBuiltin(CallInst *CI, Op OC) {
AttributeList Attrs = CI->getCalledFunction()->getAttributes();
mutateCallInstOCL(
Expand Down
6 changes: 6 additions & 0 deletions lib/SPIRV/SPIRVToOCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ class SPIRVToOCL : public ModulePass, public InstVisitor<SPIRVToOCL> {
/// intel_sub_group_media_block_write
void visitCallSPIRVImageMediaBlockBuiltin(CallInst *CI, Op OC);

/// Transform __spirv_*Convert_R{ReturnType}{_sat}{_rtp|_rtn|_rtz|_rte} to
/// convert_{ReturnType}_{sat}{_rtp|_rtn|_rtz|_rte}
/// example: <2 x i8> __spirv_SatConvertUToS(<2 x i32>) =>
/// convert_uchar2_sat(int2)
void visitCallSPIRVCvtBuiltin(CallInst *CI, Op OC, StringRef DemangledName);

/// Transform __spirv_* builtins to OCL 2.0 builtins.
/// No change with arguments.
void visitCallSPIRVBuiltin(CallInst *CI, Op OC);
Expand Down
55 changes: 54 additions & 1 deletion lib/SPIRV/SPIRVUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ std::string getSPIRVFuncName(Op OC, StringRef PostFix) {

std::string getSPIRVFuncName(Op OC, const Type *PRetTy, bool IsSigned) {
return prefixSPIRVName(getName(OC) + kSPIRVPostfix::Divider +
getPostfixForReturnType(PRetTy, false));
getPostfixForReturnType(PRetTy, IsSigned));
}

std::string getSPIRVExtFuncName(SPIRVExtInstSetKind Set, unsigned ExtOp,
Expand Down Expand Up @@ -1597,6 +1597,52 @@ bool checkTypeForSPIRVExtendedInstLowering(IntrinsicInst *II, SPIRVModule *BM) {
} // namespace SPIRV

namespace {
class SPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
public:
SPIRVFriendlyIRMangleInfo(spv::Op OC, ArrayRef<Type *> ArgTys)
: OC(OC), ArgTys(ArgTys) {}

void init(StringRef UniqUnmangledName) override {
UnmangledName = UniqUnmangledName.str();
switch (OC) {
case OpConvertUToF:
LLVM_FALLTHROUGH;
case OpUConvert:
LLVM_FALLTHROUGH;
case OpSatConvertUToS:
// Treat all arguments as unsigned
addUnsignedArg(-1);
break;
case OpSubgroupShuffleINTEL:
LLVM_FALLTHROUGH;
case OpSubgroupShuffleXorINTEL:
addUnsignedArg(1);
break;
case OpSubgroupShuffleDownINTEL:
LLVM_FALLTHROUGH;
case OpSubgroupShuffleUpINTEL:
addUnsignedArg(2);
break;
case OpSubgroupBlockWriteINTEL:
addUnsignedArg(0);
addUnsignedArg(1);
break;
case OpSubgroupImageBlockWriteINTEL:
addUnsignedArg(2);
break;
case OpSubgroupBlockReadINTEL:
setArgAttr(0, SPIR::ATTR_CONST);
addUnsignedArg(0);
break;
default:;
// No special handling is needed
}
}

private:
spv::Op OC;
ArrayRef<Type *> ArgTys;
};
class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
public:
OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,
Expand Down Expand Up @@ -1660,4 +1706,11 @@ std::string getSPIRVFriendlyIRFunctionName(OCLExtOpKind ExtOpId,
return mangleBuiltin(MangleInfo.getUnmangledName(), ArgTys, &MangleInfo);
}

std::string getSPIRVFriendlyIRFunctionName(const std::string &UniqName,
spv::Op OC,
ArrayRef<Type *> ArgTys) {
SPIRVFriendlyIRMangleInfo MangleInfo(OC, ArgTys);
return mangleBuiltin(UniqName, ArgTys, &MangleInfo);
}

} // namespace SPIRV
4 changes: 4 additions & 0 deletions lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ inline bool isCvtFromUnsignedOpCode(Op OpCode) {
OpCode == OpSatConvertUToS;
}

inline bool isSatCvtOpCode(Op OpCode) {
return OpCode == OpSatConvertUToS || OpCode == OpSatConvertSToU;
}

inline bool isOpaqueGenericTypeOpCode(Op OpCode) {
return ((unsigned)OpCode >= OpTypeEvent && (unsigned)OpCode <= OpTypeQueue) ||
OpCode == OpTypeSampler;
Expand Down
24 changes: 0 additions & 24 deletions test/transcoding/SatConvert.cl

This file was deleted.

86 changes: 86 additions & 0 deletions test/transcoding/explicit-conversions.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// RUN: %clang_cc1 -triple spir-unknown-unknown -O1 -cl-std=CL2.0 -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc %s -o %t.bc
// RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck %s --check-prefix=CHECK-SPIRV
// RUN: llvm-spirv %t.bc -o %t.spv
// RUN: spirv-val %t.spv
// RUN: llvm-spirv -r %t.spv -o %t.rev.bc
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
// RUN: llvm-spirv -r --spirv-target-env=SPV-IR %t.spv -o %t.rev.bc
// RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-SPV-IR

// CHECK-SPIRV: SatConvertSToU

// CHECK-LLVM-LABEL: @testSToU
// CHECK-LLVM: call spir_func <2 x i8> @_Z18convert_uchar2_satDv2_i

// CHECK-SPV-IR-LABEL: @testSToU
// CHECK-SPV-IR: call spir_func <2 x i8> @_Z30__spirv_SatConvertSToU_Ruchar2Dv2_i

kernel void testSToU(global int2 *a, global uchar2 *res) {
res[0] = convert_uchar2_sat(*a);
}

// CHECK-SPIRV: SatConvertUToS

// CHECK-LLVM-LABEL: @testUToS
// CHECK-LLVM: call spir_func <2 x i8> @_Z17convert_char2_satDv2_j

// CHECK-SPV-IR-LABEL: @testUToS
// CHECK-SPV-IR: call spir_func <2 x i8> @_Z29__spirv_SatConvertUToS_Rchar2Dv2_j
kernel void testUToS(global uint2 *a, global char2 *res) {
res[0] = convert_char2_sat(*a);
}

// CHECK-SPIRV: ConvertUToF

// CHECK-LLVM-LABEL: @testUToF
// CHECK-LLVM: call spir_func <2 x float> @_Z18convert_float2_rtzDv2_j

// CHECK-SPV-IR-LABEL: @testUToF
// CHECK-SPV-IR: call spir_func <2 x float> @_Z31__spirv_ConvertUToF_Rfloat2_rtzDv2_j
kernel void testUToF(global uint2 *a, global float2 *res) {
res[0] = convert_float2_rtz(*a);
}

// CHECK-SPIRV: ConvertFToU

// CHECK-LLVM-LABEL: @testFToUSat
// CHECK-LLVM: call spir_func <2 x i32> @_Z21convert_uint2_sat_rtnDv2_f

// CHECK-SPV-IR-LABEL: @testFToUSat
// CHECK-SPV-IR: call spir_func <2 x i32> @_Z34__spirv_ConvertFToU_Ruint2_sat_rtnDv2_f
kernel void testFToUSat(global float2 *a, global uint2 *res) {
res[0] = convert_uint2_sat_rtn(*a);
}

// CHECK-SPIRV: UConvert

// CHECK-LLVM-LABEL: @testUToUSat
// CHECK-LLVM: call spir_func i32 @_Z16convert_uint_sath

// CHECK-SPV-IR-LABEL: @testUToUSat
// CHECK-SPV-IR: call spir_func i32 @_Z26__spirv_UConvert_Ruint_sath
kernel void testUToUSat(global uchar *a, global uint *res) {
res[0] = convert_uint_sat(*a);
}

// CHECK-SPIRV: UConvert

// CHECK-LLVM-LABEL: @testUToUSat1
// CHECK-LLVM: call spir_func i8 @_Z17convert_uchar_satj

// CHECK-SPV-IR-LABEL: @testUToUSat1
// CHECK-SPV-IR: call spir_func i8 @_Z27__spirv_UConvert_Ruchar_satj
kernel void testUToUSat1(global uint *a, global uchar *res) {
res[0] = convert_uchar_sat(*a);
}

// CHECK-SPIRV: ConvertFToU

// CHECK-LLVM-LABEL: @testFToU
// CHECK-LLVM: call spir_func <3 x i32> @_Z17convert_uint3_rtpDv3_f

// CHECK-SPV-IR-LABEL: @testFToU
// CHECK-SPV-IR: call spir_func <3 x i32> @_Z30__spirv_ConvertFToU_Ruint3_rtpDv3_f
kernel void testFToU(global float3 *a, global uint3 *res) {
res[0] = convert_uint3_rtp(*a);
}

0 comments on commit 9f3c10f

Please sign in to comment.