Skip to content

Commit

Permalink
Add subsets(collection,Val{k}) (#13)
Browse files Browse the repository at this point in the history
* Add subsets(collection,Val{k})

* Update REQUIRE

* Prettify

* Fix iteratoreltype for

* Change ::Type{Val{K}} -> ::Val{K}

* Add additional tests

* Fix

* Make changes

* Fix things

* Update

* Replace

* Update documentation

* Fix iteratorsize(StaticSizeBinomial)

* Clean up

* Clean up

* Remove unnecessary inlines
  • Loading branch information
ettersi authored and iamed2 committed Mar 7, 2018
1 parent fd47db9 commit 3ef3ca4
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 7 deletions.
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
julia 0.6
Compat 0.31.0
72 changes: 65 additions & 7 deletions src/IterTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import Base: IteratorSize, IteratorEltype
import Base: SizeUnknown, IsInfinite, HasLength, HasShape
import Base: HasEltype, EltypeUnknown

import Compat # for ntuple with ::Val{N} arg

@static if VERSION < v"0.7.0-DEV.3309"
import Base: iteratorsize, iteratoreltype
else
Expand Down Expand Up @@ -607,10 +609,16 @@ length(it::Subsets) = 1 << length(it.xs)
"""
subsets(xs)
subsets(xs, k)
subsets(xs, Val{k}())
Iterate over every subset of the collection `xs`. You can restrict the subsets to a specific
size `k`.
Giving the subset size in the form `Val{k}()` allows the compiler to produce code optimized
for the particular size requested. This leads to performance comparable to hand-written
loops if `k` is small and known at compile time, but may or may not improve performance
otherwise.
```jldoctest
julia> for i in subsets([1, 2, 3])
@show i
Expand All @@ -633,6 +641,16 @@ i = [1, 4]
i = [2, 3]
i = [2, 4]
i = [3, 4]
julia> for i in subsets(1:4, Val{2}())
@show i
end
i = (1, 2)
i = (1, 3)
i = (1, 4)
i = (2, 3)
i = (2, 4)
i = (3, 4)
```
"""
function subsets(xs)
Expand Down Expand Up @@ -667,19 +685,20 @@ end

# Iterate over all subsets of a collection with a given size

struct Binomial{T}
xs::Vector{T}
struct Binomial{Collection}
xs::Collection
n::Int64
k::Int64
end
Binomial(xs::AbstractVector{T}, n::Integer, k::Integer) where {T} = Binomial{T}(xs, n, k)
Binomial(xs::C, n::Integer, k::Integer) where {C} = Binomial{C}(xs, n, k)

iteratorsize(::Type{<:Binomial}) = HasLength()
iteratoreltype(::Type{Binomial{C}}) where {C} = iteratoreltype(C)

eltype(::Type{Binomial{T}}) where {T} = Vector{T}
eltype(::Type{Binomial{C}}) where {C} = Vector{eltype(C)}
length(it::Binomial) = binomial(it.n,it.k)

subsets(xs,k) = Binomial(xs,length(xs),k)
subsets(xs, k) = Binomial(xs, length(xs), k)

mutable struct BinomialIterState
idx::Vector{Int64}
Expand All @@ -694,7 +713,7 @@ function next(it::Binomial, state::BinomialIterState)
idx = state.idx
set = it.xs[idx]
i = it.k
while(i>0)
while i > 0
if idx[i] < it.n - it.k + i
idx[i] += 1

Expand All @@ -708,14 +727,53 @@ function next(it::Binomial, state::BinomialIterState)
end
end

state.done = i==0
state.done = i == 0

return set, state
end

done(it::Binomial, state::BinomialIterState) = state.done


# Iterate over all subsets of a collection with a given *statically* known size

struct StaticSizeBinomial{K, Container}
xs::Container
end

iteratorsize(::Type{StaticSizeBinomial{K, C}}) where {K, C} = HasLength()
iteratoreltype(::Type{StaticSizeBinomial{K, C}}) where {K, C} = iteratoreltype(C)

eltype(::Type{StaticSizeBinomial{K, C}}) where {K, C} = NTuple{K, eltype(C)}
length(it::StaticSizeBinomial{K}) where {K} = binomial(length(it.xs), K)

subsets(xs::C, ::Val{K}) where {K, C} = StaticSizeBinomial{K, C}(xs)

# Special cases for K == 0
start(it::StaticSizeBinomial{0}) = false
next(it::StaticSizeBinomial{0}, _) = (), true
done(it::StaticSizeBinomial{0}, d) = d

# Generic case K >= 1
pop(t::NTuple) = reverse(Base.tail(reverse(t))), t[end]

function advance(it::StaticSizeBinomial{K}, idx) where {K}
xs = it.xs
lidx, i = pop(idx)
i += 1
if i > length(xs) - K + length(idx)
lidx = advance(it, lidx)
i = lidx[end] + 1
end
return (lidx..., i)
end
advance(it::StaticSizeBinomial, idx::NTuple{1}) = (idx[end]+1,)

start(it::StaticSizeBinomial{K}) where {K} = ntuple(identity, Val{K}())
next(it::StaticSizeBinomial, idx) = map(i -> it.xs[i], idx), advance(it, idx)
done(it::StaticSizeBinomial, state) = state[end] > length(it.xs)


# nth : return the nth element in a collection

"""
Expand Down
46 changes: 46 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,52 @@ end
@test length(collect(sk11)) == binomial(3, 2)
end
end

@testset "specific static length" begin
sk0 = subsets([:a, :b, :c], Val{0}())
@test collect(sk0) == [()]

sk1 = subsets([:a, :b, :c], Val{1}())
@test eltype(eltype(sk1)) == Symbol
@test collect(sk1) == [(:a,), (:b,), (:c,)]

sk2 = subsets([:a, :b, :c], Val{2}())
@test eltype(eltype(sk2)) == Symbol
@test collect(sk2) == [(:a, :b), (:a, :c), (:b, :c)]

sk3 = subsets([:a, :b, :c], Val{3}())
@test eltype(eltype(sk3)) == Symbol
@test collect(sk3) == [(:a, :b, :c)]

sk4 = subsets([:a, :b, :c], Val{4}())
@test eltype(eltype(sk4)) == Symbol
@test collect(sk4) == []

sk5 = subsets([:a, :b, :c], Val{5}())
@test eltype(eltype(sk5)) == Symbol
@test collect(sk5) == []

@testset for i in 1:6
sk5 = subsets(collect(1:4), Val{i}())
@test eltype(eltype(sk5)) == Int
@test length(collect(sk5)) == binomial(4, i)
end

function collect_pairs(x)
p = Vector{NTuple{2, eltype(x)}}(binomial(length(x), 2))
idx = 1
for i = 1:length(x)
for j = i+1:length(x)
p[idx] = (x[i], x[j])
idx += 1
end
end
return p
end
@testset for n = 1:10
@test collect(subsets(1:n, Val{2}())) == collect_pairs(1:n)
end
end
end

@testset "nth" begin
Expand Down

0 comments on commit 3ef3ca4

Please sign in to comment.