Skip to content

Commit

Permalink
Fix LINQ Aggregate/CountBy tests for Native AOT (#105357)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Jul 24, 2024
1 parent 5acfa2c commit 6ce4666
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 60 deletions.
48 changes: 18 additions & 30 deletions src/libraries/System.Linq/tests/AggregateByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,65 +93,50 @@ public void AggregateBy_SourceThrowsOnCurrent()
Assert.Throws<InvalidOperationException>(() => enumerator.MoveNext());
}

[Theory]
[ActiveIssue("https://github.com/dotnet/runtime/issues/92387", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[MemberData(nameof(AggregateBy_TestData))]
public static void AggregateBy_HasExpectedOutput<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, TAccumulate>> expected)
{
Assert.Equal(expected, source.AggregateBy(keySelector, seedSelector, func, comparer));
}

[Theory]
[ActiveIssue("https://github.com/dotnet/runtime/issues/92387", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[MemberData(nameof(AggregateBy_TestData))]
public static void AggregateBy_RunOnce_HasExpectedOutput<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, TAccumulate>> expected)
{
Assert.Equal(expected, source.RunOnce().AggregateBy(keySelector, seedSelector, func, comparer));
}

public static IEnumerable<object[]> AggregateBy_TestData()
[Fact]
public void AggregateBy_HasExpectedOutput()
{
yield return WrapArgs(
Validate(
source: Enumerable.Empty<int>(),
keySelector: x => x,
seedSelector: x => 0,
func: (x, y) => x + y,
comparer: null,
expected: Enumerable.Empty<KeyValuePair<int,int>>());

yield return WrapArgs(
Validate(
source: Enumerable.Range(0, 10),
keySelector: x => x,
seedSelector: x => 0,
func: (x, y) => x + y,
comparer: null,
expected: Enumerable.Range(0, 10).Select(x => new KeyValuePair<int, int>(x, x)));

yield return WrapArgs(
Validate(
source: Enumerable.Range(5, 10),
keySelector: x => true,
seedSelector: x => 0,
func: (x, y) => x + y,
comparer: null,
expected: Enumerable.Repeat(true, 1).Select(x => new KeyValuePair<bool, int>(x, 95)));

yield return WrapArgs(
Validate(
source: Enumerable.Range(0, 20),
keySelector: x => x % 5,
seedSelector: x => 0,
func: (x, y) => x + y,
comparer: null,
expected: Enumerable.Range(0, 5).Select(x => new KeyValuePair<int, int>(x, 30 + 4 * x)));

yield return WrapArgs(
Validate(
source: Enumerable.Repeat(5, 20),
keySelector: x => x,
seedSelector: x => 0,
func: (x, y) => x + y,
comparer: null,
expected: Enumerable.Repeat(5, 1).Select(x => new KeyValuePair<int, int>(x, 100)));

yield return WrapArgs(
Validate(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
seedSelector: x => string.Empty,
Expand All @@ -165,7 +150,7 @@ public static IEnumerable<object[]> AggregateBy_TestData()
new("Tim", "Tim"),
]);

yield return WrapArgs(
Validate(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
seedSelector: x => string.Empty,
Expand All @@ -177,7 +162,7 @@ public static IEnumerable<object[]> AggregateBy_TestData()
new("tim", "timTim")
]);

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) },
keySelector: x => x.Age,
seedSelector: x => $"I am {x} and my name is ",
Expand All @@ -190,7 +175,7 @@ public static IEnumerable<object[]> AggregateBy_TestData()
new(40, "I am 40 and my name is Harry")
]);

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) },
keySelector: x => x.Age,
seedSelector: x => $"I am {x} and my name is",
Expand All @@ -202,15 +187,15 @@ public static IEnumerable<object[]> AggregateBy_TestData()
new(40, "I am 40 and my name is maybe Harry")
]);

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 20), ("Harry", 20) },
keySelector: x => x.Name,
seedSelector: x => 0,
func: (x, y) => x + y.Age,
comparer: null,
expected: new string[] { "Bob", "bob", "Harry" }.Select(x => new KeyValuePair<string, int>(x, 20)));

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) },
keySelector: x => x.Name,
seedSelector: x => 0,
Expand All @@ -222,8 +207,11 @@ public static IEnumerable<object[]> AggregateBy_TestData()
new("Harry", 40)
]);

