diff --git a/base/ctypes.jl b/base/ctypes.jl index 26640ed82bef5..45f01b684902f 100644 --- a/base/ctypes.jl +++ b/base/ctypes.jl @@ -113,3 +113,7 @@ const Cfloat = Float32 Equivalent to the native `double` c-type ([`Float64`](@ref)). """ const Cdouble = Float64 + + +# we have no `Float16` alias, because C does not define a standard fp16 type. Julia follows +# the _Float16 C ABI; if that becomes standard, we can add an appropriate alias here. diff --git a/src/APInt-C.cpp b/src/APInt-C.cpp index f06d4362bf958..7ff68edb0868c 100644 --- a/src/APInt-C.cpp +++ b/src/APInt-C.cpp @@ -313,10 +313,13 @@ void LLVMByteSwap(unsigned numbits, integerPart *pa, integerPart *pr) { ASSIGN(r, a) } +extern "C" float julia_half_to_float(uint16_t ival) JL_NOTSAFEPOINT; +extern "C" uint16_t julia_float_to_half(float param) JL_NOTSAFEPOINT; + void LLVMFPtoInt(unsigned numbits, void *pa, unsigned onumbits, integerPart *pr, bool isSigned, bool *isExact) { double Val; if (numbits == 16) - Val = julia__gnu_h2f_ieee(*(uint16_t*)pa); + Val = julia_half_to_float(*(uint16_t*)pa); else if (numbits == 32) Val = *(float*)pa; else if (numbits == 64) @@ -391,7 +394,7 @@ void LLVMSItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(true); } if (onumbits == 16) - *(uint16_t*)pr = julia__gnu_f2h_ieee(val); + *(uint16_t*)pr = julia_float_to_half(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) @@ -408,7 +411,7 @@ void LLVMUItoFP(unsigned numbits, integerPart *pa, unsigned onumbits, integerPar val = a.roundToDouble(false); } if (onumbits == 16) - *(uint16_t*)pr = julia__gnu_f2h_ieee(val); + *(uint16_t*)pr = julia_float_to_half(val); else if (onumbits == 32) *(float*)pr = val; else if (onumbits == 64) diff --git a/src/abi_ppc64le.cpp b/src/abi_ppc64le.cpp index 2e18acdbd4f4b..44d110422a099 100644 --- a/src/abi_ppc64le.cpp +++ b/src/abi_ppc64le.cpp @@ -118,7 +118,12 @@ bool needPassByRef(jl_datatype_t *dt, AttrBuilder &ab, LLVMContext &ctx, 
Type *T Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const override { // Arguments are either scalar or passed by value - size_t size = jl_datatype_size(dt); + + // LLVM passes Float16 in floating-point registers, but this doesn't match the ABI. + // No C compiler seems to support _Float16 yet, so in the meantime, pass as i16 + if (dt == jl_float16_type || dt == jl_bfloat16_type) + return Type::getInt16Ty(ctx); + // don't need to change bitstypes if (!jl_datatype_nfields(dt)) return NULL; @@ -143,6 +148,7 @@ Type *preferred_llvm_type(jl_datatype_t *dt, bool isret, LLVMContext &ctx) const } // rewrite integer-sized (non-HFA) struct to an array // the bitsize of the integer gives the desired alignment + size_t size = jl_datatype_size(dt); if (size > 8) { if (jl_datatype_align(dt) <= 8) { Type *T_int64 = Type::getInt64Ty(ctx); diff --git a/src/abi_x86_64.cpp b/src/abi_x86_64.cpp index 7800c44b4d3ae..5938e1e5778a2 100644 --- a/src/abi_x86_64.cpp +++ b/src/abi_x86_64.cpp @@ -118,7 +118,8 @@ struct Classification { void classifyType(Classification& accum, jl_datatype_t *dt, uint64_t offset) const { // Floating point types - if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_bfloat16_type) { + if (dt == jl_float64_type || dt == jl_float32_type || dt == jl_float16_type || + dt == jl_bfloat16_type) { accum.addField(offset, Sse); } // Misc types diff --git a/src/aotcompile.cpp b/src/aotcompile.cpp index fab53fa4de14c..2e1a9d2418eaa 100644 --- a/src/aotcompile.cpp +++ b/src/aotcompile.cpp @@ -986,8 +986,6 @@ struct ShardTimers { } }; -void emitFloat16Wrappers(Module &M, bool external); - struct AOTOutputs { SmallVector unopt, opt, obj, asm_; }; @@ -1047,11 +1045,12 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer // no need to inject aliases if we have no functions if (inject_aliases) { -#if JULIA_FLOAT16_ABI == 1 // We would like to emit an alias or an weakref alias to redirect these symbols // but 
LLVM doesn't let us emit a GlobalAlias to a declaration... // So for now we inject a definition of these functions that calls our runtime // functions. We do so after optimization to avoid cloning these functions. + + // Float16 conversion routines injectCRTAlias(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(M.getContext()), { Type::getHalfTy(M.getContext()) }, false)); injectCRTAlias(M, "__extendhfsf2", "julia__gnu_h2f_ieee", @@ -1062,10 +1061,8 @@ static AOTOutputs add_output_impl(Module &M, TargetMachine &SourceTM, ShardTimer FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false)); injectCRTAlias(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(M.getContext()), { Type::getDoubleTy(M.getContext()) }, false)); -#else - emitFloat16Wrappers(M, false); -#endif + // BFloat16 conversion routines injectCRTAlias(M, "__truncsfbf2", "julia__truncsfbf2", FunctionType::get(Type::getBFloatTy(M.getContext()), { Type::getFloatTy(M.getContext()) }, false)); injectCRTAlias(M, "__truncsdbf2", "julia__truncdfbf2", diff --git a/src/codegen.cpp b/src/codegen.cpp index e37de75b7c5bc..dbb7d5aa498dc 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -9144,58 +9144,6 @@ static JuliaVariable *julia_const_gv(jl_value_t *val) return nullptr; } -// Handle FLOAT16 ABI v2 -#if JULIA_FLOAT16_ABI == 2 -static void makeCastCall(Module &M, StringRef wrapperName, StringRef calledName, FunctionType *FTwrapper, FunctionType *FTcalled, bool external) -{ - Function *calledFun = M.getFunction(calledName); - if (!calledFun) { - calledFun = Function::Create(FTcalled, Function::ExternalLinkage, calledName, M); - } - auto linkage = external ? 
Function::ExternalLinkage : Function::InternalLinkage; - auto wrapperFun = Function::Create(FTwrapper, linkage, wrapperName, M); - wrapperFun->addFnAttr(Attribute::AlwaysInline); - llvm::IRBuilder<> builder(BasicBlock::Create(M.getContext(), "top", wrapperFun)); - SmallVector CallArgs; - if (wrapperFun->arg_size() != calledFun->arg_size()){ - llvm::errs() << "FATAL ERROR: Can't match wrapper to called function"; - abort(); - } - for (auto wrapperArg = wrapperFun->arg_begin(), calledArg = calledFun->arg_begin(); - wrapperArg != wrapperFun->arg_end() && calledArg != calledFun->arg_end(); ++wrapperArg, ++calledArg) - { - CallArgs.push_back(builder.CreateBitCast(wrapperArg, calledArg->getType())); - } - auto val = builder.CreateCall(calledFun, CallArgs); - auto retval = builder.CreateBitCast(val,wrapperFun->getReturnType()); - builder.CreateRet(retval); -} - -void emitFloat16Wrappers(Module &M, bool external) -{ - auto &ctx = M.getContext(); - makeCastCall(M, "__gnu_h2f_ieee", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), - FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external); - makeCastCall(M, "__extendhfsf2", "julia__gnu_h2f_ieee", FunctionType::get(Type::getFloatTy(ctx), { Type::getHalfTy(ctx) }, false), - FunctionType::get(Type::getFloatTy(ctx), { Type::getInt16Ty(ctx) }, false), external); - makeCastCall(M, "__gnu_f2h_ieee", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), - FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external); - makeCastCall(M, "__truncsfhf2", "julia__gnu_f2h_ieee", FunctionType::get(Type::getHalfTy(ctx), { Type::getFloatTy(ctx) }, false), - FunctionType::get(Type::getInt16Ty(ctx), { Type::getFloatTy(ctx) }, false), external); - makeCastCall(M, "__truncdfhf2", "julia__truncdfhf2", FunctionType::get(Type::getHalfTy(ctx), { Type::getDoubleTy(ctx) }, false), - 
FunctionType::get(Type::getInt16Ty(ctx), { Type::getDoubleTy(ctx) }, false), external); -} - -static void init_f16_funcs(void) -{ - auto ctx = jl_ExecutionEngine->acquireContext(); - auto TSM = jl_create_ts_module("F16Wrappers", ctx); - auto aliasM = TSM.getModuleUnlocked(); - emitFloat16Wrappers(*aliasM, true); - jl_ExecutionEngine->addModule(std::move(TSM)); -} -#endif - static void init_jit_functions(void) { add_named_global(jl_small_typeof_var, &jl_small_typeof); @@ -9438,9 +9386,6 @@ extern "C" JL_DLLEXPORT_CODEGEN void jl_init_codegen_impl(void) jl_init_llvm(); // Now that the execution engine exists, initialize all modules init_jit_functions(); -#if JULIA_FLOAT16_ABI == 2 - init_f16_funcs(); -#endif } extern "C" JL_DLLEXPORT_CODEGEN void jl_teardown_codegen_impl() JL_NOTSAFEPOINT diff --git a/src/jitlayers.cpp b/src/jitlayers.cpp index e3f30f7d22d58..7eed240529e20 100644 --- a/src/jitlayers.cpp +++ b/src/jitlayers.cpp @@ -1729,13 +1729,14 @@ JuliaOJIT::JuliaOJIT() ExternalJD.addToLinkOrder(JD, orc::JITDylibLookupFlags::MatchExportedSymbolsOnly); orc::SymbolAliasMap jl_crt = { -#if JULIA_FLOAT16_ABI == 1 + // Float16 conversion routines { mangle("__gnu_h2f_ieee"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__extendhfsf2"), { mangle("julia__gnu_h2f_ieee"), JITSymbolFlags::Exported } }, { mangle("__gnu_f2h_ieee"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, { mangle("__truncsfhf2"), { mangle("julia__gnu_f2h_ieee"), JITSymbolFlags::Exported } }, { mangle("__truncdfhf2"), { mangle("julia__truncdfhf2"), JITSymbolFlags::Exported } }, -#endif + + // BFloat16 conversion routines { mangle("__truncsfbf2"), { mangle("julia__truncsfbf2"), JITSymbolFlags::Exported } }, { mangle("__truncdfbf2"), { mangle("julia__truncdfbf2"), JITSymbolFlags::Exported } }, }; diff --git a/src/julia_internal.h b/src/julia_internal.h index 204674c6d495a..4f326216d8daf 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1661,20 
+1661,12 @@ jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; #define JL_WEAK_SYMBOL_DEFAULT(sym) NULL #endif -JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) JL_NOTSAFEPOINT; -JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; -JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) JL_NOTSAFEPOINT; -JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT; -JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT; -//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) JL_NOTSAFEPOINT; -//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) JL_NOTSAFEPOINT; +//JL_DLLEXPORT float julia__gnu_h2f_ieee(half param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT half julia__gnu_f2h_ieee(float param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT half julia__truncdfhf2(double param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT; +//JL_DLLEXPORT double julia__extendhfdf2(half n) JL_NOTSAFEPOINT; JL_DLLEXPORT uint32_t jl_crc32c(uint32_t crc, const char *buf, size_t len); diff --git a/src/llvm-version.h b/src/llvm-version.h index 01638b8d44a6e..7b8dfbbae92d6 100644 --- a/src/llvm-version.h +++ b/src/llvm-version.h @@ -18,15 +18,6 @@ #define JL_LLVM_OPAQUE_POINTERS 1 #endif -// Pre GCC 12 libgcc defined the ABI for Float16->Float32 -// to take an i16. GCC 12 silently changed the ABI to now pass -// Float16 in Float32 registers. 
-#if JL_LLVM_VERSION < 150000 || defined(_CPU_PPC64_) || defined(_CPU_PPC_) -#define JULIA_FLOAT16_ABI 1 -#else -#define JULIA_FLOAT16_ABI 2 -#endif - #ifdef __cplusplus #if defined(__GNUC__) && (__GNUC__ >= 9) // Added in GCC 9, this warning is annoying diff --git a/src/runtime_intrinsics.c b/src/runtime_intrinsics.c index b42b7d9832383..588c0359f70be 100644 --- a/src/runtime_intrinsics.c +++ b/src/runtime_intrinsics.c @@ -5,8 +5,6 @@ // // this file assumes a little-endian processor, although that isn't too hard to fix // it also assumes two's complement negative numbers, which might be a bit harder to fix -// -// TODO: add half-float support #include "APInt-C.h" #include "julia.h" @@ -14,7 +12,7 @@ const unsigned int host_char_bit = 8; -// float16 intrinsics +// float16 conversion helpers static inline float half_to_float(uint16_t ival) JL_NOTSAFEPOINT { @@ -185,94 +183,189 @@ static inline uint16_t float_to_half(float param) JL_NOTSAFEPOINT return h; } -JL_DLLEXPORT float julia__gnu_h2f_ieee(uint16_t param) +static inline uint16_t double_to_half(double param) JL_NOTSAFEPOINT { + float temp = (float)param; + uint32_t tempi; + memcpy(&tempi, &temp, sizeof(temp)); + + // if Float16(res) is subnormal + if ((tempi&0x7fffffffu) < 0x38800000u) { + // shift so that the mantissa lines up where it would for normal Float16 + uint32_t shift = 113u-((tempi & 0x7f800000u)>>23u); + if (shift<23u) { + tempi |= 0x00800000; // set implicit bit + tempi >>= shift; + } + } + + // if we are halfway between 2 Float16 values + if ((tempi & 0x1fffu) == 0x1000u) { + memcpy(&tempi, &temp, sizeof(temp)); + // adjust the value by 1 ULP in the direction that will make Float16(temp) give the right answer + tempi += (fabs(temp) < fabs(param)) - (fabs(param) < fabs(temp)); + memcpy(&temp, &tempi, sizeof(temp)); + } + + return float_to_half(temp); +} + +// x86-specific helpers for emulating the (B)Float16 ABI +#if defined(_CPU_X86_) || defined(_CPU_X86_64_) +#include <xmmintrin.h> +static inline __m128 
return_in_xmm(uint16_t input) JL_NOTSAFEPOINT { + __m128 xmm_output; + asm ( + "movd %[input], %%xmm0\n\t" + "movss %%xmm0, %[xmm_output]\n\t" + : [xmm_output] "=x" (xmm_output) + : [input] "r" ((uint32_t)input) + : "xmm0" + ); + return xmm_output; +} +static inline uint16_t take_from_xmm(__m128 xmm_input) JL_NOTSAFEPOINT { + uint32_t output; + asm ( + "movss %[xmm_input], %%xmm0\n\t" + "movd %%xmm0, %[output]\n\t" + : [output] "=r" (output) + : [xmm_input] "x" (xmm_input) + : "xmm0" + ); + return (uint16_t)output; +} +#endif + +// float16 conversion API + +// for use in APInt (without the ABI shenanigans from below) +uint16_t julia_float_to_half(float param) { + return float_to_half(param); +} +float julia_half_to_float(uint16_t param) { return half_to_float(param); } -JL_DLLEXPORT uint16_t julia__gnu_f2h_ieee(float param) +// starting with GCC 12 and Clang 15, we have _Float16 on most platforms +// (but not on Windows; this may be a bug in the MSYS2 GCC compilers) +#if ((defined(__GNUC__) && __GNUC__ > 11) || \ + (defined(__clang__) && __clang_major__ > 14)) && \ + !defined(_CPU_PPC64_) && !defined(_CPU_PPC_) && \ + !defined(_OS_WINDOWS_) + #define FLOAT16_TYPE _Float16 + #define FLOAT16_TO_UINT16(x) (*(uint16_t*)&(x)) + #define FLOAT16_FROM_UINT16(x) (*(_Float16*)&(x)) +// on older compilers, we need to emulate the platform-specific ABI +#elif defined(_CPU_X86_) || (defined(_CPU_X86_64_) && !defined(_OS_WINDOWS_)) + // on x86, we can use __m128; except on Windows where x64 calling + // conventions expect to pass __m128 by reference. 
+ #define FLOAT16_TYPE __m128 + #define FLOAT16_TO_UINT16(x) take_from_xmm(x) + #define FLOAT16_FROM_UINT16(x) return_in_xmm(x) +#elif defined(_CPU_PPC64_) || defined(_CPU_PPC_) + // on PPC, pass Float16 as if it were an integer, similar to the old x86 ABI + // before _Float16 + #define FLOAT16_TYPE uint16_t + #define FLOAT16_TO_UINT16(x) (x) + #define FLOAT16_FROM_UINT16(x) (x) +#else + // otherwise, pass using floating-point calling conventions + #define FLOAT16_TYPE float + #define FLOAT16_TO_UINT16(x) ((uint16_t)*(uint32_t*)&(x)) + #define FLOAT16_FROM_UINT16(x) ({ uint32_t tmp = (uint32_t)(x); *(float*)&tmp; }) +#endif + +JL_DLLEXPORT float julia__gnu_h2f_ieee(FLOAT16_TYPE param) { - return float_to_half(param); + uint16_t param16 = FLOAT16_TO_UINT16(param); + return half_to_float(param16); } -JL_DLLEXPORT uint16_t julia__truncdfhf2(double param) +JL_DLLEXPORT FLOAT16_TYPE julia__gnu_f2h_ieee(float param) { - float res = (float)param; - uint32_t resi; - memcpy(&resi, &res, sizeof(res)); - if ((resi&0x7fffffffu) < 0x38800000u){ // if Float16(res) is subnormal - // shift so that the mantissa lines up where it would for normal Float16 - uint32_t shift = 113u-((resi & 0x7f800000u)>>23u); - if (shift<23u) { - resi |= 0x00800000; // set implicit bit - resi >>= shift; - } - } - if ((resi & 0x1fffu) == 0x1000u) { // if we are halfway between 2 Float16 values - memcpy(&resi, &res, sizeof(res)); - // adjust the value by 1 ULP in the direction that will make Float16(res) give the right answer - resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res)); - memcpy(&res, &resi, sizeof(res)); - } - return float_to_half(res); + uint16_t res = float_to_half(param); + return FLOAT16_FROM_UINT16(res); } -JL_DLLEXPORT float julia__truncsfbf2(float param) JL_NOTSAFEPOINT +JL_DLLEXPORT FLOAT16_TYPE julia__truncdfhf2(double param) { - uint16_t result; + uint16_t res = double_to_half(param); + return FLOAT16_FROM_UINT16(res); +} + +// bfloat16 conversion helpers + +static inline 
uint16_t float_to_bfloat(float param) JL_NOTSAFEPOINT +{ if (isnan(param)) - result = 0x7fc0; - else { - uint32_t bits = *((uint32_t*) &param); + return 0x7fc0; - // round to nearest even - bits += 0x7fff + ((bits >> 16) & 1); - result = (uint16_t)(bits >> 16); - } + uint32_t bits = *((uint32_t*) &param); - // on x86, bfloat16 needs to be returned in XMM. only GCC 13 provides the necessary ABI - // support in the form of the __bf16 type; older versions only provide __bfloat16 which - // is simply a typedef for short (i16). so use float, which is passed in XMM too. - uint32_t result_32bit = (uint32_t)result; - return *(float*)&result_32bit; + // round to nearest even + bits += 0x7fff + ((bits >> 16) & 1); + return (uint16_t)(bits >> 16); } -JL_DLLEXPORT float julia__truncdfbf2(double param) JL_NOTSAFEPOINT +static inline uint16_t double_to_bfloat(double param) JL_NOTSAFEPOINT { - float res = (float)param; - uint32_t resi; - memcpy(&resi, &res, sizeof(res)); + float temp = (float)param; + uint32_t tempi; + memcpy(&tempi, &temp, sizeof(temp)); // bfloat16 uses the same exponent as float32, so we don't need special handling // for subnormals when truncating float64 to bfloat16. 
- if ((resi & 0x1ffu) == 0x100u) { // if we are halfway between 2 bfloat16 values - // adjust the value by 1 ULP in the direction that will make bfloat16(res) give the right answer - resi += (fabs(res) < fabs(param)) - (fabs(param) < fabs(res)); - memcpy(&res, &resi, sizeof(res)); + // if we are halfway between 2 bfloat16 values + if ((tempi & 0x1ffu) == 0x100u) { + // adjust the value by 1 ULP in the direction that will make bfloat16(temp) give the right answer + tempi += (fabs(temp) < fabs(param)) - (fabs(param) < fabs(temp)); + memcpy(&temp, &tempi, sizeof(temp)); } - return julia__truncsfbf2(res); -} - -//JL_DLLEXPORT double julia__extendhfdf2(uint16_t n) { return (double)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT int32_t julia__fixhfsi(uint16_t n) { return (int32_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT int64_t julia__fixhfdi(uint16_t n) { return (int64_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint32_t julia__fixunshfsi(uint16_t n) { return (uint32_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint64_t julia__fixunshfdi(uint16_t n) { return (uint64_t)julia__gnu_h2f_ieee(n); } -//JL_DLLEXPORT uint16_t julia__floatsihf(int32_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatdihf(int64_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatunsihf(uint32_t n) { return julia__gnu_f2h_ieee((float)n); } -//JL_DLLEXPORT uint16_t julia__floatundihf(uint64_t n) { return julia__gnu_f2h_ieee((float)n); } -//HANDLE_LIBCALL(F16, F128, __extendhftf2) -//HANDLE_LIBCALL(F16, F80, __extendhfxf2) -//HANDLE_LIBCALL(F80, F16, __truncxfhf2) -//HANDLE_LIBCALL(F128, F16, __trunctfhf2) -//HANDLE_LIBCALL(PPCF128, F16, __trunctfhf2) -//HANDLE_LIBCALL(F16, I128, __fixhfti) -//HANDLE_LIBCALL(F16, I128, __fixunshfti) -//HANDLE_LIBCALL(I128, F16, __floattihf) -//HANDLE_LIBCALL(I128, F16, __floatuntihf) + + return float_to_bfloat(temp); +} + +// bfloat16 conversion API + +// starting with GCC 13 and Clang 17, we have __bf16 on 
most platforms +// (but not on Windows; this may be a bug in the MSYS2 GCC compilers) +#if ((defined(__GNUC__) && __GNUC__ > 12) || \ + (defined(__clang__) && __clang_major__ > 16)) && \ + !defined(_CPU_PPC64_) && !defined(_CPU_PPC_) && \ + !defined(_OS_WINDOWS_) + #define BFLOAT16_TYPE __bf16 + #define BFLOAT16_TO_UINT16(x) (*(uint16_t*)&(x)) + #define BFLOAT16_FROM_UINT16(x) (*(__bf16*)&(x)) +// on older compilers, we need to emulate the platform-specific ABI. +// for more details, see similar code above that deals with Float16. +#elif defined(_CPU_X86_) || (defined(_CPU_X86_64_) && !defined(_OS_WINDOWS_)) + #define BFLOAT16_TYPE __m128 + #define BFLOAT16_TO_UINT16(x) take_from_xmm(x) + #define BFLOAT16_FROM_UINT16(x) return_in_xmm(x) +#elif defined(_CPU_PPC64_) || defined(_CPU_PPC_) + #define BFLOAT16_TYPE uint16_t + #define BFLOAT16_TO_UINT16(x) (x) + #define BFLOAT16_FROM_UINT16(x) (x) +#else + #define BFLOAT16_TYPE float + #define BFLOAT16_TO_UINT16(x) ((uint16_t)*(uint32_t*)&(x)) + #define BFLOAT16_FROM_UINT16(x) ({ uint32_t tmp = (uint32_t)(x); *(float*)&tmp; }) +#endif + +JL_DLLEXPORT BFLOAT16_TYPE julia__truncsfbf2(float param) JL_NOTSAFEPOINT +{ + uint16_t res = float_to_bfloat(param); + return BFLOAT16_FROM_UINT16(res); +} + +JL_DLLEXPORT BFLOAT16_TYPE julia__truncdfbf2(double param) JL_NOTSAFEPOINT +{ + uint16_t res = double_to_bfloat(param); + return BFLOAT16_FROM_UINT16(res); +} // run time version of bitcast intrinsic @@ -643,11 +736,11 @@ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ static inline void name(unsigned osize, void *pa, void *pr) JL_NOTSAFEPOINT \ { \ uint16_t a = *(uint16_t*)pa; \ - float A = julia__gnu_h2f_ieee(a); \ + float A = half_to_float(a); \ if (osize == 16) { \ float R; \ OP(&R, A); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = float_to_half(R); \ } else { \ OP((uint16_t*)pr, A); \ } \ @@ -671,11 +764,11 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void 
*pb, void *pr) { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ + float A = half_to_float(a); \ + float B = half_to_float(b); \ runtime_nbits = 16; \ float R = OP(A, B); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = float_to_half(R); \ } // float or integer inputs, bool output @@ -696,8 +789,8 @@ static int jl_##name##16(unsigned runtime_nbits, void *pa, void *pb) JL_NOTSAFEP { \ uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ + float A = half_to_float(a); \ + float B = half_to_float(b); \ runtime_nbits = 16; \ return OP(A, B); \ } @@ -737,12 +830,12 @@ static void jl_##name##16(unsigned runtime_nbits, void *pa, void *pb, void *pc, uint16_t a = *(uint16_t*)pa; \ uint16_t b = *(uint16_t*)pb; \ uint16_t c = *(uint16_t*)pc; \ - float A = julia__gnu_h2f_ieee(a); \ - float B = julia__gnu_h2f_ieee(b); \ - float C = julia__gnu_h2f_ieee(c); \ + float A = half_to_float(a); \ + float B = half_to_float(b); \ + float C = half_to_float(c); \ runtime_nbits = 16; \ float R = OP(A, B, C); \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(R); \ + *(uint16_t*)pr = float_to_half(R); \ } @@ -1412,7 +1505,7 @@ cvt_iintrinsic(LLVMFPtoUI, fptoui) if (!(osize < 8 * sizeof(a))) \ jl_error("fptrunc: output bitsize must be < input bitsize"); \ else if (osize == 16) \ - *(uint16_t*)pr = julia__gnu_f2h_ieee(a); \ + *(uint16_t*)pr = float_to_half(a); \ else if (osize == 32) \ *(float*)pr = a; \ else if (osize == 64) \ diff --git a/test/intrinsics.jl b/test/intrinsics.jl index d67dad33e60cc..8e4ab932f5eb6 100644 --- a/test/intrinsics.jl +++ b/test/intrinsics.jl @@ -180,28 +180,12 @@ end @test_intrinsic Core.Intrinsics.fptoui UInt Float16(3.3) UInt(3) end -if Sys.ARCH == :aarch64 || Sys.ARCH === :powerpc64le || Sys.ARCH === :ppc64le - # On AArch64 we are following the `_Float16` ABI. 
Buthe these functions expect `Int16`. - # TODO: Should we have `Chalf == Int16` and `Cfloat16 == Float16`? - extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (UInt16,), reinterpret(UInt16, x)) - gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (UInt16,), reinterpret(UInt16, x)) - truncsfhf2(x::Float32) = reinterpret(Float16, ccall("extern __truncsfhf2", llvmcall, UInt16, (Float32,), x)) - gnu_f2h_ieee(x::Float32) = reinterpret(Float16, ccall("extern __gnu_f2h_ieee", llvmcall, UInt16, (Float32,), x)) - truncdfhf2(x::Float64) = reinterpret(Float16, ccall("extern __truncdfhf2", llvmcall, UInt16, (Float64,), x)) -else - extendhfsf2(x::Float16) = ccall("extern __extendhfsf2", llvmcall, Float32, (Float16,), x) - gnu_h2f_ieee(x::Float16) = ccall("extern __gnu_h2f_ieee", llvmcall, Float32, (Float16,), x) - truncsfhf2(x::Float32) = ccall("extern __truncsfhf2", llvmcall, Float16, (Float32,), x) - gnu_f2h_ieee(x::Float32) = ccall("extern __gnu_f2h_ieee", llvmcall, Float16, (Float32,), x) - truncdfhf2(x::Float64) = ccall("extern __truncdfhf2", llvmcall, Float16, (Float64,), x) -end - @testset "Float16 intrinsics (crt)" begin - @test extendhfsf2(Float16(3.3)) == 3.3007812f0 + gnu_h2f_ieee(x::Float16) = ccall("julia__gnu_h2f_ieee", Float32, (Float16,), x) + gnu_f2h_ieee(x::Float32) = ccall("julia__gnu_f2h_ieee", Float16, (Float32,), x) + @test gnu_h2f_ieee(Float16(3.3)) == 3.3007812f0 - @test truncsfhf2(3.3f0) == Float16(3.3) @test gnu_f2h_ieee(3.3f0) == Float16(3.3) - @test truncdfhf2(3.3) == Float16(3.3) end using Base.Experimental: @force_compile