Skip to content

Commit

Permalink
Implement IAlternateEqualityComparer<ReadOnlySpan<char>, string> on…
Browse files Browse the repository at this point in the history
… `EqualityComparer<string>.Default`
  • Loading branch information
stephentoub committed Jun 30, 2024
1 parent b77fef7 commit 69a192b
Show file tree
Hide file tree
Showing 17 changed files with 103 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ internal static object CreateDefaultEqualityComparer(Type type)

if (type == typeof(string))
{
return new GenericEqualityComparer<string>();
return new StringEqualityComparer();
}
else if (type.IsAssignableTo(typeof(IEquatable<>).MakeGenericType(type)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ internal static unsafe bool IsEnum(RuntimeTypeHandle t)
return t.ToMethodTable()->IsEnum;
}

internal static unsafe bool IsString(RuntimeTypeHandle t)
{
return t.ToMethodTable()->IsString;
}

// this function utilizes the template type loader to generate new
// EqualityComparer types on the fly
internal static object GetComparer(RuntimeTypeHandle t)
Expand All @@ -59,6 +64,11 @@ internal static object GetComparer(RuntimeTypeHandle t)
RuntimeTypeHandle openComparerType = default(RuntimeTypeHandle);
RuntimeTypeHandle comparerTypeArgument = default(RuntimeTypeHandle);

if (IsString(t))
{
return new StringEqualityComparer();
}

if (RuntimeAugments.IsNullable(t))
{
RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ private static EqualityComparer<T> Create()
// This body serves as a fallback when instantiation-specific implementation is unavailable.
// If that happens, the compiler ensures we generate data structures to make the fallback work
// when this method is compiled.

if (typeof(T) == typeof(string))
{
return Unsafe.As<EqualityComparer<T>>(new StringEqualityComparer());
}

if (SupportsGenericIEquatableInterfaces)
{
return Unsafe.As<EqualityComparer<T>>(EqualityComparerHelpers.GetComparer(typeof(T).TypeHandle));
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ DEFINE_METHOD(UTF8BUFFERMARSHALER, CONVERT_TO_MANAGED, ConvertToManaged, NoSig)

// Classes referenced in EqualityComparer<T>.Default optimization

DEFINE_CLASS(STRING_EQUALITYCOMPARER, CollectionsGeneric, StringEqualityComparer)
DEFINE_CLASS(ENUM_EQUALITYCOMPARER, CollectionsGeneric, EnumEqualityComparer`1)
DEFINE_CLASS(NULLABLE_EQUALITYCOMPARER, CollectionsGeneric, NullableEqualityComparer`1)
DEFINE_CLASS(GENERIC_EQUALITYCOMPARER, CollectionsGeneric, GenericEqualityComparer`1)
Expand Down
18 changes: 10 additions & 8 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8848,8 +8848,7 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultComparerClassHelper(CORINFO_CLASS_HANDLE
TypeHandle elemTypeHnd(elemType);

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultComparer
// And in compile.cpp's SpecializeComparer
//

// We need to find the appropriate instantiation
Instantiation inst(&elemTypeHnd, 1);

Expand Down Expand Up @@ -8914,16 +8913,19 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultEqualityComparerClassHelper(CORINFO_CLAS
MODE_PREEMPTIVE;
} CONTRACTL_END;

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultEqualityComparer
// And in compile.cpp's SpecializeEqualityComparer
// Mirrors the logic in BCL's CompareHelpers.CreateDefaultEqualityComparer.
TypeHandle elemTypeHnd(elemType);

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultComparer
// And in compile.cpp's SpecializeComparer
//
// We need to find the appropriate instantiation
// We need to find the appropriate instantiation.
Instantiation inst(&elemTypeHnd, 1);

// string
if (elemTypeHnd.IsString())
{
TypeHandle resultTh = ((TypeHandle)CoreLibBinder::GetClass(CLASS__STRING_EQUALITYCOMPARER)).Instantiate(inst);
return CORINFO_CLASS_HANDLE(resultTh.GetMethodTable());
}

// Nullable<T>
if (Nullable::IsNullableType(elemTypeHnd))
{
Expand Down
7 changes: 7 additions & 0 deletions src/coreclr/vm/typehandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ BOOL TypeHandle::IsArray() const {
return !IsTypeDesc() && AsMethodTable()->IsArray();
}

BOOL TypeHandle::IsString() const
{
LIMITED_METHOD_CONTRACT;

return !IsTypeDesc() && AsMethodTable()->IsString();
}

BOOL TypeHandle::IsGenericVariable() const {
LIMITED_METHOD_DAC_CONTRACT;

Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/vm/typehandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ class TypeHandle
// PTR
BOOL IsPointer() const;

// String
BOOL IsString() const;

// True if this type *is* a formal generic type parameter or any component of it is a formal generic type parameter
BOOL ContainsGenericVariables(BOOL methodOnly=FALSE) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1088,13 +1088,6 @@ public void ConcurrentWriteRead_NoTornValues()
}));
}

// TODO: Revise this test when EqualityComparer<string>.Default implements IAlternateEqualityComparer<ReadOnlySpan<char>, string>
[Fact]
public void GetAlternateLookup_FailsForDefaultComparer()
{
Assert.False(new ConcurrentDictionary<string, string>().TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}

[Fact]
public void GetAlternateLookup_FailsWhenIncompatible()
{
Expand All @@ -1119,17 +1112,19 @@ public void GetAlternateLookup_FailsWhenIncompatible()
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
[InlineData(6)]
public void GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying dictionary
ConcurrentDictionary<string, int> dictionary = new(mode switch
{
0 => StringComparer.Ordinal,
1 => StringComparer.OrdinalIgnoreCase,
2 => StringComparer.InvariantCulture,
3 => StringComparer.InvariantCultureIgnoreCase,
4 => StringComparer.CurrentCulture,
5 => StringComparer.CurrentCultureIgnoreCase,
0 => EqualityComparer<string>.Default,
1 => StringComparer.Ordinal,
2 => StringComparer.OrdinalIgnoreCase,
3 => StringComparer.InvariantCulture,
4 => StringComparer.InvariantCultureIgnoreCase,
5 => StringComparer.CurrentCulture,
6 => StringComparer.CurrentCultureIgnoreCase,
_ => throw new ArgumentOutOfRangeException(nameof(mode))
});
ConcurrentDictionary<string, int>.AlternateLookup<ReadOnlySpan<char>> lookup = dictionary.GetAlternateLookup<ReadOnlySpan<char>>();
Expand Down Expand Up @@ -1165,7 +1160,8 @@ public void GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)

// Ensure that case-sensitivity of the comparer is respected
lookup["a".AsSpan()] = 42;
if (dictionary.Comparer.Equals(StringComparer.Ordinal) ||
if (dictionary.Comparer.Equals(EqualityComparer<string>.Default) ||
dictionary.Comparer.Equals(StringComparer.Ordinal) ||
dictionary.Comparer.Equals(StringComparer.InvariantCulture) ||
dictionary.Comparer.Equals(StringComparer.CurrentCulture))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,10 @@ public class FrozenDictionaryAlternateLookupTests
[Fact]
public void AlternateLookup_Empty()
{
FrozenDictionary<string, string>[] unsupported =
FrozenDictionary<string, string>[] supported =
[
FrozenDictionary<string, string>.Empty,
FrozenDictionary.ToFrozenDictionary<string, string>([]),
FrozenDictionary.ToFrozenDictionary<string, string>([], EqualityComparer<string>.Default),
];
foreach (FrozenDictionary<string, string> frozen in unsupported)
{
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}

FrozenDictionary<string, string>[] supported =
[
FrozenDictionary.ToFrozenDictionary<string, string>([], StringComparer.Ordinal),
FrozenDictionary.ToFrozenDictionary<string, string>([], StringComparer.OrdinalIgnoreCase),
];
Expand All @@ -39,7 +29,7 @@ public void AlternateLookup_Empty()
[Fact]
public void UnsupportedComparer_ThrowsOrReturnsFalse()
{
FrozenDictionary<string, int> frozen = new Dictionary<string, int> { ["a"] = 1, ["b"] = 2 }.ToFrozenDictionary();
FrozenDictionary<char, int> frozen = new Dictionary<char, int> { ['a'] = 1, ['b'] = 2 }.ToFrozenDictionary();
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ public class FrozenSetAlternateLookupTests
[Fact]
public void AlternateLookup_Empty()
{
Assert.False(FrozenSet<string>.Empty.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
Assert.True(FrozenSet<string>.Empty.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));

foreach (StringComparer comparer in new[] { StringComparer.Ordinal, StringComparer.OrdinalIgnoreCase })
foreach (IEqualityComparer<string> comparer in new IEqualityComparer<string>[] { null, EqualityComparer<string>.Default, StringComparer.Ordinal, StringComparer.OrdinalIgnoreCase })
{
FrozenSet<string>.AlternateLookup<ReadOnlySpan<char>> lookup = FrozenSet.ToFrozenSet([], comparer).GetAlternateLookup<ReadOnlySpan<char>>();
Assert.False(lookup.Contains("anything".AsSpan()));
Expand All @@ -24,7 +24,7 @@ public void AlternateLookup_Empty()
[Fact]
public void UnsupportedComparer()
{
FrozenSet<string> frozen = FrozenSet.ToFrozenSet(["a", "b"]);
FrozenSet<char> frozen = FrozenSet.ToFrozenSet(['a', 'b']);
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,6 @@ public void GetAlternateLookup_ThrowsWhenNull()
AssertExtensions.Throws<ArgumentNullException>("set", () => CollectionExtensions.TryGetAlternateLookup<int, long>((HashSet<int>)null, out _));
}

// TODO https://github.com/dotnet/runtime/issues/102906:
// Revise this test when EqualityComparer<string>.Default implements IAlternateEqualityComparer<ReadOnlySpan<char>, string>
[Fact]
public void GetAlternateLookup_FailsForDefaultComparer()
{
Assert.False(new Dictionary<string, string>().TryGetAlternateLookup<string, string, ReadOnlySpan<char>>(out _));
Assert.False(new HashSet<string>().TryGetAlternateLookup<string, ReadOnlySpan<char>>(out _));
}

[Fact]
public void GetAlternateLookup_FailsWhenIncompatible()
{
Expand Down Expand Up @@ -196,17 +187,19 @@ public void GetAlternateLookup_FailsWhenIncompatible()
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
[InlineData(6)]
public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying dictionary
Dictionary<string, int> dictionary = new(mode switch
{
0 => StringComparer.Ordinal,
1 => StringComparer.OrdinalIgnoreCase,
2 => StringComparer.InvariantCulture,
3 => StringComparer.InvariantCultureIgnoreCase,
4 => StringComparer.CurrentCulture,
5 => StringComparer.CurrentCultureIgnoreCase,
0 => EqualityComparer<string>.Default,
1 => StringComparer.Ordinal,
2 => StringComparer.OrdinalIgnoreCase,
3 => StringComparer.InvariantCulture,
4 => StringComparer.InvariantCultureIgnoreCase,
5 => StringComparer.CurrentCulture,
6 => StringComparer.CurrentCultureIgnoreCase,
_ => throw new ArgumentOutOfRangeException(nameof(mode))
});
Dictionary<string, int>.AlternateLookup<ReadOnlySpan<char>> lookup = dictionary.GetAlternateLookup<string, int, ReadOnlySpan<char>>();
Expand Down Expand Up @@ -241,7 +234,8 @@ public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(in

// Ensure that case-sensitivity of the comparer is respected
lookup["a".AsSpan()] = 42;
if (dictionary.Comparer.Equals(StringComparer.Ordinal) ||
if (dictionary.Comparer.Equals(EqualityComparer<string>.Default) ||
dictionary.Comparer.Equals(StringComparer.Ordinal) ||
dictionary.Comparer.Equals(StringComparer.InvariantCulture) ||
dictionary.Comparer.Equals(StringComparer.CurrentCulture))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class InternalHashCodeTests_Dictionary_NullComparer : InternalHashCodeTes
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => randomizedOrdinalComparerType;

Expand Down Expand Up @@ -60,7 +60,7 @@ public class InternalHashCodeTests_Dictionary_DefaultComparer : InternalHashCode
protected override bool ContainsKey(Dictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(Dictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => randomizedOrdinalComparerType;
}
Expand Down Expand Up @@ -124,7 +124,7 @@ public class InternalHashCodeTests_HashSet_NullComparer : InternalHashCodeTests<
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => randomizedOrdinalComparerType;
}
Expand All @@ -136,7 +136,7 @@ public class InternalHashCodeTests_HashSet_DefaultComparer : InternalHashCodeTes
protected override bool ContainsKey(HashSet<string> collection, string key) => collection.Contains(key);
protected override IEqualityComparer<string> GetComparer(HashSet<string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => randomizedOrdinalComparerType;
}
Expand Down Expand Up @@ -186,7 +186,7 @@ public class InternalHashCodeTests_OrderedDictionary_NullComparer : InternalHash
protected override bool ContainsKey(OrderedDictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(OrderedDictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer<string>.Default.GetType();
}
Expand All @@ -198,7 +198,7 @@ public class InternalHashCodeTests_OrderedDictionary_DefaultComparer : InternalH
protected override bool ContainsKey(OrderedDictionary<string, string> collection, string key) => collection.ContainsKey(key);
protected override IEqualityComparer<string> GetComparer(OrderedDictionary<string, string> collection) => collection.Comparer;

protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedDefaultComparerType;
protected override Type ExpectedInternalComparerTypeBeforeCollisionThreshold => nonRandomizedOrdinalComparerType;
protected override IEqualityComparer<string> ExpectedPublicComparerBeforeCollisionThreshold => EqualityComparer<string>.Default;
protected override Type ExpectedInternalComparerTypeAfterCollisionThreshold => EqualityComparer<string>.Default.GetType();
}
Expand Down Expand Up @@ -242,7 +242,6 @@ public class InternalHashCodeTests_OrderedDictionary_LinguisticComparer : Intern

public abstract class InternalHashCodeTests<TCollection>
{
protected static Type nonRandomizedDefaultComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+DefaultComparer", throwOnError: true);
protected static Type nonRandomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
protected static Type nonRandomizedOrdinalIgnoreCaseComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.NonRandomizedStringEqualityComparer+OrdinalIgnoreCaseComparer", throwOnError: true);
protected static Type randomizedOrdinalComparerType = typeof(object).Assembly.GetType("System.Collections.Generic.RandomizedStringEqualityComparer+OrdinalComparer", throwOnError: true);
Expand Down
Loading

0 comments on commit 69a192b

Please sign in to comment.