object[] WrapArgs<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, TAccumulate>> expected)
=> new object[] { source, keySelector, seedSelector, func, comparer, expected };
static void Validate<TSource, TKey, TAccumulate>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, Func<TKey, TAccumulate> seedSelector, Func<TAccumulate, TSource, TAccumulate> func, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, TAccumulate>> expected)
{
Assert.Equal(expected, source.AggregateBy(keySelector, seedSelector, func, comparer));
Assert.Equal(expected, source.RunOnce().AggregateBy(keySelector, seedSelector, func, comparer));
}
}

[Fact]
Expand Down
48 changes: 18 additions & 30 deletions src/libraries/System.Linq/tests/CountByTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,55 +57,40 @@ public void CountBy_SourceThrowsOnCurrent()
Assert.Throws<InvalidOperationException>(() => enumerator.MoveNext());
}

[Theory]
[ActiveIssue("https://github.com/dotnet/runtime/issues/92387", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[MemberData(nameof(CountBy_TestData))]
public static void CountBy_HasExpectedOutput<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
{
Assert.Equal(expected, source.CountBy(keySelector, comparer));
}

[Theory]
[ActiveIssue("https://github.com/dotnet/runtime/issues/92387", typeof(PlatformDetection), nameof(PlatformDetection.IsNativeAot))]
[MemberData(nameof(CountBy_TestData))]
public static void CountBy_RunOnce_HasExpectedOutput<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
{
Assert.Equal(expected, source.RunOnce().CountBy(keySelector, comparer));
}

public static IEnumerable<object[]> CountBy_TestData()
[Fact]
public void CountBy_HasExpectedOutput()
{
yield return WrapArgs(
Validate(
source: Enumerable.Empty<int>(),
keySelector: x => x,
comparer: null,
expected: Enumerable.Empty<KeyValuePair<int,int>>());

yield return WrapArgs(
Validate(
source: Enumerable.Range(0, 10),
keySelector: x => x,
comparer: null,
expected: Enumerable.Range(0, 10).Select(x => new KeyValuePair<int, int>(x, 1)));

yield return WrapArgs(
Validate(
source: Enumerable.Range(5, 10),
keySelector: x => true,
comparer: null,
expected: Enumerable.Repeat(true, 1).Select(x => new KeyValuePair<bool, int>(x, 10)));

yield return WrapArgs(
Validate(
source: Enumerable.Range(0, 20),
keySelector: x => x % 5,
comparer: null,
expected: Enumerable.Range(0, 5).Select(x => new KeyValuePair<int, int>(x, 4)));

yield return WrapArgs(
Validate(
source: Enumerable.Repeat(5, 20),
keySelector: x => x,
comparer: null,
expected: Enumerable.Repeat(5, 1).Select(x => new KeyValuePair<int, int>(x, 20)));

yield return WrapArgs(
Validate(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
null,
Expand All @@ -117,7 +102,7 @@ public static IEnumerable<object[]> CountBy_TestData()
new("Tim", 1)
]);

yield return WrapArgs(
Validate(
source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" },
keySelector: x => x,
StringComparer.OrdinalIgnoreCase,
Expand All @@ -127,13 +112,13 @@ public static IEnumerable<object[]> CountBy_TestData()
new("tim", 2)
]);

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) },
keySelector: x => x.Age,
comparer: null,
expected: new int[] { 20, 30, 40 }.Select(x => new KeyValuePair<int, int>(x, 1)));

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) },
keySelector: x => x.Age,
comparer: null,
Expand All @@ -143,13 +128,13 @@ public static IEnumerable<object[]> CountBy_TestData()
new(40, 1)
]);

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) },
keySelector: x => x.Name,
comparer: null,
expected: new string[] { "Bob", "bob", "Harry" }.Select(x => new KeyValuePair<string, int>(x, 1)));

yield return WrapArgs(
Validate(
source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) },
keySelector: x => x.Name,
comparer: StringComparer.OrdinalIgnoreCase,
Expand All @@ -159,8 +144,11 @@ public static IEnumerable<object[]> CountBy_TestData()
new("Harry", 1)
]);

object[] WrapArgs<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
=> new object[] { source, keySelector, comparer, expected };
static void Validate<TSource, TKey>(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IEqualityComparer<TKey>? comparer, IEnumerable<KeyValuePair<TKey, int>> expected)
{
Assert.Equal(expected, source.CountBy(keySelector, comparer));
Assert.Equal(expected, source.RunOnce().CountBy(keySelector, comparer));
}
}
}
}

0 comments on commit 6ce4666

Please sign in to comment.