|
17 | 17 | using System.Runtime.ExceptionServices;
|
18 | 18 | using System.Runtime.InteropServices;
|
19 | 19 | using System.Runtime.Versioning;
|
| 20 | +using System.Threading.Tasks.Sources; |
20 | 21 |
|
21 | 22 | namespace System.Threading.Tasks
|
22 | 23 | {
|
@@ -6659,6 +6660,191 @@ public static Task<Task<TResult>> WhenAny<TResult>(IEnumerable<Task<TResult>> ta
|
6659 | 6660 | WhenAny<Task<TResult>>(tasks);
|
6660 | 6661 | #endregion
|
6661 | 6662 |
|
| 6663 | + #region WhenEach |
| 6664 | + /// <summary>Creates an <see cref="IAsyncEnumerable{T}"/> that will yield the supplied tasks as those tasks complete.</summary> |
| 6665 | + /// <param name="tasks">The task to iterate through when completed.</param> |
| 6666 | + /// <returns>An <see cref="IAsyncEnumerable{T}"/> for iterating through the supplied tasks.</returns> |
| 6667 | + /// <remarks> |
| 6668 | + /// The supplied tasks will become available to be output via the enumerable once they've completed. The exact order |
| 6669 | + /// in which the tasks will become available is not defined. |
| 6670 | + /// </remarks> |
| 6671 | + /// <exception cref="ArgumentNullException"><paramref name="tasks"/> is null.</exception> |
| 6672 | + /// <exception cref="ArgumentException"><paramref name="tasks"/> contains a null.</exception> |
| 6673 | + public static IAsyncEnumerable<Task> WhenEach(params Task[] tasks) |
| 6674 | + { |
| 6675 | + ArgumentNullException.ThrowIfNull(tasks); |
| 6676 | + return WhenEach((ReadOnlySpan<Task>)tasks); |
| 6677 | + } |
| 6678 | + |
| 6679 | + /// <inheritdoc cref="WhenEach(Task[])"/> |
| 6680 | + public static IAsyncEnumerable<Task> WhenEach(ReadOnlySpan<Task> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params |
| 6681 | + WhenEachState.Iterate<Task>(WhenEachState.Create(tasks)); |
| 6682 | + |
| 6683 | + /// <inheritdoc cref="WhenEach(Task[])"/> |
| 6684 | + public static IAsyncEnumerable<Task> WhenEach(IEnumerable<Task> tasks) => |
| 6685 | + WhenEachState.Iterate<Task>(WhenEachState.Create(tasks)); |
| 6686 | + |
| 6687 | + /// <inheritdoc cref="WhenEach(Task[])"/> |
| 6688 | + public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(params Task<TResult>[] tasks) |
| 6689 | + { |
| 6690 | + ArgumentNullException.ThrowIfNull(tasks); |
| 6691 | + return WhenEach((ReadOnlySpan<Task<TResult>>)tasks); |
| 6692 | + } |
| 6693 | + |
| 6694 | + /// <inheritdoc cref="WhenEach(Task[])"/> |
| 6695 | + public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(ReadOnlySpan<Task<TResult>> tasks) => // TODO https://github.com/dotnet/runtime/issues/77873: Add params |
| 6696 | + WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(ReadOnlySpan<Task>.CastUp(tasks))); |
| 6697 | + |
| 6698 | + /// <inheritdoc cref="WhenEach(Task[])"/> |
| 6699 | + public static IAsyncEnumerable<Task<TResult>> WhenEach<TResult>(IEnumerable<Task<TResult>> tasks) => |
| 6700 | + WhenEachState.Iterate<Task<TResult>>(WhenEachState.Create(tasks)); |
| 6701 | + |
| 6702 | + /// <summary>Object used by <see cref="Iterate"/> to store its state.</summary> |
| 6703 | + private sealed class WhenEachState : Queue<Task>, IValueTaskSource, ITaskCompletionAction |
| 6704 | + { |
| 6705 | + /// <summary>Implementation backing the ValueTask used to wait for the next task to be available.</summary> |
| 6706 | + /// <remarks>This is a mutable struct. Do not make it readonly.</remarks> |
| 6707 | + private ManualResetValueTaskSourceCore<bool> _waitForNextCompletedTask = new() { RunContinuationsAsynchronously = true }; // _waitForNextCompletedTask.Set is called while holding a lock |
| 6708 | + /// <summary>0 if this has never been used in an iteration; 1 if it has.</summary> |
| 6709 | + /// <remarks>This is used to ensure we only ever iterate through the tasks once.</remarks> |
| 6710 | + private int _enumerated; |
| 6711 | + |
| 6712 | + /// <summary>Called at the beginning of the iterator to assume ownership of the state.</summary> |
| 6713 | + /// <returns>true if the caller owns the state; false if the caller should end immediately.</returns> |
| 6714 | + public bool TryStart() => Interlocked.Exchange(ref _enumerated, 1) == 0; |
| 6715 | + |
| 6716 | + /// <summary>Gets or sets the number of tasks that haven't yet been yielded.</summary> |
| 6717 | + public int Remaining { get; set; } |
| 6718 | + |
| 6719 | + void ITaskCompletionAction.Invoke(Task completingTask) |
| 6720 | + { |
| 6721 | + lock (this) |
| 6722 | + { |
| 6723 | + // Enqueue the task into the queue. If the Count is now 1, we transitioned from |
| 6724 | + // empty to non-empty, which means we need to signal the MRVTSC, as the consumer |
| 6725 | + // could be waiting on a ValueTask representing a completed task being available. |
| 6726 | + Enqueue(completingTask); |
| 6727 | + if (Count == 1) |
| 6728 | + { |
| 6729 | + Debug.Assert(_waitForNextCompletedTask.GetStatus(_waitForNextCompletedTask.Version) == ValueTaskSourceStatus.Pending); |
| 6730 | + _waitForNextCompletedTask.SetResult(default); |
| 6731 | + } |
| 6732 | + } |
| 6733 | + } |
| 6734 | + bool ITaskCompletionAction.InvokeMayRunArbitraryCode => false; |
| 6735 | + |
| 6736 | + // Delegate to _waitForNextCompletedTask for IValueTaskSource implementation. |
| 6737 | + void IValueTaskSource.GetResult(short token) => _waitForNextCompletedTask.GetResult(token); |
| 6738 | + ValueTaskSourceStatus IValueTaskSource.GetStatus(short token) => _waitForNextCompletedTask.GetStatus(token); |
| 6739 | + void IValueTaskSource.OnCompleted(Action<object?> continuation, object? state, short token, ValueTaskSourceOnCompletedFlags flags) => |
| 6740 | + _waitForNextCompletedTask.OnCompleted(continuation, state, token, flags); |
| 6741 | + |
| 6742 | + /// <summary>Creates a <see cref="WhenEachState"/> from the specified tasks.</summary> |
| 6743 | + public static WhenEachState? Create(ReadOnlySpan<Task> tasks) |
| 6744 | + { |
| 6745 | + WhenEachState? waiter = null; |
| 6746 | + |
| 6747 | + if (tasks.Length != 0) |
| 6748 | + { |
| 6749 | + waiter = new(); |
| 6750 | + foreach (Task task in tasks) |
| 6751 | + { |
| 6752 | + if (task is null) |
| 6753 | + { |
| 6754 | + ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks); |
| 6755 | + } |
| 6756 | + |
| 6757 | + waiter.Remaining++; |
| 6758 | + task.AddCompletionAction(waiter); |
| 6759 | + } |
| 6760 | + } |
| 6761 | + |
| 6762 | + return waiter; |
| 6763 | + } |
| 6764 | + |
| 6765 | + /// <inheritdoc cref="Create(ReadOnlySpan{Task})"/> |
| 6766 | + public static WhenEachState? Create(IEnumerable<Task> tasks) |
| 6767 | + { |
| 6768 | + ArgumentNullException.ThrowIfNull(tasks); |
| 6769 | + |
| 6770 | + WhenEachState? waiter = null; |
| 6771 | + |
| 6772 | + IEnumerator<Task> e = tasks.GetEnumerator(); |
| 6773 | + if (e.MoveNext()) |
| 6774 | + { |
| 6775 | + waiter = new(); |
| 6776 | + do |
| 6777 | + { |
| 6778 | + Task task = e.Current; |
| 6779 | + if (task is null) |
| 6780 | + { |
| 6781 | + ThrowHelper.ThrowArgumentException(ExceptionResource.Task_MultiTaskContinuation_NullTask, ExceptionArgument.tasks); |
| 6782 | + } |
| 6783 | + |
| 6784 | + waiter.Remaining++; |
| 6785 | + task.AddCompletionAction(waiter); |
| 6786 | + } |
| 6787 | + while (e.MoveNext()); |
| 6788 | + } |
| 6789 | + |
| 6790 | + return waiter; |
| 6791 | + } |
| 6792 | + |
| 6793 | + /// <summary>Iterates through the tasks represented by the provided waiter.</summary> |
| 6794 | + public static async IAsyncEnumerable<T> Iterate<T>(WhenEachState? waiter, [EnumeratorCancellation] CancellationToken cancellationToken = default) where T : Task |
| 6795 | + { |
| 6796 | + // The enumerable could have GetAsyncEnumerator called on it multiple times. As we're dealing with Tasks that |
| 6797 | + // only ever transition from non-completed to completed, re-enumeration doesn't have much benefit, so we take |
| 6798 | + // advantage of the optimizations possible by not supporting that and simply have the semantics that, no matter |
| 6799 | + // how many times the enumerable is enumerated, every task is yielded only once. The original GetAsyncEnumerator |
| 6800 | + // call will give back all the tasks, and all subsequent iterations will be empty. |
| 6801 | + if (waiter?.TryStart() is not true) |
| 6802 | + { |
| 6803 | + yield break; |
| 6804 | + } |
| 6805 | + |
| 6806 | + // Loop until we've yielded all tasks. |
| 6807 | + while (waiter.Remaining > 0) |
| 6808 | + { |
| 6809 | + // Either get the next completed task from the queue, or get a |
| 6810 | + // ValueTask with which to wait for the next task to complete. |
| 6811 | + Task? next; |
| 6812 | + ValueTask waitTask = default; |
| 6813 | + lock (waiter) |
| 6814 | + { |
| 6815 | + // Reset the MRVTSC if it was signaled, then try to dequeue a task and |
| 6816 | + // either return one we got or return a ValueTask that will be signaled |
| 6817 | + // when the next completed task is available. |
| 6818 | + waiter._waitForNextCompletedTask.Reset(); |
| 6819 | + if (!waiter.TryDequeue(out next)) |
| 6820 | + { |
| 6821 | + waitTask = new(waiter, waiter._waitForNextCompletedTask.Version); |
| 6822 | + } |
| 6823 | + } |
| 6824 | + |
| 6825 | + // If we got a completed Task, yield it. |
| 6826 | + if (next is not null) |
| 6827 | + { |
| 6828 | + cancellationToken.ThrowIfCancellationRequested(); |
| 6829 | + waiter.Remaining--; |
| 6830 | + yield return (T)next; |
| 6831 | + continue; |
| 6832 | + } |
| 6833 | + |
| 6834 | + // If we have a cancellation token and the ValueTask isn't already completed, |
| 6835 | + // get a Task from the ValueTask so we can use WaitAsync to make the wait cancelable. |
| 6836 | + // Otherwise, just await the ValueTask directly. We don't need to be concerned |
| 6837 | + // about suppressing exceptions, as the ValueTask is only ever completed successfully. |
| 6838 | + if (cancellationToken.CanBeCanceled && !waitTask.IsCompleted) |
| 6839 | + { |
| 6840 | + waitTask = new ValueTask(waitTask.AsTask().WaitAsync(cancellationToken)); |
| 6841 | + } |
| 6842 | + await waitTask.ConfigureAwait(false); |
| 6843 | + } |
| 6844 | + } |
| 6845 | + } |
| 6846 | + #endregion |
| 6847 | + |
6662 | 6848 | internal static Task<TResult> CreateUnwrapPromise<TResult>(Task outerTask, bool lookForOce)
|
6663 | 6849 | {
|
6664 | 6850 | Debug.Assert(outerTask != null);
|
|
0 commit comments