diff --git a/src/MemPool.jl b/src/MemPool.jl index c64fc13..0290bca 100644 --- a/src/MemPool.jl +++ b/src/MemPool.jl @@ -40,6 +40,9 @@ end unwrap_payload(f::FileRef) = unwrap_payload(open(deserialize, f.file)) +include("io.jl") +include("datastore.jl") + """ `approx_size(d)` @@ -50,10 +53,13 @@ function approx_size(d) end function approx_size{T}(d::Array{T}) - if isbits(T) - sizeof(d) + isbits(T) && return sizeof(d) + + fl = fixedlength(T) + if fl > 0 + return length(d) * fl else - Base.summarysize(d) + return Base.summarysize(d) end end @@ -63,9 +69,6 @@ function approx_size(xs::Array{String}) sum(map(sizeof, xs)) + 4 * length(xs) end -include("io.jl") -include("datastore.jl") - __init__() = global session = "sess-" * randstring(5) end # module diff --git a/src/io.jl b/src/io.jl index 61c5434..efdbf15 100644 --- a/src/io.jl +++ b/src/io.jl @@ -19,8 +19,18 @@ function mmwrite(io::AbstractSerializer, arr::A) where A<:Union{Array,BitArray} if isbits(T) serialize(io, size(arr)) write(io.io, arr) + return elseif T<:Union{} || T<:Nullable{Union{}} serialize(io, size(arr)) + return + end + + fl = fixedlength(T) + if fl > 0 + serialize(io, size(arr)) + for x in arr + fast_write(io.io, x) + end else serialize(io, arr) end @@ -43,6 +53,16 @@ function mmread(::Type{A}, io, mmap) where A<:Union{Array,BitArray} elseif T<:Union{} || T<:Nullable{Union{}} sz = deserialize(io) return Array{T}(sz) + end + + fl = fixedlength(T) + if fl > 0 + sz = deserialize(io) + arr = A(sz...) + @inbounds for i in eachindex(arr) + arr[i] = fast_read(io.io, T)::T + end + return arr else return deserialize(io) # slow!! end @@ -85,3 +105,68 @@ function mmread{N}(::Type{Array{String,N}}, io, mmap) end ys end + + +## Optimized fixed length IO +## E.g. this is very good for `StaticArrays.MVector`s + +function fixedlength(t::Type, cycles=ObjectIdDict()) + if isbits(t) + return sizeof(t) + elseif isa(t, UnionAll) + return -1 + end + + if haskey(cycles, t) + return -1 + end + cycles[t] = nothing + lens = ntuple(i->fixedlength(fieldtype(t, i), copy(cycles)), nfields(t)) + if isempty(lens) + # e.g. abstract type / array type + return -1 + elseif any(x->x<0, lens) + return -1 + else + return sum(lens) + end +end + +fixedlength(t::Type{<:String}) = -1 +fixedlength(t::Type{<:Ptr}) = -1 + +function gen_writer{T}(::Type{T}, expr) + @assert fixedlength(T) >= 0 "gen_writer must be called for fixed length eltypes" + if T<:Tuple + :(write(io, Ref{$T}($expr))) + elseif length(T.types) > 0 + :(begin + $([gen_writer(fieldtype(T, i), :(getfield($expr, $i))) for i=1:nfields(T)]...) + end) + elseif isbits(T) + return :(write(io, $expr)) + else + error("Don't know how to serialize $T") + end +end + +function gen_reader{T}(::Type{T}) + @assert fixedlength(T) >= 0 "gen_reader must be called for fixed length eltypes" + if T<:Tuple + :(read(io, Ref{$T}())[]) + elseif length(T.types) > 0 + return :(ccall(:jl_new_struct, Any, (Any,Any...), $T, $([gen_reader(fieldtype(T, i)) for i=1:nfields(T)]...))) + elseif isbits(T) + return :(read(io, $T)) + else + error("Don't know how to deserialize $T") + end +end + +@generated function fast_write(io, x) + gen_writer(x, :x) +end + +@generated function fast_read{T}(io, ::Type{T}) + gen_reader(T) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1e5647d..b0d0865 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,6 +25,17 @@ end @test length(y) == 10 end +using StaticArrays +@testset "StaticArrays" begin + x = [@MVector(rand(75)) for i=1:100] + io = IOBuffer() + mmwrite(SerializationState(io), x) + alloc = @allocated mmwrite(SerializationState(seekstart(io)), x) + + @test deserialize(seekstart(io)) == x + @test MemPool.approx_size(x) == 75*100*8 +end + @testset "Array{String}" begin roundtrip([randstring(rand(1:10)) for i=4]) end