Skip to content

Commit

Permalink
Fix nullable annotations on IEquatable/IComparable constraints (#73495)
Browse files Browse the repository at this point in the history
* Fix nullable annotations on IEquatable/IComparable constraints

* Fix LastIndexOfAny null dereference bug

If the first two values are non-null and don't match the input and the third value is null, it gets dereferenced and results in a null reference exception.
  • Loading branch information
stephentoub authored Aug 8, 2022
1 parent be33bc7 commit 38b9d05
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 177 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public abstract partial class EqualityComparer<T> : IEqualityComparer, IEquality
public static EqualityComparer<T> Default { [Intrinsic] get; } = (EqualityComparer<T>)ComparerHelpers.CreateDefaultEqualityComparer(typeof(T));
}

public sealed partial class GenericEqualityComparer<T> : EqualityComparer<T> where T : IEquatable<T>
public sealed partial class GenericEqualityComparer<T> : EqualityComparer<T> where T : IEquatable<T>?
{
internal override int IndexOf(T[] array, T value, int startIndex, int count)
{
Expand All @@ -30,7 +30,7 @@ internal override int IndexOf(T[] array, T value, int startIndex, int count)
{
for (int i = startIndex; i < endIndex; i++)
{
if (array[i] != null && array[i].Equals(value)) return i;
if (array[i] != null && array[i]!.Equals(value)) return i;
}
}
return -1;
Expand All @@ -50,7 +50,7 @@ internal override int LastIndexOf(T[] array, T value, int startIndex, int count)
{
for (int i = startIndex; i >= endIndex; i--)
{
if (array[i] != null && array[i].Equals(value)) return i;
if (array[i] != null && array[i]!.Equals(value)) return i;
}
}
return -1;
Expand Down
142 changes: 71 additions & 71 deletions src/libraries/System.Memory/ref/System.Memory.cs

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public static class BuffersExtensions
/// Returns position of first occurrence of item in the <see cref="ReadOnlySequence{T}"/>
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static SequencePosition? PositionOf<T>(in this ReadOnlySequence<T> source, T value) where T : IEquatable<T>
public static SequencePosition? PositionOf<T>(in this ReadOnlySequence<T> source, T value) where T : IEquatable<T>?
{
if (source.IsSingleSegment)
{
Expand All @@ -32,7 +32,7 @@ public static class BuffersExtensions
}
}

private static SequencePosition? PositionOfMultiSegment<T>(in ReadOnlySequence<T> source, T value) where T : IEquatable<T>
private static SequencePosition? PositionOfMultiSegment<T>(in ReadOnlySequence<T> source, T value) where T : IEquatable<T>?
{
SequencePosition position = source.Start;
SequencePosition result = position;
Expand Down
4 changes: 4 additions & 0 deletions src/libraries/System.Memory/tests/TestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,11 @@ public static ReadOnlyMemory<T> DangerousCreateReadOnlyMemory<T>(object obj, int

{ new string[] { "1", "3", "2" }, new string[] { null, "1" }, 0},
{ new string[] { "1", "3", "2" }, new string[] { "1", "2", null }, 2},
{ new string[] { "1", "3", "2" }, new string[] { "4", "5", null }, -1},
{ new string[] { "1", "3", "2" }, new string[] { null, null }, -1},
{ new string[] { "1", "3", "2" }, new string[] { null, null, null }, -1},
{ new string[] { "1", "3", "2" }, new string[] { null, null, null, null }, -1},
{ new string[] { "1", "3", "2" }, new string[] { null, null, null, null, null }, -1},

{ null, new string[] { null, "1" }, -1},

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public ComparisonComparer(Comparison<T> comparison)
[Serializable]
[TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
// Needs to be public to support binary serialization compatibility
public sealed partial class GenericComparer<T> : Comparer<T> where T : IComparable<T>
public sealed partial class GenericComparer<T> : Comparer<T> where T : IComparable<T>?
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override int Compare(T? x, T? y)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ internal virtual int LastIndexOf(T[] array, T value, int startIndex, int count)
[Serializable]
[TypeForwardedFrom("mscorlib, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089")]
// Needs to be public to support binary serialization compatibility
public sealed partial class GenericEqualityComparer<T> : EqualityComparer<T> where T : IEquatable<T>
public sealed partial class GenericEqualityComparer<T> : EqualityComparer<T> where T : IEquatable<T>?
{
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override bool Equals(T? x, T? y)
Expand Down

Large diffs are not rendered by default.

92 changes: 46 additions & 46 deletions src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs

Large diffs are not rendered by default.

64 changes: 39 additions & 25 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public static void Fill<T>(ref T refData, nuint numElements, T value)
}
}

public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>?
{
Debug.Assert(searchSpaceLength >= 0);
Debug.Assert(valueLength >= 0);
Expand Down Expand Up @@ -220,14 +220,16 @@ public static int IndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T val
}

// Adapted from IndexOf(...)
public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations

if (default(T) != null || (object)value != null)
if (default(T) != null || (object?)value != null)
{
Debug.Assert(value is not null);

while (length >= 8)
{
length -= 8;
Expand Down Expand Up @@ -277,7 +279,7 @@ public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) wh
nint len = length;
for (index = 0; index < len; index++)
{
if ((object)Unsafe.Add(ref searchSpace, index) is null)
if ((object?)Unsafe.Add(ref searchSpace, index) is null)
{
goto Found;
}
Expand Down Expand Up @@ -399,13 +401,15 @@ internal static unsafe int IndexOfValueType<T>(ref T searchSpace, T value, int l
return (int)(index + 7);
}

public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
if (default(T) != null || (object)value != null)
if (default(T) != null || (object?)value != null)
{
Debug.Assert(value is not null);

while (length >= 8)
{
length -= 8;
Expand Down Expand Up @@ -460,7 +464,7 @@ public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) wher
nint len = (nint)length;
for (index = 0; index < len; index++)
{
if ((object)Unsafe.Add(ref searchSpace, index) is null)
if ((object?)Unsafe.Add(ref searchSpace, index) is null)
{
goto Found;
}
Expand All @@ -486,14 +490,16 @@ public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) wher
return (int)(index + 7);
}

public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, int length) where T : IEquatable<T>
public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

T lookUp;
int index = 0;
if (default(T) != null || ((object)value0 != null && (object)value1 != null))
if (default(T) != null || ((object?)value0 != null && (object?)value1 != null))
{
Debug.Assert(value0 is not null && value1 is not null);

while ((length - index) >= 8)
{
lookUp = Unsafe.Add(ref searchSpace, index);
Expand Down Expand Up @@ -590,14 +596,16 @@ public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, int lengt
return index + 7;
}

public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, T value2, int length) where T : IEquatable<T>
public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, T value2, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

T lookUp;
int index = 0;
if (default(T) != null || ((object)value0 != null && (object)value1 != null && (object)value2 != null))
if (default(T) != null || ((object?)value0 != null && (object?)value1 != null && (object?)value2 != null))
{
Debug.Assert(value0 is not null && value1 is not null && value2 is not null);

while ((length - index) >= 8)
{
lookUp = Unsafe.Add(ref searchSpace, index);
Expand Down Expand Up @@ -693,7 +701,7 @@ public static int IndexOfAny<T>(ref T searchSpace, T value0, T value1, T value2,
return index + 7;
}

public static int IndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
public static int IndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>?
{
Debug.Assert(searchSpaceLength >= 0);
Debug.Assert(valueLength >= 0);
Expand Down Expand Up @@ -723,7 +731,7 @@ public static int IndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T
T candidate = Unsafe.Add(ref searchSpace, i);
for (int j = 0; j < valueLength; j++)
{
if (Unsafe.Add(ref value, j).Equals(candidate))
if (Unsafe.Add(ref value, j)!.Equals(candidate))
{
return i;
}
Expand Down Expand Up @@ -764,7 +772,7 @@ public static int IndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T
return -1; // not found
}

public static int LastIndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
public static int LastIndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>?
{
Debug.Assert(searchSpaceLength >= 0);
Debug.Assert(valueLength >= 0);
Expand Down Expand Up @@ -804,12 +812,14 @@ public static int LastIndexOf<T>(ref T searchSpace, int searchSpaceLength, ref T
return -1;
}

public static int LastIndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
public static int LastIndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

if (default(T) != null || (object)value != null)
if (default(T) != null || (object?)value != null)
{
Debug.Assert(value is not null);

while (length >= 8)
{
length -= 8;
Expand Down Expand Up @@ -858,7 +868,7 @@ public static int LastIndexOf<T>(ref T searchSpace, T value, int length) where T
{
for (length--; length >= 0; length--)
{
if ((object)Unsafe.Add(ref searchSpace, length) is null)
if ((object?)Unsafe.Add(ref searchSpace, length) is null)
{
goto Found;
}
Expand All @@ -885,13 +895,15 @@ public static int LastIndexOf<T>(ref T searchSpace, T value, int length) where T
return length + 7;
}

public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, int length) where T : IEquatable<T>
public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

T lookUp;
if (default(T) != null || ((object)value0 != null && (object)value1 != null))
if (default(T) != null || ((object?)value0 != null && (object?)value1 != null))
{
Debug.Assert(value0 is not null && value1 is not null);

while (length >= 8)
{
length -= 8;
Expand Down Expand Up @@ -988,13 +1000,15 @@ public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, int l
return length + 7;
}

public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, T value2, int length) where T : IEquatable<T>
public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, T value2, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

T lookUp;
if (default(T) != null || ((object)value0 != null && (object)value1 != null))
if (default(T) != null || ((object?)value0 != null && (object?)value1 != null && (object?)value2 != null))
{
Debug.Assert(value0 is not null && value1 is not null && value2 is not null);

while (length >= 8)
{
length -= 8;
Expand Down Expand Up @@ -1091,7 +1105,7 @@ public static int LastIndexOfAny<T>(ref T searchSpace, T value0, T value1, T val
return length + 7;
}

public static int LastIndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>
public static int LastIndexOfAny<T>(ref T searchSpace, int searchSpaceLength, ref T value, int valueLength) where T : IEquatable<T>?
{
Debug.Assert(searchSpaceLength >= 0);
Debug.Assert(valueLength >= 0);
Expand All @@ -1109,7 +1123,7 @@ public static int LastIndexOfAny<T>(ref T searchSpace, int searchSpaceLength, re
T candidate = Unsafe.Add(ref searchSpace, i);
for (int j = 0; j < valueLength; j++)
{
if (Unsafe.Add(ref value, j).Equals(candidate))
if (Unsafe.Add(ref value, j)!.Equals(candidate))
{
return i;
}
Expand Down Expand Up @@ -1147,7 +1161,7 @@ public static int LastIndexOfAny<T>(ref T searchSpace, int searchSpaceLength, re
return -1; // not found
}

public static bool SequenceEqual<T>(ref T first, ref T second, int length) where T : IEquatable<T>
public static bool SequenceEqual<T>(ref T first, ref T second, int length) where T : IEquatable<T>?
{
Debug.Assert(length >= 0);

Expand Down Expand Up @@ -1239,7 +1253,7 @@ public static bool SequenceEqual<T>(ref T first, ref T second, int length) where
}

public static int SequenceCompareTo<T>(ref T first, int firstLength, ref T second, int secondLength)
where T : IComparable<T>
where T : IComparable<T>?
{
Debug.Assert(firstLength >= 0);
Debug.Assert(secondLength >= 0);
Expand Down

0 comments on commit 38b9d05

Please sign in to comment.