Skip to content

Commit

Permalink
Ensure serializing Arrow.DictEncoded writes dictionary messages (#149)
Browse files Browse the repository at this point in the history
Fixes #126. The issue here was when `Arrow.write` was faced with the
task of serializing an `Arrow.DictEncoded`. For most arrow array types,
if the input array is already an arrow array type, it's a no-op (e.g. if
you're writing out an `Arrow.Table`). The problem comes from
`Arrow.DictEncoded`, where there is still no conversion required, but we
do need to make a note of the dict encoded column to ensure a dictionary
message is written before the record batch. In addition, we also add
some code for handling delta dictionary messages if required from
multiple record batches that contain `Arrow.DictEncoded`s, which is a
valid use-case where you may have multiple arrow files, with the same
schema, that you wish to serialize as a single arrow file w/ each file
as a separate record batch.

Slightly unrelated, but there's also a fix here in our use of Lockable.
We actually had a race condition I ran into once where the locking was
on the Lockable object, but inside the locked region, we replaced the
entire Lockable instead of the _contents_ of the Lockable. This meant
anyone who started waiting on the Lockable's lock didn't see updates
when unlocked because the entire Lockable had been updated.
  • Loading branch information
quinnj authored Mar 12, 2021
1 parent d7a1e32 commit 4cc34a3
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 8 deletions.
33 changes: 31 additions & 2 deletions src/arraytypes/dictencoding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,36 @@ dictencodeid(colidx, nestedlevel, fieldid) = (Int64(nestedlevel) << 48) | (Int64
# Return the dictionary-encoding id stored on a dict-encoded column's encoding.
function getid(d::DictEncoded)
    return d.encoding.id
end
# Return the dictionary-encoding id of a compressed column whose wrapped `data`
# is a `DictEncoded`.
function getid(c::Compressed{Z, A}) where {Z, A <: DictEncoded}
    inner = c.data
    return inner.encoding.id
end

arrowvector(::DictEncodedType, x::DictEncoded, i, nl, fi, de, ded, meta; kw...) = x
# Serialize a column that is already an `Arrow.DictEncoded`.
#
# No conversion of `x` is required, so `x` is returned unchanged. Its encoding
# still has to be recorded so that a dictionary message can be written before
# the record batch:
#   * if the encoding id is not yet in `de`, `x.encoding` is stored there,
#     wrapped in a `Lockable`;
#   * otherwise, any values of `x.encoding` that the stored encoding lacks are
#     converted with `arrowvector`, pushed onto `ded` as a delta `DictEncoding`,
#     and appended to the stored encoding (all while holding the Lockable's lock).
#
# Calls `error` when the stored values plus the new deltas would exceed
# `typemax` of the stored encoding's index type.
function arrowvector(::DictEncodedType, x::DictEncoded, i, nl, fi, de, ded, meta; dictencode::Bool=false, dictencodenested::Bool=false, kw...)
    id = x.encoding.id
    if !haskey(de, id)
        de[id] = Lockable(x.encoding)
    else
        encodinglockable = de[id]
        @lock encodinglockable begin
            encoding = encodinglockable.x
            # in this case, we just need to check if any values in our local pool need to be delta dictionary serialized
            deltas = setdiff(x.encoding, encoding)
            if !isempty(deltas)
                ET = indextype(encoding)
                if length(deltas) + length(encoding) > typemax(ET)
                    error("fatal error serializing dict encoded column with ref index type of $ET; subsequent record batch unique values resulted in $(length(deltas) + length(encoding)) unique values, which exceeds possible index values in $ET")
                end
                data = arrowvector(deltas, i, nl, fi, de, ded, nothing; dictencode=dictencodenested, dictencodenested=dictencodenested, dictencoding=true, kw...)
                push!(ded, DictEncoding{eltype(data), ET, typeof(data)}(id, data, false, getmetadata(data)))
                if typeof(encoding.data) <: ChainedVector
                    append!(encoding.data, data)
                else
                    data2 = ChainedVector([encoding.data, data])
                    encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
                    # replace the *contents* of the Lockable (not the Lockable itself)
                    # so that anyone waiting on its lock sees the updated encoding
                    de[id].x = encoding
                end
            end
        end
    end
    return x
end

function arrowvector(::DictEncodedType, x, i, nl, fi, de, ded, meta; dictencode::Bool=false, dictencodenested::Bool=false, kw...)
@assert x isa DictEncode
Expand Down Expand Up @@ -195,7 +224,7 @@ function arrowvector(::DictEncodedType, x, i, nl, fi, de, ded, meta; dictencode:
else
data2 = ChainedVector([encoding.data, data])
encoding = DictEncoding{eltype(data2), ET, typeof(data2)}(id, data2, false, getmetadata(encoding))
de[id] = Lockable(encoding)
de[id].x = encoding
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,12 @@ function Base.close(ch::OrderedChannel)
return
end

struct Lockable{T}
x::T
# Pairs a value `x` with the `ReentrantLock` that guards it.
# The struct is mutable and `x` is untyped so the guarded contents can be
# replaced in place (`lockable.x = newvalue`) while the lock is held, rather
# than swapping out the whole `Lockable`; that way a task already waiting on
# `lock` observes the updated contents once it acquires the lock.
mutable struct Lockable
x
lock::ReentrantLock
end

Lockable(x::T) where {T} = Lockable{T}(x, ReentrantLock())
# Wrap `x` in a `Lockable` guarded by a freshly constructed `ReentrantLock`.
function Lockable(x)
    return Lockable(x, ReentrantLock())
end

# Acquire the `ReentrantLock` stored in the `lock` field of `x`.
function Base.lock(x::Lockable)
    return lock(x.lock)
end
# Release the `ReentrantLock` stored in the `lock` field of `x`.
function Base.unlock(x::Lockable)
    return unlock(x.lock)
end
7 changes: 4 additions & 3 deletions src/write.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ function write(io, source, writetofile, largelists, compress, denseunions, dicte
wait(tsk)
# write empty message
if !writetofile
Base.write(io, Message(UInt8[], nothing, 0, true, false), blocks, sch, alignment)
Base.write(io, Message(UInt8[], nothing, 0, true, false, Meta.Schema), blocks, sch, alignment)
end
if writetofile
b = FlatBuffers.Builder(1024)
Expand Down Expand Up @@ -223,6 +223,7 @@ struct Message
bodylen
isrecordbatch::Bool
blockmsg::Bool
headerType
end

struct Block
Expand All @@ -233,7 +234,7 @@ end

function Base.write(io::IO, msg::Message, blocks, sch, alignment)
metalen = padding(length(msg.msgflatbuf), alignment)
@debug 1 "writing message: metalen = $metalen, bodylen = $(msg.bodylen), isrecordbatch = $(msg.isrecordbatch)"
@debug -1 "writing message: metalen = $metalen, bodylen = $(msg.bodylen), isrecordbatch = $(msg.isrecordbatch), headerType = $(msg.headerType)"
if msg.blockmsg
push!(blocks[msg.isrecordbatch ? 1 : 2], Block(position(io), metalen + 8, msg.bodylen))
end
Expand Down Expand Up @@ -266,7 +267,7 @@ function makemessage(b, headerType, header, columns=nothing, bodylen=0)
# Meta.messageStartCustomMetadataVector(b, num_meta_elems)
msg = Meta.messageEnd(b)
FlatBuffers.finish!(b, msg)
return Message(FlatBuffers.finishedbytes(b), columns, bodylen, headerType == Meta.RecordBatch, headerType == Meta.RecordBatch || headerType == Meta.DictionaryBatch)
return Message(FlatBuffers.finishedbytes(b), columns, bodylen, headerType == Meta.RecordBatch, headerType == Meta.RecordBatch || headerType == Meta.DictionaryBatch, headerType)
end

function makeschema(b, sch::Tables.Schema{names}, columns) where {names}
Expand Down
24 changes: 24 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,30 @@ seekstart(io)
tt = Arrow.Table(io)
@test length(tt.a) == 132

# Issue 126: write partitions whose columns have already been passed through
# `Arrow.toarrowvector` (presumably yielding `Arrow.DictEncoded` columns for
# `PooledArray` input -- confirm). Each partition's pool contains a value the
# previous partition lacked, so later record batches require delta dictionary
# messages.
t = Tables.partitioner(
(
(a=Arrow.toarrowvector(PooledArray([1,2,3 ])),),
(a=Arrow.toarrowvector(PooledArray([1,2,3,4])),),
(a=Arrow.toarrowvector(PooledArray([1,2,3,4,5])),),
)
)
io = IOBuffer()
Arrow.write(io, t)
# rewind and read back: the 3 + 4 + 5 rows must come back as one 12-element column
seekstart(io)
tt = Arrow.Table(io)
@test length(tt.a) == 12
@test tt.a == [1, 2, 3, 1, 2, 3, 4, 1, 2, 3, 4, 5]

# Overflow case: the first partition's pool is built with `signed=true, compress=true`
# (presumably giving a small signed ref type such as Int8 -- confirm against
# PooledArrays), while the second partition carries 129 unique values. The deltas
# cannot fit in the first encoding's index type, so the write is expected to fail.
t = Tables.partitioner(
(
(a=Arrow.toarrowvector(PooledArray([1,2,3 ], signed=true, compress=true)),),
(a=Arrow.toarrowvector(PooledArray(collect(1:129))),),
)
)
io = IOBuffer()
# NOTE(review): CompositeException presumably arises because the error is raised
# inside a task spawned by `Arrow.write` -- confirm.
@test_throws CompositeException Arrow.write(io, t)

end # @testset "misc"

end

0 comments on commit 4cc34a3

Please sign in to comment.