Skip to content

Commit

Permalink
Loosen searchsorted* index type (fixes JuliaLang#30763)
Browse files Browse the repository at this point in the history
* Might also address JuliaLang#31618
* Types of start and stop indicies are still restricted to
  Union{Int32,Int64} and must be the same
  • Loading branch information
kmsquire committed Apr 6, 2019
1 parent b4e1d0e commit b0165eb
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 12 deletions.
24 changes: 12 additions & 12 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,10 @@ partialsort(v::AbstractVector, k::Union{Int,OrdinalRange}; kws...) =

# index of the first value of vector a that is greater than or equal to x;
# returns length(v)+1 if x is greater than all values in v.
function searchsortedfirst(v::AbstractVector, x, lo::Int, hi::Int, o::Ordering)
lo = lo-1
hi = hi+1
@inbounds while lo < hi-1
function searchsortedfirst(v::AbstractVector, x, lo::INT, hi::INT, o::Ordering) where INT<:Union{Int32,Int64}
lo = lo-INT(1)
hi = hi+INT(1)
@inbounds while lo < hi-INT(1)
m = (lo+hi)>>>1
if lt(o, v[m], x)
lo = m
Expand All @@ -186,10 +186,10 @@ end

# index of the last value of vector a that is less than or equal to x;
# returns 0 if x is less than all values of v.
function searchsortedlast(v::AbstractVector, x, lo::Int, hi::Int, o::Ordering)
lo = lo-1
hi = hi+1
@inbounds while lo < hi-1
function searchsortedlast(v::AbstractVector, x, lo::INT, hi::INT, o::Ordering) where INT<:Union{Int32,Int64}
lo = lo-INT(1)
hi = hi+INT(1)
@inbounds while lo < hi-INT(1)
m = (lo+hi)>>>1
if lt(o, x, v[m])
hi = m
Expand All @@ -203,10 +203,10 @@ end
# returns the range of indices of v equal to x
# if v does not contain x, returns a 0-length range
# indicating the insertion point of x
function searchsorted(v::AbstractVector, x, ilo::Int, ihi::Int, o::Ordering)
lo = ilo-1
hi = ihi+1
@inbounds while lo < hi-1
function searchsorted(v::AbstractVector, x, ilo::INT, ihi::INT, o::Ordering) where INT<:Union{Int32,Int64}
lo = ilo-INT(1)
hi = ihi+INT(1)
@inbounds while lo < hi-INT(1)
m = (lo+hi)>>>1
if lt(o, v[m], x)
lo = m
Expand Down
29 changes: 29 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module SortingTests

using Base.Order: Forward
using Random
using Test

@test sort([2,3,1]) == [1,2,3]
@test sort([2,3,1], rev=true) == [3,2,1]
Expand Down Expand Up @@ -371,3 +374,29 @@ end
end
# https://discourse.julialang.org/t/sorting-big-int-with-v-0-6/1241
@test sort([big(3), big(2)]) == [big(2), big(3)]

@testset "issue #30763" begin
for INT in [:Int32, :Int64]
@eval begin
struct T_30763{T}
n::T
end

Base.zero(::T_30763{$INT}) = T_30763{$INT}(0)
Base.convert(::Type{T_30763{$INT}}, n::Integer) = T_30763{$INT}($INT(n))
Base.isless(a::T_30763{$INT}, b::T_30763{$INT}) = isless(a.n, b.n)
Base.:(-)(a::T_30763{$INT}, b::T_30763{$INT}) = T_30763{$INT}(a.n - b.n)
Base.:(+)(a::T_30763{$INT}, b::T_30763{$INT}) = T_30763{$INT}(a.n + b.n)
Base.:(*)(n::Integer, a::T_30763{$INT}) = T_30763{$INT}(n * a.n)
Base.rem(a::T_30763{$INT}, b::T_30763{$INT}) = T_30763{$INT}(rem(a.n, b.n))

# The important part of this test is that the return type of length might be different from Int
Base.length(r::StepRange{T_30763{$INT},T_30763{$INT}}) = $INT((last(r).n - first(r).n) ÷ step(r).n)
end
end

@test searchsorted(T_30763{Int32}(1):T_30763{Int32}(3), T_30763{Int32}(2)) == 2:2
@test searchsorted(T_30763{Int64}(1):T_30763{Int64}(3), T_30763{Int64}(2)) == 2:2
end

end

0 comments on commit b0165eb

Please sign in to comment.