Skip to content

Commit

Permalink
Fix MinBy and MaxBy handling of empty sources (#53544)
Browse files Browse the repository at this point in the history
Addresses a bug where the empty source behaviour
is determined by the key type rather than the source type.
  • Loading branch information
eiriktsarpalis authored Jun 2, 2021
1 parent 88053fe commit dbb05eb
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 92 deletions.
90 changes: 47 additions & 43 deletions src/libraries/System.Linq/src/System/Linq/Max.cs
Original file line number Diff line number Diff line change
Expand Up @@ -527,29 +527,56 @@ public static decimal Max(this IEnumerable<decimal> source)

comparer ??= Comparer<TKey>.Default;

TKey? key = default;
TSource? value = default;
using (IEnumerator<TSource> e = source.GetEnumerator())
using IEnumerator<TSource> e = source.GetEnumerator();

if (!e.MoveNext())
{
if (key == null)
if (default(TSource) is null)
{
do
{
if (!e.MoveNext())
{
return value;
}
return default;
}
else
{
ThrowHelper.ThrowNoElementsException();
}
}

value = e.Current;
key = keySelector(value);
TSource value = e.Current;
TKey key = keySelector(value);

if (default(TKey) is null)
{
while (key == null)
{
if (!e.MoveNext())
{
return value;
}
while (key == null);

value = e.Current;
key = keySelector(value);
}

while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (nextKey != null && comparer.Compare(nextKey, key) > 0)
{
key = nextKey;
value = nextValue;
}
}
}
else
{
if (comparer == Comparer<TKey>.Default)
{
while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (nextKey != null && comparer.Compare(nextKey, key) > 0)
if (Comparer<TKey>.Default.Compare(nextKey, key) > 0)
{
key = nextKey;
value = nextValue;
Expand All @@ -558,37 +585,14 @@ public static decimal Max(this IEnumerable<decimal> source)
}
else
{
if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}

value = e.Current;
key = keySelector(value);
if (comparer == Comparer<TSource>.Default)
{
while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (Comparer<TKey>.Default.Compare(nextKey, key) > 0)
{
key = nextKey;
value = nextValue;
}
}
}
else
while (e.MoveNext())
{
while (e.MoveNext())
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (comparer.Compare(nextKey, key) > 0)
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (comparer.Compare(nextKey, key) > 0)
{
key = nextKey;
value = nextValue;
}
key = nextKey;
value = nextValue;
}
}
}
Expand Down
90 changes: 47 additions & 43 deletions src/libraries/System.Linq/src/System/Linq/Min.cs
Original file line number Diff line number Diff line change
Expand Up @@ -485,29 +485,56 @@ public static decimal Min(this IEnumerable<decimal> source)

comparer ??= Comparer<TKey>.Default;

TKey? key = default;
TSource? value = default;
using (IEnumerator<TSource> e = source.GetEnumerator())
using IEnumerator<TSource> e = source.GetEnumerator();

if (!e.MoveNext())
{
if (key == null)
if (default(TSource) is null)
{
do
{
if (!e.MoveNext())
{
return value;
}
return default;
}
else
{
ThrowHelper.ThrowNoElementsException();
}
}

value = e.Current;
key = keySelector(value);
TSource value = e.Current;
TKey key = keySelector(value);

if (default(TKey) is null)
{
while (key == null)
{
if (!e.MoveNext())
{
return value;
}
while (key == null);

value = e.Current;
key = keySelector(value);
}

while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (nextKey != null && comparer.Compare(nextKey, key) < 0)
{
key = nextKey;
value = nextValue;
}
}
}
else
{
if (comparer == Comparer<TKey>.Default)
{
while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (nextKey != null && comparer.Compare(nextKey, key) < 0)
if (Comparer<TKey>.Default.Compare(nextKey, key) < 0)
{
key = nextKey;
value = nextValue;
Expand All @@ -516,37 +543,14 @@ public static decimal Min(this IEnumerable<decimal> source)
}
else
{
if (!e.MoveNext())
{
ThrowHelper.ThrowNoElementsException();
}

value = e.Current;
key = keySelector(value);
if (comparer == Comparer<TKey>.Default)
{
while (e.MoveNext())
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (Comparer<TKey>.Default.Compare(nextKey, key) < 0)
{
key = nextKey;
value = nextValue;
}
}
}
else
while (e.MoveNext())
{
while (e.MoveNext())
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (comparer.Compare(nextKey, key) < 0)
{
TSource nextValue = e.Current;
TKey nextKey = keySelector(nextValue);
if (comparer.Compare(nextKey, key) < 0)
{
key = nextKey;
value = nextValue;
}
key = nextKey;
value = nextValue;
}
}
}
Expand Down
52 changes: 49 additions & 3 deletions src/libraries/System.Linq/tests/MaxTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -868,9 +868,49 @@ public static void MaxBy_Generic_NullKeySelector_ThrowsArgumentNullException()
[Fact]
public static void MaxBy_Generic_EmptyStructSource_ThrowsInvalidOperationException()
{
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x, comparer: null));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x, Comparer<int>.Create((_, _) => 0)));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x.ToString()));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x.ToString(), comparer: null));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MaxBy(x => x.ToString(), Comparer<string>.Create((_, _) => 0)));
}

