Skip to content

Commit

Permalink
Fix the chunk-merging algorithm in groupby
Browse files Browse the repository at this point in the history
  • Loading branch information
krynju committed Sep 5, 2021
1 parent 31d709f commit 79993d7
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 78 deletions.
2 changes: 1 addition & 1 deletion src/table/dtable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import TableOperations

import Base: fetch, show

export DTable, tabletype, tabletype!, trim, trim!
export DTable, tabletype, tabletype!, trim, trim!, groupby

const VTYPE = Vector{Union{Dagger.Chunk,Dagger.EagerThunk}}

Expand Down
122 changes: 57 additions & 65 deletions src/table/groupby.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@

function groupby(d::DTable, col; merge=true, chunksize=0)
function groupby(d::DTable, col::Symbol; merge=true, chunksize=0)
distinct_values = (_chunk, _col) -> unique(Tables.getcolumn(_chunk, _col))

filter_wrap = (_chunk, _f) -> begin
Expand All @@ -9,6 +8,7 @@ function groupby(d::DTable, col; merge=true, chunksize=0)

chunk_wrap = (_chunk, _col) -> begin
vals = distinct_values(_chunk, _col)
sort!(vals)
if length(vals) > 1
[v => Dagger.@spawn filter_wrap(_chunk, x -> Tables.getcolumn(x, _col) .== v) for v in vals]
else
Expand All @@ -18,76 +18,68 @@ function groupby(d::DTable, col; merge=true, chunksize=0)

v = [Dagger.@spawn chunk_wrap(c, col) for c in d.chunks]

build_index = (merge, chunksize, vs...) -> begin
v = vcat(vs...)
ks = unique(map(x-> x[1], v))
chunks = Vector{Union{EagerThunk, Nothing}}(map(x-> x[2], v))

idx = Dict([k => Vector{Int}() for k in ks])
for (i, k) in enumerate(map(x-> x[1], v))
push!(idx[k], i)
#ret = _build_groupby_index(merge, chunksize, tabletype(d), fetch.(v)...)
ret = fetch(Dagger.@spawn _build_groupby_index(merge, chunksize, tabletype(d), v...))
DTable(VTYPE(ret[2]), d.tabletype, Dict(col => ret[1]))
end

function _build_groupby_index(merge::Bool, chunksize::Int, tabletype, vs...)
v = vcat(vs...)
ks = unique(map(x-> x[1], v))
chunks = Vector{Union{EagerThunk, Nothing}}(map(x-> x[2], v))

idx = Dict([k => Vector{Int}() for k in ks])
for (i, k) in enumerate(map(x-> x[1], v))
push!(idx[k], i)
end

if merge && chunksize <= 0 # merge all partitions into one
sink = Tables.materializer(tabletype())
v2 = Vector{EagerThunk}()
sizehint!(v2, length(keys(idx)))
for (i, k) in enumerate(keys(idx))
c = getindex.(Ref(chunks), idx[k])
push!(v2, Dagger.@spawn merge_chunks(sink, c...))
idx[k] = [i]
end

if merge && chunksize <= 0 # merge all partitions into one
sink = Tables.materializer(tabletype(d)())
v2 = Vector{EagerThunk}()
sizehint!(v2, length(keys(idx)))
for (i, k) in enumerate(keys(idx))
c = getindex.(Ref(chunks), idx[k])
push!(v2, Dagger.@spawn merge_chunks(sink, c...))
idx[k] = [i]
end
idx, v2
elseif merge && chunksize > 0 # merge all but keep the chunking approximately at chunksize with minimal merges
sink = Tables.materializer(tabletype(d)())
for (i, k) in enumerate(keys(idx))
_indices = idx[k]
_chunks = getindex.(Ref(chunks), _indices)
_lengths = fetch.(Dagger.spawn.(rowcount, _chunks))
c = collect.(collect(zip(_indices, _chunks, _lengths)))
sort!(c, by=(x->x[3]), rev=true)
idx, v2
elseif merge && chunksize > 0 # merge all but try to merge all the small chunks into chunks of chunksize
sink = Tables.materializer(tabletype())
for k in keys(idx)
_indices = idx[k]
_chunks = getindex.(Ref(chunks), _indices)
_lengths = fetch.(Dagger.spawn.(rowcount, _chunks))

c = collect.(collect(zip(_indices, _lengths, _chunks)))
index = 1; len = 2; chunk = 3

l = 1
r = length(c)
prev_r = r
while l < r
if c[l][3] >= chunksize
if r < prev_r
c[l][2] = Dagger.@spawn merge_chunks(sink, c[l][2], getindex.(c[r+1:prev_r], 2)...)
prev_r = r
end
l += 1
elseif c[l][3] + c[r][3] > chunksize
if r < prev_r
c[l][2] = Dagger.@spawn merge_chunks(sink, c[l][2], getindex.(c[r+1:prev_r], 2)...)
prev_r = r
end
l += 1

