Skip to content

Commit

Permalink
codegen: explicitly handle Float16 intrinsics
Browse files Browse the repository at this point in the history
Fixes #44829, until llvm fixes the support for these intrinsics itself
  • Loading branch information
vtjnash committed May 9, 2022
1 parent 5c557b2 commit f8b2c46
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 87 deletions.
4 changes: 2 additions & 2 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ function unsafe_trunc(::Type{Int128}, x::Float32)
copysign(unsafe_trunc(UInt128,x) % Int128, x)
end

unsafe_trunc(::Type{UInt128}, x::Float16) = unsafe_trunc(UInt128, Float32(x))
unsafe_trunc(::Type{Int128}, x::Float16) = unsafe_trunc(Int128, Float32(x))
unsafe_trunc(::Type{UInt128}, x::Float16) = fptoui(UInt128, x)
unsafe_trunc(::Type{Int128}, x::Float16) = fptosi(Int128, x)

# matches convert methods
# also determines floor, ceil, round
Expand Down
6 changes: 3 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = __gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +391,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +408,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = __gnu_f2h_ieee(val);
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
6 changes: 0 additions & 6 deletions src/julia.expmap
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,6 @@
environ;
__progname;

/* compiler run-time intrinsics */
__gnu_h2f_ieee;
__extendhfsf2;
__gnu_f2h_ieee;
__truncdfhf2;

local:
*;
};
14 changes: 12 additions & 2 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1523,8 +1523,18 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_GC_ASSERT_LIVE(x) (void)(x)
#endif

float __gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
uint16_t __gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;

#ifdef __cplusplus
}
Expand Down
268 changes: 214 additions & 54 deletions src/llvm-demote-float16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,169 @@ INST_STATISTIC(FCmp);

namespace {

inline AttributeSet getFnAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION >= 140000
return Attrs.getFnAttrs();
#else
return Attrs.getFnAttributes();
#endif
}

inline AttributeSet getRetAttrs(const AttributeList &Attrs)
{
#if JL_LLVM_VERSION >= 140000
return Attrs.getRetAttrs();
#else
return Attrs.getRetAttributes();
#endif
}

static Instruction *replaceIntrinsicWith(IntrinsicInst *call, Type *RetType, ArrayRef<Value*> args)
{
Intrinsic::ID ID = call->getIntrinsicID();
assert(ID);
auto oldfType = call->getFunctionType();
auto nargs = oldfType->getNumParams();
assert(args.size() > nargs);
SmallVector<Type*, 8> argTys(nargs);
for (unsigned i = 0; i < nargs; i++)
argTys[i] = args[i]->getType();
auto newfType = FunctionType::get(RetType, argTys, oldfType->isVarArg());

// Accumulate an array of overloaded types for the given intrinsic
// and compute the new name mangling schema
SmallVector<Type*, 4> overloadTys;
{
SmallVector<Intrinsic::IITDescriptor, 8> Table;
getIntrinsicInfoTableEntries(ID, Table);
ArrayRef<Intrinsic::IITDescriptor> TableRef = Table;
auto res = Intrinsic::matchIntrinsicSignature(newfType, TableRef, overloadTys);
assert(res == Intrinsic::MatchIntrinsicTypes_Match);
(void)res;
bool matchvararg = !Intrinsic::matchIntrinsicVarArg(newfType->isVarArg(), TableRef);
assert(matchvararg);
(void)matchvararg;
}
auto newF = Intrinsic::getDeclaration(call->getModule(), ID, overloadTys);
assert(newF->getFunctionType() == newfType);
newF->setCallingConv(call->getCallingConv());
assert(args.back() == call->getCalledFunction());
auto newCall = CallInst::Create(newF, args.drop_back(), "", call);
newCall->setTailCallKind(call->getTailCallKind());
auto old_attrs = call->getAttributes();
newCall->setAttributes(AttributeList::get(call->getContext(), getFnAttrs(old_attrs),
getRetAttrs(old_attrs), {})); // drop parameter attributes
return newCall;
}


static Value* CreateFPCast(Instruction::CastOps opcode, Value *V, Type *DestTy, IRBuilder<> &builder)
{
if (auto *VC = dyn_cast<Constant>(V)) {
// The input IR often has things of the form
// fcmp olt half %0, 0xH7C00
// and we would like to avoid turning that constant into a call here
// if we can simply constant fold it to the new type.
VC = ConstantExpr::getCast(opcode, VC, DestTy, true);
if (VC)
return VC;
}
auto &M = *builder.GetInsertBlock()->getModule();
auto &ctx = M.getContext();
Type *SrcTy = V->getType();
Type *RetTy = DestTy;
// Pick the Function to call in the Julia runtime
StringRef Name;
switch (opcode) {
case Instruction::FPExt:
// this is exact, so we only need one conversion
assert(SrcTy->isHalfTy());
Name = "julia__gnu_h2f_ieee";
RetTy = Type::getFloatTy(ctx);
break;
case Instruction::FPTrunc:
assert(DestTy->isHalfTy());
if (SrcTy->isFloatTy())
Name = "julia__gnu_f2h_ieee";
else if (SrcTy->isDoubleTy())
Name = "julia__truncdfhf2";
break;
// All F16 fit exactly in Int32 (-65504 to 65504)
case Instruction::FPToSI: JL_FALLTHROUGH;
case Instruction::FPToUI:
assert(SrcTy->isHalfTy());
Name = "julia__gnu_h2f_ieee";
RetTy = Type::getFloatTy(ctx);
break;
case Instruction::SIToFP: JL_FALLTHROUGH;
case Instruction::UIToFP:
assert(DestTy->isHalfTy());
Name = "julia__gnu_f2h_ieee";
SrcTy = Type::getFloatTy(ctx);
break;
default:
llvm_unreachable("invalid cast");
}
if (Name.empty())
llvm_unreachable("illegal cast");
// Coerce the source to the required size and type
auto T_int16 = Type::getInt16Ty(ctx);
if (SrcTy->isHalfTy())
SrcTy = T_int16;
if (opcode == Instruction::SIToFP)
V = builder.CreateSIToFP(V, SrcTy);
else if (opcode == Instruction::UIToFP)
V = builder.CreateUIToFP(V, SrcTy);
else
V = builder.CreateBitCast(V, SrcTy);
// Call our intrinsic
if (RetTy->isHalfTy())
RetTy = T_int16;
auto FT = FunctionType::get(RetTy, {SrcTy}, false);
FunctionCallee F = M.getOrInsertFunction(Name, FT);
Value *I = builder.CreateCall(F, {V});
// Coerce the result to the expected type
if (opcode == Instruction::FPToSI)
I = builder.CreateFPToSI(I, DestTy);
else if (opcode == Instruction::FPToUI)
I = builder.CreateFPToUI(I, DestTy);
else if (opcode == Instruction::FPExt)
I = builder.CreateFPCast(I, DestTy);
else
I = builder.CreateBitCast(I, DestTy);
return I;
}

