Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Linq CountBy method for IEnumerable and IQueryable #91507

Merged
merged 8 commits into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ public static partial class Queryable
public static bool Contains<TSource>(this System.Linq.IQueryable<TSource> source, TSource item, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static int Count<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static int Count<TSource>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, bool>> predicate) { throw null; }
public static System.Linq.IQueryable<System.Collections.Generic.KeyValuePair<TKey, int>> CountBy<TSource, TKey>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? comparer = null) where TKey : notnull { throw null; }
public static System.Linq.IQueryable<TSource?> DefaultIfEmpty<TSource>(this System.Linq.IQueryable<TSource> source) { throw null; }
public static System.Linq.IQueryable<TSource> DefaultIfEmpty<TSource>(this System.Linq.IQueryable<TSource> source, TSource defaultValue) { throw null; }
public static System.Linq.IQueryable<TSource> DistinctBy<TSource, TKey>(this System.Linq.IQueryable<TSource> source, System.Linq.Expressions.Expression<System.Func<TSource, TKey>> keySelector) { throw null; }
Expand Down
21 changes: 21 additions & 0 deletions src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,27 @@ public static int Count<TSource>(this IQueryable<TSource> source, Expression<Fun
source.Expression, Expression.Quote(predicate)));
}

/// <summary>Returns the count of each element from a sequence according to a specified key selector function.</summary>
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
/// <typeparam name="TKey">The type of key to distinguish elements by.</typeparam>
/// <param name="source">The sequence to count elements from.</param>
/// <param name="keySelector">A function to extract the key for each element.</param>
/// <param name="comparer">An <see cref="IEqualityComparer{TKey}" /> to compare keys.</param>
/// <returns>An <see cref="IQueryable{T}" /> that contains count for each distinct elements from the source sequence as a <see cref="KeyValuePair{TKey, TValue}"/> object.</returns>
/// <exception cref="ArgumentNullException"><paramref name="source" /> is <see langword="null" />.</exception>
[DynamicDependency("CountBy`2", typeof(Enumerable))]
public static IQueryable<KeyValuePair<TKey, int>> CountBy<TSource, TKey>(this IQueryable<TSource> source, Expression<Func<TSource, TKey>> keySelector, IEqualityComparer<TKey>? comparer = null) where TKey : notnull
{
ArgumentNullException.ThrowIfNull(source);
ArgumentNullException.ThrowIfNull(keySelector);

return source.Provider.CreateQuery<KeyValuePair<TKey, int>>(
Expression.Call(
null,
new Func<IQueryable<TSource>, Expression<Func<TSource, TKey>>, IEqualityComparer<TKey>, IQueryable<KeyValuePair<TKey, int>>>(CountBy).Method,
source.Expression, Expression.Quote(keySelector), Expression.Constant(comparer, typeof(IEqualityComparer<TKey>))));
}