elseif c[l][3] + c[r][3] <= chunksize # merge
c[l][3] = c[l][3] + c[r][3]
r -= 1
end
end
@assert l == r
for i in 1:length(c)
if i <= l
chunks[c[i][1]] = c[i][2]
else
chunks[c[i][1]] = nothing
sort!(c, by=(x->x[len]), rev=true)

l = 1
r = length(c)
prev_r = r

while l <= r
if c[l][len] >= chunksize || c[l][len] + c[r][len] > chunksize || l == r
if r < prev_r
c[l][chunk] = Dagger.@spawn merge_chunks(sink, c[l][chunk], getindex.(c[r+1:prev_r], chunk)...)
prev_r = r
end
l += 1
elseif c[l][len] + c[r][len] <= chunksize # merge
c[l][len] += c[r][len]
r -= 1
end
idx[k] = map(x-> x[1], c[1:l])
end
idx, filter(x-> !isnothing(x), chunks)
else
idx, chunks
for i in 1:length(c)
chunks[c[i][index]] = i <= r ? c[i][chunk] : nothing
end
idx[k] = map(x-> x[index], c[1:r])
end
idx, filter(x-> x !== nothing, chunks)
else
idx, chunks
end

res = Dagger.@spawn build_index(merge, chunksize, v...)
r = fetch(res)
DTable(VTYPE(r[2]), d.tabletype, Dict(col => r[1]))
end

# Merge several table chunks into one table: wrap `chunks` in a lazy
# `Tables.partitioner`, join the partitions with
# `TableOperations.joinpartitions`, and materialize the result via `sink`
# (a `Tables.materializer`, as passed at the call sites above).
# NOTE(review): assumes every chunk is a Tables.jl-compatible table with a
# common schema — confirm against the callers in `groupby`.
merge_chunks(sink, chunks...) = sink(TableOperations.joinpartitions(Tables.partitioner(x -> x, chunks)))
Expand Down
37 changes: 25 additions & 12 deletions test/table.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,17 +213,30 @@ using CSV

@testset "groupby" begin
d = DTable((a=repeat(['a','b','c','d'], 6),), 4)
g = Dagger.groupby(d, :a) # merge=true, chunksize=0
@test length(g.chunks) == 4
g = Dagger.groupby(d, :a, chunksize=1)
@test length(g.chunks) == 24
g = Dagger.groupby(d, :a, merge=false)
@test length(g.chunks) == 24
g = Dagger.groupby(d, :a, chunksize=2)
@test length(g.chunks) == 12
g = Dagger.groupby(d, :a, chunksize=3)
@test length(g.chunks) == 8
g = Dagger.groupby(d, :a, chunksize=6)
@test length(g.chunks) == 4

@test length(groupby(d, :a).chunks) == 4
@test length(groupby(d, :a, chunksize=1).chunks) == 24
@test length(groupby(d, :a, merge=false).chunks) == 24
@test length(groupby(d, :a, chunksize=2).chunks) == 12
@test length(groupby(d, :a, chunksize=3).chunks) == 8
@test length(groupby(d, :a, chunksize=6).chunks) == 4

@test sort(collect(fetch(d).a)) == sort(collect(fetch(groupby(d, :a)).a))
@test sort(collect(fetch(d).a)) == sort(collect(fetch(groupby(d, :a, chunksize=3)).a))
@test sort(collect(fetch(d).a)) == sort(collect(fetch(groupby(d, :a, chunksize=1)).a))

d = DTable((a=repeat(['a','a', 'b', 'b'], 6),), 2)
@test length(groupby(d, :a).chunks) == 2
@test length(groupby(d, :a, chunksize=1).chunks) == 12
@test length(groupby(d, :a, merge=false).chunks) == 12
@test length(groupby(d, :a, chunksize=2).chunks) == 12
@test length(groupby(d, :a, chunksize=3).chunks) == 12 # grouping doesn't split chunks, so two 2-long chunks won't merge on chunksize 3
@test length(groupby(d, :a, chunksize=4).chunks) == 6
@test length(groupby(d, :a, chunksize=6).chunks) == 4
@test length(groupby(d, :a, chunksize=12).chunks) == 2
@test length(groupby(d, :a, chunksize=24).chunks) == 2

@test sort(collect(fetch(d).a)) == sort(collect(fetch(groupby(d, :a)).a))
@test sort(collect(fetch(d).a)) == sort(collect(fetch(groupby(d, :a, chunksize=5)).a))
end
end

0 comments on commit 79993d7

Please sign in to comment.