Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sort for NTuples, make it extensible by packages, and make DefaultStable and DefaultUnstable dispatchable #54494

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 75 additions & 21 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Sort

using Base.Order

using Base: copymutable, midpoint, require_one_based_indexing, uinttype,
using Base: copymutable, midpoint, require_one_based_indexing, uinttype, tail,
sub_with_overflow, add_with_overflow, OneTo, BitSigned, BitIntegerType, top_set_bit

import Base:
Expand Down Expand Up @@ -1475,21 +1475,16 @@ InitialOptimizations(next) = SubArrayOptimization(
Small{10}(
IEEEFloatOptimization(
next)))))
"""
DEFAULT_STABLE

The default sorting algorithm.

This algorithm is guaranteed to be stable (i.e. it will not reorder elements that compare
equal). It makes an effort to be fast for most inputs.

The algorithms used by `DEFAULT_STABLE` are an implementation detail. See extended help
for the current dispatch system.
"""
struct DefaultStable <: Algorithm end

# Extended Help
`DefaultStable` is an algorithm which indicates that a fast, general purpose sorting
algorithm should be used, but does not specify exactly which algorithm.

`DEFAULT_STABLE` is composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid
of Radix, Insertion, Counting, Quick sorts.
Currently, when sorting short NTuples, this is an unrolled mergesort, and otherwise it is
composed of two parts: the [`InitialOptimizations`](@ref) and a hybrid of Radix, Insertion,
Counting, Quick sorts.

We begin with MissingOptimization because it has no runtime cost when it is not
triggered and can enable other optimizations to be applied later. For example,
Expand Down Expand Up @@ -1549,7 +1544,39 @@ stage.
Finally, if the input has length less than 80, we dispatch to [`InsertionSort`](@ref) and
otherwise we dispatch to [`ScratchQuickSort`](@ref).
"""
const DEFAULT_STABLE = InitialOptimizations(
struct DefaultStable <: Algorithm end

"""
DEFAULT_STABLE

The default sorting algorithm.

This algorithm is guaranteed to be stable (i.e. it will not reorder elements that compare
equal). It makes an effort to be fast for most inputs.

The algorithms used by `DEFAULT_STABLE` are an implementation detail. See the extended help
of `Base.Sort.DefaultStable` for the current dispatch system.
"""
const DEFAULT_STABLE = DefaultStable()

"""
DefaultUnstable <: Algorithm

Like [`DefaultStable`](@ref), but does not guarantee stability.
"""
struct DefaultUnstable <: Algorithm end

"""
DEFAULT_UNSTABLE

An efficient sorting algorithm which may or may not be stable.

The algorithms used by `DEFAULT_UNSTABLE` are an implementation detail. They are currently
the same as those used by [`DEFAULT_STABLE`](@ref), but this is subject to change in future.
"""
const DEFAULT_UNSTABLE = DefaultUnstable()

const _DEFAULT_ALGORITHMS_FOR_VECTORS = InitialOptimizations(
IsUIntMappable(
Small{40}(
CheckSorted(
Expand All @@ -1560,15 +1587,10 @@ const DEFAULT_STABLE = InitialOptimizations(
ScratchQuickSort())))))),
StableCheckSorted(
ScratchQuickSort())))
"""
DEFAULT_UNSTABLE

An efficient sorting algorithm.
_sort!(v::AbstractVector, ::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw) =
_sort!(v, _DEFAULT_ALGORITHMS_FOR_VECTORS, o, kw)