[DynamicDependency("LongCount`1", typeof(Enumerable))]
public static long LongCount<TSource>(this IQueryable<TSource> source)
{
Expand Down
68 changes: 68 additions & 0 deletions src/libraries/System.Linq.Queryable/tests/CountByTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Linq.Expressions;
using Xunit;

namespace System.Linq.Tests
{
public class CountByTests : EnumerableBasedTests
{
[Fact]
public void NullSource_ThrowsArgumentNullException()
{
IQueryable<int> source = null;

AssertExtensions.Throws<ArgumentNullException>("source", () => source.CountBy(x => x));
AssertExtensions.Throws<ArgumentNullException>("source", () => source.CountBy(x => x, EqualityComparer<int>.Default));
}

[Fact]
public void NullKeySelector_ThrowsArgumentNullException()
{
IQueryable<int> source = Enumerable.Empty<int>().AsQueryable();
Expression<Func<int, int>> keySelector = null;

AssertExtensions.Throws<ArgumentNullException>("keySelector", () => source.CountBy(keySelector));
AssertExtensions.Throws<ArgumentNullException>("keySelector", () => source.CountBy(keySelector, EqualityComparer<int>.Default));
}

[Fact]
public void EmptySource()
{
int[] source = { };
Assert.Empty(source.AsQueryable().CountBy(x => x));
}

[Fact]
public void CountBy()
{
string[] source = { "now", "own", "won" };
var counts = source.AsQueryable().CountBy(x => x).ToArray();
Assert.Equal(source.Length, counts.Length);
Assert.Equal(source, counts.Select(x => x.Key).ToArray());
Assert.All(counts, x => Assert.Equal(1, x.Value));
}

[Fact]
public void CountBy_CustomKeySelector()
{
string[] source = { "now", "own", "won" };
var counts = source.AsQueryable().CountBy(x => string.Concat(x.Order())).ToArray();
var count = Assert.Single(counts);
Assert.Equal(source[0], count.Key);
Assert.Equal(source.Length, count.Value);
}

[Fact]
public void CountBy_CustomComparison()
{
string[] source = { "now", "own", "won" };
var counts = source.AsQueryable().CountBy(x => x, new AnagramEqualityComparer()).ToArray();
var count = Assert.Single(counts);
Assert.Equal(source[0], count.Key);
Assert.Equal(source.Length, count.Value);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
<Compile Include="ContainsTests.cs" />
<Compile Include="CountTests.cs" />
<Compile Include="DefaultIfEmptyTests.cs" />
<Compile Include="CountByTests.cs" />
<Compile Include="DistinctTests.cs" />
<Compile Include="ElementAtOrDefaultTests.cs" />
<Compile Include="ElementAtTests.cs" />
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Linq/ref/System.Linq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ public static System.Collections.Generic.IEnumerable<
public static bool Contains<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, TSource value, System.Collections.Generic.IEqualityComparer<TSource>? comparer) { throw null; }
public static int Count<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
public static int Count<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, bool> predicate) { throw null; }
public static System.Collections.Generic.IEnumerable<System.Collections.Generic.KeyValuePair<TKey, int>> CountBy<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector, System.Collections.Generic.IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull { throw null; }
public static System.Collections.Generic.IEnumerable<TSource?> DefaultIfEmpty<TSource>(this System.Collections.Generic.IEnumerable<TSource> source) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> DefaultIfEmpty<TSource>(this System.Collections.Generic.IEnumerable<TSource> source, TSource defaultValue) { throw null; }
public static System.Collections.Generic.IEnumerable<TSource> DistinctBy<TSource, TKey>(this System.Collections.Generic.IEnumerable<TSource> source, System.Func<TSource, TKey> keySelector) { throw null; }
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Linq/src/System.Linq.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
<Compile Include="System\Linq\Chunk.cs" />
<Compile Include="System\Linq\Concat.cs" />
<Compile Include="System\Linq\Contains.cs" />
<Compile Include="System\Linq\CountBy.cs" />
<Compile Include="System\Linq\Count.cs" />
<Compile Include="System\Linq\DebugView.cs" />
<Compile Include="System\Linq\DefaultIfEmpty.cs" />
Expand Down
60 changes: 60 additions & 0 deletions src/libraries/System.Linq/src/System/Linq/CountBy.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace System.Linq
{
public static partial class Enumerable
{
public static IEnumerable<KeyValuePair<TKey, int>> CountBy<TSource, TKey>(this IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? keyComparer = null) where TKey : notnull
{
if (source is null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source);
}
if (keySelector is null)
{
ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector);
}

return CountByIterator(source, keySelector, keyComparer);
}

private static IEnumerable<KeyValuePair<TKey, int>> CountByIterator<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? keyComparer) where TKey : notnull
{
using IEnumerator<TSource> enumerator = source.GetEnumerator();

if (!enumerator.MoveNext())
{
yield break;
}

foreach (KeyValuePair<TKey, int> countBy in BuildCountDictionary(enumerator, keySelector, keyComparer))
{
yield return countBy;
}
}

private static Dictionary<TKey, int> BuildCountDictionary<TSource, TKey>(IEnumerator<TSource> enumerator, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? keyComparer) where TKey : notnull
{
Dictionary<TKey, int> countsBy = new(keyComparer);

do
{
TSource value = enumerator.Current;
TKey key = keySelector(value);

ref int currentCount = ref CollectionsMarshal.GetValueRefOrAddDefault(countsBy, key, out _);
checked
{
currentCount++;
}
}
while (enumerator.MoveNext());

return countsBy;
}
}
}
146 changes: 146 additions & 0 deletions src/libraries/System.Linq/tests/CountByTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Xunit;

