Skip to content

Commit

Permalink
Implement IAlternateEqualityComparer<ReadOnlySpan<char>, string> on…
Browse files Browse the repository at this point in the history
… `EqualityComparer<string>.Default` (#104202)

* Implement `IAlternateEqualityComparer<ReadOnlySpan<char>, string>` on `EqualityComparer<string>.Default`

* Fixes and feedback

* Fix another test

* Fix another test

* Address test feedback
  • Loading branch information
stephentoub authored Jul 2, 2024
1 parent 5044e93 commit aa280ad
Show file tree
Hide file tree
Showing 24 changed files with 160 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ internal static object CreateDefaultEqualityComparer(Type type)

if (type == typeof(string))
{
return new GenericEqualityComparer<string>();
return new StringEqualityComparer();
}
else if (type.IsAssignableTo(typeof(IEquatable<>).MakeGenericType(type)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ internal static unsafe bool IsEnum(RuntimeTypeHandle t)
return t.ToMethodTable()->IsEnum;
}

internal static unsafe bool IsString(RuntimeTypeHandle t)
{
return t.ToMethodTable()->IsString;
}

// this function utilizes the template type loader to generate new
// EqualityComparer types on the fly
internal static object GetComparer(RuntimeTypeHandle t)
Expand All @@ -59,6 +64,11 @@ internal static object GetComparer(RuntimeTypeHandle t)
RuntimeTypeHandle openComparerType = default(RuntimeTypeHandle);
RuntimeTypeHandle comparerTypeArgument = default(RuntimeTypeHandle);

if (IsString(t))
{
return new StringEqualityComparer();
}

if (RuntimeAugments.IsNullable(t))
{
RuntimeTypeHandle nullableType = RuntimeAugments.GetNullableType(t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ private static EqualityComparer<T> Create()
// This body serves as a fallback when instantiation-specific implementation is unavailable.
// If that happens, the compiler ensures we generate data structures to make the fallback work
// when this method is compiled.

if (typeof(T) == typeof(string))
{
return Unsafe.As<EqualityComparer<T>>(new StringEqualityComparer());
}

if (SupportsGenericIEquatableInterfaces)
{
return Unsafe.As<EqualityComparer<T>>(EqualityComparerHelpers.GetComparer(typeof(T).TypeHandle));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ private static MethodIL EmitComparerAndEqualityComparerCreateCommon(MethodDesc m
/// Gets the comparer type that is suitable to compare instances of <paramref name="type"/>
/// or null if such comparer cannot be determined at compile time.
/// </summary>
private static InstantiatedType GetComparerForType(TypeDesc type, string flavor, string interfaceName)
private static TypeDesc GetComparerForType(TypeDesc type, string flavor, string interfaceName)
{
TypeSystemContext context = type.Context;

Expand All @@ -92,6 +92,11 @@ private static InstantiatedType GetComparerForType(TypeDesc type, string flavor,
.MakeInstantiatedType(type.Instantiation[0]);
}

if (type.IsString && flavor == "EqualityComparer")
{
return context.SystemModule.GetKnownType("System.Collections.Generic", "StringEqualityComparer");
}

if (type.IsEnum)
{
// Enums have a specialized comparer that avoids boxing
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/vm/corelib.h
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ DEFINE_METHOD(UTF8BUFFERMARSHALER, CONVERT_TO_MANAGED, ConvertToManaged, NoSig)

// Classes referenced in EqualityComparer<T>.Default optimization

DEFINE_CLASS(STRING_EQUALITYCOMPARER, CollectionsGeneric, StringEqualityComparer)
DEFINE_CLASS(ENUM_EQUALITYCOMPARER, CollectionsGeneric, EnumEqualityComparer`1)
DEFINE_CLASS(NULLABLE_EQUALITYCOMPARER, CollectionsGeneric, NullableEqualityComparer`1)
DEFINE_CLASS(GENERIC_EQUALITYCOMPARER, CollectionsGeneric, GenericEqualityComparer`1)
Expand Down
17 changes: 9 additions & 8 deletions src/coreclr/vm/jitinterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8903,8 +8903,7 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultComparerClassHelper(CORINFO_CLASS_HANDLE
TypeHandle elemTypeHnd(elemType);

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultComparer
// And in compile.cpp's SpecializeComparer
//

// We need to find the appropriate instantiation
Instantiation inst(&elemTypeHnd, 1);

Expand Down Expand Up @@ -8969,16 +8968,18 @@ CORINFO_CLASS_HANDLE CEEInfo::getDefaultEqualityComparerClassHelper(CORINFO_CLAS
MODE_PREEMPTIVE;
} CONTRACTL_END;

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultEqualityComparer
// And in compile.cpp's SpecializeEqualityComparer
// Mirrors the logic in BCL's CompareHelpers.CreateDefaultEqualityComparer.
TypeHandle elemTypeHnd(elemType);

// Mirrors the logic in BCL's CompareHelpers.CreateDefaultComparer
// And in compile.cpp's SpecializeComparer
//
// We need to find the appropriate instantiation
// We need to find the appropriate instantiation.
Instantiation inst(&elemTypeHnd, 1);

// string
if (elemTypeHnd.IsString())
{
return CORINFO_CLASS_HANDLE(CoreLibBinder::GetClass(CLASS__STRING_EQUALITYCOMPARER));
}

// Nullable<T>
if (Nullable::IsNullableType(elemTypeHnd))
{
Expand Down
7 changes: 7 additions & 0 deletions src/coreclr/vm/typehandle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ BOOL TypeHandle::IsArray() const {
return !IsTypeDesc() && AsMethodTable()->IsArray();
}

BOOL TypeHandle::IsString() const
{
LIMITED_METHOD_CONTRACT;

return !IsTypeDesc() && AsMethodTable()->IsString();
}

BOOL TypeHandle::IsGenericVariable() const {
LIMITED_METHOD_DAC_CONTRACT;

Expand Down
3 changes: 3 additions & 0 deletions src/coreclr/vm/typehandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ class TypeHandle
// PTR
BOOL IsPointer() const;

// String
BOOL IsString() const;

// True if this type *is* a formal generic type parameter or any component of it is a formal generic type parameter
BOOL ContainsGenericVariables(BOOL methodOnly=FALSE) const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1088,13 +1088,6 @@ public void ConcurrentWriteRead_NoTornValues()
}));
}

// TODO: Revise this test when EqualityComparer<string>.Default implements IAlternateEqualityComparer<ReadOnlySpan<char>, string>
[Fact]
public void GetAlternateLookup_FailsForDefaultComparer()
{
Assert.False(new ConcurrentDictionary<string, string>().TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}

[Fact]
public void GetAlternateLookup_FailsWhenIncompatible()
{
Expand All @@ -1112,26 +1105,23 @@ public void GetAlternateLookup_FailsWhenIncompatible()
Assert.False(dictionary.TryGetAlternateLookup<int>(out _));
}

public static IEnumerable<object[]> Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData()
{
yield return new object[] { EqualityComparer<string>.Default };
yield return new object[] { StringComparer.Ordinal };
yield return new object[] { StringComparer.OrdinalIgnoreCase };
yield return new object[] { StringComparer.InvariantCulture };
yield return new object[] { StringComparer.InvariantCultureIgnoreCase };
yield return new object[] { StringComparer.CurrentCulture };
yield return new object[] { StringComparer.CurrentCultureIgnoreCase };
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
public void GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)
[MemberData(nameof(Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData))]
public void GetAlternateLookup_OperationsMatchUnderlyingDictionary(IEqualityComparer<string> comparer)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying dictionary
ConcurrentDictionary<string, int> dictionary = new(mode switch
{
0 => StringComparer.Ordinal,
1 => StringComparer.OrdinalIgnoreCase,
2 => StringComparer.InvariantCulture,
3 => StringComparer.InvariantCultureIgnoreCase,
4 => StringComparer.CurrentCulture,
5 => StringComparer.CurrentCultureIgnoreCase,
_ => throw new ArgumentOutOfRangeException(nameof(mode))
});
ConcurrentDictionary<string, int> dictionary = new(comparer);
ConcurrentDictionary<string, int>.AlternateLookup<ReadOnlySpan<char>> lookup = dictionary.GetAlternateLookup<ReadOnlySpan<char>>();
Assert.Same(dictionary, lookup.Dictionary);
Assert.Same(lookup.Dictionary, lookup.Dictionary);
Expand Down Expand Up @@ -1165,7 +1155,8 @@ public void GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)

// Ensure that case-sensitivity of the comparer is respected
lookup["a".AsSpan()] = 42;
if (dictionary.Comparer.Equals(StringComparer.Ordinal) ||
if (dictionary.Comparer.Equals(EqualityComparer<string>.Default) ||
dictionary.Comparer.Equals(StringComparer.Ordinal) ||
dictionary.Comparer.Equals(StringComparer.InvariantCulture) ||
dictionary.Comparer.Equals(StringComparer.CurrentCulture))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,10 @@ public class FrozenDictionaryAlternateLookupTests
[Fact]
public void AlternateLookup_Empty()
{
FrozenDictionary<string, string>[] unsupported =
FrozenDictionary<string, string>[] supported =
[
FrozenDictionary<string, string>.Empty,
FrozenDictionary.ToFrozenDictionary<string, string>([]),
FrozenDictionary.ToFrozenDictionary<string, string>([], EqualityComparer<string>.Default),
];
foreach (FrozenDictionary<string, string> frozen in unsupported)
{
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}

FrozenDictionary<string, string>[] supported =
[
FrozenDictionary.ToFrozenDictionary<string, string>([], StringComparer.Ordinal),
FrozenDictionary.ToFrozenDictionary<string, string>([], StringComparer.OrdinalIgnoreCase),
];
Expand All @@ -39,7 +29,7 @@ public void AlternateLookup_Empty()
[Fact]
public void UnsupportedComparer_ThrowsOrReturnsFalse()
{
FrozenDictionary<string, int> frozen = new Dictionary<string, int> { ["a"] = 1, ["b"] = 2 }.ToFrozenDictionary();
FrozenDictionary<char, int> frozen = new Dictionary<char, int> { ['a'] = 1, ['b'] = 2 }.ToFrozenDictionary();
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ public class FrozenSetAlternateLookupTests
[Fact]
public void AlternateLookup_Empty()
{
Assert.False(FrozenSet<string>.Empty.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
Assert.True(FrozenSet<string>.Empty.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));

foreach (StringComparer comparer in new[] { StringComparer.Ordinal, StringComparer.OrdinalIgnoreCase })
foreach (IEqualityComparer<string> comparer in new IEqualityComparer<string>[] { null, EqualityComparer<string>.Default, StringComparer.Ordinal, StringComparer.OrdinalIgnoreCase })
{
FrozenSet<string>.AlternateLookup<ReadOnlySpan<char>> lookup = FrozenSet.ToFrozenSet([], comparer).GetAlternateLookup<ReadOnlySpan<char>>();
Assert.False(lookup.Contains("anything".AsSpan()));
Expand All @@ -24,7 +24,7 @@ public void AlternateLookup_Empty()
[Fact]
public void UnsupportedComparer()
{
FrozenSet<string> frozen = FrozenSet.ToFrozenSet(["a", "b"]);
FrozenSet<char> frozen = FrozenSet.ToFrozenSet(['a', 'b']);
Assert.Throws<InvalidOperationException>(() => frozen.GetAlternateLookup<ReadOnlySpan<char>>());
Assert.False(frozen.TryGetAlternateLookup<ReadOnlySpan<char>>(out _));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,6 @@ public void GetAlternateLookup_ThrowsWhenNull()
AssertExtensions.Throws<ArgumentNullException>("set", () => CollectionExtensions.TryGetAlternateLookup<int, long>((HashSet<int>)null, out _));
}

// TODO https://github.com/dotnet/runtime/issues/102906:
// Revise this test when EqualityComparer<string>.Default implements IAlternateEqualityComparer<ReadOnlySpan<char>, string>
[Fact]
public void GetAlternateLookup_FailsForDefaultComparer()
{
Assert.False(new Dictionary<string, string>().TryGetAlternateLookup<string, string, ReadOnlySpan<char>>(out _));
Assert.False(new HashSet<string>().TryGetAlternateLookup<string, ReadOnlySpan<char>>(out _));
}

[Fact]
public void GetAlternateLookup_FailsWhenIncompatible()
{
Expand Down Expand Up @@ -189,26 +180,23 @@ public void GetAlternateLookup_FailsWhenIncompatible()
Assert.False(hashSet.TryGetAlternateLookup<string, int>(out _));
}

public static IEnumerable<object[]> Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData()
{
yield return new object[] { EqualityComparer<string>.Default };
yield return new object[] { StringComparer.Ordinal };
yield return new object[] { StringComparer.OrdinalIgnoreCase };
yield return new object[] { StringComparer.InvariantCulture };
yield return new object[] { StringComparer.InvariantCultureIgnoreCase };
yield return new object[] { StringComparer.CurrentCulture };
yield return new object[] { StringComparer.CurrentCultureIgnoreCase };
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(3)]
[InlineData(4)]
[InlineData(5)]
public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(int mode)
[MemberData(nameof(Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary_MemberData))]
public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(IEqualityComparer<string> comparer)
{
// Test with a variety of comparers to ensure that the alternate lookup is consistent with the underlying dictionary
Dictionary<string, int> dictionary = new(mode switch
{
0 => StringComparer.Ordinal,
1 => StringComparer.OrdinalIgnoreCase,
2 => StringComparer.InvariantCulture,
3 => StringComparer.InvariantCultureIgnoreCase,
4 => StringComparer.CurrentCulture,
5 => StringComparer.CurrentCultureIgnoreCase,
_ => throw new ArgumentOutOfRangeException(nameof(mode))
});
Dictionary<string, int> dictionary = new(comparer);
Dictionary<string, int>.AlternateLookup<ReadOnlySpan<char>> lookup = dictionary.GetAlternateLookup<string, int, ReadOnlySpan<char>>();
Assert.Same(dictionary, lookup.Dictionary);
Assert.Same(lookup.Dictionary, lookup.Dictionary);
Expand Down Expand Up @@ -241,7 +229,8 @@ public void Dictionary_GetAlternateLookup_OperationsMatchUnderlyingDictionary(in

// Ensure that case-sensitivity of the comparer is respected
lookup["a".AsSpan()] = 42;
if (dictionary.Comparer.Equals(StringComparer.Ordinal) ||
if (dictionary.Comparer.Equals(EqualityComparer<string>.Default) ||
dictionary.Comparer.Equals(StringComparer.Ordinal) ||
dictionary.Comparer.Equals(StringComparer.InvariantCulture) ||
dictionary.Comparer.Equals(StringComparer.CurrentCulture))
{
Expand Down
Loading

0 comments on commit aa280ad

Please sign in to comment.