Skip to content

Commit

Permalink
Use a single Float16 ABI. (#51666)
Browse files Browse the repository at this point in the history
Currently, Julia uses 2 different Float16 ABIs, depending on the host compiler used to
compile Julia: either pass as integer, or pass as LLVM's native `half`. Since the runtime
intrinsics are implemented in C using `uint16`, this necessitated conversions around the
runtime functions (`gnu_h2f_ieee`, `truncdfhf2`, etc) that the compiler may emit calls to.

This PR switches to always using the 'native' ABIs that platforms have for Float16,
by removing the conversions around runtime calls, and defining our runtime intrinsics
using the native `_Float16` type. Availability of this type depends on the platform, and
the compiler version, so we also define fallbacks that mimick the platform-specific
calling convention.
  • Loading branch information
maleadt authored Oct 20, 2023
1 parent f8f573d commit ef3bf66
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 190 deletions.
4 changes: 4 additions & 0 deletions base/ctypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,7 @@ const Cfloat = Float32
Equivalent to the native `double` c-type ([`Float64`](@ref)).
"""
const Cdouble = Float64


# we have no `Float16` alias, because C does not define a standard fp16 type. Julia follows
# the _Float16 C ABI; if that becomes standard, we can add an appropriate alias here.
9 changes: 6 additions & 3 deletions src/APInt-C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,13 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) {
ASSIGN(r, a)
}

extern "C" float julia_half_to_float(uint16_t ival) JL_NOTSAFEPOINT;
extern "C" uint16_t julia_float_to_half(float param) JL_NOTSAFEPOINT;

void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) {
double Val;
if (numbits == 16)
Val = julia__gnu_h2f_ieee(*(uint16_t*)pa);
Val = julia_half_to_float(*(uint16_t*)pa);
else if (numbits == 32)
Val = *(float*)pa;
else if (numbits == 64)
Expand Down Expand Up @@ -391,7 +394,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(true);
}
if (onumbits == 16)
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
*(uint16_t*)pr = julia_float_to_half(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand All @@ -408,7 +411,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar
val = a.roundToDouble(false);
}
if (onumbits == 16)
*(uint16_t*)pr = julia__gnu_f2h_ieee(val);
*(uint16_t*)pr = julia_float_to_half(val);
else if (onumbits == 32)
*(float*)pr = val;
else if (onumbits == 64)
Expand Down
8 changes: 7 additions & 1 deletion src/abi_ppc64le.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ bool needPassByRef(jl_datatype_t *dt, AttrBuilder &ab, LLVMContext &ctx, Type *T
Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const override
{
// Arguments are either scalar or passed by value
size_t size = jl_datatype_size(dt);

// LLVM passes Float16 in floating-point registers, but this doesn't match the ABI.
// No C compiler seems to support _Float16 yet, so in the meantime, pass as i16
if (dt == jl_float16_type || dt == jl_bfloat16_type)
return Type::getInt16Ty(ctx);

// don't need to change bitstypes
if (!jl_datatype_nfields(dt))
return NULL;
Expand All @@ -143,6 +148,7 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const
}
// rewrite integer-sized (non-HFA) struct to an array
// the bitsize of the integer gives the desired alignment
size_t size = jl_datatype_size(dt);
if (size > 8) {
if (jl_datatype_align(dt) <= 8) {
Type *T_int64 = Type::getInt64Ty(ctx);
Expand Down
3 changes: 2 additions & 1 deletion src/abi_x86_64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ struct Classification {
void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const
{
// Floating point types
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) {
if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_float16_type ||
dt == jl_bfloat16_type) {
accum.addField(offset, Sse);
}
// Misc types
Expand Down
9 changes: 3 additions & 6 deletions src/aotcompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,8 +986,6 @@ struct ShardTimers {
}
};

void emitFloat16Wrappers(Module &M, bool external);

struct AOTOutputs {
SmallVector<char, 0> unopt, opt, obj, asm_;
};
Expand Down Expand Up @@ -1047,11 +1045,12 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
// no need to inject aliases if we have no functions

if (inject_aliases) {
#if JULIA_FLOAT16_ABI == 1
// We would like to emit an alias or an weakref alias to redirect these symbols
// but LLVM doesn't let us emit a GlobalAlias to a declaration...
// So for now we inject a definition of these functions that calls our runtime
// functions. We do so after optimization to avoid cloning these functions.

// Float16 conversion routines
injectCRTAlias(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee",
FunctionType::get(Type::getFloatTy(M.getContext()), { Type::getHalfTy(M.getContext()) }, false));
injectCRTAlias(M, "__extendhfsf2", "julia__gnu_h2f_ieee",
Expand All @@ -1062,10 +1061,8 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2",
FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false));
#else
emitFloat16Wrappers(M, false);
#endif

// BFloat16 conversion routines
injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2",
FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false));
injectCRTAlias(M, "__truncsdbf2", "julia__truncdfbf2",
Expand Down
55 changes: 0 additions & 55 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9144,58 +9144,6 @@ static JuliaVariable *julia_const_gv(jl_value_t *val)
return nullptr;
}

// Handle FLOAT16 ABI v2
#if JULIA_FLOAT16_ABI == 2
static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external)
{
Function *calledFun = M.getFunction(calledName);
if (!calledFun) {
calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M);
}
auto linkage = external ? Function::ExternalLinkage : Function::InternalLinkage;
auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M);
wrapperFun->addFnAttr(Attribute::AlwaysInline);
llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun));
SmallVector<Value *, 4> CallArgs;
if (wrapperFun->arg_size() != calledFun->arg_size()){
llvm::errs() << "FATAL ERROR: Can't match wrapper to called function";
abort();
}
for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin();
wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg)
{
CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType()));
}
auto val = builder.CreateCall(calledFun, CallArgs);
auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType());
builder.CreateRet(retval);
}

void emitFloat16Wrappers(Module &M, bool external)
{
auto &ctx = M.getContext();
makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false),
FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external);
makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false),
FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external);
makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false),
FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false), external);
}

static void init_f16_funcs(void)
{
auto ctx = jl_ExecutionEngine->acquireContext();
auto TSM = jl_create_ts_module("F16Wrappers", ctx);
auto aliasM = TSM.getModuleUnlocked();
emitFloat16Wrappers(*aliasM, true);
jl_ExecutionEngine->addModule(std::move(TSM));
}
#endif

static void init_jit_functions(void)
{
add_named_global(jl_small_typeof_var, &jl_small_typeof);
Expand Down Expand Up @@ -9438,9 +9386,6 @@ extern "C" JL_DLLEXPORT_CODEGEN void jl_init_codegen_impl(void)
jl_init_llvm();
// Now that the execution engine exists, initialize all modules
init_jit_functions();
#if JULIA_FLOAT16_ABI == 2
init_f16_funcs();
#endif
}

extern "C" JL_DLLEXPORT_CODEGEN void jl_teardown_codegen_impl() JL_NOTSAFEPOINT
Expand Down
5 changes: 3 additions & 2 deletions src/jitlayers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1729,13 +1729,14 @@ JuliaOJIT::JuliaOJIT()
ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly);

orc::SymbolAliasMap jl_crt = {
#if JULIA_FLOAT16_ABI == 1
// Float16 conversion routines
{ mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } },
{ mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } },
{ mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } },
#endif

// BFloat16 conversion routines
{ mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } },
{ mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } },
};
Expand Down
20 changes: 6 additions & 14 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1661,20 +1661,12 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;
#define JL_WEAK_SYMBOL_DEFAULT(sym) NULL
#endif

JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT;
JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT;
//JL_DLLEXPORT float julia__gnu_h2f_ieee(half param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT half julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT half julia__truncdfhf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT;
//JL_DLLEXPORT double julia__extendhfdf2(half n) JL_NOTSAFEPOINT;

JL_DLLEXPORT uint32_t jl_crc32c(uint32_t crc, const char *buf, size_t len);

Expand Down
9 changes: 0 additions & 9 deletions src/llvm-version.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
#define JL_LLVM_OPAQUE_POINTERS 1
#endif

// Pre GCC 12 libgcc defined the ABI for Float16->Float32
// to take an i16. GCC 12 silently changed the ABI to now pass
// Float16 in Float32 registers.
#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_)
#define JULIA_FLOAT16_ABI 1
#else
#define JULIA_FLOAT16_ABI 2
#endif

#ifdef __cplusplus
#if defined(__GNUC__) && (__GNUC__ >= 9)
// Added in GCC 9, this warning is annoying
Expand Down
Loading

0 comments on commit ef3bf66

Please sign in to comment.