Skip to content

Commit

Permalink
Implement lifting infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
davidagold committed Oct 1, 2016
1 parent 35b2583 commit 47b0a97
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 0 deletions.
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,7 @@ export
# nullable types
isnull,
unsafe_get,
Lifted,

# Macros
# parser internal
Expand Down
124 changes: 124 additions & 0 deletions base/nullable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,127 @@ function hash(x::Nullable, h::UInt)
return hash(x.value, h + nullablehash_seed)
end
end

"""
Lifted{F}
A type used to represent the lifted version of a function `f::F`.
Calling an `_f::Lifted{F}` on arguments `xs...` lowers to
`lift(_f.f, U, xs...)`, where the return type parameter `U` is chosen with the
help of type inference.
"""
immutable Lifted{F}
f::F
cache::Dict{Tuple{Vararg{DataType}}, DataType}

(::Type{Lifted}){F}(f::F) = new{F}(
f, Dict{Tuple{Vararg{DataType}}, DataType}()
)
end

function (_f::Lifted{F}){F}(xs...)
f, cache = _f.f, _f.cache
signature = map(eltype, xs)
U = Base.@get!(
cache,
signature,
Core.Inference.return_type(f, Tuple{signature...})
)
return lift(f, U, xs...)
end

"""
lift(f::F)::Lifted{F}
Return a lifted version of `f`.
"""
lift(f) = Lifted(f)

"""
lift(f, U, xs...)
Return an empty `Nullable{U}` if any of the `xs` is null; otherwise, return the
(`Nullable`-wrapped) value of `f` applied to the values of the `xs`.
NOTE: There are two exceptions to the above: `lift(|, Bool, x, y)` and
`lift(&, Bool, x, y)`. These methods both follow three-valued logic semantics.
"""
function lift(f, U::DataType, x)
if isnull(x)
return Nullable{U}()
else
return Nullable{U}(f(unsafe_get(x)))
end
end

function lift(f, U::DataType, x1, x2)
if isnull(x1) | isnull(x2)
return Nullable{U}()
else
return Nullable{U}(f(unsafe_get(x1), unsafe_get(x2)))
end
end

function lift(f, U::DataType, xs...)
if mapreduce(isnull, |, false, xs)
return Nullable{U}()
else
return Nullable{U}(f(map(unsafe_get, xs)...))
end
end

# Three-valued logic

(::Lifted{&})(x::Union{Bool, Nullable{Bool}}, y::Union{Bool, Nullable{Bool}}) =
lift(&, Bool, x, y)
(::Lifted{|})(x::Union{Bool, Nullable{Bool}}, y::Union{Bool, Nullable{Bool}}) =
lift(|, Bool, x, y)

function lift(f::typeof(&), ::Type{Bool}, x, y)::Nullable{Bool}
return ifelse(
isnull(x),
ifelse(
isnull(y),
Nullable{Bool}(),
ifelse(
unsafe_get(y),
Nullable{Bool}(),
Nullable(false)
)
),
ifelse(
isnull(y),
ifelse(
unsafe_get(x),
Nullable{Bool}(),
Nullable(false)
),
Nullable(unsafe_get(x) & unsafe_get(y))
)
)
end

function lift(f::typeof(|), ::Type{Bool}, x, y)::Nullable{Bool}
return ifelse(
isnull(x),
ifelse(
isnull(y),
Nullable{Bool}(),
ifelse(
unsafe_get(y),
Nullable(true),
Nullable{Bool}()
)
),
ifelse(
isnull(y),
ifelse(
unsafe_get(x),
Nullable(true),
Nullable{Bool}()
),
Nullable(unsafe_get(x) | unsafe_get(y))
)
)
end
69 changes: 69 additions & 0 deletions test/nullable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,3 +387,72 @@ end

# issue #11675
@test repr(Nullable()) == "Nullable{Union{}}()"

# lifting

f(x::Number) = 5 * x
f(x::Number, y::Number) = x + y
f(x::Number, y::Number, z::Number) = x + y * z
_f = lift(f)

for T in setdiff(types, [Bool])
a = one(T)
x = Nullable{T}(a)
y = Nullable{T}()

U1 = Core.Inference.return_type(f, Tuple{T})
@test isequal(_f(x), Nullable(f(a)))
@test isequal(_f(y), Nullable{U1}())

U2 = Core.Inference.return_type(f, Tuple{T, T})
@test isequal(_f(x, x), Nullable(f(a, a)))
@test isequal(_f(x, y), Nullable{U2}())

U3 = Core.Inference.return_type(f, Tuple{T, T, T})
@test isequal(_f(x, x, x), Nullable(f(a, a, a)))
@test isequal(_f(x, y, x), Nullable{U3}())
end

# three-valued logic

# & truth table
v1 = lift(&, Bool, Nullable(true), Nullable(true))
v2 = lift(&, Bool, Nullable(true), Nullable(false))
v3 = lift(&, Bool, Nullable(true), Nullable{Bool}())
v4 = lift(&, Bool, Nullable(false), Nullable(true))
v5 = lift(&, Bool, Nullable(false), Nullable(false))
v6 = lift(&, Bool, Nullable(false), Nullable{Bool}())
v7 = lift(&, Bool, Nullable{Bool}(), Nullable(true))
v8 = lift(&, Bool, Nullable{Bool}(), Nullable(false))
v9 = lift(&, Bool, Nullable{Bool}(), Nullable{Bool}())

@test isequal(v1, Nullable(true))
@test isequal(v2, Nullable(false))
@test isequal(v3, Nullable{Bool}())
@test isequal(v4, Nullable(false))
@test isequal(v5, Nullable(false))
@test isequal(v6, Nullable(false))
@test isequal(v7, Nullable{Bool}())
@test isequal(v8, Nullable(false))
@test isequal(v9, Nullable{Bool}())

# | truth table
u1 = lift(|, Bool, Nullable(true), Nullable(true))
u2 = lift(|, Bool, Nullable(true), Nullable(false))
u3 = lift(|, Bool, Nullable(true), Nullable{Bool}())
u4 = lift(|, Bool, Nullable(false), Nullable(true))
u5 = lift(|, Bool, Nullable(false), Nullable(false))
u6 = lift(|, Bool, Nullable(false), Nullable{Bool}())
u7 = lift(|, Bool, Nullable{Bool}(), Nullable(true))
u8 = lift(|, Bool, Nullable{Bool}(), Nullable(false))
u9 = lift(|, Bool, Nullable{Bool}(), Nullable{Bool}())

@test isequal(u1, Nullable(true))
@test isequal(u2, Nullable(true))
@test isequal(u3, Nullable(true))
@test isequal(u4, Nullable(true))
@test isequal(u5, Nullable(false))
@test isequal(u6, Nullable{Bool}())
@test isequal(u7, Nullable(true))
@test isequal(u8, Nullable{Bool}())
@test isequal(u9, Nullable{Bool}())

0 comments on commit 47b0a97

Please sign in to comment.