namespace System.Linq.Tests
{
public class CountByTests : EnumerableTests
{
[Fact]
public void CountBy_SourceNull_ThrowsArgumentNullException()
{
string[] first = null;

AssertExtensions.Throws<ArgumentNullException>("source", () => first.CountBy(x => x));
AssertExtensions.Throws<ArgumentNullException>("source", () => first.CountBy(x => x, new AnagramEqualityComparer()));
}

[Fact]
public void CountBy_KeySelectorNull_ThrowsArgumentNullException()
{
string[] source = { "Bob", "Tim", "Robert", "Chris" };
Func<string, string> keySelector = null;

AssertExtensions.Throws<ArgumentNullException>("keySelector", () => source.CountBy(keySelector));
AssertExtensions.Throws<ArgumentNullException>("keySelector", () => source.CountBy(keySelector, new AnagramEqualityComparer()));
}

[Fact]
public void CountBy_SourceThrowsOnGetEnumerator()
{
IEnumerable<int> source = new ThrowsOnGetEnumerator();

var enumerator = source.CountBy(x => x).GetEnumerator();
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved

Assert.Throws<InvalidOperationException>(() => enumerator.MoveNext());
}

[Fact]
public void CountBy_SourceThrowsOnMoveNext()
{
IEnumerable<int> source = new ThrowsOnMoveNext();

var enumerator = source.CountBy(x => x).GetEnumerator();

Assert.Throws<InvalidOperationException>(() => enumerator.MoveNext());
}

[Fact]
public void CountBy_SourceThrowsOnCurrent()
{
IEnumerable<int> source = new ThrowsOnCurrentEnumerator();

var enumerator = source.CountBy(x => x).GetEnumerator();

Assert.Throws<InvalidOperationException>(() => enumerator.MoveNext());
}

[Theory]
[MemberData(nameof(CountBy_TestData))]
public static void CountBy_HasExpectedOutput<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
{
Assert.Equal(expected, source.CountBy(keySelector, comparer));
}

[Theory]
[MemberData(nameof(CountBy_TestData))]
public static void CountBy_RunOnce_HasExpectedOutput<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
{
Assert.Equal(expected, source.RunOnce().CountBy(keySelector, comparer));
}
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved

public static IEnumerable<object[]> CountBy_TestData()
{
yield return WrapArgs(
source: Enumerable.Empty<int>(),
keySelector: x => x,
comparer: null,
expected: Enumerable.Empty<KeyValuePair<int,int>>());

yield return WrapArgs(
source: Enumerable.Range(0, 10),
keySelector: x => x,
comparer: null,
expected: Enumerable.Range(0, 10).ToDictionary(x => x, x => 1));

yield return WrapArgs(
source: Enumerable.Range(5, 10),
keySelector: x => true,
comparer: null,
expected: Enumerable.Repeat(true, 1).ToDictionary(x => x, x => 10));

yield return WrapArgs(
source: Enumerable.Range(0, 20),
keySelector: x => x % 5,
comparer: null,
expected: Enumerable.Range(0, 5).ToDictionary(x => x, x => 4));

yield return WrapArgs(
source: Enumerable.Repeat(5, 20),
keySelector: x => x,
comparer: null,
expected: Enumerable.Repeat(5, 1).ToDictionary(x => x, x => 20));

yield return WrapArgs(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
null,
expected: new string[] { "Bob", "bob", "tim", "Tim" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1));
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved

yield return WrapArgs(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
StringComparer.OrdinalIgnoreCase,
expected: new string[] { "Bob", "tim" }.ToDictionary(x => x, x => x == "Bob" ? 3 : 2));

yield return WrapArgs(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) },
keySelector: x => x.Age,
comparer: null,
expected: new int[] { 20, 30, 40 }.ToDictionary(x => x, x => 1));

yield return WrapArgs(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) },
keySelector: x => x.Age,
comparer: null,
expected: new int[] { 20, 40 }.ToDictionary(x => x, x => x == 20 ? 2 : 1));

yield return WrapArgs(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) },
keySelector: x => x.Name,
comparer: null,
expected: new string[] { "Bob", "bob", "Harry" }.ToDictionary(x => x, x => 1));

yield return WrapArgs(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) },
keySelector: x => x.Name,
comparer: StringComparer.OrdinalIgnoreCase,
expected: new string[] { "Bob", "Harry" }.ToDictionary(x => x, x => x == "Bob" ? 2 : 1));

object[] WrapArgs<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
=> new object[] { source, keySelector, comparer, expected };
}
}
}
1 change: 1 addition & 0 deletions src/libraries/System.Linq/tests/System.Linq.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
<Compile Include="ContainsTests.cs" />
<Compile Include="CountTests.cs" />
<Compile Include="DefaultIfEmptyTests.cs" />
<Compile Include="CountByTests.cs" />
<Compile Include="DistinctTests.cs" />
<Compile Include="ElementAtOrDefaultTests.cs" />
<Compile Include="ElementAtTests.cs" />
Expand Down