Skip to content

Commit

Permalink
exact BigFloat to IEEE FP conversion in pure Julia
Browse files Browse the repository at this point in the history
There's lots of code, but most of it seems like it will be useful in
general. For example, I think I'll use the changes in float.jl and
rounding.jl to improve the JuliaLang#49749 PR. The changes in float.jl could
also be used to refactor float.jl to remove many magic constants.

Benchmarking script:
```julia
using BenchmarkTools
f(::Type{T} = BigFloat, n::Int = 2000) where {T} = rand(T, n)
g!(u, v) = map!(eltype(u), u, v)
@btime g!(u, v) setup=(u = f(Float16); v = f();)
@btime g!(u, v) setup=(u = f(Float32); v = f();)
@btime g!(u, v) setup=(u = f(Float64); v = f();)
```

On master (dc06468):
```
  46.116 μs (0 allocations: 0 bytes)
  38.842 μs (0 allocations: 0 bytes)
  37.039 μs (0 allocations: 0 bytes)
```

With both this commit and JuliaLang#50674 applied:
```
  42.310 μs (0 allocations: 0 bytes)
  42.661 μs (0 allocations: 0 bytes)
  41.608 μs (0 allocations: 0 bytes)
```

So, with this benchmark at least, on an AMD Zen 2 laptop, conversion
to `Float16` is faster, but there's a slowdown for `Float32` and
`Float64`.

Fixes JuliaLang#50642 (exact conversion to `Float16`)
  • Loading branch information
nsajko committed Jul 28, 2023
1 parent dc06468 commit 1ad2b27
Show file tree
Hide file tree
Showing 7 changed files with 363 additions and 29 deletions.
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ include("hashing.jl")
include("rounding.jl")
using .Rounding
include("div.jl")
include("rawbigints.jl")
include("float.jl")
include("twiceprecision.jl")
include("complex.jl")
Expand Down
62 changes: 62 additions & 0 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,68 @@ i.e. the maximum integer value representable by [`exponent_bits(T)`](@ref) bits.
"""
function exponent_raw_max end

"""
IEEE 754 definition of the minimum exponent.
"""
function ieee754_exponent_min(::Type{T}) where {T<:IEEEFloat}
    # The minimum exponent is derived from the maximum one as `1 - emax`.
    emax = exponent_max(T)
    return Int(1 - emax)::Int
end

# For each of the three IEEE binary formats, `exponent_min` is the IEEE 754
# minimum exponent. One method per concrete type is generated.
for F in (:Float16, :Float32, :Float64)
    @eval exponent_min(::Type{$F}) = ieee754_exponent_min($F)
end

# Assemble the raw bit pattern of a value of the IEEE 754 format `F` from its
# three fields.
#
# The fields are packed from most to least significant: the sign bit, then
# `exponent_bits(F)` bits for `exponent_field`, then `significand_bits(F)` bits
# for `significand_field`. Both fields are OR-ed in as given (no masking), so
# the caller must pass values that fit their field.
#
# Returns the bit pattern as a `uinttype(F)`.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, exponent_field::Integer, significand_field::Integer
) where {F<:IEEEFloat}
    T = uinttype(F)
    ret::T = sign_bit
    ret <<= exponent_bits(F)
    ret |= exponent_field
    ret <<= significand_bits(F)
    ret |= significand_field
    # Return the typed local explicitly: the value of a trailing assignment
    # expression is its unconverted right-hand side, so without this the
    # return type would follow the type of `significand_field` rather than `T`.
    return ret
end

# ±floatmax(F): the largest exponent field below `exponent_raw_max(F)`,
# combined with a significand field of `significand_mask(F)`.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:omega}
) where {F<:IEEEFloat}
    largest_finite_exponent = exponent_raw_max(F) - 1
    full_significand = significand_mask(F)
    return ieee754_representation(F, sign_bit, largest_finite_exponent, full_significand)
end

# NaN or an infinity: the exponent field is `exponent_raw_max(F)`, the
# significand field is supplied by the caller.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:nan}
) where {F<:IEEEFloat}
    max_exponent_field = exponent_raw_max(F)
    return ieee754_representation(F, sign_bit, max_exponent_field, significand_field)
end

# NaN with default payload: only the most significant bit of the significand
# field is set.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:nan}
) where {F<:IEEEFloat}
    top_significand_bit = one(uinttype(F)) << (significand_bits(F) - 1)
    return ieee754_representation(F, sign_bit, top_significand_bit, Val(:nan))
end

# Infinity: same exponent field as a NaN, but with a zero significand field.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:inf}
) where {F<:IEEEFloat}
    zero_significand = false
    return ieee754_representation(F, sign_bit, zero_significand, Val(:nan))
end

# Subnormal or zero: the exponent field is zero, the significand field is
# supplied by the caller.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:subnormal}
) where {F<:IEEEFloat}
    zero_exponent = false
    return ieee754_representation(F, sign_bit, zero_exponent, significand_field)
end

# Zero: a "subnormal" whose significand field is zero as well.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:zero}
) where {F<:IEEEFloat}
    zero_significand = false
    return ieee754_representation(F, sign_bit, zero_significand, Val(:subnormal))
end

"""
uabs(x::Integer)
Expand Down
112 changes: 84 additions & 28 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ import
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
isone, big, _string_n, decompose, minmax,
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
RawBigIntRoundingIncrementHelper, truncated, RawBigInt


