Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

generate floats using lexicographically ordered encoding #49

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Supposition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ using StyledStrings
include("types.jl")
include("testcase.jl")
include("util.jl")
include("float.jl")
include("data.jl")
include("teststate.jl")
include("shrink.jl")
Expand Down
13 changes: 12 additions & 1 deletion src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ module Data

using Supposition
using Supposition: smootherstep, lerp, TestCase, choice!, weighted!, forced_choice!, reject
using Supposition.FloatEncoding: lex_to_float
using RequiredInterfaces: @required
using StyledStrings: @styled_str
using Printf: format, @format_str
Expand Down Expand Up @@ -1449,9 +1450,19 @@ function Base.show(io::IO, ::MIME"text/plain", f::Floats)
E.g. {code:$obj}; {code:isinf}: $inf, {code:isnan}: $nan""")
end


function produce!(tc::TestCase, f::Floats{T}) where {T}
iT = Supposition.uint(T)
res = reinterpret(T, produce!(tc, Integers{iT}()))

bits = produce!(tc, Integers{iT}())

is_negative = produce!(tc, Booleans())

res = lex_to_float(T, bits)
if is_negative
res = -res
end

# early rejections
!f.infs && isinf(res) && reject(tc)
!f.nans && isnan(res) && reject(tc)
Expand Down
147 changes: 147 additions & 0 deletions src/float.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
module FloatEncoding
using Supposition: uint, tear, bias, fracsize, exposize, max_exponent, assemble

"""
exponent_key(T, e)

A lexicographical ordering for floating point exponents. The encoding is taken
from Hypothesis.
The ordering is
- non-negative exponents in increasing order
- negative exponents in decreasing order
- the maximum exponent
"""
raineszm marked this conversation as resolved.
Show resolved Hide resolved
function exponent_key(::Type{T}, e::iT) where {T<:Base.IEEEFloat,iT<:Unsigned}
    # The all-ones exponent always sorts last.
    e == max_exponent(T) && return Inf
    unbiased = float(e) - bias(T)
    # Negative exponents are pushed past every non-negative one. Subtracting from
    # the offset makes exponents closer to zero sort first among the negatives.
    return unbiased < 0 ? 10000 - unbiased : unbiased
end

# Build the vector of every biased exponent of `T`, sorted by `exponent_key`.
function _make_encoding_table(T)
    all_exponents = zero(uint(T)):max_exponent(T)
    return sort(all_exponents; by=e -> exponent_key(T, e))
end
raineszm marked this conversation as resolved.
Show resolved Hide resolved
# Keyed by unsigned integer type: every biased exponent of the float type of the
# same width, ordered by `exponent_key`. Entry `i + 1` holds the exponent found at
# zero-based position `i` of that ordering.
const ENCODING_TABLE = Dict(
UInt16 => _make_encoding_table(Float16),
UInt32 => _make_encoding_table(Float32),
UInt64 => _make_encoding_table(Float64))

# Return the entry of the `ENCODING_TABLE` for integer type `T` at zero-based
# position `e` (`+ 1` converts to Julia's one-based indexing).
function encode_exponent(e::T) where {T<:Unsigned}
    table = ENCODING_TABLE[T]
    return table[e + 1]
end

# Build the inverse permutation of the encoding table for `T`: entry `e + 1`
# holds the zero-based position at which exponent `e` appears in that table.
function _make_decoding_table(T)
    encoding = ENCODING_TABLE[uint(T)]
    inverse = zeros(uint(T), max_exponent(T) + 1)
    for position in eachindex(encoding)
        inverse[encoding[position] + 1] = position - 1
    end
    return inverse
end
# Keyed by unsigned integer type: the inverse of the matching `ENCODING_TABLE`
# entry. Entry `e + 1` holds the zero-based position of the biased exponent `e`
# within the encoding table.
const DECODING_TABLE = Dict(
UInt16 => _make_decoding_table(Float16),
UInt32 => _make_decoding_table(Float32),
UInt64 => _make_decoding_table(Float64))
# Return the entry of the `DECODING_TABLE` for integer type `T` that belongs to
# exponent `e` (`+ 1` converts to Julia's one-based indexing).
function decode_exponent(e::T) where {T<:Unsigned}
    table = DECODING_TABLE[T]
    return table[e + 1]
end
Seelengrab marked this conversation as resolved.
Show resolved Hide resolved


"""
update_mantissa(T, exponent, mantissa)

Encode the mantissa of a floating point number using an encoding with better shrinking.
"""
function update_mantissa(::Type{T}, exponent::iT, mantissa::iT)::iT where {T<:Base.IEEEFloat,iT<:Unsigned}
    @assert uint(T) == iT
    # Smallest biased exponent from which the mantissa is returned untouched.
    passthrough_exponent = fracsize(T) + bias(T)

    # Unbiased exponent <= 0: reverse the whole mantissa. `bitreverse` moves the
    # mantissa bits to the top of the word, so shift them back down.
    if exponent <= bias(T)
        return bitreverse(mantissa) >> (exposize(T) + 1)
    end

    exponent >= passthrough_exponent && return mantissa

    # In between, only the low bits of the mantissa (as many as the exponent
    # leaves below `passthrough_exponent`) are reversed.
    n_reverse_bits = passthrough_exponent - exponent
    low_mask = iT((1 << n_reverse_bits) - 1)
    low_bits = mantissa & low_mask
    reversed_low_bits = bitreverse(low_bits) >> (8 * sizeof(T) - n_reverse_bits)
    # Clear the low bits, then put their reversed form back in place.
    return (mantissa ⊻ low_bits) | reversed_low_bits
end


"""
lex_to_float(T, bits)

