From 3c2ce9312ee094f84dbac5b44a42e478026a8396 Mon Sep 17 00:00:00 2001 From: Thomas Levesque Date: Sat, 18 Jul 2020 11:05:22 +0200 Subject: [PATCH 1/2] Fix Min/Max(By) to correctly handle null values --- src/Linq.Extras/MinMax.cs | 155 ++++++++++++++++-- .../XEnumerableTests/MinMaxTests.cs | 94 +++++++++-- 2 files changed, 223 insertions(+), 26 deletions(-) diff --git a/src/Linq.Extras/MinMax.cs b/src/Linq.Extras/MinMax.cs index 053470a..10089a5 100644 --- a/src/Linq.Extras/MinMax.cs +++ b/src/Linq.Extras/MinMax.cs @@ -15,6 +15,12 @@ partial class XEnumerable /// The sequence to return the maximum element from. /// The comparer used to compare elements. /// The maximum element according to the specified comparer. + /// + /// If TSource is a reference type or nullable value type, null values are ignored, unless the sequence consists + /// entirely of null values (in which case the method will return null). + /// If TSource is a reference type or nullable value type, and the sequence is empty, the method will return null. + /// If TSource is a value type, and the sequence is empty, the method will throw an . + /// [Pure] public static TSource Max( [NotNull] this IEnumerable source, @@ -32,6 +38,12 @@ public static TSource Max( /// The sequence to return the minimum element from. /// The comparer used to compare elements. /// The minimum element according to the specified comparer. + /// + /// If TSource is a reference type or nullable value type, null values are ignored, unless the sequence consists + /// entirely of null values (in which case the method will return null). + /// If TSource is a reference type or nullable value type, and the sequence is empty, the method will return null. + /// If TSource is a value type, and the sequence is empty, the method will throw an . + /// [Pure] public static TSource Min( [NotNull] this IEnumerable source, @@ -47,22 +59,56 @@ private static TSource Extreme(this IEnumerable source, ICompa { comparer = comparer ?? Comparer.Default; TSource extreme = default!; - bool first = true; - foreach (var item in source) + + using var e = source.GetEnumerator(); + if (extreme is null) { - int compare = 0; - if (!first) - compare = comparer.Compare(item, extreme); + // For nullable types, return null if the sequence is empty + // or contains only null values. - if (Math.Sign(compare) == sign || first) + // First, skip until the first non-null value, if any + do { - extreme = item; + if (!e.MoveNext()) + { + return extreme; + } + + extreme = e.Current; + } while (extreme is null); + + while (e.MoveNext()) + { + if (e.Current is null) + { + continue; + } + + if (Math.Sign(comparer.Compare(e.Current, extreme)) == sign) + { + extreme = e.Current; + } } - first = false; } + else + { + // For non-nullable types, throw an exception if the sequence is empty + + if (!e.MoveNext()) + { + throw EmptySequenceException(); + } + + extreme = e.Current; - if (first) - throw EmptySequenceException(); + while (e.MoveNext()) + { + if (Math.Sign(comparer.Compare(e.Current, extreme)) == sign) + { + extreme = e.Current; + } + } + } return extreme; } @@ -81,6 +127,12 @@ private static InvalidOperationException EmptySequenceException() /// A delegate that returns the key used to compare elements. /// A comparer to compare the keys. /// The element of source that has the maximum value for the specified key. + /// + /// If TKey is a reference type or nullable value type, null keys are ignored, unless the sequence consists + /// entirely of items with null keys (in which case the method will return null). + /// If TKey is a reference type or nullable value type, and the sequence is empty, the method will return null. + /// If TKey is a value type, and the sequence is empty, the method will throw an . + /// [Pure] public static TSource MaxBy( [NotNull] this IEnumerable source, @@ -89,8 +141,7 @@ public static TSource MaxBy( { source.CheckArgumentNull(nameof(source)); keySelector.CheckArgumentNull(nameof(keySelector)); - var comparer = XComparer.By(keySelector, keyComparer); - return source.Max(comparer); + return source.ExtremeBy(keySelector, keyComparer, 1); } /// @@ -102,6 +153,12 @@ public static TSource MaxBy( /// A delegate that returns the key used to compare elements. /// A comparer to compare the keys. /// The element of source that has the minimum value for the specified key. + /// + /// If TKey is a reference type or nullable value type, null keys are ignored, unless the sequence consists + /// entirely of items with null keys (in which case the method will return null). + /// If TKey is a reference type or nullable value type, and the sequence is empty, the method will return null. + /// If TKey is a value type, and the sequence is empty, the method will throw an . + /// [Pure] public static TSource MinBy( [NotNull] this IEnumerable source, @@ -110,8 +167,78 @@ public static TSource MinBy( { source.CheckArgumentNull(nameof(source)); keySelector.CheckArgumentNull(nameof(keySelector)); - var comparer = XComparer.By(keySelector, keyComparer); - return source.Min(comparer); + return source.ExtremeBy(keySelector, keyComparer, -1); + } + + [Pure] + private static TSource ExtremeBy( + this IEnumerable source, + Func keySelector, + IComparer? keyComparer, + int sign) + { + keyComparer = keyComparer ?? Comparer.Default; + TSource extreme = default!; + TKey extremeKey = default!; + + using var e = source.GetEnumerator(); + + if (extremeKey is null) + { + // For nullable types, return null if the sequence is empty + // or contains only values with null keys. + + // First, skip until the first non-null key value, if any + do + { + if (!e.MoveNext()) + { + return extreme; + } + + extreme = e.Current; + extremeKey = keySelector(extreme); + } while (extremeKey is null); + + while (e.MoveNext()) + { + var currentKey = keySelector(e.Current); + if (currentKey is null) + { + continue; + } + + if (Math.Sign(keyComparer.Compare(currentKey, extremeKey)) == sign) + { + extreme = e.Current; + extremeKey = currentKey; + } + } + } + else + { + // For non-nullable types, throw an exception if the sequence is empty + + if (!e.MoveNext()) + { + throw EmptySequenceException(); + } + + extreme = e.Current; + extremeKey = keySelector(e.Current); + + while (e.MoveNext()) + { + var currentKey = keySelector(e.Current); + if (Math.Sign(keyComparer.Compare(currentKey, extremeKey)) == sign) + { + extreme = e.Current; + extremeKey = currentKey; + } + } + } + + return extreme; } } } diff --git a/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs b/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs index 41499e1..8370be6 100644 --- a/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs +++ b/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs @@ -18,11 +18,19 @@ public void MaxBy_Throws_If_Argument_Is_Null() } [Fact] - public void MaxBy_Throws_If_Source_Is_Empty() + public void MaxBy_Throws_If_Source_Is_Empty_And_TSource_Is_Not_Nullable() { - var foos = new Foo[] { }.ForbidMultipleEnumeration(); + var items = new DateTime[] { }.ForbidMultipleEnumeration(); // ReSharper disable once ReturnValueOfPureMethodIsNotUsed - Assert.Throws(() => foos.MaxBy(_getFooValue)); + Assert.Throws(() => items.MaxBy(d => d.Ticks)); + } + + [Fact] + public void MaxBy_Returns_Null_If_Source_Is_Empty_And_TSource_Is_Nullable() + { + var items = new Foo[] { }.ForbidMultipleEnumeration(); + var actual = items.MaxBy(f => f.Value); + actual.Should().BeNull(); } [Fact] @@ -54,11 +62,19 @@ public void MinBy_Throws_If_Argument_Is_Null() } [Fact] - public void MinBy_Throws_If_Source_Is_Empty() + public void MinBy_Throws_If_Source_Is_Empty_And_TSource_Is_Not_Nullable() { - var foos = new Foo[] { }.ForbidMultipleEnumeration(); + var items = new DateTime[] { }.ForbidMultipleEnumeration(); // ReSharper disable once ReturnValueOfPureMethodIsNotUsed - Assert.Throws(() => foos.MinBy(_getFooValue)); + Assert.Throws(() => items.MinBy(d => d.Ticks)); + } + + [Fact] + public void MinBy_Returns_Null_If_Source_Is_Empty_And_TSource_Is_Nullable() + { + var items = new Foo[] { }.ForbidMultipleEnumeration(); + var actual = items.MinBy(f => f.Value); + actual.Should().BeNull(); } [Fact] @@ -81,6 +97,34 @@ public void MinBy_Returns_Item_With_Min_Value_For_Key_Based_On_Comparer() actual.Should().Be(expected); } + [Fact] + public void MinBy_Returns_Item_With_Min_Non_Null_Value_For_Key() + { + var bars = new[] + { + new Bar("abcd"), + new Bar(null), + new Bar("efgh") + }.ForbidMultipleEnumeration(); + var fooWithMinValue = bars.MinBy(b => b.Value); + var expected = "abcd"; + var actual = fooWithMinValue.Value; + actual.Should().Be(expected); + } + + [Fact] + public void MinBy_Returns_Item_With_Null_Key_If_All_Items_Have_A_Null_Key() + { + var bars = new[] { + new Bar(null), + new Bar(null), + new Bar(null) + }.ForbidMultipleEnumeration(); + var fooWithMinValue = bars.MinBy(b => b.Value); + var actual = fooWithMinValue.Value; + actual.Should().BeNull(); + } + [Fact] public void Max_Throws_If_Argument_Is_Null() { @@ -101,11 +145,19 @@ public void Max_Return_Max_Value_Based_On_Comparer() } [Fact] - public void Max_Throws_If_Source_Is_Empty() + public void Max_Throws_If_Source_Is_Empty_And_TSource_Is_Not_Nullable() { - var foos = new Foo[] { }.ForbidMultipleEnumeration(); + var items = new DateTime[] { }.ForbidMultipleEnumeration(); // ReSharper disable once ReturnValueOfPureMethodIsNotUsed - Assert.Throws(() => foos.Max(new FooComparer())); + Assert.Throws(() => items.Max(Comparer.Default)); + } + + [Fact] + public void Max_Returns_Null_If_Source_Is_Empty_And_TSource_Is_Nullable() + { + var items = new Foo[] { }.ForbidMultipleEnumeration(); + var actual = items.Max(new FooComparer()); + actual.Should().BeNull(); } [Fact] @@ -128,11 +180,19 @@ public void Min_Return_Min_Value_Based_On_Comparer() } [Fact] - public void Min_Throws_If_Source_Is_Empty() + public void Min_Throws_If_Source_Is_Empty_And_TSource_Is_Not_Nullable() { - var foos = new Foo[] { }.ForbidMultipleEnumeration(); + var items = new DateTime[] { }.ForbidMultipleEnumeration(); // ReSharper disable once ReturnValueOfPureMethodIsNotUsed - Assert.Throws(() => foos.Min(new FooComparer())); + Assert.Throws(() => items.Min(Comparer.Default)); + } + + [Fact] + public void Min_Returns_Null_If_Source_Is_Empty_And_TSource_Is_Nullable() + { + var items = new Foo[] { }.ForbidMultipleEnumeration(); + var actual = items.Min(new FooComparer()); + actual.Should().BeNull(); } private static IEnumerable GetFoos() @@ -163,6 +223,16 @@ public Foo(string value) public string Value { get; } } + private class Bar + { + public Bar(string? value) + { + Value = value; + } + + public string? Value { get; } + } + [ExcludeFromCodeCoverage] private class FooComparer : IComparer { From a49e1b0082019e790b70ed356c5f8ed7f21c9c2a Mon Sep 17 00:00:00 2001 From: Thomas Levesque Date: Sat, 18 Jul 2020 11:07:52 +0200 Subject: [PATCH 2/2] Add [return: MaybeNull] to Min/Max(By) --- src/Linq.Extras/Internal/MaybeNullAttribute.cs | 8 ++++++++ src/Linq.Extras/MinMax.cs | 5 +++++ .../XEnumerableTests/MinMaxTests.cs | 16 ++++++++-------- 3 files changed, 21 insertions(+), 8 deletions(-) create mode 100644 src/Linq.Extras/Internal/MaybeNullAttribute.cs diff --git a/src/Linq.Extras/Internal/MaybeNullAttribute.cs b/src/Linq.Extras/Internal/MaybeNullAttribute.cs new file mode 100644 index 0000000..6534c9a --- /dev/null +++ b/src/Linq.Extras/Internal/MaybeNullAttribute.cs @@ -0,0 +1,8 @@ +namespace System.Diagnostics.CodeAnalysis +{ + /// Specifies that an output may be null even if the corresponding type disallows it. + [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, Inherited = false)] + internal sealed class MaybeNullAttribute : Attribute + { + } +} diff --git a/src/Linq.Extras/MinMax.cs b/src/Linq.Extras/MinMax.cs index 10089a5..8d15fdf 100644 --- a/src/Linq.Extras/MinMax.cs +++ b/src/Linq.Extras/MinMax.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; using JetBrains.Annotations; using Linq.Extras.Internal; using Linq.Extras.Properties; @@ -22,6 +23,7 @@ partial class XEnumerable /// If TSource is a value type, and the sequence is empty, the method will throw an . /// [Pure] + [return: MaybeNull] public static TSource Max( [NotNull] this IEnumerable source, [NotNull] IComparer comparer) @@ -45,6 +47,7 @@ public static TSource Max( /// If TSource is a value type, and the sequence is empty, the method will throw an . /// [Pure] + [return: MaybeNull] public static TSource Min( [NotNull] this IEnumerable source, [NotNull] IComparer comparer) @@ -134,6 +137,7 @@ private static InvalidOperationException EmptySequenceException() /// If TKey is a value type, and the sequence is empty, the method will throw an . /// [Pure] + [return: MaybeNull] public static TSource MaxBy( [NotNull] this IEnumerable source, [NotNull] Func keySelector, @@ -160,6 +164,7 @@ public static TSource MaxBy( /// If TKey is a value type, and the sequence is empty, the method will throw an . /// [Pure] + [return: MaybeNull] public static TSource MinBy( [NotNull] this IEnumerable source, [NotNull] Func keySelector, diff --git a/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs b/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs index 8370be6..5de9267 100644 --- a/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs +++ b/tests/Linq.Extras.Tests/XEnumerableTests/MinMaxTests.cs @@ -39,7 +39,7 @@ public void MaxBy_Returns_Item_With_Max_Value_For_Key() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMaxValue = foos.MaxBy(_getFooValue); var expected = "xyz"; - var actual = fooWithMaxValue.Value; + var actual = fooWithMaxValue?.Value; actual.Should().Be(expected); } @@ -49,7 +49,7 @@ public void MaxBy_Returns_Item_With_Max_Value_For_Key_Based_On_Comparer() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMaxValue = foos.MaxBy(_getFooValue, Comparer.Default.Reverse()); var expected = "abcd"; - var actual = fooWithMaxValue.Value; + var actual = fooWithMaxValue?.Value; actual.Should().Be(expected); } @@ -83,7 +83,7 @@ public void MinBy_Returns_Item_With_Min_Value_For_Key() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMinValue = foos.MinBy(_getFooValue); var expected = "abcd"; - var actual = fooWithMinValue.Value; + var actual = fooWithMinValue?.Value; actual.Should().Be(expected); } @@ -93,7 +93,7 @@ public void MinBy_Returns_Item_With_Min_Value_For_Key_Based_On_Comparer() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMinValue = foos.MinBy(_getFooValue, Comparer.Default.Reverse()); var expected = "xyz"; - var actual = fooWithMinValue.Value; + var actual = fooWithMinValue?.Value; actual.Should().Be(expected); } @@ -108,7 +108,7 @@ public void MinBy_Returns_Item_With_Min_Non_Null_Value_For_Key() }.ForbidMultipleEnumeration(); var fooWithMinValue = bars.MinBy(b => b.Value); var expected = "abcd"; - var actual = fooWithMinValue.Value; + var actual = fooWithMinValue?.Value; actual.Should().Be(expected); } @@ -121,7 +121,7 @@ public void MinBy_Returns_Item_With_Null_Key_If_All_Items_Have_A_Null_Key() new Bar(null) }.ForbidMultipleEnumeration(); var fooWithMinValue = bars.MinBy(b => b.Value); - var actual = fooWithMinValue.Value; + var actual = fooWithMinValue?.Value; actual.Should().BeNull(); } @@ -140,7 +140,7 @@ public void Max_Return_Max_Value_Based_On_Comparer() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMaxValue = foos.Max(new FooComparer()); var expected = "xyz"; - var actual = fooWithMaxValue.Value; + var actual = fooWithMaxValue?.Value; actual.Should().Be(expected); } @@ -175,7 +175,7 @@ public void Min_Return_Min_Value_Based_On_Comparer() var foos = GetFoos().ForbidMultipleEnumeration(); var fooWithMinValue = foos.Min(new FooComparer()); var expected = "abcd"; - var actual = fooWithMinValue.Value; + var actual = fooWithMinValue?.Value; actual.Should().Be(expected); }