diff --git a/src/Iterators.jl b/src/Iterators.jl index cc907aa..9216072 100644 --- a/src/Iterators.jl +++ b/src/Iterators.jl @@ -17,6 +17,7 @@ export subsets, iterate, takenth, + mergesorted, @itr # iteratorsize is new in 0.5, declare it here for older versions. However, @@ -714,6 +715,129 @@ start(it::Iterate) = it.seed next(it::Iterate, state) = (state, it.f(state)) @compat done(it::Iterate, state) = (state==Union{}) +# PrefetchIterator: helper type that fetches one element of the inner iterator in advance + +immutable PrefetchIterator + inner # underlying iterator +end + +immutable PrefetchState + hd # current head of iterator (next to emit) + hd_prev # previous head (just emitted) + st # current state of iterator (next to emit) + st_prev # previous state (just emitted) +end + +function head(state::PrefetchState) + return state.hd +end + +function start(it::PrefetchIterator) + st0 = start(it.inner) + hd, st = next(it.inner, st0) + return PrefetchState(hd, hd, st, st0) +end + +function next(it::PrefetchIterator, state::PrefetchState) + if done(it.inner, state.st) + # can't read any more elements, repeating current one forever + return state.hd, PrefetchState(state.hd, state.hd, state.st, state.st) + else + new_hd, new_st = next(it.inner, state.st) + return state.hd, PrefetchState(new_hd, state.hd, new_st, state.st) + end +end + +function done(it::PrefetchIterator, state::PrefetchState) + return done(it.inner, state.st_prev) +end + + +# mergesorted: merge sorted iterators + +type MergeIter{T} + iterators::Vector{PeekIter} + lt::Function +end + +function show{T}(io::IO, merged::MergeIter{T}) + print(io, "MergeIter{$T}($(length(merged.iterators)))") +end + +length(it::MergeIter) = sum(map(iteratorsize, it.iterators)) +size(it::MergeIter) = (length(it),) + +function start(merged::MergeIter) + states = Array(Tuple, length(merged.iterators)) + for i in eachindex(merged.iterators) + states[i] = start(merged.iterators[i]) + end + return states +end + +function 
smaller_iter_state(merged::MergeIter, states::Vector{Tuple}, + i::Int, j::Int) + lt = merged.lt + hd1 = get(peek(merged.iterators[i], states[i])) + hd2 = get(peek(merged.iterators[j], states[j])) + return lt(hd1, hd2) ? i : j +end + +function move_iterators{T}(merged::MergeIter{T}, states::Vector{Tuple}, hd::T) + iters = merged.iterators + new_states = copy(states) + for i in eachindex(states) + while (!done(iters[i], new_states[i]) && + get(peek(iters[i], new_states[i])) == hd) + _, new_states[i] = next(iters[i], new_states[i]) + end + end + return new_states +end + +function next{T}(merged::MergeIter{T}, states::Vector{Tuple}) + active_idxs::Vector{Int} = filter(i -> !done(merged.iterators[i], states[i]), + 1:length(merged.iterators)) + @assert length(active_idxs) != 0 + min_idx::Int = reduce((i, j) -> smaller_iter_state(merged, states, i, j), + active_idxs) + min_it = merged.iterators[min_idx] + min_s = states[min_idx] + hd::T = convert(T, get(peek(min_it, min_s))) + new_states = move_iterators(merged, states, hd) + return hd, new_states +end + +function done(merged::MergeIter, states::Vector{Tuple}) + return all([done(it, s) for (it, s) in zip(merged.iterators, states)]) +end + +function peek{T}(merged::MergeIter{T}, states::Vector{Tuple}) + active_idxs::Vector{Int} = filter(i -> !done(merged.iterators[i], states[i]), + 1:length(merged.iterators)) + if length(active_idxs) != 0 + min_idx::Int = reduce((i, j) -> smaller_iter_state(merged, states, i, j), + active_idxs) + min_it = merged.iterators[min_idx] + min_s = states[min_idx] + hd::T = convert(T, get(peek(min_it, min_s))) + return Nullable{T}(hd) + else + return Nullable{T}() + end +end + +""" +Merge sorted iterators, removing duplicates. Comparison is done using the `lt` keyword +argument, which defaults to the `isless` function. +""" +function mergesorted(iterators...; lt = isless) + T = promote_type(map(eltype, iterators)...) 
+ peek_iterators = [PeekIter(it) for it in iterators] + return MergeIter{T}(peek_iterators, lt) +end + + using Base.Meta ## @itr macro for auto-inlining in for loops diff --git a/test/runtests.jl b/test/runtests.jl index 6df1b09..8fee66e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -224,6 +224,25 @@ test_groupby( @test collect(takenth(10:20, 1)) == collect(10:20) +# mergesorted + +macro test_mergesorted(expected, iterators...) + x = gensym() + w = :(mergesorted($(iterators...))) + quote + actual = Any[] + for $x in $w + push!(actual, $x) + end + @test actual == $expected + end +end + +@test_mergesorted [0, 1, 2, 3, 4, 5] 1:3 0:1 [2, 3, 4, 5] +@test_mergesorted [] [] [] +@test_mergesorted [:a, :b, :c] [:b, :c] [:a] [] + + ## @itr ## ====