diff --git a/Project.toml b/Project.toml index 1c7af6f..d04f82c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "IntervalSets" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.10" +version = "0.7.11" [deps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -24,8 +25,10 @@ OffsetArrays = "1" Plots = "1" Random = "1" RecipesBase = "1" +Scanf = "0.5" Statistics = "1" Test = "1" +TupleTools = "1" Unitful = "1" julia = "1.6" @@ -36,9 +39,10 @@ OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +Scanf = "6ef1bc8b-493b-44e1-8d40-549aa65c4b41" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" [targets] -test = ["Aqua", "Dates", "Test", "Plots", "Random", "RecipesBase", "OffsetArrays", "Statistics", "Unitful"] +test = ["Aqua", "Dates", "Test", "Plots", "Random", "RecipesBase", "OffsetArrays", "Scanf", "Statistics", "Unitful"] diff --git a/src/interval.jl b/src/interval.jl index a35bf39..2cadebf 100644 --- a/src/interval.jl +++ b/src/interval.jl @@ -161,11 +161,18 @@ show(io::IO, I::OpenInterval) = print(io, leftendpoint(I), " .. ", rightendpoint show(io::IO, I::Interval{:open,:closed}) = print(io, leftendpoint(I), " .. ", rightendpoint(I), " (open-closed)") show(io::IO, I::Interval{:closed,:open}) = print(io, leftendpoint(I), " .. ", rightendpoint(I), " (closed-open)") +leftendpointtype(::TypedEndpointsInterval{L,R}) where {L,R} = L +rightendpointtype(::TypedEndpointsInterval{L,R}) where {L,R} = R + # The following are not typestable for mixed endpoint types _left_intersect_type(::Type{Val{:open}}, ::Type{Val{L2}}, a1, a2) where L2 = a1 < a2 ? (a2,L2) : (a1,:open) _left_intersect_type(::Type{Val{:closed}}, ::Type{Val{L2}}, a1, a2) where L2 = a1 ≤ a2 ? (a2,L2) : (a1,:closed) _right_intersect_type(::Type{Val{:open}}, ::Type{Val{R2}}, b1, b2) where R2 = b1 > b2 ? (b2,R2) : (b1,:open) _right_intersect_type(::Type{Val{:closed}}, ::Type{Val{R2}}, b1, b2) where R2 = b1 ≥ b2 ? (b2,R2) : (b1,:closed) +_left_union_type(::Type{Val{:open}}, ::Type{Val{L2}}, a1, a2) where L2 = a1 < a2 ? (a1,:open) : (a2,L2) +_left_union_type(::Type{Val{:closed}}, ::Type{Val{L2}}, a1, a2) where L2 = a1 ≤ a2 ? (a1,:closed) : (a2,L2) +_right_union_type(::Type{Val{:open}}, ::Type{Val{R2}}, b1, b2) where R2 = b1 > b2 ? (b1,:open) : (b2,R2) +_right_union_type(::Type{Val{:closed}}, ::Type{Val{R2}}, b1, b2) where R2 = b1 ≥ b2 ? (b1,:closed) : (b2,R2) function intersect(d1::TypedEndpointsInterval{L1,R1}, d2::TypedEndpointsInterval{L2,R2}) where {L1,R1,L2,R2} a1, b1 = endpoints(d1); a2, b2 = endpoints(d2) @@ -181,44 +188,12 @@ end intersect(d1::AbstractInterval, d2::AbstractInterval) = intersect(Interval(d1), Interval(d2)) +include("unionalgorithms.jl") -function union(d1::TypedEndpointsInterval{L1,R1,T1}, d2::TypedEndpointsInterval{L2,R2,T2}) where {L1,R1,T1,L2,R2,T2} - T = promote_type(T1,T2) - isempty(d1) && return Interval{L2,R2,T}(d2) - isempty(d2) && return Interval{L1,R1,T}(d1) - any(∈(d1), endpoints(d2)) || any(∈(d2), endpoints(d1)) || - throw(ArgumentError("Cannot construct union of disjoint sets.")) - _union(d1, d2) -end - -# these assume overlap -function _union(A::TypedEndpointsInterval{L,R}, B::TypedEndpointsInterval{L,R}) where {L,R} - left = min(leftendpoint(A), leftendpoint(B)) - right = max(rightendpoint(A), rightendpoint(B)) - Interval{L,R}(left, right) -end - -# this is not typestable -function _union(A::TypedEndpointsInterval{L1,R1}, B::TypedEndpointsInterval{L2,R2}) where {L1,R1,L2,R2} - if leftendpoint(A) == leftendpoint(B) - L = L1 == :closed ? :closed : L2 - elseif leftendpoint(A) < leftendpoint(B) - L = L1 - else - L = L2 - end - if rightendpoint(A) == rightendpoint(B) - R = R1 == :closed ? :closed : R2 - elseif rightendpoint(A) > rightendpoint(B) - R = R1 - else - R = R2 - end - left = min(leftendpoint(A), leftendpoint(B)) - right = max(rightendpoint(A), rightendpoint(B)) - - Interval{L,R}(left, right) -end +union(d::TypedEndpointsInterval) = d # 1 interval +union(d1::TypedEndpointsInterval, d2::TypedEndpointsInterval) = union2(d1, d2) # 2 intervals +Base.@nexprs(18,N -> union(I::Vararg{TypedEndpointsInterval,N+2}) = iterunion(TupleTools.sort(I; lt = leftof))) # 3 to 20 intervals +union(I::TypedEndpointsInterval...) = iterunion(sort!(collect(I); lt = leftof)) # ≥21 intervals ClosedInterval{T}(i::AbstractUnitRange{I}) where {T,I<:Integer} = ClosedInterval{T}(minimum(i), maximum(i)) ClosedInterval(i::AbstractUnitRange{I}) where {I<:Integer} = ClosedInterval{I}(minimum(i), maximum(i)) diff --git a/src/unionalgorithms.jl b/src/unionalgorithms.jl new file mode 100644 index 0000000..8994653 --- /dev/null +++ b/src/unionalgorithms.jl @@ -0,0 +1,73 @@ +import TupleTools + +""" + leftof(I1::TypedEndpointsInterval, I2::TypedEndpointsInterval) + +Returns if `I1` has a part to the left of `I2`. +""" +function leftof(I1::TypedEndpointsInterval{L1,R1}, I2::TypedEndpointsInterval{L2,R2}) where {L1,R1,L2,R2} + if leftendpoint(I1) < leftendpoint(I2) + true + elseif leftendpoint(I1) > leftendpoint(I2) + false + elseif L1 == :closed && L2 == :open + true + else + false + end +end + +""" + canunion(d1, d2) + +Returns if `d1 ∪ d2` is a single interval. `d1` and `d2` have to be non-empty. Note that `canunion` is not always `!isdisjoint`. For example, ``[0,1)`` and ``[1,2]`` are disjoint, but the union of them is a single interval. +""" +@inline canunion(d1, d2) = any(∈(d1), endpoints(d2)) || any(∈(d2), endpoints(d1)) + +function iterunion(iter) + T = promote_type(map(eltype, iter)...) + next = iterate(iter) + while !isnothing(next) + (item, state) = next + # find the first non-empty interval + if isempty(item) + next = iterate(iter, state) + continue + end + L = leftendpointtype(item) + R = rightendpointtype(item) + l = leftendpoint(item) + r = rightendpoint(item) + next = iterate(iter, state) + while !isnothing(next) + (item, state) = next + if isempty(item) + elseif leftendpoint(item) > r + throw(ArgumentError("IntervalSets doesn't support union of disjoint intervals, while the interval $r..$(leftendpoint(item)) (open) is not covered. Try using DomainSets.UnionDomain for disjoint intervals or ∪(a,b,c...) if the intervals are not sorted.")) + elseif R==:open && leftendpoint(item)==r && leftendpointtype(item)==:open + throw(ArgumentError("IntervalSets doesn't support union of disjoint intervals, while the point $r is not covered. Try using DomainSets.UnionDomain for disjoint intervals or ∪(a,b,c...) if the intervals are not sorted.")) + else + (r,R) = _right_union_type(Val{R}, Val{rightendpointtype(item)}, r, rightendpoint(item)) + end + next = iterate(iter, state) + end + return Interval{L,R,T}(l,r) + end + return first(iter) # can't find the first non-empty interval. return the first one. +end + +# good old union +function union2(d1::TypedEndpointsInterval{L1,R1,T1}, d2::TypedEndpointsInterval{L2,R2,T2}) where {L1,R1,T1,L2,R2,T2} + T = promote_type(T1,T2) + isempty(d1) && return Interval{L2,R2,T}(d2) + isempty(d2) && return Interval{L1,R1,T}(d1) + canunion(d1, d2) && return _union(d1, d2) + throw(ArgumentError("Cannot construct union of disjoint sets.")) +end + +# this is not typestable +function _union(A::TypedEndpointsInterval{L1,R1}, B::TypedEndpointsInterval{L2,R2}) where {L1,R1,L2,R2} + l, L = _left_union_type(Val{L1}, Val{L2}, leftendpoint(A), leftendpoint(B)) + r, R = _right_union_type(Val{R1}, Val{R2}, rightendpoint(A), rightendpoint(B)) + Interval{L,R}(l, r) +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index d9dc0ea..189a88b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import Statistics: mean using Random using Unitful using Plots +using Scanf import IntervalSets: Domain, endpoints, closedendpoints, TypedEndpointsInterval diff --git a/test/setoperations.jl b/test/setoperations.jl index 551300c..226c565 100644 --- a/test/setoperations.jl +++ b/test/setoperations.jl @@ -230,6 +230,67 @@ # - different interval types @test (1..2) ∩ OpenInterval(0.5, 1.5) ≡ Interval{:closed, :open}(1, 1.5) @test (1..2) ∪ OpenInterval(0.5, 1.5) ≡ Interval{:open, :closed}(0.5, 2) + + intervals = [i1, i2, i3, i4, i5, i_empty] + for _ in 1:10 + @test ∪(shuffle!(intervals)...) == 0..3 + end + intervals = [i1, i2, i4, i5, i_empty] + for _ in 1:10 + @test_throws ArgumentError union(shuffle!(intervals)...) + end +end + +randinterval(s) = Interval{rand([:closed,:open]),rand([:closed,:open])}(rand(s), rand(s)) +function test_multipleunion(intervals) + if all(isempty, intervals) + @test isempty(∪(intervals...)) + else + u = nothing + try + u = ∪(intervals...) + catch e + @test e isa ArgumentError + s = e.msg + ind = findfirst("while the ", s)[end] + 1 + if startswith(s[ind:end], "interval") + u = OpenInterval(@scanf(s[ind+9:end], "%d..%d", Int, Int)[2:3]...) + @test all(v -> isempty(u ∩ v), intervals) + elseif startswith(s[ind:end], "point") + x = @scanf(s[ind+6:end], "%d", Int)[2] + @test all(v -> x ∉ v, intervals) + else + error("have you touched the error message?") + end + return + end + @test all(v -> v ⊆ u, intervals) + # The following codes rigorously tests the correctness of u. + # However, the `setdiff` is only implemented in `DomainSets.jl` + # which introduces piracies. See https://github.com/JuliaMath/IntervalSets.jl/pull/156#discussion_r1497829695 + # as a result, the correctness of interval union is not thoroughly tested. + #= v = u + for i in intervals + try + v = setdiff(v, i) + catch + println("setdiff($v, $i) was not successful. The union was $u and the components are $intervals.") + @test false + end + end + @test isempty(v) =# + end +end + +@testset "general union" begin + for _ in 1:500 + intervals = [randinterval(0:10) for _ in 1:5] + test_multipleunion(intervals) + end + for _ in 1:50 + intervals = [randinterval(0:100) for _ in 1:100] + test_multipleunion(intervals) + end end @testset "in" begin