From cf123a8e402b2eb83e3d532b42f07415ee0803d7 Mon Sep 17 00:00:00 2001
From: Takafumi Arakaki
Date: Mon, 4 May 2020 21:13:56 -0700
Subject: [PATCH] Optimize copy!(::Dict, ::Dict) (#34101)

Co-authored-by: Jameson Nash
---
 base/dict.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 test/dict.jl | 32 ++++++++++++++++++++++++++-----
 2 files changed, 81 insertions(+), 5 deletions(-)

diff --git a/base/dict.jl b/base/dict.jl
index 0dc8d4936719a..072cd6e39d18c 100644
--- a/base/dict.jl
+++ b/base/dict.jl
@@ -274,6 +274,60 @@ function empty!(h::Dict{K,V}) where V where K
     return h
 end
 
+# Fast path for exactly the same `Dict` type: copy the arrays and fields directly.
+function copy!(dst::D, src::D) where {D <: Dict}
+    copy!(dst.vals, src.vals)
+    copy!(dst.keys, src.keys)
+    copy!(dst.slots, src.slots)
+    dst.ndel = src.ndel
+    dst.count = src.count
+    dst.age = src.age  # NOTE(review): `age` is copied, not bumped -- confirm intended
+    dst.idxfloor = src.idxfloor
+    dst.maxprobe = src.maxprobe
+    return dst
+end
+
+# Fast path when not changing key types (hence not changing hashes).
+# It is unsafe to call this function in the sense that it may leave `dst`
+# in a state that is unsafe to use after `unsafe_copy!` has failed.
+function unsafe_copy!(dst::Dict, src::Dict)  # slot-by-slot copy of `src` into `dst`
+    resize!(dst.vals, length(src.vals))  # give `dst` the same table length as `src`
+    resize!(dst.keys, length(src.keys))
+
+    svals = src.vals
+    skeys = src.keys
+    dvals = dst.vals
+    dkeys = dst.keys
+    @inbounds for i = src.idxfloor:lastindex(svals)
+        if isslotfilled(src, i)  # only slots filled in `src` are written to `dst`
+            dvals[i] = svals[i]  # may throw when the value cannot be converted
+            dkeys[i] = skeys[i]
+        end
+    end
+
+    copy!(dst.slots, src.slots)
+
+    dst.ndel = src.ndel
+    dst.count = src.count
+    dst.age = src.age  # NOTE(review): `age` is copied, not bumped -- confirm intended
+    dst.idxfloor = src.idxfloor
+    dst.maxprobe = src.maxprobe
+    return dst
+end
+
+function copy!(dst::Dict{Kd}, src::Dict{Ks}) where {Kd, Ks<:Kd}  # guarded fast path
+    try
+        return unsafe_copy!(dst, src)
+    catch
+        empty!(dst) # avoid leaving `dst` in a corrupt state
+        rethrow()  # the original error still reaches the caller
+    end
+end
+
+# It's safe to call `unsafe_copy!` directly (no try/catch) if `Vd>:Vs`:
+copy!(dst::Dict{Kd,Vd}, src::Dict{Ks,Vs}) where {Kd, Ks<:Kd, Vd, Vs<:Vd} =
+    unsafe_copy!(dst, src)
+
 # get the index where a key is stored, or -1 if not present
 function ht_keyindex(h::Dict{K,V}, key) where V where K
     sz = length(h.keys)
diff --git a/test/dict.jl b/test/dict.jl
index 541bc3b50f84d..705d637756952 100644
--- a/test/dict.jl
+++ b/test/dict.jl
@@ -1071,13 +1071,35 @@ end
 end
 
 @testset "copy!" begin
-    s = Dict(1=>2, 2=>3)
-    for a = ([3=>4], [0x3=>0x4], [3=>4, 5=>6, 7=>8], Pair{UInt,UInt}[3=>4, 5=>6, 7=>8])
-        @test s === copy!(s, Dict(a)) == Dict(a)
-        if length(a) == 1 # current limitation of Base.ImmutableDict
-            @test s === copy!(s, Base.ImmutableDict(a[])) == Dict(a[])
+    @testset "copy!(::$(typeof(s)), _)" for s in Any[
+        Dict(1 => 2, 2 => 3), # concrete key type
+        Dict{Union{Int,UInt},Int}(1 => 2, 2 => 3), # union key type
+        Dict{Any,Int}(1 => 2, 2 => 3), # abstract key type
+        Dict{Int,Float64}(1 => 2, 2 => 3), # values are converted
+    ]
+        @testset "copy!(_, ::$(typeof(Dict(a))))" for a in Any[
+            [3 => 4],
+            [0x3 => 0x4],
+            [3 => 4, 5 => 6, 7 => 8],
+            Pair{UInt,UInt}[3=>4, 5=>6, 7=>8],
+        ]
+            if s isa Dict{Union{Int,UInt},Int} && a isa Vector{Pair{UInt8,UInt8}}
+                @test_broken s === copy!(s, Dict(a)) == Dict(a)  # known-failing combination
+                continue
+            end
+            @test s === copy!(s, Dict(a)) == Dict(a)
+            if length(a) == 1 # current limitation of Base.ImmutableDict
+                @test s === copy!(s, Base.ImmutableDict(a[])) == Dict(a[])
+            end
         end
     end
+
+    @testset "no corruption on failed copy!" begin
+        dst = Dict{Int,Int}(1 => 2)
+        # Fails while trying `convert`:
+        @test_throws MethodError copy!(dst, Dict(1 => "2"))
+        @test dst == Dict()  # `dst` is left empty after the failed copy
+    end
 end
 
 @testset "map!(f, values(dict))" begin