Skip to content

Commit

Permalink
Add CollectionsMarshal.GetValueRef API
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergio0694 committed Jun 25, 2021
1 parent e9f101c commit ab061a3
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Internal.Runtime.CompilerServices;

namespace System.Runtime.InteropServices
{
Expand All @@ -18,6 +19,25 @@ public static class CollectionsMarshal
public static Span<T> AsSpan<T>(List<T>? list)
=> list is null ? default : new Span<T>(list._items, 0, list._size);

/// <summary>
/// Gets a ref to a <typeparamref name="TValue"/> in the <see cref="Dictionary{TKey, TValue}"/>.
/// </summary>
/// <param name="dictionary">The dictionary to get the ref to <typeparamref name="TValue"/> from.</param>
/// <param name="key">The key used for lookup.</param>
/// <remarks>Items should not be added or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.</remarks>
/// <exception cref="KeyNotFoundException">Thrown when <paramref name="key"/> does not exist in the <paramref name="dictionary"/>.</exception>
public static ref TValue GetValueRef<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
{
ref TValue valueRef = ref dictionary.FindValue(key);

if (Unsafe.IsNullRef(ref valueRef))
{
ThrowHelper.ThrowKeyNotFoundException(key);
}

return ref valueRef;
}

/// <summary>
/// Gets either a ref to a <typeparamref name="TValue"/> in the <see cref="Dictionary{TKey, TValue}"/> or a ref null if it does not exist in the <paramref name="dictionary"/>.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ public CoClassAttribute(System.Type coClass) { }
public static partial class CollectionsMarshal
{
public static System.Span<T> AsSpan<T>(System.Collections.Generic.List<T>? list) { throw null; }
public static ref TValue GetValueRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,159 @@ public void ListAsSpanLinkBreaksOnResize()
}
}

[Fact]
public void GetValueRefValueType()
{
var dict = new Dictionary<int, Struct>
{
{ 1, default },
{ 2, default }
};

Assert.Equal(2, dict.Count);

Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

var itemVal = dict[1];
itemVal.Value = 1;
itemVal.Property = 2;

// Does not change values in dictionary
Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

CollectionsMarshal.GetValueRef(dict, 1).Value = 3;
CollectionsMarshal.GetValueRef(dict, 1).Property = 4;

Assert.Equal(3, dict[1].Value);
Assert.Equal(4, dict[1].Property);

ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 2);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 5;
itemRef.Property = 6;

Assert.Equal(5, itemRef.Value);
Assert.Equal(6, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

itemRef = new() { Value = 7, Property = 8 };

Assert.Equal(7, itemRef.Value);
Assert.Equal(8, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

// Check for exception

Assert.Throws<KeyNotFoundException>(() => CollectionsMarshal.GetValueRef(dict, 3));

Assert.Equal(2, dict.Count);
}

[Fact]
public void GetValueRefClass()
{
var dict = new Dictionary<int, IntAsObject>
{
{ 1, new() },
{ 2, new() }
};

Assert.Equal(2, dict.Count);

Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

var itemVal = dict[1];
itemVal.Value = 1;
itemVal.Property = 2;

// Does change values in dictionary
Assert.Equal(1, dict[1].Value);
Assert.Equal(2, dict[1].Property);

CollectionsMarshal.GetValueRef(dict, 1).Value = 3;
CollectionsMarshal.GetValueRef(dict, 1).Property = 4;

Assert.Equal(3, dict[1].Value);
Assert.Equal(4, dict[1].Property);

ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 2);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 5;
itemRef.Property = 6;

Assert.Equal(5, itemRef.Value);
Assert.Equal(6, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

itemRef = new() { Value = 7, Property = 8 };

Assert.Equal(7, itemRef.Value);
Assert.Equal(8, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

// Check for exception

Assert.Throws<KeyNotFoundException>(() => CollectionsMarshal.GetValueRef(dict, 3));

Assert.Equal(2, dict.Count);
}

[Fact]
public void GetValueRefLinkBreaksOnResize()
{
var dict = new Dictionary<int, Struct>
{
{ 1, new() }
};

Assert.Equal(1, dict.Count);

ref var itemRef = ref CollectionsMarshal.GetValueRef(dict, 1);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 1;
itemRef.Property = 2;

Assert.Equal(1, itemRef.Value);
Assert.Equal(2, itemRef.Property);
Assert.Equal(dict[1].Value, itemRef.Value);
Assert.Equal(dict[1].Property, itemRef.Property);

// Resize
dict.EnsureCapacity(100);
for (int i = 2; i <= 50; i++)
{
dict.Add(i, new());
}

itemRef.Value = 3;
itemRef.Property = 4;

Assert.Equal(3, itemRef.Value);
Assert.Equal(4, itemRef.Property);

// Check connection broken
Assert.NotEqual(dict[1].Value, itemRef.Value);
Assert.NotEqual(dict[1].Property, itemRef.Property);

Assert.Equal(50, dict.Count);
}

[Fact]
public void GetValueRefOrNullRefValueType()
{
Expand Down

0 comments on commit ab061a3

Please sign in to comment.