From 86b246f18c38a75a9f22323b65d123042ff01503 Mon Sep 17 00:00:00 2001 From: Dmitriy Matrenichev Date: Mon, 20 May 2024 01:00:18 +0300 Subject: [PATCH] chore: add hashtriemap implementation This PR adds a hashtriemap implementation to the project. This is mostly a direct copy from unique package [^1] internals [^2] with some adjustments to make it work with our project. Once `maphash.Comparable` lands we can remove dependency on Go internals [^3]. Michael Knyszek says he plans to use this implementation as a general idea for the future generic sync/v2.Map, but nothing is in stone yet. [^1]: https://github.com/golang/go/issues/62483 [^2]: https://go-review.googlesource.com/c/go/+/573956 [^3]: https://github.com/golang/go/issues/54670 Signed-off-by: Dmitriy Matrenichev --- concurrent/go122.go | 74 +++++ concurrent/go122_bench_test.go | 71 +++++ concurrent/go122_test.go | 35 +++ concurrent/hashtriemap.go | 428 +++++++++++++++++++++++++++ concurrent/hashtriemap_bench_test.go | 69 +++++ concurrent/hashtriemap_test.go | 388 ++++++++++++++++++++++++ 6 files changed, 1065 insertions(+) create mode 100644 concurrent/go122.go create mode 100644 concurrent/go122_bench_test.go create mode 100644 concurrent/go122_test.go create mode 100644 concurrent/hashtriemap.go create mode 100644 concurrent/hashtriemap_bench_test.go create mode 100644 concurrent/hashtriemap_test.go diff --git a/concurrent/go122.go b/concurrent/go122.go new file mode 100644 index 0000000..87cb5da --- /dev/null +++ b/concurrent/go122.go @@ -0,0 +1,74 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.22 && !go1.24 && !nomaptypehash + +//nolint:revive,govet,stylecheck +package concurrent + +import ( + "math/rand/v2" + "unsafe" +) + +// NewHashTrieMap creates a new HashTrieMap for the provided key and value. +func NewHashTrieMap[K, V comparable]() *HashTrieMap[K, V] { + var m map[K]V + + mapType := efaceMapOf(m) + ht := &HashTrieMap[K, V]{ + root: newIndirectNode[K, V](nil), + keyHash: mapType._type.Hasher, + seed: uintptr(rand.Uint64()), + } + return ht +} + +// _MapType is runtime.maptype from runtime/type.go. +type _MapType struct { + _Type + Key *_Type + Elem *_Type + Bucket *_Type // internal type representing a hash bucket + // function for hashing keys (ptr to key, seed) -> hash + Hasher func(unsafe.Pointer, uintptr) uintptr + KeySize uint8 // size of key slot + ValueSize uint8 // size of elem slot + BucketSize uint16 // size of bucket + Flags uint32 +} + +// _Type is runtime._type from runtime/type.go. +type _Type struct { + Size_ uintptr + PtrBytes uintptr // number of (prefix) bytes in the type that can contain pointers + Hash uint32 // hash of type; avoids computation in hash tables + TFlag uint8 // extra type information flags + Align_ uint8 // alignment of variable with this type + FieldAlign_ uint8 // alignment of struct field with this type + Kind_ uint8 // enumeration for C + // function for comparing objects of this type + // (ptr to object A, ptr to object B) -> ==? + Equal func(unsafe.Pointer, unsafe.Pointer) bool + // GCData stores the GC type data for the garbage collector. + // If the KindGCProg bit is set in kind, GCData is a GC program. + // Otherwise it is a ptrmask bitmap. See mbitmap.go for details. + GCData *byte + Str int32 // string form + PtrToThis int32 // type for pointer to this type, may be zero +} + +// efaceMap is runtime.eface from runtime/runtime2.go. +type efaceMap struct { + _type *_MapType + data unsafe.Pointer +} + +func efaceMapOf(ep any) *efaceMap { + return (*efaceMap)(unsafe.Pointer(&ep)) +} diff --git a/concurrent/go122_bench_test.go b/concurrent/go122_bench_test.go new file mode 100644 index 0000000..7c9573b --- /dev/null +++ b/concurrent/go122_bench_test.go @@ -0,0 +1,71 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.22 && !go1.24 && !nomaptypehash + +//nolint:wsl,testpackage +package concurrent + +import ( + "testing" +) + +func BenchmarkHashTrieMapLoadSmall(b *testing.B) { + benchmarkHashTrieMapLoad(b, testDataSmall[:]) +} + +func BenchmarkHashTrieMapLoad(b *testing.B) { + benchmarkHashTrieMapLoad(b, testData[:]) +} + +func BenchmarkHashTrieMapLoadLarge(b *testing.B) { + benchmarkHashTrieMapLoad(b, testDataLarge[:]) +} + +func benchmarkHashTrieMapLoad(b *testing.B, data []string) { + b.ReportAllocs() + m := NewHashTrieMap[string, int]() + for i := range data { + m.LoadOrStore(data[i], i) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _, _ = m.Load(data[i]) + i++ + if i >= len(data) { + i = 0 + } + } + }) +} + +func BenchmarkHashTrieMapLoadOrStore(b *testing.B) { + benchmarkHashTrieMapLoadOrStore(b, testData[:]) +} + +func BenchmarkHashTrieMapLoadOrStoreLarge(b *testing.B) { + benchmarkHashTrieMapLoadOrStore(b, testDataLarge[:]) +} + +func benchmarkHashTrieMapLoadOrStore(b *testing.B, data []string) { + b.ReportAllocs() + m := NewHashTrieMap[string, int]() + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _, _ = m.LoadOrStore(data[i], i) + i++ + if i >= len(data) { + i = 0 + } + } + }) +} diff --git a/concurrent/go122_test.go b/concurrent/go122_test.go new file mode 100644 index 0000000..67df644 --- /dev/null +++ b/concurrent/go122_test.go @@ -0,0 +1,35 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.22 && !go1.24 && !nomaptypehash + +//nolint:revive,nlreturn,wsl,gocyclo,unparam,unused,cyclop,testpackage +package concurrent + +import ( + "testing" + "unsafe" +) + +func TestHashTrieMap(t *testing.T) { + testHashTrieMap(t, func() *HashTrieMap[string, int] { + return NewHashTrieMap[string, int]() + }) +} + +func TestHashTrieMapBadHash(t *testing.T) { + testHashTrieMap(t, func() *HashTrieMap[string, int] { + // Stub out the good hash function with a terrible one. + // Everything should still work as expected. + m := NewHashTrieMap[string, int]() + m.keyHash = func(_ unsafe.Pointer, _ uintptr) uintptr { + return 0 + } + return m + }) +} diff --git a/concurrent/hashtriemap.go b/concurrent/hashtriemap.go new file mode 100644 index 0000000..365daa8 --- /dev/null +++ b/concurrent/hashtriemap.go @@ -0,0 +1,428 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//nolint:revive,govet,nakedret,nlreturn,stylecheck,wsl,staticcheck +package concurrent + +import ( + "hash/maphash" + "math/rand/v2" + "sync" + "sync/atomic" + "unsafe" +) + +// HashTrieMap is an implementation of a concurrent hash-trie. The implementation +// is designed around frequent loads, but offers decent performance for stores +// and deletes as well, especially if the map is larger. It's primary use-case is +// the unique package, but can be used elsewhere as well. +type HashTrieMap[K, V comparable] struct { + root *indirect[K, V] + keyHash hashFunc + seed uintptr +} + +// NewHashTrieMapHasher creates a new HashTrieMap for the provided key and value and uses +// the provided hasher function to hash keys. +func NewHashTrieMapHasher[K, V comparable](hasher func(*K, maphash.Seed) uintptr) *HashTrieMap[K, V] { + ht := &HashTrieMap[K, V]{ + root: newIndirectNode[K, V](nil), + keyHash: func(pointer unsafe.Pointer, u uintptr) uintptr { + return hasher((*K)(pointer), seedToSeed(u)) + }, + seed: uintptr(rand.Uint64()), + } + return ht +} + +type hashFunc func(unsafe.Pointer, uintptr) uintptr + +// Load returns the value stored in the map for a key, or nil if no +// value is present. +// The ok result indicates whether value was found in the map. +func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) { + hash := ht.keyHash(noescape(unsafe.Pointer(&key)), ht.seed) + + i := ht.root + hashShift := 8 * ptrSize + for hashShift != 0 { + hashShift -= nChildrenLog2 + + n := i.children[(hash>>hashShift)&nChildrenMask].Load() + if n == nil { + return *new(V), false + } + if n.isEntry { + return n.entry().lookup(key) + } + i = n.indirect() + } + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") +} + +// LoadOrStore returns the existing value for the key if present. +// Otherwise, it stores and returns the given value. +// The loaded result is true if the value was loaded, false if stored. +func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) { + hash := ht.keyHash(noescape(unsafe.Pointer(&key)), ht.seed) + var i *indirect[K, V] + var hashShift uint + var slot *atomic.Pointer[node[K, V]] + var n *node[K, V] + for { + // Find the key or a candidate location for insertion. + i = ht.root + hashShift = 8 * ptrSize + for hashShift != 0 { + hashShift -= nChildrenLog2 + + slot = &i.children[(hash>>hashShift)&nChildrenMask] + n = slot.Load() + if n == nil { + // We found a nil slot which is a candidate for insertion. + break + } + if n.isEntry { + // We found an existing entry, which is as far as we can go. + // If it stays this way, we'll have to replace it with an + // indirect node. + if v, ok := n.entry().lookup(key); ok { + return v, true + } + break + } + i = n.indirect() + } + if hashShift == 0 { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") + } + + // Grab the lock and double-check what we saw. + i.mu.Lock() + n = slot.Load() + if (n == nil || n.isEntry) && !i.dead.Load() { + // What we saw is still true, so we can continue with the insert. + break + } + // We have to start over. + i.mu.Unlock() + } + // N.B. This lock is held from when we broke out of the outer loop above. + // We specifically break this out so that we can use defer here safely. + // One option is to break this out into a new function instead, but + // there's so much local iteration state used below that this turns out + // to be cleaner. + defer i.mu.Unlock() + + var oldEntry *entry[K, V] + if n != nil { + oldEntry = n.entry() + if v, ok := oldEntry.lookup(key); ok { + // Easy case: by loading again, it turns out exactly what we wanted is here! + return v, true + } + } + newEntry := newEntryNode(key, value) + if oldEntry == nil { + // Easy case: create a new entry and store it. + slot.Store(&newEntry.node) + } else { + // We possibly need to expand the entry already there into one or more new nodes. + // + // Publish the node last, which will make both oldEntry and newEntry visible. We + // don't want readers to be able to observe that oldEntry isn't in the tree. + slot.Store(ht.expand(oldEntry, newEntry, hash, hashShift, i)) + } + return value, false +} + +// expand takes oldEntry and newEntry whose hashes conflict from bit 64 down to hashShift and +// produces a subtree of indirect nodes to hold the two new entries. +func (ht *HashTrieMap[K, V]) expand(oldEntry, newEntry *entry[K, V], newHash uintptr, hashShift uint, parent *indirect[K, V]) *node[K, V] { + // Check for a hash collision. + oldHash := ht.keyHash(unsafe.Pointer(&oldEntry.key), ht.seed) + if oldHash == newHash { + // Store the old entry in the new entry's overflow list, then store + // the new entry. + newEntry.overflow.Store(oldEntry) + return &newEntry.node + } + // We have to add an indirect node. Worse still, we may need to add more than one. + newIndirect := newIndirectNode(parent) + top := newIndirect + for { + if hashShift == 0 { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while inserting") + } + hashShift -= nChildrenLog2 // hashShift is for the level parent is at. We need to go deeper. + oi := (oldHash >> hashShift) & nChildrenMask + ni := (newHash >> hashShift) & nChildrenMask + if oi != ni { + newIndirect.children[oi].Store(&oldEntry.node) + newIndirect.children[ni].Store(&newEntry.node) + break + } + nextIndirect := newIndirectNode(newIndirect) + newIndirect.children[oi].Store(&nextIndirect.node) + newIndirect = nextIndirect + } + return &top.node +} + +// CompareAndDelete deletes the entry for key if its value is equal to old. +// +// If there is no current value for key in the map, CompareAndDelete returns false +// (even if the old value is the nil interface value). +func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + hash := ht.keyHash(noescape(unsafe.Pointer(&key)), ht.seed) + var i *indirect[K, V] + var hashShift uint + var slot *atomic.Pointer[node[K, V]] + var n *node[K, V] + for { + // Find the key or return when there's nothing to delete. + i = ht.root + hashShift = 8 * ptrSize + for hashShift != 0 { + hashShift -= nChildrenLog2 + + slot = &i.children[(hash>>hashShift)&nChildrenMask] + n = slot.Load() + if n == nil { + // Nothing to delete. Give up. + return + } + if n.isEntry { + // We found an entry. Check if it matches. + if _, ok := n.entry().lookup(key); !ok { + // No match, nothing to delete. + return + } + // We've got something to delete. + break + } + i = n.indirect() + } + if hashShift == 0 { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") + } + + // Grab the lock and double-check what we saw. + i.mu.Lock() + n = slot.Load() + if !i.dead.Load() { + if n == nil { + // Valid node that doesn't contain what we need. Nothing to delete. + i.mu.Unlock() + return + } + if n.isEntry { + // What we saw is still true, so we can continue with the delete. + break + } + } + // We have to start over. + i.mu.Unlock() + } + // Try to delete the entry. + e, deleted := n.entry().compareAndDelete(key, old) + if !deleted { + // Nothing was actually deleted, which means the node is no longer there. + i.mu.Unlock() + return false + } + if e != nil { + // We didn't actually delete the whole entry, just one entry in the chain. + // Nothing else to do, since the parent is definitely not empty. + slot.Store(&e.node) + i.mu.Unlock() + return true + } + // Delete the entry. + slot.Store(nil) + + // Check if the node is now empty (and isn't the root), and delete it if able. + for i.parent != nil && i.empty() { + if hashShift == 64 { + panic("internal/concurrent.HashMapTrie: ran out of hash bits while iterating") + } + hashShift += nChildrenLog2 + + // Delete the current node in the parent. + parent := i.parent + parent.mu.Lock() + i.dead.Store(true) + parent.children[(hash>>hashShift)&nChildrenMask].Store(nil) + i.mu.Unlock() + i = parent + } + i.mu.Unlock() + return true +} + +// Enumerate produces all key-value pairs in the map. The enumeration does +// not represent any consistent snapshot of the map, but is guaranteed +// to visit each unique key-value pair only once. It is safe to operate +// on the tree during iteration. No particular enumeration order is +// guaranteed. +func (ht *HashTrieMap[K, V]) Enumerate(yield func(key K, value V) bool) { + ht.iter(ht.root, yield) +} + +func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) bool) bool { + for j := range i.children { + n := i.children[j].Load() + if n == nil { + continue + } + if !n.isEntry { + if !ht.iter(n.indirect(), yield) { + return false + } + continue + } + e := n.entry() + for e != nil { + if !yield(e.key, e.value) { + return false + } + e = e.overflow.Load() + } + } + return true +} + +const ( + // 16 children. This seems to be the sweet spot for + // load performance: any smaller and we lose out on + // 50% or more in CPU performance. Any larger and the + // returns are miniscule (~1% improvement for 32 children). + nChildrenLog2 = 4 + nChildren = 1 << nChildrenLog2 + nChildrenMask = nChildren - 1 +) + +// indirect is an internal node in the hash-trie. +type indirect[K, V comparable] struct { + node[K, V] + dead atomic.Bool + mu sync.Mutex // Protects mutation to children and any children that are entry nodes. + parent *indirect[K, V] + children [nChildren]atomic.Pointer[node[K, V]] +} + +func newIndirectNode[K, V comparable](parent *indirect[K, V]) *indirect[K, V] { + return &indirect[K, V]{node: node[K, V]{isEntry: false}, parent: parent} +} + +func (i *indirect[K, V]) empty() bool { + nc := 0 + for j := range i.children { + if i.children[j].Load() != nil { + nc++ + } + } + return nc == 0 +} + +// entry is a leaf node in the hash-trie. +type entry[K, V comparable] struct { + node[K, V] + overflow atomic.Pointer[entry[K, V]] // Overflow for hash collisions. + key K + value V +} + +func newEntryNode[K, V comparable](key K, value V) *entry[K, V] { + return &entry[K, V]{ + node: node[K, V]{isEntry: true}, + key: key, + value: value, + } +} + +func (e *entry[K, V]) lookup(key K) (V, bool) { + for e != nil { + if e.key == key { + return e.value, true + } + e = e.overflow.Load() + } + return *new(V), false +} + +// compareAndDelete deletes an entry in the overflow chain if both the key and value compare +// equal. Returns the new entry chain and whether or not anything was deleted. +// +// compareAndDelete must be called under the mutex of the indirect node which e is a child of. +func (head *entry[K, V]) compareAndDelete(key K, value V) (*entry[K, V], bool) { + if head.key == key && head.value == value { + // Drop the head of the list. + return head.overflow.Load(), true + } + i := &head.overflow + e := i.Load() + for e != nil { + if e.key == key && e.value == value { + i.Store(e.overflow.Load()) + return head, true + } + i = &e.overflow + e = e.overflow.Load() + } + return head, false +} + +// node is the header for a node. It's polymorphic and +// is actually either an entry or an indirect. +type node[K, V comparable] struct { + isEntry bool +} + +func (n *node[K, V]) entry() *entry[K, V] { + if !n.isEntry { + panic("called entry on non-entry node") + } + return (*entry[K, V])(unsafe.Pointer(n)) +} + +func (n *node[K, V]) indirect() *indirect[K, V] { + if n.isEntry { + panic("called indirect on entry node") + } + return (*indirect[K, V])(unsafe.Pointer(n)) +} + +// noescape hides a pointer from escape analysis. It is the identity function +// but escape analysis doesn't think the output depends on the input. +// noescape is inlined and currently compiles down to zero instructions. +// USE CAREFULLY! +// This was copied from the runtime. +// +//go:nosplit +func noescape(p unsafe.Pointer) unsafe.Pointer { + x := uintptr(p) + return unsafe.Pointer(x ^ 0) +} + +const ptrSize = uint(unsafe.Sizeof(uintptr(0))) + +func init() { + // copied from [maphash.Seed] + type seedType struct { + s uint64 + } + + if unsafe.Sizeof(seedType{}) != unsafe.Sizeof(maphash.Seed{}) { + panic("concurrent: seedType is not maphash.Seed") + } +} + +func seedToSeed(val uintptr) maphash.Seed { + return *(*maphash.Seed)(unsafe.Pointer(&val)) +} diff --git a/concurrent/hashtriemap_bench_test.go b/concurrent/hashtriemap_bench_test.go new file mode 100644 index 0000000..9fd5b31 --- /dev/null +++ b/concurrent/hashtriemap_bench_test.go @@ -0,0 +1,69 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//nolint:wsl,testpackage +package concurrent + +import ( + "testing" +) + +func BenchmarkHashTrieMapHasherLoadSmall(b *testing.B) { + benchmarkHashTrieMapHasherLoad(b, testDataSmall[:]) +} + +func BenchmarkHashTrieMapHasherLoad(b *testing.B) { + benchmarkHashTrieMapHasherLoad(b, testData[:]) +} + +func BenchmarkHashTrieMapHasherLoadLarge(b *testing.B) { + benchmarkHashTrieMapHasherLoad(b, testDataLarge[:]) +} + +func benchmarkHashTrieMapHasherLoad(b *testing.B, data []string) { + b.ReportAllocs() + m := NewHashTrieMapHasher[string, int](stringHasher) + for i := range data { + m.LoadOrStore(data[i], i) + } + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _, _ = m.Load(data[i]) + i++ + if i >= len(data) { + i = 0 + } + } + }) +} + +func BenchmarkHashTrieMapHasherLoadOrStore(b *testing.B) { + benchmarkHashTrieMapHasherLoadOrStore(b, testData[:]) +} + +func BenchmarkHashTrieMapHasherLoadOrStoreLarge(b *testing.B) { + benchmarkHashTrieMapHasherLoadOrStore(b, testDataLarge[:]) +} + +func benchmarkHashTrieMapHasherLoadOrStore(b *testing.B, data []string) { + b.ReportAllocs() + m := NewHashTrieMapHasher[string, int](stringHasher) + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + _, _ = m.LoadOrStore(data[i], i) + i++ + if i >= len(data) { + i = 0 + } + } + }) +} diff --git a/concurrent/hashtriemap_test.go b/concurrent/hashtriemap_test.go new file mode 100644 index 0000000..d49f720 --- /dev/null +++ b/concurrent/hashtriemap_test.go @@ -0,0 +1,388 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//nolint:revive,nlreturn,wsl,gocyclo,unparam,unused,cyclop,testpackage +package concurrent + +import ( + "fmt" + "hash/maphash" + "math" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "unsafe" +) + +func stringHasher(val *string, seed maphash.Seed) uintptr { + var hm maphash.Hash + + hm.SetSeed(seed) + + hm.WriteString(*val) + + return uintptr(hm.Sum64()) +} + +func TestHashTrieMapHasher(t *testing.T) { + testHashTrieMap(t, func() *HashTrieMap[string, int] { + return NewHashTrieMapHasher[string, int](stringHasher) + }) +} + +func TestNewHashTrieMapHasherBadHash(t *testing.T) { + testHashTrieMap(t, func() *HashTrieMap[string, int] { + // Stub out the good hash function with a terrible one. + // Everything should still work as expected. + m := NewHashTrieMap[string, int]() + m.keyHash = func(_ unsafe.Pointer, _ uintptr) uintptr { + return 0 + } + return m + }) +} + +//nolint:gocognit +func testHashTrieMap(t *testing.T, newMap func() *HashTrieMap[string, int]) { + t.Run("LoadEmpty", func(t *testing.T) { + m := newMap() + + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + }) + t.Run("LoadOrStore", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + for i, s := range testData { + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + }) + t.Run("CompareAndDeleteAll", func(t *testing.T) { + m := newMap() + + for range 3 { + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + for i, s := range testData { + expectPresent(t, s, i)(m.Load(s)) + expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt)) + expectDeleted(t, s, i)(m.CompareAndDelete(s, i)) + expectNotDeleted(t, s, i)(m.CompareAndDelete(s, i)) + expectMissing(t, s, 0)(m.Load(s)) + } + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + } + }) + t.Run("CompareAndDeleteOne", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + expectNotDeleted(t, testData[15], math.MaxInt)(m.CompareAndDelete(testData[15], math.MaxInt)) + expectDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15)) + expectNotDeleted(t, testData[15], 15)(m.CompareAndDelete(testData[15], 15)) + for i, s := range testData { + if i == 15 { + expectMissing(t, s, 0)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("DeleteMultiple", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + for _, i := range []int{1, 105, 6, 85} { + expectNotDeleted(t, testData[i], math.MaxInt)(m.CompareAndDelete(testData[i], math.MaxInt)) + expectDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i)) + expectNotDeleted(t, testData[i], i)(m.CompareAndDelete(testData[i], i)) + } + for i, s := range testData { + if i == 1 || i == 105 || i == 6 || i == 85 { + expectMissing(t, s, 0)(m.Load(s)) + } else { + expectPresent(t, s, i)(m.Load(s)) + } + } + }) + t.Run("Enumerate", func(t *testing.T) { + m := newMap() + + testEnumerate(t, m, testDataMap(testData[:]), func(_ string, _ int) bool { + return true + }) + }) + t.Run("EnumerateDelete", func(t *testing.T) { + m := newMap() + + testEnumerate(t, m, testDataMap(testData[:]), func(s string, i int) bool { + expectDeleted(t, s, i)(m.CompareAndDelete(s, i)) + return true + }) + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + }) + t.Run("ConcurrentLifecycleUnsharedKeys", func(t *testing.T) { + m := newMap() + + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + makeKey := func(s string) string { + return s + "-" + strconv.Itoa(id) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + expectStored(t, key, id)(m.LoadOrStore(key, id)) + expectPresent(t, key, id)(m.Load(key)) + expectLoaded(t, key, id)(m.LoadOrStore(key, 0)) + } + for _, s := range testData { + key := makeKey(s) + expectPresent(t, key, id)(m.Load(key)) + expectDeleted(t, key, id)(m.CompareAndDelete(key, id)) + expectMissing(t, key, 0)(m.Load(key)) + } + for _, s := range testData { + key := makeKey(s) + expectMissing(t, key, 0)(m.Load(key)) + } + }(i) + } + wg.Wait() + }) + t.Run("ConcurrentDeleteSharedKeys", func(t *testing.T) { + m := newMap() + + // Load up the map. + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + } + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + for i, s := range testData { + expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt)) + m.CompareAndDelete(s, i) + expectMissing(t, s, 0)(m.Load(s)) + } + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + }(i) + } + wg.Wait() + }) +} + +func testEnumerate[K, V comparable](t *testing.T, m *HashTrieMap[K, V], testData map[K]V, yield func(K, V) bool) { + for k, v := range testData { + expectStored(t, k, v)(m.LoadOrStore(k, v)) + } + visited := make(map[K]int) + m.Enumerate(func(key K, got V) bool { + want, ok := testData[key] + if !ok { + t.Errorf("unexpected key %v in map", key) + return false + } + if got != want { + t.Errorf("expected key %v to have value %v, got %v", key, want, got) + return false + } + visited[key]++ + return yield(key, got) + }) + for key, n := range visited { + if n > 1 { + t.Errorf("visited key %v more than once", key) + } + } +} + +func expectPresent[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) { + t.Helper() + return func(got V, ok bool) { + t.Helper() + + if !ok { + t.Errorf("expected key %v to be present in map", key) + } + if ok && got != want { + t.Errorf("expected key %v to have value %v, got %v", key, want, got) + } + } +} + +func expectMissing[K, V comparable](t *testing.T, key K, want V) func(got V, ok bool) { + t.Helper() + if want != *new(V) { + // This is awkward, but the want argument is necessary to smooth over type inference. + // Just make sure the want argument always looks the same. + panic("expectMissing must always have a zero value variable") + } + return func(got V, ok bool) { + t.Helper() + + if ok { + t.Errorf("expected key %v to be missing from map, got value %v", key, got) + } + if !ok && got != want { + t.Errorf("expected missing key %v to be paired with the zero value; got %v", key, got) + } + } +} + +func expectLoaded[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) { + t.Helper() + return func(got V, loaded bool) { + t.Helper() + + if !loaded { + t.Errorf("expected key %v to have been loaded, not stored", key) + } + if got != want { + t.Errorf("expected key %v to have value %v, got %v", key, want, got) + } + } +} + +func expectStored[K, V comparable](t *testing.T, key K, want V) func(got V, loaded bool) { + t.Helper() + return func(got V, loaded bool) { + t.Helper() + + if loaded { + t.Errorf("expected inserted key %v to have been stored, not loaded", key) + } + if got != want { + t.Errorf("expected inserted key %v to have value %v, got %v", key, want, got) + } + } +} + +func expectDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) { + t.Helper() + return func(deleted bool) { + t.Helper() + + if !deleted { + t.Errorf("expected key %v with value %v to be in map and deleted", key, old) + } + } +} + +func expectNotDeleted[K, V comparable](t *testing.T, key K, old V) func(deleted bool) { + t.Helper() + return func(deleted bool) { + t.Helper() + + if deleted { + t.Errorf("expected key %v with value %v to not be in map and thus not deleted", key, old) + } + } +} + +func testDataMap(data []string) map[string]int { + m := make(map[string]int) + for i, s := range data { + m[s] = i + } + return m +} + +var ( + testDataSmall [8]string + testData [128]string + testDataLarge [128 << 10]string +) + +func init() { + for i := range testDataSmall { + testDataSmall[i] = fmt.Sprintf("%b", i) + } + for i := range testData { + testData[i] = fmt.Sprintf("%b", i) + } + for i := range testDataLarge { + testDataLarge[i] = fmt.Sprintf("%b", i) + } +} + +func dumpMap[K, V comparable](ht *HashTrieMap[K, V]) { + dumpNode(ht, &ht.root.node, 0) +} + +func dumpNode[K, V comparable](ht *HashTrieMap[K, V], n *node[K, V], depth int) { + var sb strings.Builder + for range depth { + fmt.Fprintf(&sb, "\t") + } + prefix := sb.String() + if n.isEntry { + e := n.entry() + for e != nil { + fmt.Printf("%s%p [Entry Key=%v Value=%v Overflow=%p, Hash=%016x]\n", prefix, e, e.key, e.value, e.overflow.Load(), ht.keyHash(unsafe.Pointer(&e.key), ht.seed)) + e = e.overflow.Load() + } + return + } + i := n.indirect() + fmt.Printf("%s%p [Indirect Parent=%p Dead=%t Children=[", prefix, i, i.parent, i.dead.Load()) + for j := range i.children { + c := i.children[j].Load() + fmt.Printf("%p", c) + if j != len(i.children)-1 { + fmt.Printf(", ") + } + } + fmt.Printf("]]\n") + for j := range i.children { + c := i.children[j].Load() + if c != nil { + dumpNode(ht, c, depth+1) + } + } +}