diff --git a/base/dict.jl b/base/dict.jl index 26896c78a7008..83180f5c0ee1b 100644 --- a/base/dict.jl +++ b/base/dict.jl @@ -902,6 +902,10 @@ returns a new dictionary separate from the previous one, but the underlying implementation is space-efficient and may share storage across multiple separate dictionaries. +!!! note + It behaves like an `IdDict`: keys are hashed with `objectid` and compared with `===`. + + PersistentDict(KV::Pair) # Examples @@ -922,24 +926,31 @@ Base.PersistentDict{Symbol, Int64} with 1 entry: PersistentDict PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}()) -PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV...)) -PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV...)) +PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) +PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV)) PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...) PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...) function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V} key = convert(K, key) val = convert(V, val) trie = dict.trie - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true) HAMT.insert!(found, present, trie, i, bi, hs, val) return PersistentDict(top) end +function PersistentDict{K,V}(KV::Pair, rest::Pair...) where {K,V} + dict = PersistentDict{K,V}(KV) + for (key, value) in rest + dict = PersistentDict(dict, key, value) + end + return dict +end + function PersistentDict(kv::Pair, rest::Pair...) 
dict = PersistentDict(kv) - for kv in rest - key, value = kv + for (key, value) in rest dict = PersistentDict(dict, key, value) end return dict @@ -955,7 +966,7 @@ function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K key, val = key_val - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} @@ -966,7 +977,7 @@ end function haskey(dict::PersistentDict{K}, key::K) where K trie = dict.trie - h = hash(key) + h = HAMT.HashState(key) found, present, _, _, _, _, _ = HAMT.path(trie, key, h) return found && present end @@ -976,7 +987,7 @@ function getindex(dict::PersistentDict{K,V}, key::K) where {K,V} if HAMT.islevel_empty(trie) throw(KeyError(key)) end - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} @@ -990,7 +1001,7 @@ function get(dict::PersistentDict{K,V}, key::K, default) where {K,V} if HAMT.islevel_empty(trie) return default end - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} @@ -1004,7 +1015,7 @@ function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V} if HAMT.islevel_empty(trie) return default end - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, _, _, _ = HAMT.path(trie, key, h) if found && present leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V} @@ -1017,7 +1028,7 @@ iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state) function delete(dict::PersistentDict{K}, key::K) where K trie = dict.trie - h = hash(key) + h = HAMT.HashState(key) found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true) if found && present deleteat!(trie.data, i) diff --git a/base/hamt.jl b/base/hamt.jl index d801352fce6c5..e940f4e00b1d5 
100644 --- a/base/hamt.jl +++ b/base/hamt.jl @@ -62,36 +62,47 @@ A HashArrayMappedTrie that optionally supports persistence. mutable struct HAMT{K, V} const data::Vector{Union{Leaf{K, V}, HAMT{K, V}}} bitmap::BITMAP + HAMT{K,V}(data, bitmap) where {K,V} = new{K,V}(data, bitmap) + HAMT{K, V}() where {K, V} = new{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP)) end -HAMT{K, V}() where {K, V} = HAMT(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 0), zero(BITMAP)) -function HAMT{K,V}(k::K, v) where {K, V} +function HAMT{K,V}((k,v)::Pair) where {K, V} + k = convert(K, k) v = convert(V, v) # For a single element we can't have a hash-collision - trie = HAMT(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP)) + trie = HAMT{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP)) trie.data[1] = Leaf{K,V}(k,v) bi = BitmapIndex(HashState(k)) set!(trie, bi) return trie end -HAMT(k::K, v::V) where {K, V} = HAMT{K,V}(K, V) +HAMT(pair::Pair{K,V}) where {K, V} = HAMT{K,V}(pair) +# TODO: Parameterize by hash function struct HashState{K} key::K hash::UInt depth::Int shift::Int end -HashState(key)= HashState(key, hash(key), 0, 0) +HashState(key) = HashState(key, objectid(key), 0, 0) # Reconstruct -HashState(key, depth, shift) = HashState(key, hash(key, UInt(depth ÷ BITS_PER_LEVEL)), depth, shift) +function HashState(other::HashState, key) + h = HashState(key) + while h.depth !== other.depth + h = next(h) + end + return h +end function next(h::HashState) depth = h.depth + 1 shift = h.shift + BITS_PER_LEVEL + @assert h.shift <= MAX_SHIFT if shift > MAX_SHIFT # Note we use `UInt(depth ÷ BITS_PER_LEVEL)` to seed the hash function # the hash docs, do we need to hash `UInt(depth ÷ BITS_PER_LEVEL)` first? - h_hash = hash(h.key, UInt(depth ÷ BITS_PER_LEVEL)) + h_hash = hash(objectid(h.key), UInt(depth ÷ BITS_PER_LEVEL)) + shift = 0 else h_hash = h.hash end @@ -137,8 +148,7 @@ as the current `level`. 
If a copy function is provided `copyf` use the return `top` for the new persistent tree. """ -@inline function path(trie::HAMT{K,V}, key, _h, copy=false) where {K, V} - h = HashState(key, _h, 0, 0) +@inline function path(trie::HAMT{K,V}, key, h::HashState, copy=false) where {K, V} if copy trie = top = HAMT{K,V}(Base.copy(trie.data), trie.bitmap) else @@ -151,7 +161,7 @@ new persistent tree. next = @inbounds trie.data[i] if next isa Leaf{K,V} # Check if key match if not we will need to grow. - found = (next.key === h.key || isequal(next.key, h.key)) + found = next.key === h.key return found, true, trie, i, bi, top, h end if copy @@ -184,7 +194,7 @@ or grows the HAMT by inserting a new trie instead. @assert present # collision -> grow leaf = @inbounds trie.data[i]::Leaf{K,V} - leaf_h = HashState(leaf.key, h.depth, h.shift) + leaf_h = HashState(h, leaf.key) if leaf_h.hash == h.hash error("Perfect hash collision") end diff --git a/test/dict.jl b/test/dict.jl index d3d1c3732eee2..d5ae88735dd43 100644 --- a/test/dict.jl +++ b/test/dict.jl @@ -1084,17 +1084,30 @@ Dict(1 => rand(2,3), 'c' => "asdf") # just make sure this does not trigger a dep GC.@preserve A B C D nothing end -mutable struct CollidingHash -end -Base.hash(::CollidingHash, h::UInt) = hash(UInt(0), h) - -struct PredictableHash - x::UInt -end -Base.hash(x::PredictableHash, h::UInt) = x.x - import Base.PersistentDict @testset "PersistentDict" begin + @testset "HAMT HashState" begin + key = :key + h = Base.HAMT.HashState(key) + h1 = Base.HAMT.HashState(key, objectid(key), 0, 0) + h2 = Base.HAMT.HashState(h, key) # reconstruct + @test h.hash == h1.hash + @test h.hash == h2.hash + + hs = Base.HAMT.next(h1) + @test hs.depth == 1 + recompute_depth = (Base.HAMT.MAX_SHIFT ÷ Base.HAMT.BITS_PER_LEVEL) + 1 + for i in 2:recompute_depth + hs = Base.HAMT.next(hs) + @test hs.depth == i + end + @test hs.depth == recompute_depth + @test hs.shift == 0 + hsr = Base.HAMT.HashState(hs, key) + @test hs.hash == hsr.hash + @test 
hs.depth == hsr.depth + @test hs.shift == hsr.shift + end @testset "basics" begin dict = PersistentDict{Int, Int}() @test_throws KeyError dict[1] @@ -1145,6 +1158,21 @@ import Base.PersistentDict @test dict[4] == 1 end + @testset "objectid" begin + c = [0] + dict = PersistentDict{Any, Int}(c => 1, [1] => 2) + @test dict[c] == 1 + c[1] = 1 + @test dict[c] == 1 + + c[1] = 0 + dict = PersistentDict{Any, Int}((c,) => 1, ([1],) => 2) + @test dict[(c,)] == 1 + + c[1] = 1 + @test dict[(c,)] == 1 + end + @testset "stress" begin N = 2^14 dict = PersistentDict{Int, Int}() @@ -1164,53 +1192,6 @@ import Base.PersistentDict end @test isempty(dict) end - - @testset "CollidingHash" begin - dict = PersistentDict{CollidingHash, Nothing}() - dict = PersistentDict(dict, CollidingHash(), nothing) - @test_throws ErrorException PersistentDict(dict, CollidingHash(), nothing) - end - - # Test the internal implementation - @testset "PredictableHash" begin - dict = PersistentDict{PredictableHash, Nothing}() - for i in 1:Base.HashArrayMappedTries.ENTRY_COUNT - key = PredictableHash(UInt(i-1)) # Level 0 - dict = PersistentDict(dict, key, nothing) - end - @test length(dict.trie.data) == Base.HashArrayMappedTries.ENTRY_COUNT - @test dict.trie.bitmap == typemax(Base.HashArrayMappedTries.BITMAP) - for entry in dict.trie.data - @test entry isa Base.HashArrayMappedTries.Leaf - end - - dict = PersistentDict{PredictableHash, Nothing}() - for i in 1:Base.HashArrayMappedTries.ENTRY_COUNT - key = PredictableHash(UInt(i-1) << Base.HashArrayMappedTries.BITS_PER_LEVEL) # Level 1 - dict = PersistentDict(dict, key, nothing) - end - @test length(dict.trie.data) == 1 - @test length(dict.trie.data[1].data) == 32 - - max_level = (Base.HashArrayMappedTries.NBITS ÷ Base.HashArrayMappedTries.BITS_PER_LEVEL) - dict = PersistentDict{PredictableHash, Nothing}() - for i in 1:Base.Base.HashArrayMappedTries.ENTRY_COUNT - key = PredictableHash(UInt(i-1) << (max_level * Base.HashArrayMappedTries.BITS_PER_LEVEL)) # Level 12 
- dict = PersistentDict(dict, key, nothing) - end - data = dict.trie.data - for level in 1:max_level - @test length(data) == 1 - data = only(data).data - end - last_level_nbits = Base.HashArrayMappedTries.NBITS - (max_level * Base.HashArrayMappedTries.BITS_PER_LEVEL) - if Base.HashArrayMappedTries.NBITS == 64 - @test last_level_nbits == 4 - elseif Base.HashArrayMappedTries.NBITS == 32 - @test last_level_nbits == 2 - end - @test length(data) == 2^last_level_nbits - end end @testset "issue #19995, hash of dicts" begin