diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs index e131d4bb25ee5..8736cef946905 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs @@ -317,7 +317,7 @@ public void ComparerImplementations_Dictionary_WithWellKnownStringComparers() // Now exceed the collision threshold, which should rebucket entries. // Continue adding a few more entries to ensure we didn't corrupt internal state. - for (int i = 100; i < 110; i++) + for (int i = 100; i < _collidingStrings.Count; i++) { string newKey = _collidingStrings[i]; Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal @@ -364,7 +364,8 @@ public void ComparerImplementations_Dictionary_WithWellKnownStringComparers() () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCodeOrdinalIgnoreCase")); // n.b., must be initialized *after* delegate fields above - private static readonly List _collidingStrings = GenerateCollidingStrings(110); + // FIXME: This can't generate more than 124 colliding strings + private static readonly List _collidingStrings = GenerateCollidingStrings(124); private static List GenerateCollidingStrings(int count) { diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs index c890bbed83c7e..7f81d87569d50 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/Dictionary.cs @@ -1,12 +1,16 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Collections.ObjectModel; using System.ComponentModel; using System.Diagnostics; using System.Diagnostics.CodeAnalysis; +using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; +using System.Runtime.Intrinsics.Wasm; +using System.Runtime.Intrinsics.X86; using System.Runtime.Serialization; namespace System.Collections.Generic @@ -17,17 +21,269 @@ namespace System.Collections.Generic [TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")] public class Dictionary : IDictionary, IDictionary, IReadOnlyDictionary, ISerializable, IDeserializationCallback where TKey : notnull { + // A comparison protocol is an adapter to allow us to have a single implementation of all our + // search/insertion/removal algorithms that generalizes efficiently over different comparers at JIT + // or AOT time without duplicated code + private interface IComparisonProtocol + where TActualKey : allows ref struct + { + int GetHashCode(TActualKey key); + bool Equals(TActualKey lhs, TKey rhs); + TKey GetKey(TActualKey input); + } + + private readonly struct DefaultValueTypeComparerComparisonProtocol : IComparisonProtocol + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(TKey lhs, TKey rhs) => EqualityComparer.Default.Equals(lhs, rhs); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetHashCode(TKey key) => EqualityComparer.Default.GetHashCode(key); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey(TKey input) => input; + } + + private readonly struct ComparerComparisonProtocol + : IComparisonProtocol + { + public readonly IEqualityComparer comparer; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ComparerComparisonProtocol (IEqualityComparer comparer) + { + Debug.Assert(comparer != null); + this.comparer = comparer; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(TKey lhs, TKey rhs) => comparer.Equals(lhs, rhs); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetHashCode(TKey key) => comparer.GetHashCode(key); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey(TKey input) => input; + } + + [method: MethodImpl(MethodImplOptions.AggressiveInlining)] + private readonly struct AlternateComparerComparisonProtocol(IAlternateEqualityComparer comparer) + : IComparisonProtocol + where TAlternateKey : allows ref struct + { + public readonly IAlternateEqualityComparer comparer = comparer; + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Equals(TAlternateKey lhs, TKey rhs) => comparer.Equals(lhs, rhs); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetHashCode(TAlternateKey key) => comparer.GetHashCode(key); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public TKey GetKey(TAlternateKey input) + { + TKey result = comparer.Create(input); + if (result == null) + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + return result; + } + } + + /* + static Dictionary () { + while (!Debugger.IsAttached) + Debugger.Launch(); + } + */ + + private ref struct LoopingBucketEnumerator + { + // The size of this struct is REALLY important! Adding even a single field to this will add stack spills to critical loops. + // FIXME: This span being a field puts pressure on the JIT to do recursive struct decomposition; I'm not sure it always does + private readonly Span _buckets; + private readonly int _initialIndex; + private int _index; + + [Obsolete("Use LoopingBucketEnumerator.New")] + public LoopingBucketEnumerator() + { + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private LoopingBucketEnumerator(Span buckets, uint hashCode, ulong fastModMultiplier) + { + _buckets = buckets; + _initialIndex = GetBucketIndexForHashCode(buckets, hashCode, fastModMultiplier); + _index = _initialIndex; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static ref Bucket New(Span buckets, uint hashCode, ulong fastModMultiplier, out LoopingBucketEnumerator enumerator) + { + // FIXME: Optimize this out with EmptyBuckets array like SimdDictionary + if (buckets.IsEmpty) + { + enumerator = default; + return ref Unsafe.NullRef(); + } + else + { + enumerator = new LoopingBucketEnumerator(buckets, hashCode, fastModMultiplier); + // FIXME: Optimize out the memory load of _initialIndex somehow. + return ref enumerator._buckets[enumerator._initialIndex]; + } + } + + /// + /// Walks forward through buckets, wrapping around at the end of the container. + /// Never visits a bucket twice. + /// + /// The next bucket, or NullRef if you have visited every bucket exactly once. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref Bucket Advance() + { + // Operating on the index field directly is harmless as long as the enumerator struct got decomposed, which it seems to + // Caching index into a local and then doing a writeback at the end increases generated code size so it's not worth it + if (++_index >= _buckets.Length) + _index = 0; + + if (_index == _initialIndex) + return ref Unsafe.NullRef(); + else + return ref _buckets[_index]; + } + + /// + /// Walks back through the buckets you have previously visited. + /// + /// Each bucket you previously visited, exactly once, in reverse order, then NullRef. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public ref Bucket Retreat() + { + if (_index == _initialIndex) + return ref Unsafe.NullRef(); + + if (--_index < 0) + _index = _buckets.Length - 1; + return ref _buckets[_index]; + } + } + + [InlineArray(12)] + [StructLayout(LayoutKind.Sequential)] + private struct InlineEntryIndexArray + { + public int Index0; + } + + private struct Bucket + { + public const int Capacity = 12, + CountSlot = 13, + CascadeSlot = 14, + DegradedCascadeCount = 0xFF; + + public Vector128 Suffixes; + public InlineEntryIndexArray Indices; + + // This analysis is incorrect +#pragma warning disable IDE0251 + public ref byte Count + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.AddByteOffset(ref Unsafe.As, byte>(ref Unsafe.AsRef(in Suffixes)), CountSlot); + } + + public ref ushort CascadeCount + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + get => ref Unsafe.AddByteOffset(ref Unsafe.As, ushort>(ref Unsafe.AsRef(in Suffixes)), CascadeSlot); + } +#pragma warning restore IDE0251 + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public readonly byte GetSlot(int index) + { + Debug.Assert(index < Vector128.Count); + // the extract-lane opcode this generates is slower than doing a byte load from memory, + // even if we already have the bucket in a register. Not sure why, but my guess based on agner's + // instruction tables is that it's because lane extract generates more uops than a byte move. + // the two operations have the same latency on icelake, and the byte move's latency is lower on zen4 + // return self[index]; + // index &= 15; + return Unsafe.AddByteOffset(ref Unsafe.As, byte>(ref Unsafe.AsRef(in Suffixes)), index); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public void SetSlot(nuint index, byte value) + { + Debug.Assert(index < (nuint)Vector128.Count); + // index &= 15; + Unsafe.AddByteOffset(ref Unsafe.As, byte>(ref Suffixes), index) = value; + } + + public readonly int FindSuffix(int bucketCount, byte suffix, Vector128 searchVector) + { + if (Sse2.IsSupported) + { + return BitOperations.TrailingZeroCount(Sse2.MoveMask(Sse2.CompareEqual(searchVector, Suffixes))); + } + else if (AdvSimd.Arm64.IsSupported) + { + // Completely untested + var laneBits = AdvSimd.And( + AdvSimd.CompareEqual(searchVector, Suffixes), + Vector128.Create(1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128) + ); + var moveMask = AdvSimd.Arm64.AddAcross(laneBits.GetLower()).ToScalar() | + (AdvSimd.Arm64.AddAcross(laneBits.GetUpper()).ToScalar() << 8); + return BitOperations.TrailingZeroCount(moveMask); + } + else if (PackedSimd.IsSupported) + { + // Completely untested + return BitOperations.TrailingZeroCount(PackedSimd.Bitmask(PackedSimd.CompareEqual(searchVector, Suffixes))); + } + else + { + return FindSuffixScalar(bucketCount, suffix); + } + } + + public readonly unsafe int FindSuffixScalar(int bucketCount, byte suffix) + { + // Hand-unrolling the search into four comparisons per loop iteration is a significant performance improvement + // for a moderate code size penalty (733b -> 826b; 399usec -> 321usec) + var haystack = (byte*)Unsafe.AsPointer(ref Unsafe.AsRef(in Suffixes)); + for (int i = 0; i < bucketCount; i += 4, haystack += 4) + { + // FIXME: It's not possible to use cmovs here due to a JIT limitation (can't do cmovs in loops) + // A chain of cmovs would be much faster. + if (haystack[0] == suffix) + return i; + if (haystack[1] == suffix) + return i + 1; + if (haystack[2] == suffix) + return i + 2; + if (haystack[3] == suffix) + return i + 3; + } + + return 32; + } + } + // constants for serialization private const string VersionName = "Version"; // Do not rename (binary serialization) private const string HashSizeName = "HashSize"; // Do not rename (binary serialization). Must save buckets.Length private const string KeyValuePairsName = "KeyValuePairs"; // Do not rename (binary serialization) private const string ComparerName = "Comparer"; // Do not rename (binary serialization) - private int[]? _buckets; + private Bucket[]? _buckets; private Entry[]? _entries; -#if TARGET_64BIT + //#if TARGET_64BIT private ulong _fastModMultiplier; -#endif + //#endif private int _count; private int _freeList; private int _freeCount; @@ -132,16 +388,19 @@ private void AddRange(IEnumerable> enumerable) Debug.Assert(_count == 0); Entry[] oldEntries = source._entries; + // FIXME + /* if (source._comparer == _comparer) { // If comparers are the same, we can copy _entries without rehashing. CopyEntries(oldEntries, source._count); return; } + */ // Comparers differ need to rehash all the entries via Add - int count = source._count; - for (int i = 0; i < count; i++) + int allocatedEntryCount = source._count; + for (int i = 0; i < allocatedEntryCount; i++) { // Only copy if an entry if (oldEntries[i].next >= -1) @@ -240,15 +499,15 @@ public TValue this[TKey key] } set { - bool modified = TryInsert(key, value, InsertionBehavior.OverwriteExisting); - Debug.Assert(modified); + ref Entry result = ref TryInsert(key, value, InsertionBehavior.OverwriteExisting, out _); + Debug.Assert(!Unsafe.IsNullRef(ref result)); } } public void Add(TKey key, TValue value) { - bool modified = TryInsert(key, value, InsertionBehavior.ThrowOnExisting); - Debug.Assert(modified); // If there was an existing key and the Add failed, an exception will already have been thrown. + ref Entry result = ref TryInsert(key, value, InsertionBehavior.ThrowOnExisting, out _); + Debug.Assert(!Unsafe.IsNullRef(ref result)); } void ICollection>.Add(KeyValuePair keyValuePair) => @@ -285,6 +544,7 @@ public void Clear() Debug.Assert(_buckets != null, "_buckets should be non-null"); Debug.Assert(_entries != null, "_entries should be non-null"); + // TODO: Optimized clear that only touches buckets where count is nonzero Array.Clear(_buckets); _count = 0; @@ -384,9 +644,9 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte info.AddValue(VersionName, _version); info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer)); - info.AddValue(HashSizeName, _buckets == null ? 0 : _buckets.Length); // This is the length of the bucket array + info.AddValue(HashSizeName, _entries == null ? 0 : _entries.Length); // This is the length of the bucket array - if (_buckets != null) + if (_entries != null) { var array = new KeyValuePair[Count]; CopyTo(array, 0); @@ -394,207 +654,316 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte } } - internal ref TValue FindValue(TKey key) + private ref Entry FindEntry(TProtocol protocol, TActualKey key) + where TProtocol : struct, IComparisonProtocol + where TActualKey : allows ref struct { - if (key == null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + uint hashCode = (uint)protocol.GetHashCode(key); + var suffix = GetHashSuffix(hashCode); + var vectorized = Sse2.IsSupported || AdvSimd.Arm64.IsSupported || PackedSimd.IsSupported; + Vector128 searchVector = vectorized ? Vector128.Create(suffix) : default; + ref var bucket = ref LoopingBucketEnumerator.New(_buckets, hashCode, _fastModMultiplier, out var enumerator); + Span entries = _entries!; + // FIXME: Change to do { } while () by introducing EmptyBuckets optimization from SimdDictionary + while (!Unsafe.IsNullRef(ref bucket)) + { + // Pipelining + int bucketCount = bucket.Count; + // Determine start index for key search + int startIndex = vectorized + ? bucket.FindSuffix(bucketCount, suffix, searchVector) + : bucket.FindSuffixScalar(bucketCount, suffix); + ref var entry = ref FindEntryInBucket(protocol, ref bucket, entries, startIndex, bucketCount, key, out _, out _); + if (Unsafe.IsNullRef(ref entry)) + { + if (bucket.CascadeCount == 0) + return ref Unsafe.NullRef(); + } + else + return ref entry; + bucket = ref enumerator.Advance(); } - ref Entry entry = ref Unsafe.NullRef(); - if (_buckets != null) - { - Debug.Assert(_entries != null, "expected entries to be != null"); - IEqualityComparer? comparer = _comparer; - if (typeof(TKey).IsValueType && // comparer can only be null for value types; enable JIT to eliminate entire if block for ref types - comparer == null) - { - uint hashCode = (uint)key.GetHashCode(); - int i = GetBucket(hashCode); - Entry[]? entries = _entries; - uint collisionCount = 0; - - // ValueType: Devirtualize with EqualityComparer.Default intrinsic - i--; // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional. - do - { - // Test in if to drop range check for following array access - if ((uint)i >= (uint)entries.Length) - { - goto ReturnNotFound; - } - - entry = ref entries[i]; - if (entry.hashCode == hashCode && EqualityComparer.Default.Equals(entry.key, key)) - { - goto ReturnFound; - } + return ref Unsafe.NullRef(); + } - i = entry.next; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Entry FindEntryInBucket( + TProtocol protocol, ref Bucket bucket, Span entries, + int startIndex, int bucketCount, TActualKey key, + // These out-params are annoying but inlining seems to optimize them away + out int entryIndex, out int matchIndexInBucket + ) + where TProtocol : struct, IComparisonProtocol + where TActualKey : allows ref struct + { + Unsafe.SkipInit(out matchIndexInBucket); + Unsafe.SkipInit(out entryIndex); + Debug.Assert(startIndex >= 0); - collisionCount++; - } while (collisionCount <= (uint)entries.Length); + int count = bucketCount - startIndex; + if (count <= 0) + return ref Unsafe.NullRef(); - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - goto ConcurrentOperation; + ref int indexSlot = ref bucket.Indices[startIndex]; + while (true) + { + ref var entry = ref entries[indexSlot]; + if (protocol.Equals(key, entry.key)) + { + // We could optimize out the bucketCount local to prevent a stack spill in some cases by doing + // Unsafe.ByteOffset(...) / sizeof(Pair), but the potential idiv is extremely painful + entryIndex = indexSlot; + matchIndexInBucket = bucketCount - count; + return ref entry; } + + // NOTE: --count <= 0 produces an extra 'test' opcode + if (--count == 0) + return ref Unsafe.NullRef(); else - { - Debug.Assert(comparer is not null); - uint hashCode = (uint)comparer.GetHashCode(key); - int i = GetBucket(hashCode); - Entry[]? entries = _entries; - uint collisionCount = 0; - i--; // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional. - do - { - // Test in if to drop range check for following array access - if ((uint)i >= (uint)entries.Length) - { - goto ReturnNotFound; - } + indexSlot = ref Unsafe.Add(ref indexSlot, 1); + } + } - entry = ref entries[i]; - if (entry.hashCode == hashCode && comparer.Equals(entry.key, key)) - { - goto ReturnFound; - } + internal ref TValue FindValue(TKey key) + { + if (key == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } - i = entry.next; + var comparer = _comparer; + ref Entry entry = ref (typeof(TKey).IsValueType && (comparer == null)) + ? ref FindEntry(default(DefaultValueTypeComparerComparisonProtocol), key) + : ref FindEntry(new ComparerComparisonProtocol(comparer!), key); - collisionCount++; - } while (collisionCount <= (uint)entries.Length); + if (Unsafe.IsNullRef(ref entry)) + return ref Unsafe.NullRef(); + else + return ref entry.value; + } - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - goto ConcurrentOperation; - } + private static void FillNewBucketsForResizeOrRehash( + Span newBuckets, ulong fastModMultiplier, + Span entries, int allocatedEntryCount, IEqualityComparer comparer + ) + { + for (int index = 0; index < allocatedEntryCount; index++) + { + // FIXME: Use Unsafe.Add to optimize out the imul per element + ref var entry = ref entries[index]; + if (entry.next >= -1) + InsertExistingEntryIntoNewBucket(newBuckets, fastModMultiplier, comparer, ref entry, index); } - goto ReturnNotFound; + static void InsertExistingEntryIntoNewBucket( + Span newBuckets, ulong fastModMultiplier, + IEqualityComparer comparer, ref Entry entry, int entryIndex + ) + { + Debug.Assert(comparer is not null || typeof(TKey).IsValueType); + uint hashCode = (uint)((typeof(TKey).IsValueType && comparer == null) ? entry.key.GetHashCode() : comparer!.GetHashCode(entry.key)); + var suffix = GetHashSuffix(hashCode); + // FIXME: Skip this on non-vectorized targets + var searchVector = Vector128.Create(suffix); + + ref var bucket = ref LoopingBucketEnumerator.New(newBuckets, hashCode, fastModMultiplier, out var enumerator); + // FIXME: Change to do { } while () by introducing EmptyBuckets optimization from SimdDictionary + while (!Unsafe.IsNullRef(ref bucket)) + { + // Pipelining + int bucketCount = bucket.Count; + if (bucketCount < Bucket.Capacity) + { + InsertIntoBucket(ref bucket, suffix, bucketCount, entryIndex); + // We can ignore the return value of this, we're in the middle of rehashing/resizing so we wouldn't ever + // do a comparer swap in this scenario + AdjustCascadeCounts(enumerator, true); + return; + } + + bucket = ref enumerator.Advance(); + } - ConcurrentOperation: - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - ReturnFound: - ref TValue value = ref entry.value; - Return: - return ref value; - ReturnNotFound: - value = ref Unsafe.NullRef(); - goto Return; + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + } } private int Initialize(int capacity) { int size = HashHelpers.GetPrime(capacity); - int[] buckets = new int[size]; + int bucketCount = GetBucketCountForEntryCount(size); + Bucket[] buckets = new Bucket[bucketCount]; Entry[] entries = new Entry[size]; // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails _freeList = -1; + _freeCount = 0; #if TARGET_64BIT - _fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)size); + _fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)bucketCount); #endif _buckets = buckets; _entries = entries; + _count = 0; return size; } - private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) + // TODO: Figure out if we can outline this (reduces code size) without regressing performance for all inserts/removes + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool AdjustCascadeCounts(LoopingBucketEnumerator enumerator, bool increase) { - // NOTE: this method is mirrored in CollectionsMarshal.GetValueRefOrAddDefault below. - // If you make any changes here, make sure to keep that version in sync as well. + bool needRehash = false; + // We may have cascaded out of a previous bucket; if so, scan backwards and update + // the cascade count for every bucket we previously scanned. + ref Bucket bucket = ref enumerator.Retreat(); + while (!Unsafe.IsNullRef(ref bucket)) + { + // FIXME: Track number of times we cascade out of a bucket for string rehashing anti-DoS mitigation! + var cascadeCount = bucket.CascadeCount; + if (increase) + { + // Never overflow (wrap around) the counter + if (cascadeCount < Bucket.DegradedCascadeCount) + { + int newCascadeCount = bucket.CascadeCount = (ushort)(cascadeCount + 1); + if (!typeof(TKey).IsValueType && (newCascadeCount >= HashHelpers.HashCollisionThreshold)) + needRehash = true; + } + } + else + { + if (cascadeCount == 0) + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - if (key == null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); - } + // If the cascade counter hit the maximum, it's possible the actual cascade count through here is higher, + // so it's no longer safe to decrement. This is a very rare scenario, but it permanently degrades the table. + // TODO: Track this and trigger a rehash once too many buckets are in this state + dict is mostly empty. + else if (cascadeCount < Bucket.DegradedCascadeCount) + bucket.CascadeCount = (ushort)(cascadeCount - 1); + } - if (_buckets == null) - { - Initialize(0); + bucket = ref enumerator.Retreat(); } - Debug.Assert(_buckets != null); - Entry[]? entries = _entries; - Debug.Assert(entries != null, "expected entries to be non-null"); - - IEqualityComparer? comparer = _comparer; - Debug.Assert(comparer is not null || typeof(TKey).IsValueType); - uint hashCode = (uint)((typeof(TKey).IsValueType && comparer == null) ? key.GetHashCode() : comparer!.GetHashCode(key)); + return needRehash; + } - uint collisionCount = 0; - ref int bucket = ref GetBucket(hashCode); - int i = bucket - 1; // Value in _buckets is 1-based + private ref Entry TryInsert(TKey key, TValue value, InsertionBehavior behavior, out bool exists) + { + var comparer = _comparer; + return ref (typeof(TKey).IsValueType && (comparer == null)) + ? ref TryInsert(default(DefaultValueTypeComparerComparisonProtocol), key, value, behavior, out exists) + : ref TryInsert(new ComparerComparisonProtocol(comparer!), key, value, behavior, out exists); + } - if (typeof(TKey).IsValueType && // comparer can only be null for value types; enable JIT to eliminate entire if block for ref types - comparer == null) + private ref Entry TryInsert(TProtocol protocol, TActualKey key, TValue value, InsertionBehavior behavior, out bool exists) + where TProtocol : struct, IComparisonProtocol + where TActualKey : allows ref struct + { + if (key == null) { - // ValueType: Devirtualize with EqualityComparer.Default intrinsic - while ((uint)i < (uint)entries.Length) - { - if (entries[i].hashCode == hashCode && EqualityComparer.Default.Equals(entries[i].key, key)) - { - if (behavior == InsertionBehavior.OverwriteExisting) - { - entries[i].value = value; - return true; - } - - if (behavior == InsertionBehavior.ThrowOnExisting) - { - ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(key); - } - - return false; - } + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + } - i = entries[i].next; + uint hashCode = (uint)protocol.GetHashCode(key); + var suffix = GetHashSuffix(hashCode); + var vectorized = Sse2.IsSupported || AdvSimd.Arm64.IsSupported || PackedSimd.IsSupported; + Vector128 searchVector = vectorized ? Vector128.Create(suffix) : default; - collisionCount++; - if (collisionCount > (uint)entries.Length) + // We need to retry when we grow the buckets array since the correct destination bucket will have changed and might not + // be the same as the destination bucket before resizing (it probably isn't, in fact) + retry: + Span buckets = _buckets; + if (buckets.IsEmpty) + { + Initialize(0); + buckets = _buckets; + } + Debug.Assert(!buckets.IsEmpty); + Span entries = _entries!; + Debug.Assert(!entries.IsEmpty, "expected entries to be non-null"); + + ref var bucket = ref LoopingBucketEnumerator.New(buckets, hashCode, _fastModMultiplier, out var enumerator); + // FIXME: Change to do { } while () by introducing EmptyBuckets optimization from SimdDictionary + while (!Unsafe.IsNullRef(ref bucket)) + { + // Pipelining + int bucketCount = bucket.Count; + // Determine start index for key search + int startIndex = vectorized + ? bucket.FindSuffix(bucketCount, suffix, searchVector) + : bucket.FindSuffixScalar(bucketCount, suffix); + ref var entry = ref FindEntryInBucket(protocol, ref bucket, entries, startIndex, bucketCount, key, out _, out _); + if (!Unsafe.IsNullRef(ref entry)) + { + exists = true; + switch (behavior) { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + case InsertionBehavior.InsertNewOnly: + return ref entry; + case InsertionBehavior.OverwriteExisting: + entry.value = value; + return ref entry; + case InsertionBehavior.ThrowOnExisting: + ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(protocol.GetKey(key)); + return ref entry; + default: + ThrowHelper.ThrowArgumentOutOfRangeException(); + return ref entry; } } - } - else - { - Debug.Assert(comparer is not null); - while ((uint)i < (uint)entries.Length) + else if (startIndex < Bucket.Capacity) { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) - { - if (behavior == InsertionBehavior.OverwriteExisting) - { - entries[i].value = value; - return true; - } + // FIXME: Suffix collision. Track these for rehashing anti-DoS mitigation. + } - if (behavior == InsertionBehavior.ThrowOnExisting) - { - ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(key); - } + if (bucketCount < Bucket.Capacity) + { + // NOTE: Compute this before creating the entry, otherwise a comparer that throws could corrupt us. + var actualKey = protocol.GetKey(key); - return false; + int newEntryIndex = TryCreateNewEntry(entries); + if (newEntryIndex < 0) + { + // We can't reuse the existing target bucket once we resized, so start over. This is very rare. + Resize(); + goto retry; } - i = entries[i].next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) + ref var newEntry = ref entries[newEntryIndex]; + PopulateEntry(ref newEntry, actualKey, value); + InsertIntoBucket(ref bucket, suffix, bucketCount, newEntryIndex); + _version++; + exists = false; + if (AdjustCascadeCounts(enumerator, true) && (_comparer is NonRandomizedStringEqualityComparer)) { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + // if AdjustCascadeCounts returned true, we need to change comparers (if possible) to one with better collision + // resistance. + ChangeToRandomizedStringEqualityComparer(); + // This will have invalidated our buckets but not our entries, so it's safe to return newEntry. } + return ref newEntry; } + + bucket = ref enumerator.Advance(); } + // We failed to find any bucket with room and hit the end of the loop, so we should be full. This is very rare. + if (_count >= entries.Length) + { + Resize(); + goto retry; + } + + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); + exists = false; + return ref Unsafe.NullRef(); + } + + private int TryCreateNewEntry(Span entries) + { int index; if (_freeCount > 0) { @@ -605,34 +974,36 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) } else { - int count = _count; - if (count == entries.Length) - { - Resize(); - bucket = ref GetBucket(hashCode); - } - index = count; - _count = count + 1; - entries = _entries; + index = _count; + // Resize needed + if (_count >= entries.Length) + return -1; + _count = index + 1; } + return index; + } - ref Entry entry = ref entries![index]; - entry.hashCode = hashCode; - entry.next = bucket - 1; // Value in _buckets is 1-based + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void PopulateEntry(ref Entry entry, TKey key, TValue value) + { entry.key = key; entry.value = value; - bucket = index + 1; // Value in _buckets is 1-based - _version++; + entry.next = 0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool InsertIntoBucket(ref Bucket bucket, byte suffix, int bucketCount, int entryIndex) + { + Debug.Assert(bucketCount < Bucket.Capacity); - // Value types never rehash - if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) + unchecked { - // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing - // i.e. EqualityComparer.Default. - Resize(entries.Length, true); + ref var destination = ref bucket.Indices[bucketCount]; + bucket.Count = (byte)(bucketCount + 1); + bucket.SetSlot((nuint)bucketCount, suffix); + destination = entryIndex; + return true; } - - return true; } /// @@ -800,64 +1171,31 @@ public bool ContainsKey(TAlternateKey key) => internal ref TValue FindValue(TAlternateKey key, [MaybeNullWhen(false)] out TKey actualKey) { Dictionary dictionary = Dictionary; - IAlternateEqualityComparer comparer = GetAlternateComparer(dictionary); + AlternateComparerComparisonProtocol protocol = new(GetAlternateComparer(dictionary)); - ref Entry entry = ref Unsafe.NullRef(); - if (dictionary._buckets != null) + ref var entry = ref dictionary.FindEntry(protocol, key); + if (Unsafe.IsNullRef(ref entry)) { - Debug.Assert(dictionary._entries != null, "expected entries to be != null"); - - uint hashCode = (uint)comparer.GetHashCode(key); - int i = dictionary.GetBucket(hashCode); - Entry[]? entries = dictionary._entries; - uint collisionCount = 0; - i--; // Value in _buckets is 1-based; subtract 1 from i. We do it here so it fuses with the following conditional. - do - { - // Should be a while loop https://github.com/dotnet/runtime/issues/9422 - // Test in if to drop range check for following array access - if ((uint)i >= (uint)entries.Length) - { - goto ReturnNotFound; - } - - entry = ref entries[i]; - if (entry.hashCode == hashCode && comparer.Equals(key, entry.key)) - { - goto ReturnFound; - } - - i = entry.next; - - collisionCount++; - } while (collisionCount <= (uint)entries.Length); - - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - goto ConcurrentOperation; + actualKey = default!; + return ref Unsafe.NullRef(); + } + else + { + actualKey = entry.key; + return ref entry.value; } - - goto ReturnNotFound; - - ConcurrentOperation: - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - ReturnFound: - ref TValue value = ref entry.value; - actualKey = entry.key; - Return: - return ref value; - ReturnNotFound: - value = ref Unsafe.NullRef(); - actualKey = default!; - goto Return; } /// Removes the value with the specified alternate key from the . /// The alternate key of the element to remove. /// true if the element is successfully found and removed; otherwise, false. /// is . - public bool Remove(TAlternateKey key) => - Remove(key, out _, out _); + public bool Remove(TAlternateKey key) + { + Dictionary dictionary = Dictionary; + AlternateComparerComparisonProtocol protocol = new(GetAlternateComparer(dictionary)); + return dictionary.Remove(protocol, key, out Unsafe.NullRef()!, out Unsafe.NullRef()!); + } /// /// Removes the value with the specified alternate key from the , @@ -871,71 +1209,8 @@ public bool Remove(TAlternateKey key) => public bool Remove(TAlternateKey key, [MaybeNullWhen(false)] out TKey actualKey, [MaybeNullWhen(false)] out TValue value) { Dictionary dictionary = Dictionary; - IAlternateEqualityComparer comparer = GetAlternateComparer(dictionary); - - if (dictionary._buckets != null) - { - Debug.Assert(dictionary._entries != null, "entries should be non-null"); - uint collisionCount = 0; - - uint hashCode = (uint)comparer.GetHashCode(key); - - ref int bucket = ref dictionary.GetBucket(hashCode); - Entry[]? entries = dictionary._entries; - int last = -1; - int i = bucket - 1; // Value in buckets is 1-based - while (i >= 0) - { - ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && comparer.Equals(key, entry.key)) - { - if (last < 0) - { - bucket = entry.next + 1; // Value in buckets is 1-based - } - else - { - entries[last].next = entry.next; - } - - actualKey = entry.key; - value = entry.value; - - Debug.Assert((StartOfFreeList - dictionary._freeList) < 0, "shouldn't underflow because max hashtable length is MaxPrimeArrayLength = 0x7FEFFFFD(2146435069) _freelist underflow threshold 2147483646"); - entry.next = StartOfFreeList - dictionary._freeList; - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.key = default!; - } - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.value = default!; - } - - dictionary._freeList = i; - dictionary._freeCount++; - return true; - } - - last = i; - i = entry.next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } - } - } - - actualKey = default; - value = default; - return false; + AlternateComparerComparisonProtocol protocol = new(GetAlternateComparer(dictionary)); + return dictionary.Remove(protocol, key, out actualKey, out value); } /// Attempts to add the specified key and value to the dictionary. @@ -956,106 +1231,13 @@ public bool TryAdd(TAlternateKey key, TValue value) } /// +#pragma warning disable IDE0060 internal ref TValue? GetValueRefOrAddDefault(TAlternateKey key, out bool exists) +#pragma warning restore IDE0060 { - // NOTE: this method is a mirror of GetValueRefOrAddDefault above. Keep it in sync. - Dictionary dictionary = Dictionary; - IAlternateEqualityComparer comparer = GetAlternateComparer(dictionary); - - if (dictionary._buckets == null) - { - dictionary.Initialize(0); - } - Debug.Assert(dictionary._buckets != null); - - Entry[]? entries = dictionary._entries; - Debug.Assert(entries != null, "expected entries to be non-null"); - - uint hashCode = (uint)comparer.GetHashCode(key); - - uint collisionCount = 0; - ref int bucket = ref dictionary.GetBucket(hashCode); - int i = bucket - 1; // Value in _buckets is 1-based - - Debug.Assert(comparer is not null); - while ((uint)i < (uint)entries.Length) - { - if (entries[i].hashCode == hashCode && comparer.Equals(key, entries[i].key)) - { - exists = true; - - return ref entries[i].value!; - } - - i = entries[i].next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } - } - - TKey actualKey = comparer.Create(key); - if (actualKey is null) - { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); - } - - int index; - if (dictionary._freeCount > 0) - { - index = dictionary._freeList; - Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow"); - dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next; - dictionary._freeCount--; - } - else - { - int count = dictionary._count; - if (count == entries.Length) - { - dictionary.Resize(); - bucket = ref dictionary.GetBucket(hashCode); - } - index = count; - dictionary._count = count + 1; - entries = dictionary._entries; - } - - ref Entry entry = ref entries![index]; - entry.hashCode = hashCode; - entry.next = bucket - 1; // Value in _buckets is 1-based - entry.key = actualKey; - entry.value = default!; - bucket = index + 1; // Value in _buckets is 1-based - dictionary._version++; - - // Value types never rehash - if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) - { - // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing - // i.e. EqualityComparer.Default. - dictionary.Resize(entries.Length, true); - - exists = false; - - // At this point the entries array has been resized, so the current reference we have is no longer valid. - // We're forced to do a new lookup and return an updated reference to the new entry instance. This new - // lookup is guaranteed to always find a value though and it will never return a null reference here. - ref TValue? value = ref dictionary.FindValue(actualKey)!; - - Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here"); - - return ref value; - } - - exists = false; - - return ref entry.value!; + AlternateComparerComparisonProtocol protocol = new(GetAlternateComparer(dictionary)); + return ref dictionary.TryInsert(protocol, key, default!, InsertionBehavior.InsertNewOnly, out exists).value!; } } @@ -1077,122 +1259,11 @@ internal static class CollectionsMarshalHelper ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } - if (dictionary._buckets == null) - { - dictionary.Initialize(0); - } - Debug.Assert(dictionary._buckets != null); - - Entry[]? entries = dictionary._entries; - Debug.Assert(entries != null, "expected entries to be non-null"); - - IEqualityComparer? comparer = dictionary._comparer; - Debug.Assert(comparer is not null || typeof(TKey).IsValueType); - uint hashCode = (uint)((typeof(TKey).IsValueType && comparer == null) ? key.GetHashCode() : comparer!.GetHashCode(key)); - - uint collisionCount = 0; - ref int bucket = ref dictionary.GetBucket(hashCode); - int i = bucket - 1; // Value in _buckets is 1-based - - if (typeof(TKey).IsValueType && // comparer can only be null for value types; enable JIT to eliminate entire if block for ref types - comparer == null) - { - // ValueType: Devirtualize with EqualityComparer.Default intrinsic - while ((uint)i < (uint)entries.Length) - { - if (entries[i].hashCode == hashCode && EqualityComparer.Default.Equals(entries[i].key, key)) - { - exists = true; - - return ref entries[i].value!; - } - - i = entries[i].next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } - } - } + ref var entry = ref dictionary.TryInsert(key, default!, InsertionBehavior.InsertNewOnly, out exists); + if (!Unsafe.IsNullRef(ref entry)) + return ref entry.value!; else - { - Debug.Assert(comparer is not null); - while ((uint)i < (uint)entries.Length) - { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) - { - exists = true; - - return ref entries[i].value!; - } - - i = entries[i].next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } - } - } - - int index; - if (dictionary._freeCount > 0) - { - index = dictionary._freeList; - Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow"); - dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next; - dictionary._freeCount--; - } - else - { - int count = dictionary._count; - if (count == entries.Length) - { - dictionary.Resize(); - bucket = ref dictionary.GetBucket(hashCode); - } - index = count; - dictionary._count = count + 1; - entries = dictionary._entries; - } - - ref Entry entry = ref entries![index]; - entry.hashCode = hashCode; - entry.next = bucket - 1; // Value in _buckets is 1-based - entry.key = key; - entry.value = default!; - bucket = index + 1; // Value in _buckets is 1-based - dictionary._version++; - - // Value types never rehash - if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer) - { - // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing - // i.e. EqualityComparer.Default. - dictionary.Resize(entries.Length, true); - - exists = false; - - // At this point the entries array has been resized, so the current reference we have is no longer valid. - // We're forced to do a new lookup and return an updated reference to the new entry instance. This new - // lookup is guaranteed to always find a value though and it will never return a null reference here. - ref TValue? value = ref dictionary.FindValue(key)!; - - Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here"); - - return ref value; - } - - exists = false; - - return ref entry.value!; + return ref Unsafe.NullRef(); } } @@ -1242,12 +1313,10 @@ public virtual void OnDeserialization(object? sender) HashHelpers.SerializationInfoTable.Remove(this); } - private void Resize() => Resize(HashHelpers.ExpandPrime(_count), false); + private void Resize() => Resize(HashHelpers.ExpandPrime(_count)); - private void Resize(int newSize, bool forceNewHashCodes) + private void Resize(int newSize) { - // Value types never rehash - Debug.Assert(!forceNewHashCodes || !typeof(TKey).IsValueType); Debug.Assert(_entries != null, "_entries should be non-null"); Debug.Assert(newSize >= _entries.Length); @@ -1256,185 +1325,175 @@ private void Resize(int newSize, bool forceNewHashCodes) int count = _count; Array.Copy(_entries, entries, count); - if (!typeof(TKey).IsValueType && forceNewHashCodes) - { - Debug.Assert(_comparer is NonRandomizedStringEqualityComparer); - IEqualityComparer comparer = _comparer = (IEqualityComparer)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer(); + int newBucketCount = GetBucketCountForEntryCount(newSize); + ulong fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)newBucketCount); - for (int i = 0; i < count; i++) - { - if (entries[i].next >= -1) - { - entries[i].hashCode = (uint)comparer.GetHashCode(entries[i].key); - } - } + if (newBucketCount != _buckets?.Length) + { + Bucket[] newBuckets = new Bucket[newBucketCount]; + FillNewBucketsForResizeOrRehash( + newBuckets, fastModMultiplier, entries, _count, + // FIXME + _comparer ?? EqualityComparer.Default + ); + _buckets = newBuckets; + _fastModMultiplier = fastModMultiplier; } - // Assign member variables after both arrays allocated to guard against corruption from OOM if second fails - _buckets = new int[newSize]; -#if TARGET_64BIT - _fastModMultiplier = HashHelpers.GetFastModMultiplier((uint)newSize); -#endif - for (int i = 0; i < count; i++) - { - if (entries[i].next >= -1) + _entries = entries; + _version++; + } + + private void ChangeToRandomizedStringEqualityComparer() + { + Debug.Assert(_comparer is NonRandomizedStringEqualityComparer); + _comparer = (IEqualityComparer)((NonRandomizedStringEqualityComparer)_comparer).GetRandomizedEqualityComparer(); + Debug.Assert(_buckets != null); + Array.Clear(_buckets!); + FillNewBucketsForResizeOrRehash(_buckets, _fastModMultiplier, _entries, _count, _comparer); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void RemoveIndexFromBucket(ref Bucket bucket, int indexInBucket, int bucketCount) + { + Debug.Assert(bucketCount > 0); + unchecked + { + int replacementIndexInBucket = bucketCount - 1; + bucket.Count = (byte)replacementIndexInBucket; + ref var toRemove = ref bucket.Indices[indexInBucket]; + ref var replacement = ref bucket.Indices[replacementIndexInBucket]; + // This rotate-back algorithm makes removes more expensive than if we were to just always zero the slot. + // But then other algorithms like insertion get more expensive, since we have to search for a zero to replace... + if (!Unsafe.AreSame(ref toRemove, ref replacement)) + { + // TODO: This is the only place in the find/insert/remove algorithms that actually needs indexInBucket. + // Can we refactor it away? The good news is RyuJIT optimizes it out entirely in find/insert. + bucket.SetSlot((uint)indexInBucket, bucket.GetSlot(replacementIndexInBucket)); + bucket.SetSlot((uint)replacementIndexInBucket, 0); + toRemove = replacement; + } + else { - ref int bucket = ref GetBucket(entries[i].hashCode); - entries[i].next = bucket - 1; // Value in _buckets is 1-based - bucket = i + 1; + bucket.SetSlot((uint)indexInBucket, 0); + toRemove = default!; } } - - _entries = entries; } - public bool Remove(TKey key) + private void RemoveEntry(ref Entry entry, int entryIndex) { - // The overload Remove(TKey key, out TValue value) is a copy of this method with one additional - // statement to copy the value for entry being removed into the output parameter. - // Code has been intentionally duplicated for performance reasons. + Debug.Assert((StartOfFreeList - _freeList) < 0, "shouldn't underflow because max hashtable length is MaxPrimeArrayLength = 0x7FEFFFFD(2146435069) _freelist underflow threshold 2147483646"); + entry.next = StartOfFreeList - _freeList; - if (key == null) + if (RuntimeHelpers.IsReferenceOrContainsReferences()) { - ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); + entry.key = default!; } - if (_buckets != null) + if (RuntimeHelpers.IsReferenceOrContainsReferences()) { - Debug.Assert(_entries != null, "entries should be non-null"); - uint collisionCount = 0; - - IEqualityComparer? comparer = _comparer; - Debug.Assert(typeof(TKey).IsValueType || comparer is not null); - uint hashCode = (uint)(typeof(TKey).IsValueType && comparer == null ? key.GetHashCode() : comparer!.GetHashCode(key)); - - ref int bucket = ref GetBucket(hashCode); - Entry[]? entries = _entries; - int last = -1; - int i = bucket - 1; // Value in buckets is 1-based - while (i >= 0) - { - ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && - (typeof(TKey).IsValueType && comparer == null ? EqualityComparer.Default.Equals(entry.key, key) : comparer!.Equals(entry.key, key))) - { - if (last < 0) - { - bucket = entry.next + 1; // Value in buckets is 1-based - } - else - { - entries[last].next = entry.next; - } - - Debug.Assert((StartOfFreeList - _freeList) < 0, "shouldn't underflow because max hashtable length is MaxPrimeArrayLength = 0x7FEFFFFD(2146435069) _freelist underflow threshold 2147483646"); - entry.next = StartOfFreeList - _freeList; - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.key = default!; - } - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.value = default!; - } - - _freeList = i; - _freeCount++; - return true; - } + entry.value = default!; + } - last = i; - i = entry.next; + _freeList = entryIndex; + _freeCount++; + } - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } - } - } - return false; + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public bool Remove(TKey key) + { + return Remove(key, out Unsafe.NullRef()!); } public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value) { - // This overload is a copy of the overload Remove(TKey key) with one additional - // statement to copy the value for entry being removed into the output parameter. - // Code has been intentionally duplicated for performance reasons. + var comparer = _comparer; + return (typeof(TKey).IsValueType && (comparer == null)) + ? Remove(default(DefaultValueTypeComparerComparisonProtocol), key, out Unsafe.NullRef()!, out value) + : Remove(new ComparerComparisonProtocol(comparer!), key, out Unsafe.NullRef()!, out value); + } + + private bool Remove( + TProtocol protocol, TActualKey key, + [MaybeNullWhen(false)] out TKey actualKey, + [MaybeNullWhen(false)] out TValue value + ) + where TProtocol : struct, IComparisonProtocol + where TActualKey : allows ref struct + { + // This allows using Remove(key, out value) to implement Remove(key) efficiently, + // as long as we check whether value is a null reference before writing to it. + Unsafe.SkipInit(out actualKey); + Unsafe.SkipInit(out value); if (key == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } - if (_buckets != null) + if (_buckets == null) { - Debug.Assert(_entries != null, "entries should be non-null"); - uint collisionCount = 0; + if (!Unsafe.IsNullRef(ref actualKey)) + actualKey = default!; + if (!Unsafe.IsNullRef(ref value)) + value = default!; + return false; + } - IEqualityComparer? comparer = _comparer; - Debug.Assert(typeof(TKey).IsValueType || comparer is not null); - uint hashCode = (uint)(typeof(TKey).IsValueType && comparer == null ? key.GetHashCode() : comparer!.GetHashCode(key)); + uint hashCode = (uint)protocol.GetHashCode(key); + var suffix = GetHashSuffix(hashCode); + var vectorized = Sse2.IsSupported || AdvSimd.Arm64.IsSupported || PackedSimd.IsSupported; + Vector128 searchVector = vectorized ? Vector128.Create(suffix) : default; + Span entries = _entries!; + Debug.Assert(!entries.IsEmpty, "expected entries to be non-null"); - ref int bucket = ref GetBucket(hashCode); - Entry[]? entries = _entries; - int last = -1; - int i = bucket - 1; // Value in buckets is 1-based - while (i >= 0) - { - ref Entry entry = ref entries[i]; + ref var bucket = ref LoopingBucketEnumerator.New(_buckets, hashCode, _fastModMultiplier, out var enumerator); - if (entry.hashCode == hashCode && - (typeof(TKey).IsValueType && comparer == null ? EqualityComparer.Default.Equals(entry.key, key) : comparer!.Equals(entry.key, key))) - { - if (last < 0) - { - bucket = entry.next + 1; // Value in buckets is 1-based - } - else - { - entries[last].next = entry.next; - } + // FIXME: Change to do { } while () by introducing EmptyBuckets optimization from SimdDictionary + while (!Unsafe.IsNullRef(ref bucket)) + { + // Pipelining + int bucketCount = bucket.Count; + // Determine start index for key search + int startIndex = vectorized + ? bucket.FindSuffix(bucketCount, suffix, searchVector) + : bucket.FindSuffixScalar(bucketCount, suffix); + ref var entry = ref FindEntryInBucket( + protocol, ref bucket, entries, startIndex, bucketCount, key, + out int entryIndex, out int indexInBucket + ); + if (!Unsafe.IsNullRef(ref entry)) + { + if (!Unsafe.IsNullRef(ref actualKey)) + actualKey = entry.key; + if (!Unsafe.IsNullRef(ref value)) value = entry.value; + // NOTE: We don't increment version because it's documented that removal during enumeration works. + RemoveEntry(ref entry, entryIndex); + RemoveIndexFromBucket(ref bucket, indexInBucket, bucketCount); + AdjustCascadeCounts(enumerator, false); + return true; + } - Debug.Assert((StartOfFreeList - _freeList) < 0, "shouldn't underflow because max hashtable length is MaxPrimeArrayLength = 0x7FEFFFFD(2146435069) _freelist underflow threshold 2147483646"); - entry.next = StartOfFreeList - _freeList; - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.key = default!; - } - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.value = default!; - } - - _freeList = i; - _freeCount++; - return true; - } - - last = i; - i = entry.next; - - collisionCount++; - if (collisionCount > (uint)entries.Length) - { - // The chain of entries forms a loop; which means a concurrent update has happened. - // Break out of the loop and throw, rather than looping forever. - ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); - } + if (bucket.CascadeCount == 0) + { + if (!Unsafe.IsNullRef(ref actualKey)) + actualKey = default!; + if (!Unsafe.IsNullRef(ref value)) + value = default!; + return false; } + + bucket = ref enumerator.Advance(); } - value = default; + if (!Unsafe.IsNullRef(ref actualKey)) + actualKey = default!; + if (!Unsafe.IsNullRef(ref value)) + value = default!; + ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported(); return false; } @@ -1451,8 +1510,11 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) return false; } - public bool TryAdd(TKey key, TValue value) => - TryInsert(key, value, InsertionBehavior.None); + public bool TryAdd(TKey key, TValue value) + { + TryInsert(key, value, InsertionBehavior.InsertNewOnly, out bool exists); + return !exists; + } bool ICollection>.IsReadOnly => false; @@ -1554,7 +1616,7 @@ public int EnsureCapacity(int capacity) } int newSize = HashHelpers.GetPrime(capacity); - Resize(newSize, forceNewHashCodes: false); + Resize(newSize); return newSize; } @@ -1582,50 +1644,36 @@ public int EnsureCapacity(int capacity) /// Passed capacity is lower than entries count. public void TrimExcess(int capacity) { + int allocatedEntryCount = _count; if (capacity < Count) { ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity); } int newSize = HashHelpers.GetPrime(capacity); - Entry[]? oldEntries = _entries; - int currentCapacity = oldEntries == null ? 0 : oldEntries.Length; + Span oldEntries = _entries; + int currentCapacity = oldEntries.IsEmpty ? 0 : oldEntries.Length; if (newSize >= currentCapacity) { return; } - int oldCount = _count; _version++; Initialize(newSize); - Debug.Assert(oldEntries is not null); - - CopyEntries(oldEntries, oldCount); - } - - private void CopyEntries(Entry[] entries, int count) - { - Debug.Assert(_entries is not null); - - Entry[] newEntries = _entries; - int newCount = 0; - for (int i = 0; i < count; i++) + // FIXME: Write a dedicated special-case implementation of this loop maybe? + // Not sure how much faster it could actually be. + if (!oldEntries.IsEmpty) { - uint hashCode = entries[i].hashCode; - if (entries[i].next >= -1) + for (int i = 0; i < allocatedEntryCount; i++) { - ref Entry entry = ref newEntries[newCount]; - entry = entries[i]; - ref int bucket = ref GetBucket(hashCode); - entry.next = bucket - 1; // Value in _buckets is 1-based - bucket = newCount + 1; - newCount++; + ref var entry = ref oldEntries[i]; + // Initialize zeroed our count and created new bucket/entry arrays so we can use the regular insert operation + // to repopulate our new backing stores + if (entry.next >= -1) + TryInsert(entry.key, entry.value, InsertionBehavior.ThrowOnExisting, out _); } } - - _count = newCount; - _freeCount = 0; } bool ICollection.IsSynchronized => false; @@ -1738,23 +1786,47 @@ void IDictionary.Remove(object key) } } + // The hash suffix is selected from 8 bits of the hash, and then modified to ensure + // it is never zero (because a zero suffix indicates an empty slot.) [MethodImpl(MethodImplOptions.AggressiveInlining)] - private ref int GetBucket(uint hashCode) + private static byte GetHashSuffix(uint hashCode) { - int[] buckets = _buckets!; + // We could shift by 24 bits to take the other end of the value, but taking the low 8 + // bits produces better results for the common scenario where you're using sequential + // integers as keys (since their default hash is the identity function). + var result = unchecked((byte)hashCode); + // Assuming the JIT turns this into a cmov, this should be better than a bitwise or + // since it nearly doubles the number of possible suffixes, improving collision + // resistance and reducing the odds of having to check multiple keys. + return result == 0 ? (byte)255 : result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetBucketCountForEntryCount(int count) + { + int result = checked((count + Bucket.Capacity - 1) / Bucket.Capacity); + return (result > 1) + ? HashHelpers.GetPrime(result) + : result; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetBucketIndexForHashCode(Span buckets, uint hashCode, ulong fastModMultiplier) + { + unchecked + { #if TARGET_64BIT - return ref buckets[HashHelpers.FastMod(hashCode, (uint)buckets.Length, _fastModMultiplier)]; + return (int)HashHelpers.FastMod(hashCode, (uint)buckets.Length, fastModMultiplier); #else - return ref buckets[(uint)hashCode % buckets.Length]; + return (int)(hashCode % buckets.Length); #endif + } } private struct Entry { - public uint hashCode; /// - /// 0-based index of next entry in chain: -1 means end of chain - /// also encodes whether this entry _itself_ is part of the free list by changing sign and subtracting 3, + /// encodes whether this entry _itself_ is part of the free list by changing sign and subtracting 3, /// so -2 means end of free list, -3 means index 0 but on free list, -4 means index 1 but on free list, etc. /// public int next; diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/InsertionBehavior.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/InsertionBehavior.cs index 378283a2419d0..b097153a56687 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/InsertionBehavior.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/InsertionBehavior.cs @@ -10,8 +10,9 @@ internal enum InsertionBehavior : byte { /// /// The default insertion behavior. + /// Specifies that if an existing entry with the same key is encountered, the insertion operation should fail without throwing. /// - None = 0, + InsertNewOnly = 0, /// /// Specifies that an existing entry with the same key should be overwritten if encountered.