static bool demoteFloat16(Function &F)
{
auto &ctx = F.getContext();
auto T_float16 = Type::getHalfTy(ctx);
auto T_float32 = Type::getFloatTy(ctx);

SmallVector<Instruction *, 0> erase;
for (auto &BB : F) {
for (auto &I : BB) {
// extend Float16 operands to Float32
bool Float16 = I.getType()->isHalfTy();
for (size_t i = 0; !Float16 && i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType()->isHalfTy())
Float16 = true;
}
if (!Float16)
continue;

if (auto CI = dyn_cast<CastInst>(&I)) {
if (CI->getOpcode() != Instruction::BitCast) { // aka !CI->isNoopCast(DL)
++TotalChanged;
IRBuilder<> builder(&I);
Value *NewI = CreateFPCast(CI->getOpcode(), I.getOperand(0), I.getType(), builder);
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
continue;
}

switch (I.getOpcode()) {
case Instruction::FNeg:
case Instruction::FAdd:
Expand All @@ -64,6 +218,9 @@ static bool demoteFloat16(Function &F)
case Instruction::FCmp:
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I))
if (intrinsic->getIntrinsicID())
break;
continue;
}

Expand All @@ -75,72 +232,75 @@ static bool demoteFloat16(Function &F)
IRBuilder<> builder(&I);

// extend Float16 operands to Float32
bool OperandsChanged = false;
SmallVector<Value *, 2> Operands(I.getNumOperands());
for (size_t i = 0; i < I.getNumOperands(); i++) {
Value *Op = I.getOperand(i);
if (Op->getType() == T_float16) {
if (Op->getType()->isHalfTy()) {
++TotalExt;
Op = builder.CreateFPExt(Op, T_float32);
OperandsChanged = true;
Op = CreateFPCast(Instruction::FPExt, Op, T_float32, builder);
}
Operands[i] = (Op);
}

// recreate the instruction if any operands changed,
// truncating the result back to Float16
if (OperandsChanged) {
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
Value *NewI;
++TotalChanged;
switch (I.getOpcode()) {
case Instruction::FNeg:
assert(Operands.size() == 1);
++FNegChanged;
NewI = builder.CreateFNeg(Operands[0]);
break;
case Instruction::FAdd:
assert(Operands.size() == 2);
++FAddChanged;
NewI = builder.CreateFAdd(Operands[0], Operands[1]);
break;
case Instruction::FSub:
assert(Operands.size() == 2);
++FSubChanged;
NewI = builder.CreateFSub(Operands[0], Operands[1]);
break;
case Instruction::FMul:
assert(Operands.size() == 2);
++FMulChanged;
NewI = builder.CreateFMul(Operands[0], Operands[1]);
break;
case Instruction::FDiv:
assert(Operands.size() == 2);
++FDivChanged;
NewI = builder.CreateFDiv(Operands[0], Operands[1]);
break;
case Instruction::FRem:
assert(Operands.size() == 2);
++FRemChanged;
NewI = builder.CreateFRem(Operands[0], Operands[1]);
break;
case Instruction::FCmp:
assert(Operands.size() == 2);
++FCmpChanged;
NewI = builder.CreateFCmp(cast<FCmpInst>(&I)->getPredicate(),
Operands[0], Operands[1]);
break;
default:
if (auto intrinsic = dyn_cast<IntrinsicInst>(&I)) {
Type *RetType = I.getType();
if (RetType->isHalfTy())
RetType = T_float32;
NewI = replaceIntrinsicWith(intrinsic, RetType, Operands);
break;
default:
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = builder.CreateFPTrunc(NewI, I.getType());
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
abort();
}
cast<Instruction>(NewI)->copyMetadata(I);
cast<Instruction>(NewI)->copyFastMathFlags(&I);
if (NewI->getType() != I.getType()) {
++TotalTrunc;
NewI = CreateFPCast(Instruction::FPTrunc, NewI, I.getType(), builder);
}
I.replaceAllUsesWith(NewI);
erase.push_back(&I);
}
}

Expand Down
Loading

0 comments on commit f8b2c46

Please sign in to comment.