From 8acd0ca426b324853e5e847f1e92e606b6e556e7 Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Thu, 24 Aug 2023 00:27:54 +0200 Subject: [PATCH 1/8] Implement Linq CountBy method for IEnumerable --- src/libraries/System.Linq/ref/System.Linq.cs | 2 + .../System.Linq/src/System.Linq.csproj | 1 + .../System.Linq/src/System/Linq/CountBy.cs | 40 ++++++ .../System.Linq/tests/CountByTests.cs | 117 ++++++++++++++++++ .../tests/System.Linq.Tests.csproj | 1 + 5 files changed, 161 insertions(+) create mode 100644 src/libraries/System.Linq/src/System/Linq/CountBy.cs create mode 100644 src/libraries/System.Linq/tests/CountByTests.cs diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index e49127be6c737..31931dface380 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -47,6 +47,8 @@ public static System.Collections.Generic.IEnumerable< public static bool Contains(this System.Collections.Generic.IEnumerable source, TSource value, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static int Count(this System.Collections.Generic.IEnumerable source) { throw null; } public static int Count(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } + + public static System.Collections.Generic.IEnumerable> CountBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) where TKey : notnull { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source) { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } public static System.Collections.Generic.IEnumerable DistinctBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } diff --git a/src/libraries/System.Linq/src/System.Linq.csproj b/src/libraries/System.Linq/src/System.Linq.csproj index 0babce85894de..af4632c1dc2f6 100644 --- a/src/libraries/System.Linq/src/System.Linq.csproj +++ b/src/libraries/System.Linq/src/System.Linq.csproj @@ -56,6 +56,7 @@ + diff --git a/src/libraries/System.Linq/src/System/Linq/CountBy.cs b/src/libraries/System.Linq/src/System/Linq/CountBy.cs new file mode 100644 index 0000000000000..a77a6dc07fbb4 --- /dev/null +++ b/src/libraries/System.Linq/src/System/Linq/CountBy.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Runtime.InteropServices; + +namespace System.Linq +{ + public static partial class Enumerable + { + public static IEnumerable> CountBy(this IEnumerable source, Func keySelector, IEqualityComparer? comparer = null) where TKey : notnull + { + if (source is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + if (keySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + Dictionary countsBy = new(comparer); + + using IEnumerator e = source.GetEnumerator(); + + while (e.MoveNext()) + { + TKey currentKey = keySelector(e.Current); + + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, currentKey, out _); + checked + { + currentCount++; + } + } + + return countsBy; + } + } +} diff --git a/src/libraries/System.Linq/tests/CountByTests.cs b/src/libraries/System.Linq/tests/CountByTests.cs new file mode 100644 index 0000000000000..c484a651b965d --- /dev/null +++ b/src/libraries/System.Linq/tests/CountByTests.cs @@ -0,0 +1,117 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using Xunit; + +namespace System.Linq.Tests +{ + public class CountByTests : EnumerableTests + { + [Fact] + public void CountBy_SourceNull_ThrowsArgumentNullException() + { + string[] first = null; + + AssertExtensions.Throws("source", () => first.CountBy(x => x)); + AssertExtensions.Throws("source", () => first.CountBy(x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void CountBy_KeySelectorNull_ThrowsArgumentNullException() + { + string[] source = { "Bob", "Tim", "Robert", "Chris" }; + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.CountBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.CountBy(keySelector, new AnagramEqualityComparer())); + } + + [Theory] + [MemberData(nameof(CountBy_TestData))] + public static void CountBy_HasExpectedOutput(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable> expected) + { + Assert.Equal(expected, source.CountBy(keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(CountBy_TestData))] + public static void CountBy_RunOnce_HasExpectedOutput(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable> expected) + { + Assert.Equal(expected, source.RunOnce().CountBy(keySelector, comparer)); + } + + public static IEnumerable CountBy_TestData() + { + yield return WrapArgs( + source: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: Enumerable.Empty>()); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 10).ToDictionary(x => x, x => 1)); + + yield return WrapArgs( + source: Enumerable.Range(5, 10), + keySelector: x => true, + comparer: null, + expected: Enumerable.Repeat(true, 1).ToDictionary(x => x, x => 10)); + + yield return WrapArgs( + source: Enumerable.Range(0, 20), + keySelector: x => x % 5, + comparer: null, + expected: Enumerable.Range(0, 5).ToDictionary(x => x, x => 4)); + + yield return WrapArgs( + source: Enumerable.Repeat(5, 20), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1).ToDictionary(x => x, x => 20)); + + yield return WrapArgs( + source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, + keySelector: x => x, + null, + expected: new string[] { "Bob", "bob", "tim", "Tim" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1)); + + yield return WrapArgs( + source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, + keySelector: x => x, + StringComparer.OrdinalIgnoreCase, + expected: new string[] { "Bob", "tim" }.ToDictionary(x => x, x => x == "Bob" ? 3 : 2)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + keySelector: x => x.Age, + comparer: null, + expected: new int[] { 20, 30, 40 }.ToDictionary(x => x, x => 1)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) }, + keySelector: x => x.Age, + comparer: null, + expected: new int[] { 20, 40 }.ToDictionary(x => x, x => x == 20 ? 2 : 1)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, + keySelector: x => x.Name, + comparer: null, + expected: new string[] { "Bob", "bob", "Harry" }.ToDictionary(x => x, x => 1)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, + keySelector: x => x.Name, + comparer: StringComparer.OrdinalIgnoreCase, + expected: new string[] { "Bob", "Harry" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1)); + + object[] WrapArgs(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable> expected) + => new object[] { source, keySelector, comparer, expected }; + } + } +} diff --git a/src/libraries/System.Linq/tests/System.Linq.Tests.csproj b/src/libraries/System.Linq/tests/System.Linq.Tests.csproj index 22fb00cf9cc86..9a35c490c46a3 100644 --- a/src/libraries/System.Linq/tests/System.Linq.Tests.csproj +++ b/src/libraries/System.Linq/tests/System.Linq.Tests.csproj @@ -20,6 +20,7 @@ + From 4d54545ee7abae381da123b2c81182520d48235a Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Thu, 24 Aug 2023 18:24:57 +0200 Subject: [PATCH 2/8] Implement Linq CountBy method for IQueryable --- .../ref/System.Linq.Queryable.cs | 1 + .../src/System/Linq/Queryable.cs | 21 ++++++ .../tests/CountByTests.cs | 68 +++++++++++++++++++ .../tests/System.Linq.Queryable.Tests.csproj | 1 + 4 files changed, 91 insertions(+) create mode 100644 src/libraries/System.Linq.Queryable/tests/CountByTests.cs diff --git a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs index ea36b539667cc..3a6e52e1fb979 100644 --- a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs +++ b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs @@ -77,6 +77,7 @@ public static partial class Queryable public static bool Contains(this System.Linq.IQueryable source, TSource item, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static int Count(this System.Linq.IQueryable source) { throw null; } public static int Count(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } + public static System.Linq.IQueryable> CountBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) where TKey : notnull { throw null; } public static System.Linq.IQueryable DefaultIfEmpty(this System.Linq.IQueryable source) { throw null; } public static System.Linq.IQueryable DefaultIfEmpty(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static System.Linq.IQueryable DistinctBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index a43ec6c60ae7b..a3dcbaa7af797 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -1569,6 +1569,27 @@ public static int Count(this IQueryable source, ExpressionReturns the count of each element from a sequence according to a specified key selector function. + /// The type of the elements of . + /// The type of key to distinguish elements by. + /// The sequence to count elements from. + /// A function to extract the key for each element. + /// An to compare keys. + /// An that contains count for each distinct elements from the source sequence as a object. + /// is . + [DynamicDependency("CountBy`2", typeof(Enumerable))] + public static IQueryable> CountBy(this IQueryable source, Expression> keySelector, IEqualityComparer? comparer = null) where TKey : notnull + { + ArgumentNullException.ThrowIfNull(source); + ArgumentNullException.ThrowIfNull(keySelector); + + return source.Provider.CreateQuery>( + Expression.Call( + null, + new Func, Expression>, IEqualityComparer, IQueryable>>(CountBy).Method, + source.Expression, Expression.Quote(keySelector), Expression.Constant(comparer, typeof(IEqualityComparer)))); + } + [DynamicDependency("LongCount`1", typeof(Enumerable))] public static long LongCount(this IQueryable source) { diff --git a/src/libraries/System.Linq.Queryable/tests/CountByTests.cs b/src/libraries/System.Linq.Queryable/tests/CountByTests.cs new file mode 100644 index 0000000000000..041ac1e7426f1 --- /dev/null +++ b/src/libraries/System.Linq.Queryable/tests/CountByTests.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Linq.Expressions; +using Xunit; + +namespace System.Linq.Tests +{ + public class CountByTests : EnumerableBasedTests + { + [Fact] + public void NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + + AssertExtensions.Throws("source", () => source.CountBy(x => x)); + AssertExtensions.Throws("source", () => source.CountBy(x => x, EqualityComparer.Default)); + } + + [Fact] + public void NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.CountBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.CountBy(keySelector, EqualityComparer.Default)); + } + + [Fact] + public void EmptySource() + { + int[] source = { }; + Assert.Empty(source.AsQueryable().CountBy(x => x)); + } + + [Fact] + public void CountBy() + { + string[] source = { "now", "own", "won" }; + var counts = source.AsQueryable().CountBy(x => x).ToArray(); + Assert.Equal(source.Length, counts.Length); + Assert.Equal(source, counts.Select(x => x.Key).ToArray()); + Assert.All(counts, x => Assert.Equal(1, x.Value)); + } + + [Fact] + public void CountBy_CustomKeySelector() + { + string[] source = { "now", "own", "won" }; + var counts = source.AsQueryable().CountBy(x => string.Concat(x.Order())).ToArray(); + var count = Assert.Single(counts); + Assert.Equal(source[0], count.Key); + Assert.Equal(source.Length, count.Value); + } + + [Fact] + public void CountBy_CustomComparison() + { + string[] source = { "now", "own", "won" }; + var counts = source.AsQueryable().CountBy(x => x, new AnagramEqualityComparer()).ToArray(); + var count = Assert.Single(counts); + Assert.Equal(source[0], count.Key); + Assert.Equal(source.Length, count.Value); + } + } +} diff --git a/src/libraries/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj b/src/libraries/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj index 4a5a3ef305bab..547ce2b7796a5 100644 --- a/src/libraries/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj +++ b/src/libraries/System.Linq.Queryable/tests/System.Linq.Queryable.Tests.csproj @@ -14,6 +14,7 @@ + From b7ba6474bb421f331deb788f352d7360ffa2d488 Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Thu, 24 Aug 2023 18:26:38 +0200 Subject: [PATCH 3/8] Remove useless using statement --- src/libraries/System.Linq/tests/CountByTests.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/libraries/System.Linq/tests/CountByTests.cs b/src/libraries/System.Linq/tests/CountByTests.cs index c484a651b965d..39344291f80ed 100644 --- a/src/libraries/System.Linq/tests/CountByTests.cs +++ b/src/libraries/System.Linq/tests/CountByTests.cs @@ -1,7 +1,6 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Collections.Generic; using Xunit; From 356482d4da9598b9b1cc3922decbf4965a3a7900 Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:24:39 +0200 Subject: [PATCH 4/8] Lazy source enumeration --- .../System.Linq/src/System/Linq/CountBy.cs | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/CountBy.cs b/src/libraries/System.Linq/src/System/Linq/CountBy.cs index a77a6dc07fbb4..3b54595a93c33 100644 --- a/src/libraries/System.Linq/src/System/Linq/CountBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/CountBy.cs @@ -19,22 +19,35 @@ public static IEnumerable> CountBy(this I ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); } - Dictionary countsBy = new(comparer); + return Core(source, keySelector, comparer); - using IEnumerator e = source.GetEnumerator(); - - while (e.MoveNext()) + static IEnumerable> Core(IEnumerable source, Func keySelector, IEqualityComparer? comparer) { - TKey currentKey = keySelector(e.Current); + Dictionary countsBy = BuildCountDictionary(source, keySelector, comparer); - ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, currentKey, out _); - checked + foreach (KeyValuePair countBy in countsBy) { - currentCount++; + yield return countBy; } } - return countsBy; + static Dictionary BuildCountDictionary(IEnumerable source, Func keySelector, IEqualityComparer? comparer) + { + Dictionary countsBy = new(comparer); + + foreach (TSource element in source) + { + TKey currentKey = keySelector(element); + + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, currentKey, out _); + checked + { + currentCount++; + } + } + + return countsBy; + } } } } From c50ad74bd9673992d1307688e6ae2bf3a0d67cbc Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Tue, 12 Sep 2023 21:41:48 +0200 Subject: [PATCH 5/8] Avoid useless allocation when source is empty --- src/libraries/System.Linq/ref/System.Linq.cs | 3 +-- .../System.Linq/src/System/Linq/CountBy.cs | 27 ++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index 31931dface380..428895f9e6cb9 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -47,8 +47,7 @@ public static System.Collections.Generic.IEnumerable< public static bool Contains(this System.Collections.Generic.IEnumerable source, TSource value, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static int Count(this System.Collections.Generic.IEnumerable source) { throw null; } public static int Count(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } - - public static System.Collections.Generic.IEnumerable> CountBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer = null) where TKey : notnull { throw null; } + public static System.Collections.Generic.IEnumerable> CountBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? keyComparer = null) where TKey : notnull { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source) { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } public static System.Collections.Generic.IEnumerable DistinctBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } diff --git a/src/libraries/System.Linq/src/System/Linq/CountBy.cs b/src/libraries/System.Linq/src/System/Linq/CountBy.cs index 3b54595a93c33..27bf8877ad9e1 100644 --- a/src/libraries/System.Linq/src/System/Linq/CountBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/CountBy.cs @@ -8,7 +8,7 @@ namespace System.Linq { public static partial class Enumerable { - public static IEnumerable> CountBy(this IEnumerable source, Func keySelector, IEqualityComparer? comparer = null) where TKey : notnull + public static IEnumerable> CountBy(this IEnumerable source, Func keySelector, IEqualityComparer? keyComparer = null) where TKey : notnull { if (source is null) { @@ -19,32 +19,39 @@ public static IEnumerable> CountBy(this I ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); } - return Core(source, keySelector, comparer); + return Core(source, keySelector, keyComparer); - static IEnumerable> Core(IEnumerable source, Func keySelector, IEqualityComparer? comparer) + static IEnumerable> Core(IEnumerable source, Func keySelector, IEqualityComparer? keyComparer) { - Dictionary countsBy = BuildCountDictionary(source, keySelector, comparer); + using IEnumerator enumerator = source.GetEnumerator(); - foreach (KeyValuePair countBy in countsBy) + if (!enumerator.MoveNext()) + { + yield break; + } + + foreach (KeyValuePair countBy in BuildCountDictionary(enumerator, keySelector, keyComparer)) { yield return countBy; } } - static Dictionary BuildCountDictionary(IEnumerable source, Func keySelector, IEqualityComparer? comparer) + static Dictionary BuildCountDictionary(IEnumerator enumerator, Func keySelector, IEqualityComparer? keyComparer) { - Dictionary countsBy = new(comparer); + Dictionary countsBy = new(keyComparer); - foreach (TSource element in source) + do { - TKey currentKey = keySelector(element); + TSource value = enumerator.Current; + TKey key = keySelector(value); - ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, currentKey, out _); + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _); checked { currentCount++; } } + while (enumerator.MoveNext()); return countsBy; } From 6e3c6d43e93e9aaded10f473420d6a43ecaeed88 Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Wed, 13 Sep 2023 21:37:30 +0200 Subject: [PATCH 6/8] Rename with CountByIterator and avoid local methods --- .../System.Linq/src/System/Linq/CountBy.cs | 52 +++++++++---------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/CountBy.cs b/src/libraries/System.Linq/src/System/Linq/CountBy.cs index 27bf8877ad9e1..dcd93c5567cd1 100644 --- a/src/libraries/System.Linq/src/System/Linq/CountBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/CountBy.cs @@ -19,42 +19,42 @@ public static IEnumerable> CountBy(this I ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); } - return Core(source, keySelector, keyComparer); + return CountByIterator(source, keySelector, keyComparer); + } - static IEnumerable> Core(IEnumerable source, Func keySelector, IEqualityComparer? keyComparer) - { - using IEnumerator enumerator = source.GetEnumerator(); + private static IEnumerable> CountByIterator(IEnumerable source, Func keySelector, IEqualityComparer? keyComparer) where TKey : notnull + { + using IEnumerator enumerator = source.GetEnumerator(); - if (!enumerator.MoveNext()) - { - yield break; - } + if (!enumerator.MoveNext()) + { + yield break; + } - foreach (KeyValuePair countBy in BuildCountDictionary(enumerator, keySelector, keyComparer)) - { - yield return countBy; - } + foreach (KeyValuePair countBy in BuildCountDictionary(enumerator, keySelector, keyComparer)) + { + yield return countBy; } + } - static Dictionary BuildCountDictionary(IEnumerator enumerator, Func keySelector, IEqualityComparer? keyComparer) + private static Dictionary BuildCountDictionary(IEnumerator enumerator, Func keySelector, IEqualityComparer? keyComparer) where TKey : notnull + { + Dictionary countsBy = new(keyComparer); + + do { - Dictionary countsBy = new(keyComparer); + TSource value = enumerator.Current; + TKey key = keySelector(value); - do + ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _); + checked { - TSource value = enumerator.Current; - TKey key = keySelector(value); - - ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _); - checked - { - currentCount++; - } + currentCount++; } - while (enumerator.MoveNext()); - - return countsBy; } + while (enumerator.MoveNext()); + + return countsBy; } } } From 86fef1266bc96a4cd4f6e265ff620e05b1ea5fcd Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Wed, 13 Sep 2023 21:39:16 +0200 Subject: [PATCH 7/8] Add tests around lazy enumeration --- .../System.Linq/tests/CountByTests.cs | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/libraries/System.Linq/tests/CountByTests.cs b/src/libraries/System.Linq/tests/CountByTests.cs index 39344291f80ed..45c825c9c69d4 100644 --- a/src/libraries/System.Linq/tests/CountByTests.cs +++ b/src/libraries/System.Linq/tests/CountByTests.cs @@ -27,6 +27,36 @@ public void CountBy_KeySelectorNull_ThrowsArgumentNullException() AssertExtensions.Throws("keySelector", () => source.CountBy(keySelector, new AnagramEqualityComparer())); } + [Fact] + public void CountBy_SourceThrowsOnGetEnumerator() + { + IEnumerable source = new ThrowsOnGetEnumerator(); + + var enumerator = source.CountBy(x => x).GetEnumerator(); + + Assert.Throws(() => enumerator.MoveNext()); + } + + [Fact] + public void CountBy_SourceThrowsOnMoveNext() + { + IEnumerable source = new ThrowsOnMoveNext(); + + var enumerator = source.CountBy(x => x).GetEnumerator(); + + Assert.Throws(() => enumerator.MoveNext()); + } + + [Fact] + public void CountBy_SourceThrowsOnCurrent() + { + IEnumerable source = new ThrowsOnCurrentEnumerator(); + + var enumerator = source.CountBy(x => x).GetEnumerator(); + + Assert.Throws(() => enumerator.MoveNext()); + } + [Theory] [MemberData(nameof(CountBy_TestData))] public static void CountBy_HasExpectedOutput(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable> expected) From 256f96a8097a431b54c79636a962276066517805 Mon Sep 17 00:00:00 2001 From: Emmanuel ANDRE <2341261+manandre@users.noreply.github.com> Date: Wed, 13 Sep 2023 21:47:58 +0200 Subject: [PATCH 8/8] Make test results more explicit --- .../System.Linq/tests/CountByTests.cs | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Linq/tests/CountByTests.cs b/src/libraries/System.Linq/tests/CountByTests.cs index 45c825c9c69d4..526326d60c2bc 100644 --- a/src/libraries/System.Linq/tests/CountByTests.cs +++ b/src/libraries/System.Linq/tests/CountByTests.cs @@ -107,13 +107,23 @@ public static IEnumerable CountBy_TestData() source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, keySelector: x => x, null, - expected: new string[] { "Bob", "bob", "tim", "Tim" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1)); + expected: new Dictionary() + { + { "Bob", 2 }, + { "bob", 1 }, + { "tim", 1 }, + { "Tim", 1 } + }); yield return WrapArgs( source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, keySelector: x => x, StringComparer.OrdinalIgnoreCase, - expected: new string[] { "Bob", "tim" }.ToDictionary(x => x, x => x == "Bob" ? 3 : 2)); + expected: new Dictionary() + { + { "Bob", 3 }, + { "tim", 2 } + }); yield return WrapArgs( source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, @@ -125,7 +135,11 @@ public static IEnumerable CountBy_TestData() source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) }, keySelector: x => x.Age, comparer: null, - expected: new int[] { 20, 40 }.ToDictionary(x => x, x => x == 20 ? 2 : 1)); + expected: new Dictionary() + { + { 20, 2 }, + { 40, 1 } + }); yield return WrapArgs( source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, @@ -137,7 +151,11 @@ public static IEnumerable CountBy_TestData() source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, keySelector: x => x.Name, comparer: StringComparer.OrdinalIgnoreCase, - expected: new string[] { "Bob", "Harry" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1)); + expected: new Dictionary() + { + { "Bob", 2 }, + { "Harry", 1 } + }); object[] WrapArgs(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable> expected) => new object[] { source, keySelector, comparer, expected };