Reinterpret the bits of a floating point number using an encoding with better shrinking
properties.
raineszm marked this conversation as resolved.
Show resolved Hide resolved
This produces a non-negative floating point number, possibly including NaN or Inf.
raineszm marked this conversation as resolved.
Show resolved Hide resolved

The encoding is taken from hypothesis, and has the property that lexicographically smaller
raineszm marked this conversation as resolved.
Show resolved Hide resolved
bit patterns correspond to 'simpler' floats.
raineszm marked this conversation as resolved.
Show resolved Hide resolved

# Encoding

The encoding used is as follows:

If the sign bit is set:

- the exponent is mapped using `encode_exponent`
- the mantissa is updated using `update_mantissa`
- the float is reassembled, with the sign bit cleared, using `assemble`

If the sign bit is not set:

- the remainder of the first byte is ignored
- the remaining bytes are interpreted as an integer and converted to a float

raineszm marked this conversation as resolved.
Show resolved Hide resolved
"""
function lex_to_float(::Type{T}, bits::I)::T where {I,T<:Base.IEEEFloat}
    # `bits` is reinterpreted as a `T` below, so both must have the same size.
    sizeof(T) == sizeof(I) || throw(ArgumentError(
        "The bitwidth of `$T` needs to match the bitwidth of `$I`!"))
    iT = uint(T)
    sign, exponent, mantissa = tear(reinterpret(T, bits))
    if sign == 1
        # Tagged pattern: map the exponent field through the encoding table,
        # transform the mantissa, and rebuild the float with a cleared sign.
        exponent = encode_exponent(exponent)
        mantissa = update_mantissa(T, exponent, mantissa)
        assemble(T, zero(iT), exponent, mantissa)
    else
        # Untagged pattern: everything below the top byte is a plain integer
        # whose value is the resulting float.
        integral_mask = iT((1 << (8 * (sizeof(T) - 1))) - 1)
        integral_part = bits & integral_mask
        T(integral_part)
    end
end

"""
    float_to_lex(f)

Encode `f` as an unsigned integer of the same bit width. Floats accepted by
`is_simple_float` are converted directly to their integer value; every other float
goes through `base_float_to_lex`.
"""
function float_to_lex(f::T) where {T<:Base.IEEEFloat}
    return is_simple_float(f) ? uint(T)(f) : base_float_to_lex(f)
end

"""
    is_simple_float(f)

Return `true` if `f` is a non-negative integer whose value fits in all but the top
byte of its bit width, i.e. `0 <= f < 2^(8 * (sizeof(f) - 1))`. This is the range
kept by the integer branch of `lex_to_float`.

`NaN`, infinities, negative numbers (including `-0.0`) and floats with a fractional
part are never simple.
"""
function is_simple_float(f::T) where {T<:Base.IEEEFloat}
    # A set sign bit can never be encoded as a plain unsigned integer. `-0.0` is
    # excluded too, so it is not silently collapsed into `+0.0`.
    signbit(f) && return false
    # Rejects NaN, Inf and anything with a fractional part.
    isinteger(f) || return false
    # The check is on the integer *value*, not on the raw IEEE bit pattern:
    # an integer fits in k bits exactly when it is below 2^k.
    return f < ldexp(one(T), 8 * (sizeof(T) - 1))
end

"""
    base_float_to_lex(f)

Encode `f` through the exponent/mantissa transformation: the mantissa is passed
through `update_mantissa`, the exponent through `decode_exponent`, and the result is
reassembled with the sign field set to one. The sign of `f` itself is discarded.
"""
function base_float_to_lex(f::T) where {T<:Base.IEEEFloat}
    iT = uint(T)
    _, raw_exponent, raw_mantissa = tear(f)
    # The mantissa transformation is driven by the original exponent.
    lex_mantissa = update_mantissa(T, raw_exponent, raw_mantissa)
    lex_exponent = decode_exponent(raw_exponent)
    packed = assemble(T, one(iT), lex_exponent, lex_mantissa)
    return reinterpret(iT, packed)
end
end
3 changes: 3 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ exposize(::Type{Float16}) = 5
exposize(::Type{Float32}) = 8
exposize(::Type{Float64}) = 11

"""
    max_exponent(T)

The largest value the biased exponent field of `T` can hold (all `exposize(T)` bits
set), returned as a `uint(T)`.
"""
max_exponent(::Type{T}) where {T<:Base.IEEEFloat} = uint(T)((1 << exposize(T)) - 1)
raineszm marked this conversation as resolved.
Show resolved Hide resolved
"""
    bias(T)