using .Base.Libc
import ..Rounding: rounding_raw, setrounding_raw
import ..Rounding:
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
tie_breaker_is_to_even, correct_rounding_requires_increment

import ..GMP: ClongMax, CulongMax, CdoubleMax, Limb, libgmp

Expand Down Expand Up @@ -89,6 +93,21 @@ function convert(::Type{RoundingMode}, r::MPFRRoundingMode)
end
end

# Only `MPFRRoundNearest` is treated as a round-to-nearest mode.
function rounds_to_nearest(m::MPFRRoundingMode)
    return m == MPFRRoundNearest
end
# Whether the directed rounding mode `m` rounds a value with the given sign
# away from zero. `sign_bit` is `true` for negative values.
function rounds_away_from_zero(m::MPFRRoundingMode, sign_bit::Bool)
    m == MPFRRoundToZero && return false
    # Rounding up moves away from zero only for non-negative values ...
    m == MPFRRoundUp && return !sign_bit
    # ... and rounding down only for negative ones.
    m == MPFRRoundDown && return sign_bit
    # Assuming `m == MPFRRoundFromZero`
    return true
end
# Every `MPFRRoundingMode` reports ties-to-even as its tie-breaking rule.
function tie_breaker_is_to_even(::MPFRRoundingMode)
    return true
end

const ROUNDING_MODE = Ref{MPFRRoundingMode}(MPFRRoundNearest)
const DEFAULT_PRECISION = Ref{Clong}(256)

Expand Down Expand Up @@ -130,6 +149,9 @@ mutable struct BigFloat <: AbstractFloat
end
end

# Number of `Limb`-sized words in the buffer `x._d`.
# The rounding mode of the division shouldn't matter here.
function significand_limb_count(x::BigFloat)
    buffer_bytes = sizeof(x._d)
    limb_bytes = sizeof(Limb)
    return div(buffer_bytes, limb_bytes, RoundToZero)
end

rounding_raw(::Type{BigFloat}) = ROUNDING_MODE[]
setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode) = ROUNDING_MODE[]=r

Expand Down Expand Up @@ -380,35 +402,69 @@ function (::Type{T})(x::BigFloat) where T<:Integer
trunc(T,x)
end

## BigFloat -> AbstractFloat
# Give a NaN `x` the sign of `y`; any non-NaN `x` is returned unchanged.
function _cpynansgn(x::AbstractFloat, y::BigFloat)
    if isnan(x) && signbit(x) != signbit(y)
        return -x
    end
    return x
end