[Fact]
public static void MaxBy_Generic_EmptyNullableSource_ReturnsNull()
{
Assert.Null(Enumerable.Empty<int?>().MaxBy(x => x.GetHashCode()));
Assert.Null(Enumerable.Empty<int?>().MaxBy(x => x.GetHashCode(), comparer: null));
Assert.Null(Enumerable.Empty<int?>().MaxBy(x => x.GetHashCode(), Comparer<int>.Create((_, _) => 0)));
}

[Fact]
public static void MaxBy_Generic_EmptyReferenceSource_ReturnsNull()
{
Assert.Null(Enumerable.Empty<string>().MaxBy(x => x.GetHashCode()));
Assert.Null(Enumerable.Empty<string>().MaxBy(x => x.GetHashCode(), comparer: null));
Assert.Null(Enumerable.Empty<string>().MaxBy(x => x.GetHashCode(), Comparer<int>.Create((_, _) => 0)));
}

[Fact]
public static void MaxBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string)));
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MaxBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?)));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MaxBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MaxBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string)));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), comparer: null));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MaxBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Theory]
Expand Down Expand Up @@ -955,6 +995,12 @@ public static IEnumerable<object[]> MaxBy_Generic_TestData()
comparer: Comparer<string>.Create((x, y) => -x.CompareTo(y)),
expected: (Name: "Dick", Age: 55));

yield return WrapArgs(
source: new (string Name, int Age)[] { ("Tom", 43), (null, 55), ("Harry", 20) },
keySelector: x => x.Name,
comparer: Comparer<string>.Create((x, y) => -x.CompareTo(y)),
expected: (Name: "Harry", Age: 20));

object[] WrapArgs<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey>? comparer, TSource? expected)
=> new object[] { source, keySelector, comparer, expected };
}
Expand Down
52 changes: 49 additions & 3 deletions src/libraries/System.Linq/tests/MinTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -846,9 +846,49 @@ public static void MinBy_Generic_NullKeySelector_ThrowsArgumentNullException()
[Fact]
public static void MinBy_Generic_EmptyStructSource_ThrowsInvalidOperationException()
{
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x, comparer: null));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x, Comparer<int>.Create((_, _) => 0)));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x.ToString()));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x.ToString(), comparer: null));
Assert.Throws<InvalidOperationException>(() => Enumerable.Empty<int>().MinBy(x => x.ToString(), Comparer<string>.Create((_, _) => 0)));
}

[Fact]
public static void MinBy_Generic_EmptyNullableSource_ReturnsNull()
{
Assert.Null(Enumerable.Empty<int?>().MinBy(x => x.GetHashCode()));
Assert.Null(Enumerable.Empty<int?>().MinBy(x => x.GetHashCode(), comparer: null));
Assert.Null(Enumerable.Empty<int?>().MinBy(x => x.GetHashCode(), Comparer<int>.Create((_, _) => 0)));
}

[Fact]
public static void MinBy_Generic_EmptyReferenceSource_ReturnsNull()
{
Assert.Null(Enumerable.Empty<string>().MinBy(x => x.GetHashCode()));
Assert.Null(Enumerable.Empty<string>().MinBy(x => x.GetHashCode(), comparer: null));
Assert.Null(Enumerable.Empty<string>().MinBy(x => x.GetHashCode(), Comparer<int>.Create((_, _) => 0)));
}

[Fact]
public static void MinBy_Generic_StructSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string)));
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MinBy_Generic_NullableSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?)));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), comparer: null));
Assert.Equal(4, Enumerable.Range(0, 5).Cast<int?>().MinBy(x => default(int?), Comparer<int?>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Fact]
public static void MinBy_Generic_ReferenceSourceAllKeysAreNull_ReturnsLastElement()
{
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string)));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), comparer: null));
Assert.Equal("4", Enumerable.Range(0, 5).Select(x => x.ToString()).MinBy(x => default(string), Comparer<string>.Create((_, _) => throw new InvalidOperationException("comparer should not be called."))));
}

[Theory]
Expand Down Expand Up @@ -933,6 +973,12 @@ public static IEnumerable<object[]> MinBy_Generic_TestData()
comparer: Comparer<string>.Create((x, y) => -x.CompareTo(y)),
expected: (Name: "Tom", Age: 43));

yield return WrapArgs(
source: new (string Name, int Age)[] { (null, 43), ("Dick", 55), ("Harry", 20) },
keySelector: x => x.Name,
comparer: Comparer<string>.Create((x, y) => -x.CompareTo(y)),
expected: (Name: "Harry", Age: 20));

object[] WrapArgs<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey>? comparer, TSource? expected)
=> new object[] { source, keySelector, comparer, expected };
}
Expand Down

0 comments on commit dbb05eb

Please sign in to comment.