Skip to content

Commit

Permalink
Move GetValueRefOrAddDefault impl to separate type
Browse files Browse the repository at this point in the history
This avoids the additional overhead when loading Dictionary<TKey, TValue> instances, especially in AOT scenarios, and it makes the new API pay for play.
  • Loading branch information
Sergio0694 committed Jul 14, 2021
1 parent ffeef77 commit 1d7e386
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -681,67 +681,105 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
return true;
}

internal ref TValue? GetValueRefOrAddDefault(TKey key, out bool exists)
/// <summary>
/// A helper class containing APIs exposed through <see cref="Runtime.InteropServices.CollectionsMarshal"/>.
/// These methods are relatively niche and only used in specific scenarios, so adding them in a separate type avoids
/// the additional overhead on each <see cref="Dictionary{TKey, TValue}"/> instantiation, especially in AOT scenarios.
/// </summary>
internal static class CollectionsMarshalHelper
{
if (key == null)
/// <inheritdoc cref="Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault{TKey, TValue}(Dictionary{TKey, TValue}, TKey, out bool)"/>
public static ref TValue? GetValueRefOrAddDefault(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}
if (key == null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key);
}

if (_buckets == null)
{
Initialize(0);
}
Debug.Assert(_buckets != null);
if (dictionary._buckets == null)
{
dictionary.Initialize(0);
}
Debug.Assert(dictionary._buckets != null);

Entry[]? entries = _entries;
Debug.Assert(entries != null, "expected entries to be non-null");
Entry[]? entries = dictionary._entries;
Debug.Assert(entries != null, "expected entries to be non-null");

IEqualityComparer<TKey>? comparer = _comparer;
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));
IEqualityComparer<TKey>? comparer = dictionary._comparer;
uint hashCode = (uint)((comparer == null) ? key.GetHashCode() : comparer.GetHashCode(key));

uint collisionCount = 0;
ref int bucket = ref GetBucket(hashCode);
int i = bucket - 1; // Value in _buckets is 1-based
uint collisionCount = 0;
ref int bucket = ref dictionary.GetBucket(hashCode);
int i = bucket - 1; // Value in _buckets is 1-based

if (comparer == null)
{
if (typeof(TKey).IsValueType)
if (comparer == null)
{
// ValueType: Devirtualize with EqualityComparer<TValue>.Default intrinsic
while (true)
if (typeof(TKey).IsValueType)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
// ValueType: Devirtualize with EqualityComparer<TValue>.Default intrinsic
while (true)
{
break;
}
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entries[i].key, key))
{
exists = true;
if (entries[i].hashCode == hashCode && EqualityComparer<TKey>.Default.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}
return ref entries[i].value!;
}

i = entries[i].next;
i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
else
{
// Object type: Shared Generic, EqualityComparer<TValue>.Default won't devirtualize
// https://github.com/dotnet/runtime/issues/10050
// So cache in a local rather than get EqualityComparer per loop iteration
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
while (true)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
}
}
}
}
else
{
// Object type: Shared Generic, EqualityComparer<TValue>.Default won't devirtualize
// https://github.com/dotnet/runtime/issues/10050
// So cache in a local rather than get EqualityComparer per loop iteration
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
Expand All @@ -751,7 +789,7 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
break;
}

if (entries[i].hashCode == hashCode && defaultComparer.Equals(entries[i].key, key))
if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key))
{
exists = true;

Expand All @@ -769,88 +807,59 @@ private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
}
}
}
}
else
{
while (true)
{
// Should be a while loop https://github.com/dotnet/runtime/issues/9422
// Test uint in if rather than loop condition to drop range check for following array access
if ((uint)i >= (uint)entries.Length)
{
break;
}

if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key))
{
exists = true;

return ref entries[i].value!;
}

i = entries[i].next;

collisionCount++;
if (collisionCount > (uint)entries.Length)
int index;
if (dictionary._freeCount > 0)
{
index = dictionary._freeList;
Debug.Assert((StartOfFreeList - entries[dictionary._freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow");
dictionary._freeList = StartOfFreeList - entries[dictionary._freeList].next;
dictionary._freeCount--;
}
else
{
int count = dictionary._count;
if (count == entries.Length)
{
// The chain of entries forms a loop; which means a concurrent update has happened.
// Break out of the loop and throw, rather than looping forever.
ThrowHelper.ThrowInvalidOperationException_ConcurrentOperationsNotSupported();
dictionary.Resize();
bucket = ref dictionary.GetBucket(hashCode);
}
index = count;
dictionary._count = count + 1;
entries = dictionary._entries;
}
}

int index;
if (_freeCount > 0)
{
index = _freeList;
Debug.Assert((StartOfFreeList - entries[_freeList].next) >= -1, "shouldn't overflow because `next` cannot underflow");
_freeList = StartOfFreeList - entries[_freeList].next;
_freeCount--;
}
else
{
int count = _count;
if (count == entries.Length)
ref Entry entry = ref entries![index];
entry.hashCode = hashCode;
entry.next = bucket - 1; // Value in _buckets is 1-based
entry.key = key;
entry.value = default!;
bucket = index + 1; // Value in _buckets is 1-based
dictionary._version++;

// Value types never rehash
if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer)
{
Resize();
bucket = ref GetBucket(hashCode);
}
index = count;
_count = count + 1;
entries = _entries;
}
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
dictionary.Resize(entries.Length, true);

ref Entry entry = ref entries![index];
entry.hashCode = hashCode;
entry.next = bucket - 1; // Value in _buckets is 1-based
entry.key = key;
entry.value = default!;
bucket = index + 1; // Value in _buckets is 1-based
_version++;
exists = false;

// Value types never rehash
if (!typeof(TKey).IsValueType && collisionCount > HashHelpers.HashCollisionThreshold && comparer is NonRandomizedStringEqualityComparer)
{
// If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing
// i.e. EqualityComparer<string>.Default.
Resize(entries.Length, true);
// At this point the entries array has been resized, so the current reference we have is no longer valid.
// We're forced to do a new lookup and return an updated reference to the new entry instance. This new
// lookup is guaranteed to always find a value though and it will never return a null reference here.
ref TValue? value = ref dictionary.FindValue(key)!;

exists = false;
Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here");

// At this point the entries array has been resized, so the current reference we have is no longer valid.
// We're forced to do a new lookup and return an updated reference to the new entry instance. This new
// lookup is guaranteed to always find a value though and it will never return a null reference here.
ref TValue? value = ref FindValue(key)!;
return ref value;
}

Debug.Assert(!Unsafe.IsNullRef(ref value), "the lookup result cannot be a null ref here");
exists = false;

return ref value;
return ref entry.value!;
}

exists = false;

return ref entry.value!;
}

public virtual void OnDeserialization(object? sender)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@ public static ref TValue GetValueRefOrNullRef<TKey, TValue>(Dictionary<TKey, TVa
/// <param name="exists">Whether or not a new entry for the given key was added to the dictionary.</param>
/// <remarks>Items should not be added to or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
public static ref TValue? GetValueRefOrAddDefault<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key, out bool exists) where TKey : notnull
=> ref dictionary.GetValueRefOrAddDefault(key, out exists);
=> ref Dictionary<TKey, TValue>.CollectionsMarshalHelper.GetValueRefOrAddDefault(dictionary, key, out exists);
}
}

0 comments on commit 1d7e386

Please sign in to comment.