# Convert `x` to `Float64` by calling `mpfr_get_d` from `libmpfr` with rounding
# mode `r` (defaulting to the current `ROUNDING_MODE[]`); the result is passed
# through `_cpynansgn` together with `x`.
# NOTE(review): this view also contains `to_ieee754`-based `Float64(::BigFloat, ...)`
# methods with the same signatures — presumably these ccall-based ones are the
# pre-commit definitions; confirm which set is meant to remain.
Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
_cpynansgn(ccall((:mpfr_get_d,libmpfr), Float64, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
# Same conversion for a generic `RoundingMode`, which is converted to an
# `MPFRRoundingMode` first.
Float64(x::BigFloat, r::RoundingMode) = Float64(x, convert(MPFRRoundingMode, r))

# Convert `x` to `Float32` by calling `mpfr_get_flt` from `libmpfr` with rounding
# mode `r` (defaulting to the current `ROUNDING_MODE[]`); the result is passed
# through `_cpynansgn` together with `x`.
# NOTE(review): this view also contains `to_ieee754`-based `Float32(::BigFloat, ...)`
# methods with the same signatures — presumably these ccall-based ones are the
# pre-commit definitions; confirm which set is meant to remain.
Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) =
_cpynansgn(ccall((:mpfr_get_flt,libmpfr), Float32, (Ref{BigFloat}, MPFRRoundingMode), x, r), x)
# Same conversion for a generic `RoundingMode`, which is converted to an
# `MPFRRoundingMode` first.
Float32(x::BigFloat, r::RoundingMode) = Float32(x, convert(MPFRRoundingMode, r))

function Float16(x::BigFloat) :: Float16
res = Float32(x)
resi = reinterpret(UInt32, res)
if (resi&0x7fffffff) < 0x38800000 # if Float16(res) is subnormal
#shift so that the mantissa lines up where it would for normal Float16
shift = 113-((resi & 0x7f800000)>>23)
if shift<23
resi |= 0x0080_0000 # set implicit bit
resi >>= shift
# Convert `x` to the floating-point format `T` in pure Julia, rounding
# according to `rm`.
#
# `rm` may be any rounding-mode value supported by `rounds_to_nearest`,
# `rounds_away_from_zero` and `correct_rounding_requires_increment`.
#
# The result is first built as a raw bit pattern of type `uinttype(T)` and then
# reinterpreted as `T`. NaN, infinities and zeros keep the sign bit of `x`.
function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
    sb = signbit(x)
    is_zero = iszero(x)
    is_inf = isinf(x)
    is_nan = isnan(x)
    is_regular = !is_zero & !is_inf & !is_nan
    # NOTE(review): the `- 1` presumably accounts for MPFR normalizing the
    # significand to [1/2, 1) — confirm against the MPFR documentation.
    ieee_exp = Int(x.exp) - 1
    ieee_precision = precision(T)
    ieee_exp_max = exponent_max(T)
    ieee_exp_min = exponent_min(T)
    exp_diff = ieee_exp - ieee_exp_min
    is_normal = 0 <= exp_diff
    # For directed rounding, exactly one of the two flags is set; for rounding
    # to nearest, neither is.
    (rm_is_to_zero, rm_is_from_zero) = if rounds_to_nearest(rm)
        (false, false)
    else
        let from = rounds_away_from_zero(rm, sb)
            (!from, from)
        end
    end::NTuple{2,Bool}
    exp_is_huge_p = ieee_exp_max < ieee_exp
    # True when `exp_diff + ieee_precision` is negative.
    exp_is_huge_n = signbit(exp_diff + ieee_precision)
    rounds_to_inf = is_regular & exp_is_huge_p & !rm_is_to_zero
    rounds_to_zero = is_regular & exp_is_huge_n & !rm_is_from_zero
    U = uinttype(T)

    ret_u = if is_regular & !rounds_to_inf & !rounds_to_zero
        if !exp_is_huge_p
            # `v` only holds a raw pointer obtained from `x`, so keep `x` rooted
            # for as long as memory is read through that pointer.
            GC.@preserve x begin
                # significand
                v = RawBigInt(x.d, significand_limb_count(x))
                len = max(ieee_precision + min(exp_diff, 0), 0)::Int
                signif = truncated(U, v, len) & significand_mask(T)

                # round up if necessary
                rh = RawBigIntRoundingIncrementHelper(v, len)
                incr = correct_rounding_requires_increment(rh, rm, sb)

                # exponent
                exp_field = max(exp_diff, 0) + is_normal

                # The increment is added to the packed bit pattern, so a carry
                # out of the significand field moves into the exponent field.
                ieee754_representation(T, sb, exp_field, signif) + incr
            end
        else
            # Exponent too large, but the rounding mode rounds toward zero.
            ieee754_representation(T, sb, Val(:omega))
        end
    else
        if is_zero | rounds_to_zero
            ieee754_representation(T, sb, Val(:zero))
        elseif is_inf | rounds_to_inf
            ieee754_representation(T, sb, Val(:inf))
        else
            ieee754_representation(T, sb, Val(:nan))
        end
    end::U

    return reinterpret(T, ret_u)
end

# `BigFloat` -> IEEE conversions all go through `to_ieee754`. Each target type
# gets one method taking an `MPFRRoundingMode` (defaulting to the current
# `ROUNDING_MODE[]`) and one taking a generic `RoundingMode`.
for F in (:Float16, :Float32, :Float64)
    @eval $F(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754($F, x, r)
end
for F in (:Float16, :Float32, :Float64)
    @eval $F(x::BigFloat, r::RoundingMode) = to_ieee754($F, x, r)
end

promote_rule(::Type{BigFloat}, ::Type{<:Real}) = BigFloat
promote_rule(::Type{BigInt}, ::Type{<:AbstractFloat}) = BigFloat
promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat}) = BigFloat
Expand Down
149 changes: 149 additions & 0 deletions base/rawbigints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
Segment of raw words of bits interpreted as a big integer. Less
significant words come first. Each word is in machine-native bit-order.
"""
struct RawBigInt{T<:Unsigned}
    # Pointer to the words; word `i` (zero-based) is read at offset `i + 1`.
    d::Ptr{T}
    # Number of words that make up the integer.
    word_count::Int

    # The only inner constructor takes the word type explicitly.
    RawBigInt{T}(d::Ptr{T}, word_count::Int) where {T<:Unsigned} = new{T}(d, word_count)
end

# Convenience constructor that takes the word type from the pointer type.
function RawBigInt(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
    return RawBigInt{T}(d, word_count)
end
# Size queries: `Val(:words)` counts words, `Val(:bits)` counts bits.
function elem_count(x::RawBigInt, ::Val{:words})
    return x.word_count
end

function elem_count(x::Unsigned, ::Val{:bits})
    return sizeof(x) * 8
end

# Bit width of a single word of the given `RawBigInt`.
function word_length(::RawBigInt{T}) where {T}
    return elem_count(zero(T), Val(:bits))
end

function elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T}
    bits_per_word = word_length(x)
    return bits_per_word * elem_count(x, Val(:words))
end
# Mirror the zero-based index `i` within a sequence of `n` elements.
function reversed_index(n::Int, i::Int)
    return n - i - 1
end

# Same, with the sequence length taken from `elem_count(x, v)`.
function reversed_index(x, i::Int, v::Val)
    n = elem_count(x, v)
    return reversed_index(n, i)::Int
end
# Split the bit index `i` into a (word index, bit index within that word) pair.
function split_bit_index(x::RawBigInt, i::Int)
    bits_per_word = word_length(x)
    return divrem(i, bits_per_word, RoundToZero)
end

"""
`i` is the zero-based index of the wanted word in `x`, starting from
the less significant words.
"""
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
    # `i` is zero-based, while `unsafe_load` counts elements from one.
    one_based = i + 1
    return unsafe_load(x.d, one_based)
end

# Descending access: mirror the index, then defer to the ascending method.
function get_elem(x, i::Int, v::Val, ::Val{:descending})
    return get_elem(x, reversed_index(x, i, v), v, Val(:ascending))
end

# Whether the `i`-th word of `x`, indexed in the direction given by `v`, is nonzero.
function word_is_nonzero(x::RawBigInt, i::Int, v::Val)
    word = get_elem(x, i, Val(:words), v)
    return !iszero(word)
end

# Curried form: returns a predicate over word indices with `x` and `v` fixed.
function word_is_nonzero(x::RawBigInt, v::Val)
    return let x = x
        i -> word_is_nonzero(x, i, v)
    end
end

"""
Returns a `Bool` indicating whether the `len` least significant words
of `x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
    is_nonzero = word_is_nonzero(x, Val(:ascending))
    low_word_indices = 0:(len - 1)
    return any(is_nonzero, low_word_indices)
end

"""
Returns a `Bool` indicating whether the `len` least significant bits of
the `i`-th (zero-based index) word of `x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
    # An empty tail has no set bits; the word is not even loaded in that case.
    iszero(len) && return false
    word = get_elem(x, i, Val(:words), Val(:ascending))
    # Shift the `len` low bits to the top so that all other bits fall off.
    return !iszero(word << (word_length(x) - len))