The IEEE 754 exponent bias of `T`, i.e. `2^(exposize(T) - 1) - 1`, returned as a
`uint(T)`. Subtracting it from a stored exponent gives the unbiased exponent.
"""
bias(::Type{T}) where {T<:Base.IEEEFloat} = uint(T)((1 << (exposize(T) - 1)) - 1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is bias?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bias is the IEEE754 bias for the exponent. I'll add a docstring.


function masks(::Type{T}) where T <: Base.IEEEFloat
ui = uint(T)
signbitmask = one(ui) << (8*sizeof(ui)-1)
Expand Down
77 changes: 76 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Supposition
using Supposition: Data, test_function, shrink_remove, shrink_redistribute,
using Supposition: Data, FloatEncoding, test_function, shrink_remove, shrink_redistribute,
NoRecordDB, UnsetDB, Attempt, DEFAULT_CONFIG, TestCase, TestState, choice!, weighted!
using Test
using Aqua
Expand Down Expand Up @@ -513,6 +513,81 @@ const verb = VERSION.major == 1 && VERSION.minor < 11
@test_throws ArgumentError Data.Floats(;minimum=2.0, maximum=1.0)
end

# Tests the properties of the encoding used to represent floating point numbers
@testset "Floating point encoding" begin
@testset for T in (Float16, Float32, Float64)

iT = Supposition.uint(T)
# These invariants are ported from Hypothesis
@testset "Exponent encoding" begin
exponents = zero(iT):Supposition.max_exponent(T)
raineszm marked this conversation as resolved.
Show resolved Hide resolved

# Round tripping
@test all(exponents) do e
FloatEncoding.decode_exponent(FloatEncoding.encode_exponent(e)) == e
end

@test all(exponents) do e
FloatEncoding.encode_exponent(FloatEncoding.decode_exponent(e)) == e
end
end

function roundtrip_encoding(f)
assume!(!signbit(f))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We always check numbers >= 0.0 below, is this assume! necessary? What's the corresponding test in hypothesis?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll have to double check but I believe it's excluding -0 which makes things a bit screwy. This isn't present in the hypothesis test.

encoded = FloatEncoding.float_to_lex(f)
decoded = FloatEncoding.lex_to_float(T, encoded)
reinterpret(iT, decoded) == reinterpret(iT, f)
Seelengrab marked this conversation as resolved.
Show resolved Hide resolved
end

roundtrip_examples = map(Data.Just,
T[
0.0,
2.5,
8.000000000000007,
3.0,
2.0,
1.9999999999999998,
1.0
])
@check roundtrip_encoding(Data.OneOf(roundtrip_examples...))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't splat Vectors, it leads to unnecessary specialization of the called function since it needs to compile a new version per length. Also requires dynamic dispatch, since the length of the Vector is not known at compile time (it's nice to keep CI times lower if we can).

One problem with this approach though is that it's not guaranteed that all of these examples are actually run 🤔 Maybe looping over the array and doing @check max_examples=1 roundtrip_encoding(Data.Just(x)) would be better? If the assume! is removed, this could also just be a @testset for.

@check roundtrip_encoding(Data.Floats{T}(; minimum=zero(T)))

@testset "Ordering" begin
# A float with a fractional part must encode strictly above its own integral part.
function order_integral_part(n, g)
    f = n + g
    whole = trunc(f)
    # only sums that have a fractional part and a non-zero integral part are relevant
    assume!(whole != f)
    assume!(whole != 0)
    return FloatEncoding.float_to_lex(whole) < FloatEncoding.float_to_lex(f)
end

@check order_integral_part(Data.Just(1.0), Data.Just(0.5))
@check order_integral_part(
Data.Floats{T}(;
minimum=one(T),
maximum=T(2^(Supposition.fracsize(T) + 1)),
nans=false),
filter(x -> !(x in T[0, 1]),
Data.Floats{T}(; minimum=zero(T), maximum=one(T), nans=false)))
raineszm marked this conversation as resolved.
Show resolved Hide resolved

integral_float_gen = map(abs ∘ trunc,
Data.Floats{T}(; minimum=zero(T), infs=false, nans=false))

@check function integral_floats_order_as_integers(x=integral_float_gen,
y=integral_float_gen)
(x < y) == (FloatEncoding.float_to_lex(x) < FloatEncoding.float_to_lex(y))
end

@check function fractional_floats_greater_than_1(
f=Data.Floats{T}(; minimum=zero(T), maximum=one(T), nans=false))
assume!(0 < f < 1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do nextfloat(zero(T)) and prevfloat(one(T)) here too and skip the assume!, I think.

FloatEncoding.float_to_lex(f) > FloatEncoding.float_to_lex(one(T))
end
end
end
end

@testset "@check API" begin
# These tests are for accepted syntax, not functionality, so only one example is fine
API_conf = Supposition.merge(DEFAULT_CONFIG[]; verbose=verb, max_examples=1)
Expand Down
Loading