Skip to content

Commit ab2bd48

Browse files
authored
Add Task.WhenEach to process tasks as they complete (#100316)
* Add Task.WhenEach to process tasks as they complete * Address PR feedback * Fix some async void tests to be async Task across libs * Remove extra awaiter field from state machine Also clean up an extra level of indentation.
1 parent 282891d commit ab2bd48

File tree

9 files changed

+312
-7
lines changed

9 files changed

+312
-7
lines changed

src/libraries/Microsoft.Extensions.DependencyInjection/tests/DI.Tests/ServiceProviderValidationTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot()
8787
}
8888

8989
[Fact]
90-
public async void GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
90+
public async Task GetService_Throws_WhenGetServiceForScopedServiceIsCalledOnRoot_IL_Replacement()
9191
{
9292
// Arrange
9393
var serviceCollection = new ServiceCollection();

src/libraries/Microsoft.Extensions.Hosting/tests/UnitTests/OptionsBuilderExtensionsTests.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ private async Task ValidateOnStart_AddEagerValidation_DoesValidationWhenHostStar
246246
}
247247

248248
[Fact]
249-
private async void CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
249+
private async Task CanValidateOptionsEagerly_AddOptionsWithValidateOnStart_IValidateOptions()
250250
{
251251
var hostBuilder = CreateHostBuilder(services =>
252252
services.AddOptionsWithValidateOnStart<ComplexOptions, ComplexOptionsValidator>()

src/libraries/Microsoft.Extensions.Http/tests/Microsoft.Extensions.Http.Tests/Logging/HttpClientLoggerTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ private void AssertCounters(TestCountingLogger testLogger, int requestCount, boo
162162
[InlineData(false, true)]
163163
[InlineData(true, false)]
164164
[InlineData(true, true)]
165-
public async void CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
165+
public async Task CustomLogger_LogsCorrectEvents_Sync(bool requestSuccessful, bool asyncSecondCall)
166166
{
167167
var serviceCollection = new ServiceCollection();
168168
serviceCollection.AddTransient(_ =>

src/libraries/System.ComponentModel.TypeConverter/tests/TypeDescriptorTests.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ public static IEnumerable<object[]> GetConverter_ByMultithread_ReturnsExpected_T
13951395

13961396
[Theory]
13971397
[MemberData(nameof(GetConverter_ByMultithread_ReturnsExpected_TestData))]
1398-
public async void GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
1398+
public async Task GetConverter_ByMultithread_ReturnsExpected(Type typeForGetConverter, Type expectedConverterType)
13991399
{
14001400
TypeConverter[] actualConverters = await Task.WhenAll(
14011401
Enumerable.Range(0, 100).Select(_ =>
@@ -1415,7 +1415,7 @@ public static IEnumerable<object[]> GetConverterWithAddProvider_ByMultithread_Su
14151415

14161416
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsReflectionEmitSupported))] // Mock will try to JIT
14171417
[MemberData(nameof(GetConverterWithAddProvider_ByMultithread_Success_TestData))]
1418-
public async void GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
1418+
public async Task GetConverterWithAddProvider_ByMultithread_Success(Type typeForGetConverter, Type expectedConverterType)
14191419
{
14201420
TypeConverter[] actualConverters = await Task.WhenAll(
14211421
Enumerable.Range(0, 200).Select(_ =>

src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public async Task SendAsync_SlowServerAndCancel_ThrowsTaskCanceledException()
9898

9999
[OuterLoop]
100100
[Fact]
101-
public async void SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
101+
public async Task SendAsync_SlowServerRespondsAfterDefaultReceiveTimeout_ThrowsHttpRequestException()
102102
{
103103
var handler = new WinHttpHandler();
104104
using (var client = new HttpClient(handler))

src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs

+186
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
using System.Runtime.ExceptionServices;
1818
using System.Runtime.InteropServices;
1919
using System.Runtime.Versioning;
20+
using System.Threading.Tasks.Sources;
2021

2122
namespace System.Threading.Tasks
2223
{
@@ -6659,6 +6660,191 @@ public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> ta
66596660
WhenAny<Task<TResult>>(tasks);
66606661
#endregion
66616662

6663+
#region WhenEach
6664+
/// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that will yield the supplied tasks as those tasks complete.</summary>
6665+
/// <param name="tasks">The task to iterate through when completed.</param>
6666+
/// <returns>An <see cref="IAsyncEnumerable{T}"/> for iterating through the supplied tasks.</returns>
6667+
/// <remarks>
6668+
/// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order
6669+
/// in which the tasks will become available is not defined.
6670+
/// </remarks>
6671+
/// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception>
6672+
/// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null.</exception>
6673+
public static IAsyncEnumerable<Task> WhenEach(params Task[] tasks)
6674+
{
6675+
ArgumentNullException.ThrowIfNull(tasks);
6676+
return WhenEach((ReadOnlySpan<Task>)tasks);
6677+
}
6678+
6679+
/// <inheritdoc cref="WhenEach(Task[])"/>
6680+
public static IAsyncEnumerable<Task> WhenEach(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
6681+
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));
6682+
6683+
/// <inheritdoc cref="WhenEach(Task[])"/>
6684+
public static IAsyncEnumerable<Task> WhenEach(IEnumerable<Task> tasks) =>
6685+
WhenEachState.Iterate<Task>(WhenEachState.Create(tasks));
6686+
6687+
/// <inheritdoc cref="WhenEach(Task[])"/>
6688+
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(params Task<TResult>[] tasks)
6689+
{
6690+
ArgumentNullException.ThrowIfNull(tasks);
6691+
return WhenEach((ReadOnlySpan<Task<TResult>>)tasks);
6692+
}
6693+
6694+
/// <inheritdoc cref="WhenEach(Task[])"/>
6695+
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(ReadOnlySpan<Task<TResult>> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params
6696+
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(ReadOnlySpan<Task>.CastUp(tasks)));
6697+
6698+
/// <inheritdoc cref="WhenEach(Task[])"/>
6699+
public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(IEnumerable<Task<TResult>> tasks) =>
6700+
WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(tasks));
6701+
6702+
/// <summary>Object used by <see cref="Iterate"/> to store its state.</summary>
6703+
private sealed class WhenEachState : Queue<Task>, IValueTaskSource, ITaskCompletionAction
6704+
{
6705+
/// <summary>Implementation backing the ValueTask used to wait for the next task to be available.</summary>
6706+
/// <remarks>This is a mutable struct. Do not make it readonly.</remarks>
6707+
private ManualResetValueTaskSourceCore<bool> _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock
6708+
/// <summary>0 if this has never been used in an iteration; 1 if it has.</summary>
6709+
/// <remarks>This is used to ensure we only ever iterate through the tasks once.</remarks>
6710+
private int _enumerated;
6711+
6712+
/// <summary>Called at the beginning of the iterator to assume ownership of the state.</summary>
6713+
/// <returns>true if the caller owns the state; false if the caller should end immediately.</returns>
6714+
public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0;
6715+
6716+
/// <summary>Gets or sets the number of tasks that haven't yet been yielded.</summary>
6717+
public int Remaining { get; set; }
6718+
6719+
void ITaskCompletionAction.Invoke(Task completingTask)
6720+
{
6721+
lock (this)
6722+
{
6723+
// Enqueue the task into the queue. If the Count is now 1, we transitioned from
6724+
// empty to non-empty, which means we need to signal the MRVTSC, as the consumer
6725+
// could be waiting on a ValueTask representing a completed task being available.
6726+
Enqueue(completingTask);
6727+
if (Count == 1)
6728+
{
6729+
Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending);
6730+
_waitForNextCompletedTask.SetResult(default);
6731+
}
6732+
}
6733+
}
6734+
bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false;
6735+
6736+
// Delegate to _waitForNextCompletedTask for IValueTaskSource implementation.
6737+
void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token);
6738+
ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token);
6739+
void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) =>
6740+
_waitForNextCompletedTask.OnCompleted(continuation, state, token, flags);
6741+
6742+
/// <summary>Creates a <see cref="WhenEachState"/> from the specified tasks.</summary>
6743+
public static WhenEachState? Create(ReadOnlySpan<Task> tasks)
6744+
{
6745+
WhenEachState? waiter = null;
6746+
6747+
if (tasks.Length != 0)
6748+
{
6749+
waiter = new();
6750+
foreach (Task task in tasks)
6751+
{
6752+
if (task is null)
6753+
{
6754+
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
6755+
}
6756+
6757+
waiter.Remaining++;
6758+
task.AddCompletionAction(waiter);
6759+
}
6760+
}
6761+
6762+
return waiter;
6763+
}
6764+
6765+
/// <inheritdoc cref="Create(ReadOnlySpan{Task})"/>
6766+
public static WhenEachState? Create(IEnumerable<Task> tasks)
6767+
{
6768+
ArgumentNullException.ThrowIfNull(tasks);
6769+
6770+
WhenEachState? waiter = null;
6771+
6772+
IEnumerator<Task> e = tasks.GetEnumerator();
6773+
if (e.MoveNext())
6774+
{
6775+
waiter = new();
6776+
do
6777+
{
6778+
Task task = e.Current;
6779+
if (task is null)
6780+
{
6781+
ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks);
6782+
}
6783+
6784+
waiter.Remaining++;
6785+
task.AddCompletionAction(waiter);
6786+
}
6787+
while (e.MoveNext());
6788+
}
6789+
6790+
return waiter;
6791+
}
6792+
6793+
/// <summary>Iterates through the tasks represented by the provided waiter.</summary>
6794+
public static async IAsyncEnumerable<T> Iterate<T>(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task
6795+
{
6796+
// The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that
6797+
// only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take
6798+
// advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter
6799+
// how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator
6800+
// call will give back all the tasks, and all subsequent iterations will be empty.
6801+
if (waiter?.TryStart() is not true)
6802+
{
6803+
yield break;
6804+
}
6805+
6806+
// Loop until we've yielded all tasks.
6807+
while (waiter.Remaining > 0)
6808+
{
6809+
// Either get the next completed task from the queue, or get a
6810+
// ValueTask with which to wait for the next task to complete.
6811+
Task? next;
6812+
ValueTask waitTask = default;
6813+
lock (waiter)
6814+
{
6815+
// Reset the MRVTSC if it was signaled, then try to dequeue a task and
6816+
// either return one we got or return a ValueTask that will be signaled
6817+
// when the next completed task is available.
6818+
waiter._waitForNextCompletedTask.Reset();
6819+
if (!waiter.TryDequeue(out next))
6820+
{
6821+
waitTask = new(waiter, waiter._waitForNextCompletedTask.Version);
6822+
}
6823+
}
6824+
6825+
// If we got a completed Task, yield it.
6826+
if (next is not null)
6827+
{
6828+
cancellationToken.ThrowIfCancellationRequested();
6829+
waiter.Remaining--;
6830+
yield return (T)next;
6831+
continue;
6832+
}
6833+
6834+
// If we have a cancellation token and the ValueTask isn't already completed,
6835+
// get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable.
6836+
// Otherwise, just await the ValueTask directly. We don't need to be concerned
6837+
// about suppressing exceptions, as the ValueTask is only ever completed successfully.
6838+
if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted)
6839+
{
6840+
waitTask = new ValueTask(waitTask.AsTask().WaitAsync(cancellationToken));
6841+
}
6842+
await waitTask.ConfigureAwait(false);
6843+
}
6844+
}
6845+
}
6846+
#endregion
6847+
66626848
internal static Task<TResult> CreateUnwrapPromise<TResult>(Task outerTask, bool lookForOce)
66636849
{
66646850
Debug.Assert(outerTask != null);

src/libraries/System.Runtime/ref/System.Runtime.cs

+6
Original file line numberDiff line numberDiff line change
@@ -15335,6 +15335,12 @@ public static void WaitAll(System.Threading.Tasks.Task[] tasks, System.Threading
1533515335
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
1533615336
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(System.Threading.Tasks.Task<TResult> task1, System.Threading.Tasks.Task<TResult> task2) { throw null; }
1533715337
public static System.Threading.Tasks.Task<System.Threading.Tasks.Task<TResult>> WhenAny<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
15338+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task> tasks) { throw null; }
15339+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(params System.Threading.Tasks.Task[] tasks) { throw null; }
15340+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task> WhenEach(System.ReadOnlySpan<System.Threading.Tasks.Task> tasks) { throw null; }
15341+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.Collections.Generic.IEnumerable<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
15342+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(params System.Threading.Tasks.Task<TResult>[] tasks) { throw null; }
15343+
public static System.Collections.Generic.IAsyncEnumerable<System.Threading.Tasks.Task<TResult>> WhenEach<TResult>(System.ReadOnlySpan<System.Threading.Tasks.Task<TResult>> tasks) { throw null; }
1533815344
public static System.Runtime.CompilerServices.YieldAwaitable Yield() { throw null; }
1533915345
}
1534015346
public static partial class TaskAsyncEnumerableExtensions

0 commit comments

Comments
 (0)