end

"""
Returns a `Bool` indicating whether the `len` least significant bits of
`x` are nonzero.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
    # A non-positive length denotes an empty tail.
    0 < len || return false
    whole_words, leftover_bits = split_bit_index(x, len)
    # Check the partially covered word first; the fully covered words below it
    # are only inspected when that word's tail is all zeros.
    if tail_is_nonzero(x, leftover_bits, whole_words, Val(:word))
        return true
    end
    return tail_is_nonzero(x, whole_words, Val(:words))::Bool
end

"""
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
"""
function get_elem(x::Unsigned, i::Int, ::Val{:bits}, ::Val{:ascending})
    # Move the wanted bit into the lowest position ...
    shifted = x >>> i
    # ... and keep only that bit.
    return shifted % Bool
end

"""
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`.
"""
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
    vb = Val(:bits)
    # Indices outside of `x` yield `false`; only in-range indices load a word.
    return if 0 <= i < elem_count(x, vb)
        word_index, bit_index_in_word = split_bit_index(x, i)
        word = get_elem(x, word_index, Val(:words), v)
        get_elem(word, bit_index_in_word, vb, v)
    else
        false
    end::Bool
end

"""
Returns an integer of type `R`, consisting of the `len` most
significant bits of `x`.
"""
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
    ret = zero(R)
    # A non-positive `len` selects no bits, so the result stays zero.
    if 0 < len
        word_count, bit_count_in_word = split_bit_index(x, len)
        k = word_length(x)
        # Words are visited starting from the most significant one.
        vals = (Val(:words), Val(:descending))

        # Whole words: make room for each one, then OR it in.
        for w in 0:(word_count - 1)
            ret <<= k
            word = get_elem(x, w, vals...)
            ret |= R(word)
        end

        # Remaining bits: take the top `bit_count_in_word` bits of the next word.
        if !iszero(bit_count_in_word)
            ret <<= bit_count_in_word
            wrd = get_elem(x, word_count, vals...)
            ret |= R(wrd >>> (k - bit_count_in_word))
        end
    end
    return ret::R
end

# Caches what is needed to decide whether truncating `n` to its `trunc_len`
# most significant bits must be followed by an increment.
struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
    n::RawBigInt{T}
    trunc_len::Int

    final_bit::Bool
    round_bit::Bool

    function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
        # Bits are indexed from the most significant end: index `len - 1` is
        # the last bit kept by the truncation, index `len` the first one dropped.
        from_the_top = (Val(:bits), Val(:descending))
        last_kept = get_elem(n, len - 1, from_the_top...)
        first_dropped = get_elem(n, len, from_the_top...)
        return new{T}(n, len, last_kept, first_dropped)
    end
end

# Convenience constructor that takes the word type from `n`.
RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned} =
    RawBigIntRoundingIncrementHelper{T}(n, len)

# The helper is callable: it answers one query per bit kind.
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit)
    return h.final_bit
end

function (h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit)
    return h.round_bit
end

# Sticky bit: whether any bit below the round bit is set. Unlike the other two
# bits it is not cached but computed on each call.
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
    bits = Val(:bits)
    total_bits = elem_count(h.n, bits)
    tail_len = total_bits - h.trunc_len - 1
    return tail_is_nonzero(h.n, tail_len, bits)
end
Loading

0 comments on commit 1ad2b27

Please sign in to comment.