Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make PersistentDict behave like an IdDict #52193

Merged
merged 3 commits into from
Nov 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions base/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,10 @@ returns a new dictionary separate from the previous one, but the underlying
implementation is space-efficient and may share storage across multiple
separate dictionaries.

!!! note
    It behaves like an IdDict.


PersistentDict(KV::Pair)

# Examples
Expand All @@ -922,24 +926,31 @@ Base.PersistentDict{Symbol, Int64} with 1 entry:
PersistentDict

# Construct an empty dictionary backed by an empty trie.
PersistentDict{K,V}() where {K,V} = PersistentDict(HAMT.HAMT{K,V}())
# Construct a one-entry dictionary. The pair is passed to the trie constructor as
# a single `Pair` argument rather than being splatted into two positional ones.
# Each signature is defined exactly once so that no method is silently overwritten.
PersistentDict{K,V}(KV::Pair) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV))
PersistentDict(KV::Pair{K,V}) where {K,V} = PersistentDict(HAMT.HAMT{K,V}(KV))
# Return a new dictionary with `pair` added; both forms forward to the
# `(dict, key, val)` method by splatting the pair into key and value.
PersistentDict(dict::PersistentDict, pair::Pair) = PersistentDict(dict, pair...)
PersistentDict{K,V}(dict::PersistentDict{K,V}, pair::Pair) where {K,V} = PersistentDict(dict, pair...)
# Return a new `PersistentDict` in which `key` maps to `val`.
#
# `key` and `val` are converted to `K` and `V` first. The trie is walked with
# `persistent = true`, and the `top` returned by `HAMT.path` becomes the trie of
# the returned dictionary.
function PersistentDict(dict::PersistentDict{K,V}, key, val) where {K,V}
    key = convert(K, key)
    val = convert(V, val)
    trie = dict.trie
    # The traversal state is built by `HAMT.HashState` only; `hash(key)` is not
    # called here, so a user-defined `hash` method plays no part in the lookup.
    h = HAMT.HashState(key)
    found, present, trie, i, bi, top, hs = HAMT.path(trie, key, h, #=persistent=# true)
    HAMT.insert!(found, present, trie, i, bi, hs, val)
    return PersistentDict(top)
end

# Build a `PersistentDict{K,V}` from one or more pairs: the first pair seeds the
# dictionary and every further pair is added through the `(dict, key, val)` method.
function PersistentDict{K,V}(KV::Pair, rest::Pair...) where {K,V}
    acc = PersistentDict{K,V}(KV)
    for entry in rest
        acc = PersistentDict(acc, entry.first, entry.second)
    end
    return acc
end

function PersistentDict(kv::Pair, rest::Pair...)
dict = PersistentDict(kv)
for kv in rest
key, value = kv
for (key, value) in rest
dict = PersistentDict(dict, key, value)
end
return dict
Expand All @@ -955,7 +966,7 @@ function in(key_val::Pair{K,V}, dict::PersistentDict{K,V}, valcmp=(==)) where {K

key, val = key_val

h = hash(key)
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
Expand All @@ -966,7 +977,7 @@ end

# Return `true` when `key` is stored in `dict`.
function haskey(dict::PersistentDict{K}, key::K) where K
    trie = dict.trie
    # The traversal state is built by `HAMT.HashState` only; `hash(key)` is not
    # called here, so a user-defined `hash` method plays no part in the lookup.
    h = HAMT.HashState(key)
    found, present, _, _, _, _, _ = HAMT.path(trie, key, h)
    return found && present
end
Expand All @@ -976,7 +987,7 @@ function getindex(dict::PersistentDict{K,V}, key::K) where {K,V}
if HAMT.islevel_empty(trie)
throw(KeyError(key))
end
h = hash(key)
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
Expand All @@ -990,7 +1001,7 @@ function get(dict::PersistentDict{K,V}, key::K, default) where {K,V}
if HAMT.islevel_empty(trie)
return default
end
h = hash(key)
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
Expand All @@ -1004,7 +1015,7 @@ function get(default::Callable, dict::PersistentDict{K,V}, key::K) where {K,V}
if HAMT.islevel_empty(trie)
return default
end
h = hash(key)
h = HAMT.HashState(key)
found, present, trie, i, _, _, _ = HAMT.path(trie, key, h)
if found && present
leaf = @inbounds trie.data[i]::HAMT.Leaf{K,V}
Expand All @@ -1017,7 +1028,7 @@ iterate(dict::PersistentDict, state=nothing) = HAMT.iterate(dict.trie, state)

function delete(dict::PersistentDict{K}, key::K) where K
trie = dict.trie
h = hash(key)
h = HAMT.HashState(key)
found, present, trie, i, bi, top, _ = HAMT.path(trie, key, h, #=persistent=# true)
if found && present
deleteat!(trie.data, i)
Expand Down
32 changes: 21 additions & 11 deletions base/hamt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,36 +62,47 @@ A HashArrayMappedTrie that optionally supports persistence.
mutable struct HAMT{K, V}
    # Entries of this level: either key/value leaves or nested tries.
    # The binding is `const`; the vector itself can still be mutated.
    const data::Vector{Union{Leaf{K, V}, HAMT{K, V}}}
    # Bitmap associated with `data`; bits are set through `set!`.
    bitmap::BITMAP

    # Wrap an existing entry vector and bitmap.
    HAMT{K,V}(data, bitmap) where {K,V} = new{K,V}(data, bitmap)

    # Create an empty level: no entries and an all-zero bitmap.
    function HAMT{K,V}() where {K,V}
        entries = Vector{Union{Leaf{K,V}, HAMT{K,V}}}(undef, 0)
        return new{K,V}(entries, zero(BITMAP))
    end
end
# Build a one-element trie from `k => v`, converting the key to `K` and the
# value to `V`.
#
# The element is taken as a single `Pair` rather than as two positional
# arguments: a `(k::K, v)` signature would coincide with the inner
# `HAMT{K,V}(data, bitmap)` constructor whenever `K == Any`.
# The empty-trie constructor lives inside the struct definition and is not
# redefined here.
function HAMT{K,V}((k,v)::Pair) where {K, V}
    k = convert(K, k)
    v = convert(V, v)
    # For a single element we can't have a hash-collision
    trie = HAMT{K,V}(Vector{Union{Leaf{K, V}, HAMT{K, V}}}(undef, 1), zero(BITMAP))
    trie.data[1] = Leaf{K,V}(k,v)
    bi = BitmapIndex(HashState(k))
    set!(trie, bi)
    return trie
end
# Convenience form kept for existing callers: forwards the *values* (not the
# type parameters) to the `Pair` constructor.
HAMT(k::K, v::V) where {K, V} = HAMT{K,V}(k => v)
HAMT(pair::Pair{K,V}) where {K, V} = HAMT{K,V}(pair)

# TODO: Parameterize by hash function
# Traversal state for a single key while descending the trie.
struct HashState{K}
    # The key this state was built for.
    key::K
    # Hash bits currently in use; `next` replaces them once `shift` would
    # exceed `MAX_SHIFT`.
    hash::UInt
    # Number of `next` steps taken so far (a fresh state starts at 0).
    depth::Int
    # Bit offset into `hash`; `next` advances it by `BITS_PER_LEVEL` and resets
    # it to 0 whenever the hash is recomputed.
    shift::Int
end
# Initial state for `key` at depth 0 and shift 0. The hash is seeded from
# `objectid(key)` rather than `hash(key)`, so keys are distinguished by identity.
HashState(key) = HashState(key, objectid(key), 0, 0)
# Reconstruct
HashState(key, depth, shift) = HashState(key, hash(key, UInt(depth ÷ BITS_PER_LEVEL)), depth, shift)
# Rebuild the state for `key` at the depth already reached by `other`: start
# from a fresh state and step it with `next` until the depths agree.
function HashState(other::HashState, key)
    state = HashState(key)
    target_depth = other.depth
    while state.depth !== target_depth
        state = next(state)
    end
    return state
end

function next(h::HashState)
depth = h.depth + 1
shift = h.shift + BITS_PER_LEVEL
@assert h.shift <= MAX_SHIFT
if shift > MAX_SHIFT
# Note: we use `UInt(depth ÷ BITS_PER_LEVEL)` to seed the hash function.
# TODO: per the `hash` docs, do we need to hash `UInt(depth ÷ BITS_PER_LEVEL)` first?
h_hash = hash(h.key, UInt(depth ÷ BITS_PER_LEVEL))
h_hash = hash(objectid(h.key), UInt(depth ÷ BITS_PER_LEVEL))
shift = 0
else
h_hash = h.hash
end
Expand Down Expand Up @@ -137,8 +148,7 @@ as the current `level`.
If `copy` is `true`, the trie is copied while it is traversed; use the returned
`top` as the root of the new persistent tree.
"""
@inline function path(trie::HAMT{K,V}, key, _h, copy=false) where {K, V}
h = HashState(key, _h, 0, 0)
@inline function path(trie::HAMT{K,V}, key, h::HashState, copy=false) where {K, V}
if copy
trie = top = HAMT{K,V}(Base.copy(trie.data), trie.bitmap)
else
Expand All @@ -151,7 +161,7 @@ new persistent tree.
next = @inbounds trie.data[i]
if next isa Leaf{K,V}
# Check if key match if not we will need to grow.
found = (next.key === h.key || isequal(next.key, h.key))
found = next.key === h.key
return found, true, trie, i, bi, top, h
end
if copy
Expand Down Expand Up @@ -184,7 +194,7 @@ or grows the HAMT by inserting a new trie instead.
@assert present
# collision -> grow
leaf = @inbounds trie.data[i]::Leaf{K,V}
leaf_h = HashState(leaf.key, h.depth, h.shift)
leaf_h = HashState(h, leaf.key)
if leaf_h.hash == h.hash
error("Perfect hash collision")
end
Expand Down
93 changes: 37 additions & 56 deletions test/dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1084,17 +1084,30 @@ Dict(1 => rand(2,3), 'c' => "asdf") # just make sure this does not trigger a dep
GC.@preserve A B C D nothing
end

mutable struct CollidingHash
end
Base.hash(::CollidingHash, h::UInt) = hash(UInt(0), h)

struct PredictableHash
x::UInt
end
Base.hash(x::PredictableHash, h::UInt) = x.x

import Base.PersistentDict
@testset "PersistentDict" begin
# Checks construction, stepping and reconstruction of `Base.HAMT.HashState`.
@testset "HAMT HashState" begin
    key = :key
    h = Base.HAMT.HashState(key)
    # The same initial state spelled out field by field: hash taken from
    # `objectid`, depth 0, shift 0.
    h1 = Base.HAMT.HashState(key, objectid(key), 0, 0)
    h2 = Base.HAMT.HashState(h, key) # reconstruct
    @test h.hash == h1.hash
    @test h.hash == h2.hash

    hs = Base.HAMT.next(h1)
    @test hs.depth == 1
    # First depth at which `next` is expected to wrap the shift back to 0
    # (asserted below).
    recompute_depth = (Base.HAMT.MAX_SHIFT ÷ Base.HAMT.BITS_PER_LEVEL) + 1
    for i in 2:recompute_depth
        hs = Base.HAMT.next(hs)
        @test hs.depth == i
    end
    @test hs.depth == recompute_depth
    @test hs.shift == 0
    # Reconstructing from `hs` must land on the same hash, depth and shift as
    # stepping there one `next` at a time.
    hsr = Base.HAMT.HashState(hs, key)
    @test hs.hash == hsr.hash
    @test hs.depth == hsr.depth
    @test hs.shift == hsr.shift
end
@testset "basics" begin
dict = PersistentDict{Int, Int}()
@test_throws KeyError dict[1]
Expand Down Expand Up @@ -1145,6 +1158,21 @@ import Base.PersistentDict
@test dict[4] == 1
end

# Keys are looked up by identity, so mutating a key's contents after insertion
# must not change which entry it finds.
@testset "objectid" begin
    c = [0]
    dict = PersistentDict{Any, Int}(c => 1, [1] => 2)
    @test dict[c] == 1
    # `c` now has the same contents as the other key, but is still the same object.
    c[1] = 1
    @test dict[c] == 1

    c[1] = 0
    # Same check with the array wrapped in a tuple; each lookup below builds a
    # new `(c,)` around the same array `c`.
    dict = PersistentDict{Any, Int}((c,) => 1, ([1],) => 2)
    @test dict[(c,)] == 1

    c[1] = 1
    @test dict[(c,)] == 1
end

@testset "stress" begin
N = 2^14
dict = PersistentDict{Int, Int}()
Expand All @@ -1164,53 +1192,6 @@ import Base.PersistentDict
end
@test isempty(dict)
end

@testset "CollidingHash" begin
dict = PersistentDict{CollidingHash, Nothing}()
dict = PersistentDict(dict, CollidingHash(), nothing)
@test_throws ErrorException PersistentDict(dict, CollidingHash(), nothing)
end

# Test the internal implementation
@testset "PredictableHash" begin
dict = PersistentDict{PredictableHash, Nothing}()
for i in 1:Base.HashArrayMappedTries.ENTRY_COUNT
key = PredictableHash(UInt(i-1)) # Level 0
dict = PersistentDict(dict, key, nothing)
end
@test length(dict.trie.data) == Base.HashArrayMappedTries.ENTRY_COUNT
@test dict.trie.bitmap == typemax(Base.HashArrayMappedTries.BITMAP)
for entry in dict.trie.data
@test entry isa Base.HashArrayMappedTries.Leaf
end

dict = PersistentDict{PredictableHash, Nothing}()
for i in 1:Base.HashArrayMappedTries.ENTRY_COUNT
key = PredictableHash(UInt(i-1) << Base.HashArrayMappedTries.BITS_PER_LEVEL) # Level 1
dict = PersistentDict(dict, key, nothing)
end
@test length(dict.trie.data) == 1
@test length(dict.trie.data[1].data) == 32

max_level = (Base.HashArrayMappedTries.NBITS ÷ Base.HashArrayMappedTries.BITS_PER_LEVEL)
dict = PersistentDict{PredictableHash, Nothing}()
for i in 1:Base.Base.HashArrayMappedTries.ENTRY_COUNT
key = PredictableHash(UInt(i-1) << (max_level * Base.HashArrayMappedTries.BITS_PER_LEVEL)) # Level 12
dict = PersistentDict(dict, key, nothing)
end
data = dict.trie.data
for level in 1:max_level
@test length(data) == 1
data = only(data).data
end
last_level_nbits = Base.HashArrayMappedTries.NBITS - (max_level * Base.HashArrayMappedTries.BITS_PER_LEVEL)
if Base.HashArrayMappedTries.NBITS == 64
@test last_level_nbits == 4
elseif Base.HashArrayMappedTries.NBITS == 32
@test last_level_nbits == 2
end
@test length(data) == 2^last_level_nbits
end
end

@testset "issue #19995, hash of dicts" begin
Expand Down