From 5fc190029e19f800ba038d40048f02539e991807 Mon Sep 17 00:00:00 2001
From: Stephen Toub
Date: Thu, 2 Feb 2023 13:08:53 -0500
Subject: [PATCH 1/4] Improve ConcurrentDictionary performance, in particular for strings

- By default, string hash codes are randomized. This is an important defense-in-depth security measure, but it also adds some overhead when strings are used as keys in dictionaries. `Dictionary<>` addresses that overhead by starting out using non-randomized comparers and then upgrading to randomized comparers only once enough collisions have been detected. This PR updates `ConcurrentDictionary<>` with similar tricks. The comparer is moved from being stored on the dictionary itself to instead being stored on the Tables object that's atomically swapped when the table grows; that way, the comparer always remains in sync with the hash codes stored in the nodes of that table. When we enumerate a bucket looking for an existing key as part of an add, we count how many items we traverse, and if that number exceeds the allowed threshold while we're still using a non-randomized comparer, we force a rehash; that rehash replaces the non-randomized comparer with the equivalent randomized one (a simplified standalone sketch of this upgrade strategy is included below).
- The `ConcurrentDictionary<>` ctor is improved to presize based on the size of a collection being passed in; otherwise, it might resize multiple times while populating the dictionary. The sizing logic is also changed to use the same prime bucket-count selection as `Dictionary<>`.
- The method we were using to compute the bucket for a key wasn't being inlined for reference-type keys due to the generic context; that method has been moved to the outer type as a static to avoid the non-inlined call and the extra generic dictionary lookup.
- For all key types, we were also paying for a non-inlined ldelema helper call when reading the head node of a bucket; that's been addressed via a struct wrapper with a volatile node field rather than using Volatile.Read to access the array element.
- We were inconsistent in whether checked math was used when computing the size of the table. In some situations an overflow could go undetected, or be detected and manifest in various ways. This change always uses checked math when computing the counts.
- Removed unnecessary try/finally blocks left over from CERs and thread abort protection.
- Deduped some code with calls to helper functions.
---
 .../Concurrent/ConcurrentDictionary.cs        | 955 ++++++++++--------
 1 file changed, 518 insertions(+), 437 deletions(-)

diff --git a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs
index 187c7b5f6b91d9..b5dad696f4b0bb 100644
--- a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs
+++ b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs
@@ -22,25 +22,31 @@ namespace System.Collections.Concurrent
     public class ConcurrentDictionary<TKey, TValue> : IDictionary<TKey, TValue>, IDictionary, IReadOnlyDictionary<TKey, TValue> where TKey : notnull
     {
         /// Internal tables of the dictionary.
-        private volatile Tables _tables;
-        /// Key equality comparer.
-        private readonly IEqualityComparer<TKey>? _comparer;
-        /// Default comparer for TKey.
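// ---------------------------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): the commit message above describes starting out
// with a cheap, non-randomized string hash and upgrading to the randomized one once a bucket
// chain exceeds a collision threshold. The BCL's NonRandomizedStringEqualityComparer and the
// HashHelpers threshold are internal, so this standalone toy map uses FNV-1a as a stand-in for
// the non-randomized hash and string.GetHashCode() (randomized per process) for the upgraded
// path. All names are hypothetical, the threshold value is assumed, and the code is deliberately
// single-threaded; it only demonstrates the upgrade mechanics, not ConcurrentDictionary's locking.
// ---------------------------------------------------------------------------------------------
using System;

internal sealed class CollisionAwareMap<TValue>
{
    private const int CollisionThreshold = 100; // assumed value for illustration

    private sealed class Node
    {
        public string Key = "";
        public TValue Value = default!;
        public int Hash;
        public Node? Next;
    }

    private Node?[] _buckets = new Node?[31];
    private bool _randomized; // start on the cheap, non-randomized hash

    // Toy stand-in for a non-randomized string hash (FNV-1a): stable and cheap, but predictable,
    // which is exactly what makes hash-flooding possible and why an upgrade path is needed.
    private static int NonRandomizedHash(string s)
    {
        uint hash = 2166136261;
        foreach (char c in s)
        {
            hash = unchecked((hash ^ c) * 16777619);
        }
        return unchecked((int)hash);
    }

    private int Hash(string key) => _randomized ? key.GetHashCode() : NonRandomizedHash(key);

    public void AddOrUpdate(string key, TValue value)
    {
        int hash = Hash(key);
        ref Node? bucket = ref _buckets[(uint)hash % (uint)_buckets.Length];

        int chainLength = 0;
        for (Node? n = bucket; n is not null; n = n.Next)
        {
            if (n.Hash == hash && string.Equals(n.Key, key, StringComparison.Ordinal))
            {
                n.Value = value;
                return;
            }
            chainLength++;
        }

        bucket = new Node { Key = key, Value = value, Hash = hash, Next = bucket };

        // A suspiciously long chain while still on the non-randomized hash is the signal to
        // switch to the randomized hash and rebuild, mirroring the patch's forceRehash path.
        if (!_randomized && chainLength >= CollisionThreshold)
        {
            _randomized = true;
            Rehash();
        }
    }

    private void Rehash()
    {
        Node?[] newBuckets = new Node?[_buckets.Length];
        foreach (Node? head in _buckets)
        {
            for (Node? n = head; n is not null; )
            {
                Node? next = n.Next;
                n.Hash = Hash(n.Key); // now the randomized hash
                ref Node? slot = ref newBuckets[(uint)n.Hash % (uint)newBuckets.Length];
                n.Next = slot;
                slot = n;
                n = next;
            }
        }
        _buckets = newBuckets;
    }
}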
/// - /// Used to avoid repeatedly accessing the shared default generic static, in particular for reference types where it's - /// currently not devirtualized: https://github.com/dotnet/runtime/issues/10050. + /// When using , we must read the volatile _tables field into a local variable: + /// it is set to a new table on each table resize. Volatile.Reads on array elements then ensure that + /// we have a copy of the reference to tables._buckets[bucketNo]: this protects us from reading fields + /// ('_hashcode', '_key', '_value' and '_next') of different instances. /// - private readonly EqualityComparer _defaultComparer; - /// Whether to dynamically increase the size of the striped lock. - private readonly bool _growLockArray; + private volatile Tables _tables; /// The maximum number of elements per lock before a resize operation is triggered. private int _budget; + /// Whether to dynamically increase the size of the striped lock. + private readonly bool _growLockArray; + /// Whether a non-null comparer in is the default comparer. + /// + /// This is only used for reference types. It lets us use the key's GetHashCode directly rather than going indirectly + /// through the comparer. It can't be used for Equals, as the key might implement IEquatable and employ different + /// equality semantics than the virtual Equals, however unlikely that may be. This field enables us to save an + /// interface dispatch when using the default comparer with a non-string reference type key, at the expense of an + /// extra branch when using a custom comparer with a reference type key. + /// + private readonly bool _comparerIsDefaultForClasses; /// The default capacity, i.e. the initial # of buckets. /// /// When choosing this value, we are making a trade-off between the size of a very small dictionary, - /// and the number of resizes when constructing a large dictionary. Also, the capacity should not be - /// divisible by a small prime. + /// and the number of resizes when constructing a large dictionary. /// private const int DefaultCapacity = 31; @@ -59,7 +65,8 @@ public class ConcurrentDictionary : IDictionary, IDi /// class that is empty, has the default concurrency level, has the default initial capacity, and /// uses the default comparer for the key type. /// - public ConcurrentDictionary() : this(DefaultConcurrencyLevel, DefaultCapacity, growLockArray: true, null) { } + public ConcurrentDictionary() + : this(DefaultConcurrencyLevel, DefaultCapacity, growLockArray: true, null) { } /// /// Initializes a new instance of the @@ -71,7 +78,8 @@ public ConcurrentDictionary() : this(DefaultConcurrencyLevel, DefaultCapacity, g /// The initial number of elements that the can contain. /// is less than 1. /// is less than 0. - public ConcurrentDictionary(int concurrencyLevel, int capacity) : this(concurrencyLevel, capacity, growLockArray: false, null) { } + public ConcurrentDictionary(int concurrencyLevel, int capacity) + : this(concurrencyLevel, capacity, growLockArray: false, null) { } /// /// Initializes a new instance of the @@ -82,7 +90,8 @@ public ConcurrentDictionary(int concurrencyLevel, int capacity) : this(concurren /// cref="IEnumerable{T}"/> whose elements are copied to the new . /// is a null reference (Nothing in Visual Basic). /// contains one or more duplicate keys. 
- public ConcurrentDictionary(IEnumerable> collection) : this(collection, null) { } + public ConcurrentDictionary(IEnumerable> collection) + : this(DefaultConcurrencyLevel, collection, null) { } /// /// Initializes a new instance of the @@ -90,7 +99,8 @@ public ConcurrentDictionary(IEnumerable> collection) /// . /// /// The implementation to use when comparing keys. - public ConcurrentDictionary(IEqualityComparer? comparer) : this(DefaultConcurrencyLevel, DefaultCapacity, growLockArray: true, comparer) { } + public ConcurrentDictionary(IEqualityComparer? comparer) + : this(DefaultConcurrencyLevel, DefaultCapacity, growLockArray: true, comparer) { } /// /// Initializes a new instance of the @@ -101,7 +111,7 @@ public ConcurrentDictionary(IEqualityComparer? comparer) : this(DefaultCon /// The implementation to use when comparing keys. /// is a null reference (Nothing in Visual Basic). public ConcurrentDictionary(IEnumerable> collection, IEqualityComparer? comparer) - : this(comparer) + : this(DefaultConcurrencyLevel, GetCapacityFromCollection(collection), comparer) { ArgumentNullException.ThrowIfNull(collection); @@ -124,35 +134,13 @@ public ConcurrentDictionary(IEnumerable> collection, /// is less than 1. /// contains one or more duplicate keys. public ConcurrentDictionary(int concurrencyLevel, IEnumerable> collection, IEqualityComparer? comparer) - : this(concurrencyLevel, DefaultCapacity, growLockArray: false, comparer) + : this(concurrencyLevel, GetCapacityFromCollection(collection), growLockArray: false, comparer) { ArgumentNullException.ThrowIfNull(collection); InitializeFromCollection(collection); } - private void InitializeFromCollection(IEnumerable> collection) - { - foreach (KeyValuePair pair in collection) - { - if (pair.Key is null) - { - ThrowHelper.ThrowKeyNullException(); - } - - if (!TryAddInternal(pair.Key, null, pair.Value, updateIfExists: false, acquireLock: false, out _)) - { - throw new ArgumentException(SR.ConcurrentDictionary_SourceContainsDuplicateKeys); - } - } - - if (_budget == 0) - { - Tables tables = _tables; - _budget = tables._buckets.Length / tables._locks.Length; - } - } - /// /// Initializes a new instance of the /// class that is empty, has the specified concurrency level, has the specified initial capacity, and @@ -173,11 +161,12 @@ internal ConcurrentDictionary(int concurrencyLevel, int capacity, bool growLockA ArgumentOutOfRangeException.ThrowIfNegative(capacity); // The capacity should be at least as large as the concurrency level. Otherwise, we would have locks that don't guard - // any buckets. + // any buckets. We also want it to be a prime. 
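// Illustrative sketch (not part of the patch): HashHelpers.GetPrime, used just below to turn the
// requested capacity into a prime bucket count, is internal to the BCL. A minimal stand-in for the
// idea might look like the following; the real helper is more elaborate (for example, it consults
// a table of precomputed primes before falling back to a search).
using System;

internal static class PrimeSizing
{
    internal static int RoundUpToPrime(int min)
    {
        for (int candidate = Math.Max(2, min); ; candidate++)
        {
            if (IsPrime(candidate))
            {
                return candidate; // e.g. RoundUpToPrime(31) == 31, RoundUpToPrime(32) == 37
            }
        }
    }

    private static bool IsPrime(int n)
    {
        if (n % 2 == 0)
        {
            return n == 2;
        }
        for (int divisor = 3; divisor * divisor <= n; divisor += 2)
        {
            if (n % divisor == 0)
            {
                return false;
            }
        }
        return n > 1;
    }
}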
if (capacity < concurrencyLevel) { capacity = concurrencyLevel; } + capacity = HashHelpers.GetPrime(capacity); var locks = new object[concurrencyLevel]; locks[0] = locks; // reuse array as the first lock object just to avoid an additional allocation @@ -187,20 +176,121 @@ internal ConcurrentDictionary(int concurrencyLevel, int capacity, bool growLockA } var countPerLock = new int[locks.Length]; - var buckets = new Node[capacity]; - _tables = new Tables(buckets, locks, countPerLock); + var buckets = new VolatileNode[capacity]; - _defaultComparer = EqualityComparer.Default; - if (comparer != null && - !ReferenceEquals(comparer, _defaultComparer) && // if this is the default comparer, take the optimized path - !ReferenceEquals(comparer, StringComparer.Ordinal)) // strings as keys are extremely common, so special-case StringComparer.Ordinal, which is the same as the default comparer + // For reference types, we always want to store a comparer instance, either the one provided, or if + // one wasn't provided, the default (accessing EqualityComparer.Default with shared generics + // on every dictionary access can add measurable overhead). For value types, if no comparer is provided, + // or if the default is provided, we'd prefer to use EqualityComparer.Default.Equals/GetHashCode + // on every use, enabling the JIT to devirtualize and possibly inline the operation. + if (typeof(TKey).IsValueType) { - _comparer = comparer; + if (comparer is not null && // first check for null to avoid forcing default comparer instantiation unnecessarily + ReferenceEquals(comparer, EqualityComparer.Default)) + { + comparer = null; + } + } + else + { + comparer ??= EqualityComparer.Default; + + // Special-case EqualityComparer.Default, StringComparer.Ordinal, and StringComparer.OrdinalIgnoreCase. + // We use a non-randomized comparer for improved perf, falling back to a randomized comparer if the + // hash buckets become unbalanced. + if (typeof(TKey) == typeof(string) && + NonRandomizedStringEqualityComparer.GetStringComparer(comparer) is IEqualityComparer stringComparer) + { + comparer = (IEqualityComparer)stringComparer; + } + else if (ReferenceEquals(comparer, EqualityComparer.Default)) + { + _comparerIsDefaultForClasses = true; + } } + + _tables = new Tables(buckets, locks, countPerLock, comparer); _growLockArray = growLockArray; _budget = buckets.Length / locks.Length; } + /// Computes an initial capacity to use based on an initial seed collection. + /// The collection with which to initially populate this dictionary. + /// The capacity to use. + /// + /// Growing is expensive, and we don't know if the caller plans to add additional items beyond this + /// initial collection, so we use the maximum of the collection's size and the default capacity. That way, + /// the initial capacity selected isn't pessimized by seeding it with a collection that happens to be + /// smaller. + /// + private static int GetCapacityFromCollection(IEnumerable> collection) => + collection is ICollection> c ? Math.Max(DefaultCapacity, c.Count) : + collection is IReadOnlyCollection> rc ? Math.Max(DefaultCapacity, rc.Count) : + DefaultCapacity; + + /// Computes the hash code for the specified key using the dictionary's comparer. + /// + /// The comparer. It's passed in to avoid having to look it up via a volatile read on ; + /// such a comparer could also be incorrect if the table upgraded comparer concurrently. + /// + /// The key for which to compute the hash code. + /// The hash code of the key. 
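// Usage sketch (not part of the patch) for the presizing described in GetCapacityFromCollection
// above: when the source implements ICollection<T> or IReadOnlyCollection<T>, the constructor can
// size the table up front instead of growing repeatedly while copying the items in. The key and
// value choices here are arbitrary.
using System.Collections.Concurrent;
using System.Collections.Generic;

KeyValuePair<string, int>[] seed = new KeyValuePair<string, int>[10_000];
for (int i = 0; i < seed.Length; i++)
{
    seed[i] = new KeyValuePair<string, int>("key" + i, i);
}

// The array implements ICollection<KeyValuePair<string, int>>, so with this change the dictionary
// starts with at least seed.Length buckets (rounded up to a prime) rather than the default 31,
// and it doesn't need to resize while being populated.
var dictionary = new ConcurrentDictionary<string, int>(seed);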
+ [MethodImpl(MethodImplOptions.AggressiveInlining)] + private int GetHashCode(IEqualityComparer? comparer, TKey key) + { + if (typeof(TKey).IsValueType) + { + return comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + } + + Debug.Assert(comparer is not null); + return _comparerIsDefaultForClasses ? key.GetHashCode() : comparer.GetHashCode(key); + } + + /// Determines whether the specified key and the key stored in the specified node are equal. + /// + /// The comparer. It's passed in to avoid having to look it up via a volatile read on ; + /// such a comparer could also be incorrect if the table upgraded comparer concurrently. + /// + /// The node containing the key to compare. + /// The other key to compare. + /// true if the keys are equal; otherwise, false. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool NodeEqualsKey(IEqualityComparer? comparer, Node node, TKey key) + { + if (typeof(TKey).IsValueType) + { + return comparer is null ? + EqualityComparer.Default.Equals(node._key, key) : + comparer.Equals(node._key, key); + } + + Debug.Assert(comparer is not null); + return comparer.Equals(node._key, key); + } + + private void InitializeFromCollection(IEnumerable> collection) + { + foreach (KeyValuePair pair in collection) + { + if (pair.Key is null) + { + ThrowHelper.ThrowKeyNullException(); + } + + if (!TryAddInternal(_tables, pair.Key, null, pair.Value, updateIfExists: false, acquireLock: false, out _)) + { + throw new ArgumentException(SR.ConcurrentDictionary_SourceContainsDuplicateKeys); + } + } + + if (_budget == 0) + { + Tables tables = _tables; + _budget = tables._buckets.Length / tables._locks.Length; + } + } + /// /// Attempts to add the specified key and value to the . /// @@ -219,7 +309,7 @@ public bool TryAdd(TKey key, TValue value) ThrowHelper.ThrowKeyNullException(); } - return TryAddInternal(key, null, value, updateIfExists: false, acquireLock: true, out _); + return TryAddInternal(_tables, key, null, value, updateIfExists: false, acquireLock: true, out _); } /// @@ -228,15 +318,7 @@ public bool TryAdd(TKey key, TValue value) /// The key to locate in the . /// true if the contains an element with the specified key; otherwise, false. /// is a null reference (Nothing in Visual Basic). - public bool ContainsKey(TKey key) - { - if (key is null) - { - ThrowHelper.ThrowKeyNullException(); - } - - return TryGetValue(key, out _); - } + public bool ContainsKey(TKey key) => TryGetValue(key, out _); /// /// Attempts to remove and return the value with the specified key from the . @@ -295,13 +377,15 @@ public bool TryRemove(KeyValuePair item) /// The conditional value to compare against if is true private bool TryRemoveInternal(TKey key, [MaybeNullWhen(false)] out TValue value, bool matchValue, TValue? oldValue) { - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); + while (true) { - Tables tables = _tables; object[] locks = tables._locks; - ref Node? bucket = ref tables.GetBucketAndLock(hashcode, out uint lockNo); + ref Node? bucket = ref GetBucketAndLock(tables, hashcode, out uint lockNo); lock (locks[lockNo]) { @@ -309,15 +393,21 @@ private bool TryRemoveInternal(TKey key, [MaybeNullWhen(false)] out TValue value // This should be a rare occurrence. 
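// Illustrative sketch (hypothetical types, not part of the patch): the mutating paths below all
// follow the same shape. State that must stay consistent (buckets, locks, counts, and now the
// comparer) lives on a single Tables snapshot that is swapped atomically on resize. A writer
// hashes against the snapshot it read, takes the stripe lock, and then re-checks that the
// snapshot is still current; if not, it refreshes its locals, recomputing the hash code only if
// the comparer itself changed (e.g. a non-randomized string comparer was upgraded), and retries.
using System;
using System.Collections.Generic;

internal sealed class TablesSnapshot
{
    public readonly object[] Locks;
    public readonly IEqualityComparer<string> Comparer;
    // ...buckets and per-lock counts would live here as well...

    public TablesSnapshot(object[] locks, IEqualityComparer<string> comparer)
    {
        Locks = locks;
        Comparer = comparer;
    }
}

internal sealed class SnapshotOwner
{
    private volatile TablesSnapshot _tables =
        new TablesSnapshot(new object[] { new object(), new object() }, StringComparer.Ordinal);

    public void MutateUnderCorrectLock(string key, Action<TablesSnapshot, int> mutate)
    {
        TablesSnapshot tables = _tables;
        IEqualityComparer<string> comparer = tables.Comparer;
        int hashcode = comparer.GetHashCode(key);

        while (true)
        {
            object[] locks = tables.Locks;
            lock (locks[(uint)hashcode % (uint)locks.Length])
            {
                if (tables != _tables)
                {
                    // The snapshot was swapped between reading it and acquiring the lock.
                    tables = _tables;
                    if (!ReferenceEquals(comparer, tables.Comparer))
                    {
                        comparer = tables.Comparer;
                        hashcode = comparer.GetHashCode(key);
                    }
                    continue; // retry against the fresh snapshot (and the right lock)
                }

                mutate(tables, hashcode);
                return;
            }
        }
    }
}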
if (tables != _tables) { + tables = _tables; + if (!ReferenceEquals(comparer, tables._comparer)) + { + comparer = tables._comparer; + hashcode = GetHashCode(comparer, key); + } continue; } Node? prev = null; - for (Node? curr = bucket; curr != null; curr = curr._next) + for (Node? curr = bucket; curr is not null; curr = curr._next) { Debug.Assert((prev is null && curr == bucket) || prev!._next == curr); - if (hashcode == curr._hashcode && (comparer is null ? _defaultComparer.Equals(curr._key, key) : comparer.Equals(curr._key, key))) + if (hashcode == curr._hashcode && NodeEqualsKey(comparer, curr, key)) { if (matchValue) { @@ -369,42 +459,27 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) ThrowHelper.ThrowKeyNullException(); } - // We must capture the volatile _tables field into a local variable: it is set to a new table on each table resize. - // The Volatile.Read on the array element then ensures that we have a copy of the reference to tables._buckets[bucketNo]: - // this protects us from reading fields ('_hashcode', '_key', '_value' and '_next') of different instances. Tables tables = _tables; - IEqualityComparer? comparer = _comparer; - if (comparer is null) + IEqualityComparer? comparer = tables._comparer; + if (typeof(TKey).IsValueType && // comparer can only be null for value types; enable JIT to eliminate entire if block for ref types + comparer is null) { int hashcode = key.GetHashCode(); - if (typeof(TKey).IsValueType) + for (Node? n = GetBucket(tables, hashcode); n is not null; n = n._next) { - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) + if (hashcode == n._hashcode && EqualityComparer.Default.Equals(n._key, key)) { - if (hashcode == n._hashcode && EqualityComparer.Default.Equals(n._key, key)) - { - value = n._value; - return true; - } - } - } - else - { - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) - { - if (hashcode == n._hashcode && _defaultComparer.Equals(n._key, key)) - { - value = n._value; - return true; - } + value = n._value; + return true; } } } else { - int hashcode = comparer.GetHashCode(key); - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) + Debug.Assert(comparer is not null); + int hashcode = _comparerIsDefaultForClasses ? key.GetHashCode() : comparer.GetHashCode(key); + for (Node? n = GetBucket(tables, hashcode); n is not null; n = n._next) { if (hashcode == n._hashcode && comparer.Equals(n._key, key)) { @@ -418,45 +493,26 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) return false; } - private bool TryGetValueInternal(TKey key, int hashcode, [MaybeNullWhen(false)] out TValue value) + private static bool TryGetValueInternal(Tables tables, TKey key, int hashcode, [MaybeNullWhen(false)] out TValue value) { - Debug.Assert((_comparer is null ? key.GetHashCode() : _comparer.GetHashCode(key)) == hashcode, - $"Invalid comparer: _comparer {_comparer} key {key} _comparer.GetHashCode(key) {_comparer?.GetHashCode(key)} hashcode {hashcode}"); - - // We must capture the volatile _tables field into a local variable: it is set to a new table on each table resize. - // The Volatile.Read on the array element then ensures that we have a copy of the reference to tables._buckets[bucketNo]: - // this protects us from reading fields ('_hashcode', '_key', '_value' and '_next') of different instances. - Tables tables = _tables; + IEqualityComparer? comparer = tables._comparer; - IEqualityComparer? 
comparer = _comparer; - if (comparer is null) + if (typeof(TKey).IsValueType && // comparer can only be null for value types; enable JIT to eliminate entire if block for ref types + comparer is null) { - if (typeof(TKey).IsValueType) + for (Node? n = GetBucket(tables, hashcode); n is not null; n = n._next) { - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) + if (hashcode == n._hashcode && EqualityComparer.Default.Equals(n._key, key)) { - if (hashcode == n._hashcode && EqualityComparer.Default.Equals(n._key, key)) - { - value = n._value; - return true; - } - } - } - else - { - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) - { - if (hashcode == n._hashcode && _defaultComparer.Equals(n._key, key)) - { - value = n._value; - return true; - } + value = n._value; + return true; } } } else { - for (Node? n = Volatile.Read(ref tables.GetBucket(hashcode)); n != null; n = n._next) + Debug.Assert(comparer is not null); + for (Node? n = GetBucket(tables, hashcode); n is not null; n = n._next) { if (hashcode == n._hashcode && comparer.Equals(n._key, key)) { @@ -492,13 +548,14 @@ public bool TryUpdate(TKey key, TValue newValue, TValue comparisonValue) ThrowHelper.ThrowKeyNullException(); } - return TryUpdateInternal(key, null, newValue, comparisonValue); + return TryUpdateInternal(_tables, key, null, newValue, comparisonValue); } /// /// Updates the value associated with to if the existing value is equal /// to . /// + /// The tables that were used to create the hash code. /// The key whose value is compared with and /// possibly replaced. /// The hashcode computed for . @@ -511,25 +568,19 @@ public bool TryUpdate(TKey key, TValue newValue, TValue comparisonValue) /// replaced with ; otherwise, false. /// /// is a null reference. - private bool TryUpdateInternal(TKey key, int? nullableHashcode, TValue newValue, TValue comparisonValue) + private bool TryUpdateInternal(Tables tables, TKey key, int? nullableHashcode, TValue newValue, TValue comparisonValue) { - IEqualityComparer? comparer = _comparer; - - Debug.Assert( - nullableHashcode is null || - (comparer is null ? key.GetHashCode() : comparer.GetHashCode(key)) == nullableHashcode); + IEqualityComparer? comparer = tables._comparer; - int hashcode = - nullableHashcode ?? - (comparer is null ? key.GetHashCode() : comparer.GetHashCode(key)); + int hashcode = nullableHashcode ?? GetHashCode(comparer, key); + Debug.Assert(nullableHashcode is null || nullableHashcode == hashcode); EqualityComparer valueComparer = EqualityComparer.Default; while (true) { - Tables tables = _tables; object[] locks = tables._locks; - ref Node? bucket = ref tables.GetBucketAndLock(hashcode, out uint lockNo); + ref Node? bucket = ref GetBucketAndLock(tables, hashcode, out uint lockNo); lock (locks[lockNo]) { @@ -537,15 +588,21 @@ nullableHashcode is null || // This should be a rare occurrence. if (tables != _tables) { + tables = _tables; + if (!ReferenceEquals(comparer, tables._comparer)) + { + comparer = tables._comparer; + hashcode = GetHashCode(comparer, key); + } continue; } // Try to find this key in the bucket Node? prev = null; - for (Node? node = bucket; node != null; node = node._next) + for (Node? node = bucket; node is not null; node = node._next) { Debug.Assert((prev is null && node == bucket) || prev!._next == node); - if (hashcode == node._hashcode && (comparer is null ? 
_defaultComparer.Equals(node._key, key) : comparer.Equals(node._key, key))) + if (hashcode == node._hashcode && NodeEqualsKey(comparer, node, key)) { if (valueComparer.Equals(node._value, comparisonValue)) { @@ -603,13 +660,13 @@ public void Clear() } Tables tables = _tables; - var newTables = new Tables(new Node[DefaultCapacity], tables._locks, new int[tables._countPerLock.Length]); + var newTables = new Tables(new VolatileNode[HashHelpers.GetPrime(DefaultCapacity)], tables._locks, new int[tables._countPerLock.Length], tables._comparer); _tables = newTables; _budget = Math.Max(1, newTables._buckets.Length / newTables._locks.Length); } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -640,14 +697,8 @@ void ICollection>.CopyTo(KeyValuePair[] { AcquireAllLocks(ref locksAcquired); - int count = 0; - int[] countPerLock = _tables._countPerLock; - for (int i = 0; i < countPerLock.Length && count >= 0; i++) - { - count += countPerLock[i]; - } - - if (array.Length - count < index || count < 0) //"count" itself or "count + index" can overflow + int count = GetCountNoLocks(); + if (array.Length - count < index) { throw new ArgumentException(SR.ConcurrentDictionary_ArrayNotLargeEnough); } @@ -656,7 +707,7 @@ void ICollection>.CopyTo(KeyValuePair[] } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -673,16 +724,7 @@ public KeyValuePair[] ToArray() { AcquireAllLocks(ref locksAcquired); - int count = 0; - int[] countPerLock = _tables._countPerLock; - for (int i = 0; i < countPerLock.Length; i++) - { - checked - { - count += countPerLock[i]; - } - } - + int count = GetCountNoLocks(); if (count == 0) { return Array.Empty>(); @@ -694,7 +736,7 @@ public KeyValuePair[] ToArray() } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -702,13 +744,13 @@ public KeyValuePair[] ToArray() /// Important: the caller must hold all locks in _locks before calling CopyToPairs. private void CopyToPairs(KeyValuePair[] array, int index) { - Node?[] buckets = _tables._buckets; - for (int i = 0; i < buckets.Length; i++) + foreach (VolatileNode bucket in _tables._buckets) { - for (Node? current = buckets[i]; current != null; current = current._next) + for (Node? current = bucket._node; current is not null; current = current._next) { array[index] = new KeyValuePair(current._key, current._value); - index++; // this should never overflow, CopyToPairs is only called when there's no overflow risk + Debug.Assert(index < int.MaxValue, "This method should only be called when there's no overflow risk"); + index++; } } } @@ -717,13 +759,13 @@ private void CopyToPairs(KeyValuePair[] array, int index) /// Important: the caller must hold all locks in _locks before calling CopyToPairs. private void CopyToEntries(DictionaryEntry[] array, int index) { - Node?[] buckets = _tables._buckets; - for (int i = 0; i < buckets.Length; i++) + foreach (VolatileNode bucket in _tables._buckets) { - for (Node? current = buckets[i]; current != null; current = current._next) + for (Node? 
current = bucket._node; current is not null; current = current._next) { array[index] = new DictionaryEntry(current._key, current._value); - index++; //this should never flow, CopyToEntries is only called when there's no overflow risk + Debug.Assert(index < int.MaxValue, "This method should only be called when there's no overflow risk"); + index++; } } } @@ -732,13 +774,13 @@ private void CopyToEntries(DictionaryEntry[] array, int index) /// Important: the caller must hold all locks in _locks before calling CopyToPairs. private void CopyToObjects(object[] array, int index) { - Node?[] buckets = _tables._buckets; - for (int i = 0; i < buckets.Length; i++) + foreach (VolatileNode bucket in _tables._buckets) { - for (Node? current = buckets[i]; current != null; current = current._next) + for (Node? current = bucket._node; current is not null; current = current._next) { array[index] = new KeyValuePair(current._key, current._value); - index++; // this should never overflow, CopyToObjects is only called when there's no overflow risk + Debug.Assert(index < int.MaxValue, "This method should only be called when there's no overflow risk"); + index++; } } } @@ -758,14 +800,14 @@ private void CopyToObjects(object[] array, int index) private sealed class Enumerator : IEnumerator> { // Provides a manually-implemented version of (approximately) this iterator: - // Node?[] buckets = _tables._buckets; + // VolatileNodeWrapper[] buckets = _tables._buckets; // for (int i = 0; i < buckets.Length; i++) - // for (Node? current = Volatile.Read(ref buckets[i]); current != null; current = current._next) + // for (Node? current = buckets[i]._node; current is not null; current = current._next) // yield return new KeyValuePair(current._key, current._value); private readonly ConcurrentDictionary _dictionary; - private ConcurrentDictionary.Node?[]? _buckets; + private ConcurrentDictionary.VolatileNode[]? _buckets; private Node? _node; private int _i; private int _state; @@ -806,23 +848,20 @@ public bool MoveNext() goto case StateOuterloop; case StateOuterloop: - ConcurrentDictionary.Node?[]? buckets = _buckets; - Debug.Assert(buckets != null); + ConcurrentDictionary.VolatileNode[]? buckets = _buckets; + Debug.Assert(buckets is not null); int i = ++_i; if ((uint)i < (uint)buckets.Length) { - // The Volatile.Read ensures that we have a copy of the reference to buckets[i]: - // this protects us from reading fields ('_key', '_value' and '_next') of different instances. - _node = Volatile.Read(ref buckets[i]); + _node = buckets[i]._node; _state = StateInnerLoop; goto case StateInnerLoop; } goto default; case StateInnerLoop: - Node? node = _node; - if (node != null) + if (_node is Node node) { Current = new KeyValuePair(node._key, node._value); _node = node._next; @@ -842,26 +881,20 @@ public bool MoveNext() /// If key exists, we always return false; and if updateIfExists == true we force update with value; /// If key doesn't exist, we always add value and return true; /// - private bool TryAddInternal(TKey key, int? nullableHashcode, TValue value, bool updateIfExists, bool acquireLock, out TValue resultingValue) + private bool TryAddInternal(Tables tables, TKey key, int? nullableHashcode, TValue value, bool updateIfExists, bool acquireLock, out TValue resultingValue) { - IEqualityComparer? comparer = _comparer; - - Debug.Assert( - nullableHashcode is null || - (comparer is null && key.GetHashCode() == nullableHashcode) || - (comparer != null && comparer.GetHashCode(key) == nullableHashcode)); + IEqualityComparer? 
comparer = tables._comparer; - int hashcode = - nullableHashcode ?? - (comparer is null ? key.GetHashCode() : comparer.GetHashCode(key)); + int hashcode = nullableHashcode ?? GetHashCode(comparer, key); + Debug.Assert(nullableHashcode is null || nullableHashcode == hashcode); while (true) { - Tables tables = _tables; object[] locks = tables._locks; - ref Node? bucket = ref tables.GetBucketAndLock(hashcode, out uint lockNo); + ref Node? bucket = ref GetBucketAndLock(tables, hashcode, out uint lockNo); bool resizeDesired = false; + bool forceRehash = false; bool lockTaken = false; try { @@ -874,15 +907,22 @@ nullableHashcode is null || // This should be a rare occurrence. if (tables != _tables) { + tables = _tables; + if (!ReferenceEquals(comparer, tables._comparer)) + { + comparer = tables._comparer; + hashcode = GetHashCode(comparer, key); + } continue; } // Try to find this key in the bucket + uint collisionCount = 0; Node? prev = null; - for (Node? node = bucket; node != null; node = node._next) + for (Node? node = bucket; node is not null; node = node._next) { Debug.Assert((prev is null && node == bucket) || prev!._next == node); - if (hashcode == node._hashcode && (comparer is null ? _defaultComparer.Equals(node._key, key) : comparer.Equals(node._key, key))) + if (hashcode == node._hashcode && NodeEqualsKey(comparer, node, key)) { // The key was found in the dictionary. If updates are allowed, update the value for that key. // We need to create a new node for the update, in order to support TValue types that cannot @@ -918,6 +958,10 @@ nullableHashcode is null || return false; } prev = node; + if (!typeof(TKey).IsValueType) // this is only relevant to strings, and we can avoid this code for all value types + { + collisionCount++; + } } // The key was not found in the bucket. Insert the key-value pair. @@ -928,15 +972,22 @@ nullableHashcode is null || tables._countPerLock[lockNo]++; } - // // If the number of elements guarded by this lock has exceeded the budget, resize the bucket table. // It is also possible that GrowTable will increase the budget but won't resize the bucket table. // That happens if the bucket table is found to be poorly utilized due to a bad hash function. - // if (tables._countPerLock[lockNo] > _budget) { resizeDesired = true; } + + // We similarly want to invoke redo the tables if we're using a non-randomized comparer + // and need to upgrade to a randomized comparer due to too many collisions. + if (!typeof(TKey).IsValueType && + collisionCount > HashHelpers.HashCollisionThreshold && + comparer is NonRandomizedStringEqualityComparer) + { + forceRehash = true; + } } finally { @@ -946,17 +997,15 @@ nullableHashcode is null || } } - // // The fact that we got here means that we just performed an insertion. If necessary, we will grow the table. // // Concurrency notes: // - Notice that we are not holding any locks at when calling GrowTable. This is necessary to prevent deadlocks. // - As a result, it is possible that GrowTable will be called unnecessarily. But, GrowTable will obtain lock 0 // and then verify that the table we passed to it as the argument is still the current table. 
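// Worked example (assumed numbers, not from the patch) of the resize budget used above: the
// budget is buckets-per-lock, so with the default 31 buckets and, say, a concurrency level of 8,
// each lock starts out allowed to guard 31 / 8 = 3 entries. Once any single lock's count exceeds
// that, the add path sets resizeDesired and GrowTable runs, which may simply double the budget
// instead of resizing if the table is still sparsely populated.
int bucketCount = 31;                  // DefaultCapacity
int lockCount = 8;                     // DefaultConcurrencyLevel is Environment.ProcessorCount; 8 is just an example
int budget = bucketCount / lockCount;  // 3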
- // - if (resizeDesired) + if (resizeDesired | forceRehash) { - GrowTable(tables); + GrowTable(tables, resizeDesired, forceRehash); } resultingValue = value; @@ -993,7 +1042,7 @@ public TValue this[TKey key] ThrowHelper.ThrowKeyNullException(); } - TryAddInternal(key, null, value, updateIfExists: true, acquireLock: true, out _); + TryAddInternal(_tables, key, null, value, updateIfExists: true, acquireLock: true, out _); } } @@ -1018,7 +1067,23 @@ private static void ThrowKeyNotFoundException(TKey key) => /// generic interface by using a constructor that accepts a comparer parameter; /// if you do not specify one, the default generic equality comparer is used. /// - public IEqualityComparer Comparer => _comparer ?? _defaultComparer; + public IEqualityComparer Comparer + { + get + { + IEqualityComparer? comparer = _tables._comparer; + + if (typeof(TKey) == typeof(string)) + { + if ((comparer as NonRandomizedStringEqualityComparer)?.GetUnderlyingEqualityComparer() is IEqualityComparer ec) + { + return (IEqualityComparer)ec; + } + } + + return comparer ?? EqualityComparer.Default; + } + } /// /// Gets the number of key/value pairs contained in the - /// Gets the number of key/value pairs contained in the . Should only be used after all locks - /// have been acquired. - /// - /// The dictionary contains too many - /// elements. - /// The number of key/value pairs contained in the . - /// Count has snapshot semantics and represents the number of items in the - /// at the moment when Count was accessed. - private int GetCountInternal() + /// Gets the number of pairs stored in the dictionary. + /// This assumes all of the dictionary's locks have been taken, or else the result may not be accurate. + private int GetCountNoLocks() { int count = 0; - int[] countPerLocks = _tables._countPerLock; - - // Compute the count, we allow overflow - for (int i = 0; i < countPerLocks.Length; i++) + foreach (int value in _tables._countPerLock) { - count += countPerLocks[i]; + checked { count += value; } } return count; @@ -1104,12 +1154,14 @@ public TValue GetOrAdd(TKey key, Func valueFactory) ThrowHelper.ThrowArgumentNullException(nameof(valueFactory)); } - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; - if (!TryGetValueInternal(key, hashcode, out TValue? resultingValue)) + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); + + if (!TryGetValueInternal(tables, key, hashcode, out TValue? resultingValue)) { - TryAddInternal(key, hashcode, valueFactory(key), updateIfExists: false, acquireLock: true, out resultingValue); + TryAddInternal(tables, key, hashcode, valueFactory(key), updateIfExists: false, acquireLock: true, out resultingValue); } return resultingValue; @@ -1143,12 +1195,14 @@ public TValue GetOrAdd(TKey key, Func valueFactory, TA ThrowHelper.ThrowArgumentNullException(nameof(valueFactory)); } - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); - if (!TryGetValueInternal(key, hashcode, out TValue? resultingValue)) + if (!TryGetValueInternal(tables, key, hashcode, out TValue? 
resultingValue)) { - TryAddInternal(key, hashcode, valueFactory(key, factoryArgument), updateIfExists: false, acquireLock: true, out resultingValue); + TryAddInternal(tables, key, hashcode, valueFactory(key, factoryArgument), updateIfExists: false, acquireLock: true, out resultingValue); } return resultingValue; @@ -1173,12 +1227,14 @@ public TValue GetOrAdd(TKey key, TValue value) ThrowHelper.ThrowKeyNullException(); } - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); - if (!TryGetValueInternal(key, hashcode, out TValue? resultingValue)) + if (!TryGetValueInternal(tables, key, hashcode, out TValue? resultingValue)) { - TryAddInternal(key, hashcode, value, updateIfExists: false, acquireLock: true, out resultingValue); + TryAddInternal(tables, key, hashcode, value, updateIfExists: false, acquireLock: true, out resultingValue); } return resultingValue; @@ -1222,16 +1278,18 @@ public TValue AddOrUpdate( ThrowHelper.ThrowArgumentNullException(nameof(updateValueFactory)); } - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); while (true) { - if (TryGetValueInternal(key, hashcode, out TValue? oldValue)) + if (TryGetValueInternal(tables, key, hashcode, out TValue? oldValue)) { // key exists, try to update TValue newValue = updateValueFactory(key, oldValue, factoryArgument); - if (TryUpdateInternal(key, hashcode, newValue, oldValue)) + if (TryUpdateInternal(tables, key, hashcode, newValue, oldValue)) { return newValue; } @@ -1239,11 +1297,22 @@ public TValue AddOrUpdate( else { // key doesn't exist, try to add - if (TryAddInternal(key, hashcode, addValueFactory(key, factoryArgument), updateIfExists: false, acquireLock: true, out TValue resultingValue)) + if (TryAddInternal(tables, key, hashcode, addValueFactory(key, factoryArgument), updateIfExists: false, acquireLock: true, out TValue resultingValue)) { return resultingValue; } } + + Tables newTables = _tables; + if (tables != newTables) + { + tables = newTables; + if (!ReferenceEquals(comparer, tables._comparer)) + { + comparer = tables._comparer; + hashcode = GetHashCode(comparer, key); + } + } } } @@ -1283,16 +1352,18 @@ public TValue AddOrUpdate(TKey key, Func addValueFactory, Func? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); while (true) { - if (TryGetValueInternal(key, hashcode, out TValue? oldValue)) + if (TryGetValueInternal(tables, key, hashcode, out TValue? oldValue)) { // key exists, try to update TValue newValue = updateValueFactory(key, oldValue); - if (TryUpdateInternal(key, hashcode, newValue, oldValue)) + if (TryUpdateInternal(tables, key, hashcode, newValue, oldValue)) { return newValue; } @@ -1300,11 +1371,22 @@ public TValue AddOrUpdate(TKey key, Func addValueFactory, Func ThrowHelper.ThrowArgumentNullException(nameof(updateValueFactory)); } - IEqualityComparer? comparer = _comparer; - int hashcode = comparer is null ? key.GetHashCode() : comparer.GetHashCode(key); + Tables tables = _tables; + + IEqualityComparer? 
comparer = tables._comparer; + int hashcode = GetHashCode(comparer, key); while (true) { - if (TryGetValueInternal(key, hashcode, out TValue? oldValue)) + if (TryGetValueInternal(tables, key, hashcode, out TValue? oldValue)) { // key exists, try to update TValue newValue = updateValueFactory(key, oldValue); - if (TryUpdateInternal(key, hashcode, newValue, oldValue)) + if (TryUpdateInternal(tables, key, hashcode, newValue, oldValue)) { return newValue; } @@ -1354,11 +1438,22 @@ public TValue AddOrUpdate(TKey key, TValue addValue, Func else { // key doesn't exist, try to add - if (TryAddInternal(key, hashcode, addValue, updateIfExists: false, acquireLock: true, out TValue resultingValue)) + if (TryAddInternal(tables, key, hashcode, addValue, updateIfExists: false, acquireLock: true, out TValue resultingValue)) { return resultingValue; } } + + Tables newTables = _tables; + if (tables != newTables) + { + tables = newTables; + if (!ReferenceEquals(comparer, tables._comparer)) + { + comparer = tables._comparer; + hashcode = GetHashCode(comparer, key); + } + } } } @@ -1382,21 +1477,17 @@ public bool IsEmpty // the collection was actually empty at any point in time as items may have been // added and removed while iterating over the buckets such that we never saw an // empty bucket, but there was always an item present in at least one bucket. - int acquiredLocks = 0; + int locksAcquired = 0; try { - // Acquire all locks - AcquireAllLocks(ref acquiredLocks); + AcquireAllLocks(ref locksAcquired); return AreAllBucketsEmpty(); } finally { - // Release locks that have been acquired earlier - ReleaseLocks(0, acquiredLocks); + ReleaseLocks(locksAcquired); } - - } } @@ -1498,14 +1589,9 @@ void IDictionary.Add(TKey key, TValue value) /// cref="ICollection{TValue}"/>. /// true if the is found in the ; otherwise, false. - bool ICollection>.Contains(KeyValuePair keyValuePair) - { - if (!TryGetValue(keyValuePair.Key, out TValue? value)) - { - return false; - } - return EqualityComparer.Default.Equals(value, keyValuePair.Value); - } + bool ICollection>.Contains(KeyValuePair keyValuePair) => + TryGetValue(keyValuePair.Key, out TValue? value) && + EqualityComparer.Default.Equals(value, keyValuePair.Value); /// /// Gets a value indicating whether the dictionary is read-only. @@ -1720,14 +1806,14 @@ void IDictionary.Remove(object key) [MethodImpl(MethodImplOptions.AggressiveInlining)] private static void ThrowIfInvalidObjectValue(object? 
value) { - if (value != null) + if (value is not null) { if (!(value is TValue)) { ThrowHelper.ThrowValueNullException(); } } - else if (default(TValue) != null) + else if (default(TValue) is not null) { ThrowHelper.ThrowValueNullException(); } @@ -1764,16 +1850,9 @@ void ICollection.CopyTo(Array array, int index) try { AcquireAllLocks(ref locksAcquired); - Tables tables = _tables; - - int count = 0; - int[] countPerLock = tables._countPerLock; - for (int i = 0; i < countPerLock.Length && count >= 0; i++) - { - count += countPerLock[i]; - } - if (array.Length - count < index || count < 0) //"count" itself or "count + index" can overflow + int count = GetCountNoLocks(); + if (array.Length - count < index) { throw new ArgumentException(SR.ConcurrentDictionary_ArrayNotLargeEnough); } @@ -1806,7 +1885,7 @@ void ICollection.CopyTo(Array array, int index) } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -1829,7 +1908,6 @@ void ICollection.CopyTo(Array array, int index) #endregion - private bool AreAllBucketsEmpty() => _tables._countPerLock.AsSpan().IndexOfAnyExcept(0) < 0; @@ -1839,13 +1917,13 @@ private bool AreAllBucketsEmpty() => /// small is passed in as an argument to GrowTable(). GrowTable() obtains a lock, and then checks /// the Tables instance has been replaced in the meantime or not. /// - private void GrowTable(Tables tables) + private void GrowTable(Tables tables, bool resizeDesired, bool forceRehashIfNonRandomized) { int locksAcquired = 0; try { // The thread that first obtains _locks[0] will be the one doing the resize operation - AcquireLocks(0, 1, ref locksAcquired); + AcquireFirstLock(ref locksAcquired); // Make sure nobody resized the table while we were waiting for lock 0: if (tables != _tables) @@ -1856,67 +1934,58 @@ private void GrowTable(Tables tables) return; } - // Compute the (approx.) total size. Use an Int64 accumulation variable to avoid an overflow. - long approxCount = 0; - for (int i = 0; i < tables._countPerLock.Length; i++) - { - approxCount += tables._countPerLock[i]; - } + int newLength = tables._buckets.Length; - // - // If the bucket array is too empty, double the budget instead of resizing the table - // - if (approxCount < tables._buckets.Length / 4) + IEqualityComparer? upgradeComparer = null; + if (forceRehashIfNonRandomized && tables._comparer is NonRandomizedStringEqualityComparer nrsec) { - _budget = 2 * _budget; - if (_budget < 0) - { - _budget = int.MaxValue; - } - return; + upgradeComparer = (IEqualityComparer)nrsec.GetUnderlyingEqualityComparer(); } - // Compute the new table size. We find the smallest integer larger than twice the previous table size, and not divisible by - // 2,3,5 or 7. We can consider a different table-sizing policy in the future. - int newLength = 0; - bool maximizeTableSize = false; - try + if (resizeDesired) { - checked + // Compute the (approx.) total size. Use an Int64 accumulation variable to avoid an overflow. + // If the bucket array is too empty, we have an imbalance. + // If we have a string key and we're still using a non-randomized comparer, + // take this as a sign that we need to upgrade to one. + // Otherwise, double the budget instead of resizing the table. + if (upgradeComparer is null && GetCountNoLocks() < tables._buckets.Length / 4) { - // Double the size of the buckets table and add one, so that we have an odd integer. - newLength = tables._buckets.Length * 2 + 1; - - // Now, we only need to check odd integers, and find the first that is not divisible - // by 3, 5 or 7. 
- while (newLength % 3 == 0 || newLength % 5 == 0 || newLength % 7 == 0) + _budget = 2 * _budget; + if (_budget < 0) { - newLength += 2; + _budget = int.MaxValue; } + return; + } - Debug.Assert(newLength % 2 != 0); - + // Compute the new table size at least twice the previous table size. + bool maximizeTableSize = false; + try + { + // Double the size of the buckets table and choose a prime that's at least as large. + newLength = HashHelpers.GetPrime(checked(tables._buckets.Length * 2)); if (newLength > Array.MaxLength) { maximizeTableSize = true; } } - } - catch (OverflowException) - { - maximizeTableSize = true; - } + catch (OverflowException) + { + maximizeTableSize = true; + } - if (maximizeTableSize) - { - newLength = Array.MaxLength; - - // We want to make sure that GrowTable will not be called again, since table is at the maximum size. - // To achieve that, we set the budget to int.MaxValue. - // - // (There is one special case that would allow GrowTable() to be called in the future: - // calling Clear() on the ConcurrentDictionary will shrink the table and lower the budget.) - _budget = int.MaxValue; + if (maximizeTableSize) + { + newLength = Array.MaxLength; + + // We want to make sure that GrowTable will not be called again, since table is at the maximum size. + // To achieve that, we set the budget to int.MaxValue. + // + // (There is one special case that would allow GrowTable() to be called in the future: + // calling Clear() on the ConcurrentDictionary will shrink the table and lower the budget.) + _budget = int.MaxValue; + } } object[] newLocks = tables._locks; @@ -1932,23 +2001,25 @@ private void GrowTable(Tables tables) } } - var newBuckets = new Node[newLength]; + var newBuckets = new VolatileNode[newLength]; var newCountPerLock = new int[newLocks.Length]; - var newTables = new Tables(newBuckets, newLocks, newCountPerLock); + var newTables = new Tables(newBuckets, newLocks, newCountPerLock, upgradeComparer ?? tables._comparer); // Now acquire all other locks for the table - AcquireLocks(1, tables._locks.Length, ref locksAcquired); + AcquirePostFirstLock(tables, ref locksAcquired); // Copy all data into a new table, creating new nodes for all elements - foreach (Node? bucket in tables._buckets) + foreach (VolatileNode bucket in tables._buckets) { - Node? current = bucket; - while (current != null) + Node? current = bucket._node; + while (current is not null) { + int hashCode = upgradeComparer is null ? current._hashcode : upgradeComparer.GetHashCode(current._key); + Node? next = current._next; - ref Node? newBucket = ref newTables.GetBucketAndLock(current._hashcode, out uint newLockNo); + ref Node? newBucket = ref GetBucketAndLock(newTables, hashCode, out uint newLockNo); - newBucket = new Node(current._key, current._value, current._hashcode, newBucket); + newBucket = new Node(current._key, current._value, hashCode, newBucket); checked { @@ -1967,8 +2038,7 @@ private void GrowTable(Tables tables) } finally { - // Release all locks that we took earlier - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -1987,53 +2057,61 @@ private void AcquireAllLocks(ref int locksAcquired) CDSCollectionETWBCLProvider.Log.ConcurrentDictionary_AcquiringAllLocks(_tables._buckets.Length); } - // First, acquire lock 0 - AcquireLocks(0, 1, ref locksAcquired); - - // Now that we have lock 0, the _locks array will not change (i.e., grow), - // and so we can safely read _locks.Length. 
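// Illustrative sketch (hypothetical names, not part of the patch) of the striped-locking shape
// behind the lock helpers in this area: per-key operations take a single stripe chosen from the
// hash code, while whole-table operations (Count, Clear, ToArray, resize) take every stripe,
// always in index order starting from stripe 0, and release whatever was acquired in a finally.
// In the real code, holding stripe 0 additionally guarantees the lock array won't be replaced.
using System;
using System.Threading;

internal sealed class StripedLocks
{
    private readonly object[] _locks;

    public StripedLocks(int stripeCount)
    {
        _locks = new object[stripeCount];
        for (int i = 0; i < stripeCount; i++)
        {
            _locks[i] = new object();
        }
    }

    public void WithStripe(int hashcode, Action action)
    {
        lock (_locks[(uint)hashcode % (uint)_locks.Length])
        {
            action();
        }
    }

    public void WithAllStripes(Action action)
    {
        int acquired = 0;
        try
        {
            // Acquiring in a fixed global order (0, 1, 2, ...) is what prevents deadlocks when
            // several threads need all of the stripes at once.
            for (int i = 0; i < _locks.Length; i++)
            {
                Monitor.Enter(_locks[i]);
                acquired++;
            }

            action();
        }
        finally
        {
            for (int i = 0; i < acquired; i++)
            {
                Monitor.Exit(_locks[i]);
            }
        }
    }
}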
- AcquireLocks(1, _tables._locks.Length, ref locksAcquired); + // First, acquire lock 0, then acquire the rest. _tables won't change after acquiring lock 0. + AcquireFirstLock(ref locksAcquired); + AcquirePostFirstLock(_tables, ref locksAcquired); Debug.Assert(locksAcquired == _tables._locks.Length); } - /// - /// Acquires a contiguous range of locks for this hash table, and increments locksAcquired - /// by the number of locks that were successfully acquired. The locks are acquired in an - /// increasing order. - /// - private void AcquireLocks(int fromInclusive, int toExclusive, ref int locksAcquired) + /// Acquires the first lock. + /// The number of locks acquired. It should be 0 on entry and 1 on exit. + /// + /// Once the caller owns the lock on lock 0, _tables._locks will not change (i.e., grow), + /// so a caller can safely snap _tables._locks to read the remaining locks. When the locks array grows, + /// even though the array object itself changes, the locks from the previous array are kept. + /// + private void AcquireFirstLock(ref int locksAcquired) { - Debug.Assert(fromInclusive <= toExclusive); object[] locks = _tables._locks; + Debug.Assert(locksAcquired == 0); + Debug.Assert(!Monitor.IsEntered(locks[0])); + + Monitor.Enter(locks[0]); + locksAcquired = 1; + } - for (int i = fromInclusive; i < toExclusive; i++) + /// Acquires all of the locks after the first, which must already be acquired. + /// The tables snapped after the first lock was acquired. + /// + /// The number of locks acquired, which should be 1 on entry. It's incremented as locks + /// are taken so that the caller can reliably release those locks in a finally in case + /// of exception. + /// + private static void AcquirePostFirstLock(Tables tables, ref int locksAcquired) + { + object[] locks = tables._locks; + Debug.Assert(Monitor.IsEntered(locks[0])); + Debug.Assert(locksAcquired == 1); + + for (int i = 1; i < locks.Length; i++) { - bool lockTaken = false; - try - { - Monitor.Enter(locks[i], ref lockTaken); - } - finally - { - if (lockTaken) - { - locksAcquired++; - } - } + Monitor.Enter(locks[i]); + locksAcquired++; } + + Debug.Assert(locksAcquired == locks.Length); } - /// - /// Releases a contiguous range of locks. - /// - private void ReleaseLocks(int fromInclusive, int toExclusive) + /// Releases all of the locks up to the specified number acquired. + /// The number of locks acquired. All lock numbers in the range [0, locksAcquired) will be released. + private void ReleaseLocks(int locksAcquired) { - Debug.Assert(fromInclusive <= toExclusive); + Debug.Assert(locksAcquired >= 0); - Tables tables = _tables; - for (int i = fromInclusive; i < toExclusive; i++) + object[] locks = _tables._locks; + for (int i = 0; i < locksAcquired; i++) { - Monitor.Exit(tables._locks[i]); + Monitor.Exit(locks[i]); } } @@ -2047,34 +2125,29 @@ private ReadOnlyCollection GetKeys() { AcquireAllLocks(ref locksAcquired); - int count = GetCountInternal(); - if (count < 0) - { - ThrowHelper.ThrowOutOfMemoryException(); - } - + int count = GetCountNoLocks(); if (count == 0) { return ReadOnlyCollection.Empty; } var keys = new TKey[count]; - Node?[] buckets = _tables._buckets; - int n = 0; - for (int i = 0; i < buckets.Length; i++) + int i = 0; + foreach (VolatileNode bucket in _tables._buckets) { - for (Node? current = buckets[i]; current != null; current = current._next) + for (Node? 
node = bucket._node; node is not null; node = node._next) { - keys[n++] = current._key; + keys[i] = node._key; + i++; } } - Debug.Assert(n == count); + Debug.Assert(i == count); return new ReadOnlyCollection(keys); } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } @@ -2088,37 +2161,41 @@ private ReadOnlyCollection GetValues() { AcquireAllLocks(ref locksAcquired); - int count = GetCountInternal(); - if (count < 0) - { - ThrowHelper.ThrowOutOfMemoryException(); - } - + int count = GetCountNoLocks(); if (count == 0) { return ReadOnlyCollection.Empty; } - var values = new TValue[count]; - Node?[] buckets = _tables._buckets; - int n = 0; - for (int i = 0; i < buckets.Length; i++) + var keys = new TValue[count]; + int i = 0; + foreach (VolatileNode bucket in _tables._buckets) { - for (Node? current = buckets[i]; current != null; current = current._next) + for (Node? node = bucket._node; node is not null; node = node._next) { - values[n++] = current._value; + keys[i] = node._value; + i++; } } - Debug.Assert(n == count); + Debug.Assert(i == count); - return new ReadOnlyCollection(values); + return new ReadOnlyCollection(keys); } finally { - ReleaseLocks(0, locksAcquired); + ReleaseLocks(locksAcquired); } } + private struct VolatileNode + { + // Workaround for https://github.com/dotnet/runtime/issues/65789. + // If we had a Node?[] array, to safely read from the array we'd need to do Volatile.Read(ref array[i]), + // but that triggers an unnecessary ldelema, which in turn results in a call to CastHelpers.LdelemaRef. + // With this wrapper, the non-inlined call disappears. + internal volatile Node? _node; + } + /// /// A node in a singly-linked list representing a particular hash table bucket. /// @@ -2138,65 +2215,69 @@ internal Node(TKey key, TValue value, int hashcode, Node? next) } } + /// Computes a ref to the bucket for a particular key. + /// This reads the bucket with a read acquire barrier. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static Node? GetBucket(Tables tables, int hashcode) + { + VolatileNode[] buckets = tables._buckets; + if (IntPtr.Size == 8) + { + return buckets[HashHelpers.FastMod((uint)hashcode, (uint)buckets.Length, tables._fastModBucketsMultiplier)]._node; + } + else + { + return buckets[(uint)hashcode % (uint)buckets.Length]._node; + } + } + + /// Computes the bucket and lock number for a particular key. + /// This returns a ref to the bucket node; no barriers are employed. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Node? GetBucketAndLock(Tables tables, int hashcode, out uint lockNo) + { + VolatileNode[] buckets = tables._buckets; + uint bucketNo; + if (IntPtr.Size == 8) + { + bucketNo = HashHelpers.FastMod((uint)hashcode, (uint)buckets.Length, tables._fastModBucketsMultiplier); + } + else + { + bucketNo = (uint)hashcode % (uint)buckets.Length; + } + lockNo = bucketNo % (uint)tables._locks.Length; // doesn't use FastMod, as it would require maintaining a different multiplier + return ref buckets[bucketNo]._node; + } + /// Tables that hold the internal state of the ConcurrentDictionary - /// - /// Wrapping the three tables in a single object allows us to atomically - /// replace all tables at once. - /// + /// Wrapping all of the mutable state into a single object allows us to swap in everything atomically. private sealed class Tables { + /// The comparer to use for lookups in the tables. + internal readonly IEqualityComparer? _comparer; /// A singly-linked list for each bucket. 
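// Illustrative sketch (hypothetical names, not part of the patch) of the VolatileNode workaround
// described above: both helpers below perform an acquire read of a reference stored in an array,
// but the struct wrapper avoids taking a managed reference to an element of a reference-type
// array, which is what forces the covariance-checking ldelema helper call.
using System.Threading;

internal sealed class Item
{
    public int Value;
}

internal struct VolatileSlot
{
    // volatile is allowed here because the field is a reference type; reading it through an
    // array element access is an ordinary struct-element load plus a volatile field read.
    internal volatile Item? _item;
}

internal static class AcquireReads
{
    internal static Item? ViaVolatileRead(Item?[] items, int i) =>
        Volatile.Read(ref items[i]); // ref to an element of a reference-type array => ldelema + helper call

    internal static Item? ViaStructWrapper(VolatileSlot[] slots, int i) =>
        slots[i]._item; // no ldelema of a reference-type element; same acquire semantics
}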
- internal readonly Node?[] _buckets; + internal readonly VolatileNode[] _buckets; + /// Pre-computed multiplier for use on 64-bit performing faster modulo operations. + internal readonly ulong _fastModBucketsMultiplier; /// A set of locks, each guarding a section of the table. internal readonly object[] _locks; /// The number of elements guarded by each lock. internal readonly int[] _countPerLock; - /// Pre-computed multiplier for use on 64-bit performing faster modulo operations. - internal readonly ulong _fastModBucketsMultiplier; - internal Tables(Node?[] buckets, object[] locks, int[] countPerLock) + internal Tables(VolatileNode[] buckets, object[] locks, int[] countPerLock, IEqualityComparer? comparer) { + Debug.Assert(typeof(TKey).IsValueType || comparer is not null); + _buckets = buckets; _locks = locks; _countPerLock = countPerLock; + _comparer = comparer; if (IntPtr.Size == 8) { _fastModBucketsMultiplier = HashHelpers.GetFastModMultiplier((uint)buckets.Length); } } - - /// Computes a ref to the bucket for a particular key. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal ref Node? GetBucket(int hashcode) - { - Node?[] buckets = _buckets; - if (IntPtr.Size == 8) - { - return ref buckets[HashHelpers.FastMod((uint)hashcode, (uint)buckets.Length, _fastModBucketsMultiplier)]; - } - else - { - return ref buckets[(uint)hashcode % (uint)buckets.Length]; - } - } - - /// Computes the bucket and lock number for a particular key. - [MethodImpl(MethodImplOptions.AggressiveInlining)] - internal ref Node? GetBucketAndLock(int hashcode, out uint lockNo) - { - Node?[] buckets = _buckets; - uint bucketNo; - if (IntPtr.Size == 8) - { - bucketNo = HashHelpers.FastMod((uint)hashcode, (uint)buckets.Length, _fastModBucketsMultiplier); - } - else - { - bucketNo = (uint)hashcode % (uint)buckets.Length; - } - lockNo = bucketNo % (uint)_locks.Length; // doesn't use FastMod, as it would require maintaining a different multiplier - return ref buckets[bucketNo]; - } } /// From 21728edc27dcc6570e5c24d90735b08a555369a1 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Tue, 7 Feb 2023 17:57:22 -0500 Subject: [PATCH 2/4] Address PR feedback --- .../Concurrent/ConcurrentDictionary.cs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs index b5dad696f4b0bb..a4d60f9576c159 100644 --- a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs +++ b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs @@ -478,7 +478,7 @@ public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value) else { Debug.Assert(comparer is not null); - int hashcode = _comparerIsDefaultForClasses ? key.GetHashCode() : comparer.GetHashCode(key); + int hashcode = GetHashCode(comparer, key); for (Node? 
n = GetBucket(tables, hashcode); n is not null; n = n._next) { if (hashcode == n._hashcode && comparer.Equals(n._key, key)) @@ -1303,10 +1303,9 @@ public TValue AddOrUpdate( } } - Tables newTables = _tables; - if (tables != newTables) + if (tables != _tables) { - tables = newTables; + tables = _tables; if (!ReferenceEquals(comparer, tables._comparer)) { comparer = tables._comparer; @@ -1377,10 +1376,9 @@ public TValue AddOrUpdate(TKey key, Func addValueFactory, Func } } - Tables newTables = _tables; - if (tables != newTables) + if (tables != _tables) { - tables = newTables; + tables = _tables; if (!ReferenceEquals(comparer, tables._comparer)) { comparer = tables._comparer; From 15463af9e09703d6076c61f71f3ae58e8f53876b Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 8 Feb 2023 11:41:39 -0500 Subject: [PATCH 3/4] Address PR feedback --- .../Concurrent/ConcurrentDictionary.cs | 24 +++-- .../ConcurrentDictionary.Generic.Tests.cs | 91 +++++++++++++++++++ 2 files changed, 105 insertions(+), 10 deletions(-) diff --git a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs index a4d60f9576c159..4d1e14feb38cb2 100644 --- a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs +++ b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs @@ -224,9 +224,12 @@ internal ConcurrentDictionary(int concurrencyLevel, int capacity, bool growLockA /// smaller. /// private static int GetCapacityFromCollection(IEnumerable> collection) => - collection is ICollection> c ? Math.Max(DefaultCapacity, c.Count) : - collection is IReadOnlyCollection> rc ? Math.Max(DefaultCapacity, rc.Count) : - DefaultCapacity; + collection switch + { + ICollection> c => Math.Max(DefaultCapacity, c.Count), + IReadOnlyCollection> rc => Math.Max(DefaultCapacity, rc.Count), + _ => DefaultCapacity, + }; /// Computes the hash code for the specified key using the dictionary's comparer. /// @@ -1957,20 +1960,21 @@ private void GrowTable(Tables tables, bool resizeDesired, bool forceRehashIfNonR } // Compute the new table size at least twice the previous table size. + // Double the size of the buckets table and choose a prime that's at least as large. bool maximizeTableSize = false; - try + newLength = tables._buckets.Length * 2; + if (newLength < 0) + { + maximizeTableSize = true; + } + else { - // Double the size of the buckets table and choose a prime that's at least as large. 
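// Self-contained illustration of the sizing policy stated in the comment above: double the bucket
// count, round up to a prime, and clamp to Array.MaxLength if the doubling wraps negative or the
// prime would exceed it. HashHelpers.GetPrime is internal to the runtime, so this sketch
// substitutes a deliberately naive NextPrime helper; it is an approximation for illustration, not
// the PR's code.
using System;

internal static class TableSizingSketch
{
    public static int ComputeNewLength(int currentBucketCount)
    {
        int doubled = unchecked(currentBucketCount * 2); // wraps negative on overflow
        if (doubled < 0)
        {
            return Array.MaxLength;
        }

        int prime = NextPrime(doubled);
        return prime > Array.MaxLength ? Array.MaxLength : prime;
    }

    private static int NextPrime(int min)
    {
        for (int candidate = min | 1; ; candidate += 2) // min is even here, so start at min + 1
        {
            bool isPrime = true;
            for (int divisor = 3; (long)divisor * divisor <= candidate; divisor += 2)
            {
                if (candidate % divisor == 0)
                {
                    isPrime = false;
                    break;
                }
            }

            if (isPrime)
            {
                return candidate;
            }
        }
    }
}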
- newLength = HashHelpers.GetPrime(checked(tables._buckets.Length * 2)); + newLength = HashHelpers.GetPrime(newLength); if (newLength > Array.MaxLength) { maximizeTableSize = true; } } - catch (OverflowException) - { - maximizeTableSize = true; - } if (maximizeTableSize) { diff --git a/src/libraries/System.Collections.Concurrent/tests/ConcurrentDictionary/ConcurrentDictionary.Generic.Tests.cs b/src/libraries/System.Collections.Concurrent/tests/ConcurrentDictionary/ConcurrentDictionary.Generic.Tests.cs index b03df03b0c062e..ee79cddcfcc010 100644 --- a/src/libraries/System.Collections.Concurrent/tests/ConcurrentDictionary/ConcurrentDictionary.Generic.Tests.cs +++ b/src/libraries/System.Collections.Concurrent/tests/ConcurrentDictionary/ConcurrentDictionary.Generic.Tests.cs @@ -4,6 +4,9 @@ using System.Collections.Tests; using System.Collections.Generic; using System.Linq; +using System.Numerics; +using System.Reflection; +using System.Runtime.InteropServices; using Xunit; namespace System.Collections.Concurrent.Tests @@ -39,6 +42,94 @@ protected override string CreateTKey(int seed) } protected override string CreateTValue(int seed) => CreateTKey(seed); + + [Theory] + [InlineData(false)] + [InlineData(true)] + public void NonRandomizedToRandomizedUpgrade_FunctionsCorrectly(bool ignoreCase) + { + List strings = GenerateCollidingStrings(110); // higher than the collisions threshold + + var cd = new ConcurrentDictionary(ignoreCase ? StringComparer.OrdinalIgnoreCase : StringComparer.Ordinal); + for (int i = 0; i < strings.Count; i++) + { + string s = strings[i]; + + Assert.True(cd.TryAdd(s, s)); + Assert.False(cd.TryAdd(s, s)); + + for (int j = 0; j < strings.Count; j++) + { + Assert.Equal(j <= i, cd.ContainsKey(strings[j])); + } + } + } + + private static List GenerateCollidingStrings(int count) + { + static Func GetHashCodeFunc(ConcurrentDictionary cd) + { + // If the layout of ConcurrentDictionary changes, this will need to change as well. + object tables = cd.GetType().GetField("_tables", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(cd); + Assert.NotNull(tables); + + FieldInfo comparerField = tables.GetType().GetField("_comparer", BindingFlags.Instance | BindingFlags.NonPublic | BindingFlags.Public); + Assert.NotNull(comparerField); + + IEqualityComparer comparer = (IEqualityComparer)comparerField.GetValue(tables); + Assert.NotNull(comparer); + + return comparer.GetHashCode; + } + + Func nonRandomizedOrdinal = GetHashCodeFunc(new ConcurrentDictionary(StringComparer.Ordinal)); + Func nonRandomizedOrdinalIgnoreCase = GetHashCodeFunc(new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase)); + + const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase + const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects + int currentSeed = StartOfRange; + + List collidingStrings = new List(count); + while (collidingStrings.Count < count) + { + Assert.True(currentSeed <= ushort.MaxValue, + $"Couldn't create enough colliding strings? Created {collidingStrings.Count}, needed {count}."); + + // Generates a possible string with a well-known non-randomized hash code: + // - string.GetNonRandomizedHashCode returns 0. + // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0. + // Provide a different seed to produce a different string. + // Must check OrdinalIgnoreCase hash code to ensure correctness. 
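// A hedged, public-surface companion to the collision test above. It does not exercise the
// internal non-randomized -> randomized comparer upgrade (that path applies only to the built-in
// string comparers, which is why the real test crafts strings with known non-randomized hash
// codes); it merely shows that a collision-heavy dictionary keeps every key reachable. The
// comparer and test names below are illustrative and not part of the test suite.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using Xunit;

public sealed class AlwaysCollidingComparer : IEqualityComparer<string>
{
    public bool Equals(string? x, string? y) => string.Equals(x, y, StringComparison.Ordinal);
    public int GetHashCode(string obj) => 0; // force every key into the same bucket
}

public class CollidingKeysSketch
{
    [Fact]
    public void CollidingKeys_RemainReachable()
    {
        var cd = new ConcurrentDictionary<string, int>(new AlwaysCollidingComparer());

        for (int i = 0; i < 200; i++)
        {
            Assert.True(cd.TryAdd("key" + i, i));
        }

        for (int i = 0; i < 200; i++)
        {
            Assert.Equal(i, cd["key" + i]);
        }
    }
}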
+ string candidate = string.Create(8, currentSeed, static (span, seed) => + { + Span asBytes = MemoryMarshal.AsBytes(span); + + uint hash1 = (5381 << 16) + 5381; + uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1; + + MemoryMarshal.Write(asBytes, ref seed); + MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal) + + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed; + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1); + + MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal) + }); + + int ordinalHashCode = nonRandomizedOrdinal(candidate); + Assert.Equal(0, ordinalHashCode); // ensure has a zero hash code Ordinal + + int ordinalIgnoreCaseHashCode = nonRandomizedOrdinalIgnoreCase(candidate); + if (ordinalIgnoreCaseHashCode == 0x24716ca0) // ensure has a zero hash code OrdinalIgnoreCase (might not have one) + { + collidingStrings.Add(candidate); // success! + } + + currentSeed += Stride; + } + + return collidingStrings; + } } public class ConcurrentDictionary_Generic_Tests_ulong_ulong : ConcurrentDictionary_Generic_Tests From 6ec42c4ca7153633b341b17faacabacf632f0fa8 Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Wed, 8 Feb 2023 18:14:10 -0500 Subject: [PATCH 4/4] Address PR feedback --- .../Concurrent/ConcurrentDictionary.cs | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs index 4d1e14feb38cb2..d102b22a7c7145 100644 --- a/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs +++ b/src/libraries/System.Collections.Concurrent/src/System/Collections/Concurrent/ConcurrentDictionary.cs @@ -1961,22 +1961,8 @@ private void GrowTable(Tables tables, bool resizeDesired, bool forceRehashIfNonR // Compute the new table size at least twice the previous table size. // Double the size of the buckets table and choose a prime that's at least as large. - bool maximizeTableSize = false; - newLength = tables._buckets.Length * 2; - if (newLength < 0) - { - maximizeTableSize = true; - } - else - { - newLength = HashHelpers.GetPrime(newLength); - if (newLength > Array.MaxLength) - { - maximizeTableSize = true; - } - } - - if (maximizeTableSize) + if ((newLength = tables._buckets.Length * 2) < 0 || + (newLength = HashHelpers.GetPrime(newLength)) > Array.MaxLength) { newLength = Array.MaxLength;
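// Small demonstration of why the condensed condition above is safe: '||' short-circuits, so when
// the doubling wraps negative the second assignment (the prime lookup) never runs, and newLength
// is simply clamped to Array.MaxLength. FakeGetPrime is a stand-in for the internal
// HashHelpers.GetPrime; the values below are illustrative only.
using System;

int bucketsLength = int.MaxValue / 2 + 1;               // doubling this wraps negative
int newLength;
if ((newLength = unchecked(bucketsLength * 2)) < 0 ||
    (newLength = FakeGetPrime(newLength)) > Array.MaxLength)
{
    newLength = Array.MaxLength;
}

Console.WriteLine(newLength == Array.MaxLength);        // prints: True

static int FakeGetPrime(int min) => min;                // stand-in; never reached in this example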