The algorithms used by `DEFAULT_UNSTABLE` are an implementation detail. They are currently
the same as those used by [`DEFAULT_STABLE`](@ref), but this is subject to change in future.
"""
const DEFAULT_UNSTABLE = DEFAULT_STABLE
const SMALL_THRESHOLD = 20

function Base.show(io::IO, alg::Algorithm)
Expand Down Expand Up @@ -1598,6 +1620,7 @@ defalg(v::AbstractArray) = DEFAULT_STABLE
defalg(v::AbstractArray{<:Union{Number, Missing}}) = DEFAULT_UNSTABLE
defalg(v::AbstractArray{Missing}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::AbstractArray{Union{}}) = DEFAULT_UNSTABLE # for method disambiguation
defalg(v::NTuple) = DEFAULT_STABLE

"""
sort!(v; alg::Base.Sort.Algorithm=Base.Sort.defalg(v), lt=isless, by=identity, rev::Bool=false, order::Base.Order.Ordering=Base.Order.Forward)
Expand Down Expand Up @@ -1736,6 +1759,37 @@ julia> v
"""
sort(v::AbstractVector; kws...) = sort!(copymutable(v); kws...)

function sort(x::NTuple{N,T};
alg::Algorithm=defalg(x),
lt=isless,
by=identity,
rev::Union{Bool,Nothing}=nothing,
order::Ordering=Forward,
scratch::Union{Vector{T}, Nothing}=nothing) where {N,T}
_sort(x, alg, ord(lt,by,rev,order), (;scratch))
end
# Folks who want to hack internals can define a new _sort(x::NTuple, ::TheirAlg, o::Ordering)
# or _sort(x::NTuple{N, TheirType}, ::DefaultStable, o::Ordering) where N
function _sort(x::NTuple, a::Union{DefaultStable, DefaultUnstable}, o::Ordering, kw)
if length(x) > 9
LilithHafner marked this conversation as resolved.
Show resolved Hide resolved
v = copymutable(x)
_sort!(v, a, o, kw)
typeof(x)(v)
else
_mergesort(x, o)
end
end
_mergesort(x::Union{NTuple{0}, NTuple{1}}, o::Ordering) = x
function _mergesort(x::NTuple, o::Ordering)
a, b = Base.IteratorsMD.split(x, Val(length(x)>>1))
merge(_mergesort(a, o), _mergesort(b, o), o)
end
merge(x::NTuple, y::NTuple{0}, o::Ordering) = x
merge(x::NTuple{0}, y::NTuple, o::Ordering) = y
merge(x::NTuple{0}, y::NTuple{0}, o::Ordering) = x # Method ambiguity
merge(x::NTuple, y::NTuple, o::Ordering) =
(lt(o, y[1], x[1]) ? (y[1], merge(x, tail(y), o)...) : (x[1], merge(tail(x), y, o)...))

## partialsortperm: the permutation to sort the first k elements of an array ##

"""
Expand Down
43 changes: 37 additions & 6 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,20 @@ end
vcat(2000, (x:x+99 for x in 1900:-100:100)..., 1:99)
end

function tuple_sort_test(x)
@test issorted(sort(x))
length(x) > 9 && return # length > 9 uses a vector fallback
@test 0 == @allocated sort(x)
end
@testset "sort(::NTuple)" begin
@test sort((9,8,3,3,6,2,0,8)) == (0,2,3,3,6,8,8,9)
@test sort((9,8,3,3,6,2,0,8), by=x->x÷3) == (2,0,3,3,8,6,8,9)
for i in 1:40
tuple_sort_test(rand(NTuple{i, Float64}))
end
@test_throws MethodError sort((1,2,3.0))
end

@testset "partialsort" begin
@test partialsort([3,6,30,1,9],3) == 6
@test partialsort([3,6,30,1,9],3:4) == [6,9]
Expand Down Expand Up @@ -799,9 +813,9 @@ end
let
requires_uint_mappable = Union{Base.Sort.RadixSort, Base.Sort.ConsiderRadixSort,
Base.Sort.CountingSort, Base.Sort.ConsiderCountingSort,
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes),
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes.big),
typeof(Base.Sort.DEFAULT_STABLE.next.next.next.big.next.yes.big.next)}
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes),
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes.big),
typeof(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS.next.next.next.big.next.yes.big.next)}

function test_alg(kw, alg, float=true)
for order in [Base.Forward, Base.Reverse, Base.By(x -> x^2)]
Expand Down Expand Up @@ -841,15 +855,18 @@ end
end
end

test_alg_rec(Base.DEFAULT_STABLE)
test_alg_rec(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS)
end
end

@testset "show(::Algorithm)" begin
@test eval(Meta.parse(string(Base.DEFAULT_STABLE))) === Base.DEFAULT_STABLE
lines = split(string(Base.DEFAULT_STABLE), '\n')
@test eval(Meta.parse(string(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS))) === Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS
lines = split(string(Base.Sort._DEFAULT_ALGORITHMS_FOR_VECTORS), '\n')
@test 10 < maximum(length, lines) < 100
@test 1 < length(lines) < 30

@test eval(Meta.parse(string(Base.DEFAULT_STABLE))) === Base.DEFAULT_STABLE
@test string(Base.DEFAULT_STABLE) == "Base.Sort.DefaultStable()"
end

@testset "Extensibility" begin
Expand Down Expand Up @@ -890,6 +907,20 @@ end
end
@test sort([1,2,3], alg=MySecondAlg()) == [9,9,9]
@test all(sort(v, alg=Base.Sort.InitialOptimizations(MySecondAlg())) .=== vcat(fill(9, 100), fill(missing, 10)))

# Tuple extensions (custom alg)
@test_throws MethodError sort((1,2,3), alg=MyFirstAlg())
Base.Sort._sort(v::NTuple, ::MyFirstAlg, o::Base.Order.Ordering, kw) = "hi!"
@test sort((1,2,3), alg=MyFirstAlg()) == "hi!"

struct TupleFoo
x::Int
end

# Tuple extensions (custom type)
@test_throws MethodError sort(TupleFoo.((3,1,2)))
Base.Sort._sort(v::NTuple{N, TupleFoo}, ::Base.Sort.DefaultStable, o::Base.Order.Ordering, kw) where N = v
@test sort(TupleFoo.((3,1,2))) === TupleFoo.((3,1,2))
end

@testset "sort!(v, lo, hi, alg, order)" begin
Expand Down