diff --git a/src/DotNext.Tests/Threading/Tasks/TaskQueueTests.cs b/src/DotNext.Tests/Threading/Tasks/TaskQueueTests.cs new file mode 100644 index 000000000..6ce440c3d --- /dev/null +++ b/src/DotNext.Tests/Threading/Tasks/TaskQueueTests.cs @@ -0,0 +1,105 @@ +namespace DotNext.Threading.Tasks; + +public class TaskQueueTests : Test +{ + [Fact] + public static async Task EmptyQueue() + { + var queue = new TaskQueue(10); + Null(queue.HeadTask); + False(queue.TryDequeue(out _)); + Null(await queue.TryDequeueAsync()); + } + + [Fact] + public static async Task QueueOverflow() + { + var queue = new TaskQueue(3); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + NotNull(queue.HeadTask); + + var enqueueTask = queue.EnqueueAsync(Task.CompletedTask).AsTask(); + False(enqueueTask.IsCompleted); + + True(queue.TryDequeue(out var task)); + True(task.IsCompleted); + + await enqueueTask.WaitAsync(DefaultTimeout); + queue.Clear(); + } + + [Fact] + public static async Task EnumerateTasks() + { + var queue = new TaskQueue(3); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + + var count = 0; + await foreach (var task in queue) + { + Same(Task.CompletedTask, task); + count++; + } + + Equal(3, count); + } + + [Fact] + public static async Task EnumerateTasks2() + { + var queue = new TaskQueue(3); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + True(queue.TryEnqueue(Task.CompletedTask)); + + var count = 0; + while (await queue.TryDequeueAsync() is { } task) + { + Same(Task.CompletedTask, task); + count++; + } + + Equal(3, count); + } + + [Fact] + public static async Task DelayedDequeue() + { + var queue = new TaskQueue(3); + var enqueueTask = queue.DequeueAsync().AsTask(); + False(enqueueTask.IsCompleted); + + True(queue.TryEnqueue(Task.CompletedTask)); + + await enqueueTask.WaitAsync(DefaultTimeout); + Null(await queue.TryDequeueAsync()); + } + + [Fact] + public static async Task DequeueCancellation() + { + var source = new TaskCompletionSource(); + var queue = new TaskQueue(3); + True(queue.TryEnqueue(source.Task)); + + await ThrowsAnyAsync(queue.DequeueAsync(new(canceled: true)).AsTask); + } + + [Fact] + public static async Task FailedTask() + { + var source = new TaskCompletionSource(); + var queue = new TaskQueue(3); + True(queue.TryEnqueue(source.Task)); + + var dequeueTask = queue.DequeueAsync().AsTask(); + False(dequeueTask.IsCompleted); + + source.SetException(new Exception()); + Same(source.Task, await dequeueTask); + } +} \ No newline at end of file diff --git a/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs b/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs new file mode 100644 index 000000000..5d1b5a2d2 --- /dev/null +++ b/src/DotNext.Threading/Threading/Tasks/TaskQueue.cs @@ -0,0 +1,302 @@ +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +namespace DotNext.Threading.Tasks; + +/// +/// Represents a queue of scheduled tasks. +/// +/// +/// The queue returns tasks in the order as they added (FIFO) in contrast +/// to . +/// +/// The type of tasks in the queue. +public class TaskQueue : IAsyncEnumerable, IResettable + where T : Task +{ + private readonly int capacity; + private readonly T?[] array; + private int tail, head, count; + private Signal? signal; + + /// + /// Initializes a new empty queue. + /// + /// The maximum number of tasks in the queue. + public TaskQueue(int capacity) + { + ArgumentOutOfRangeException.ThrowIfNegativeOrZero(capacity); + + array = new T[capacity]; + this.capacity = capacity; + } + + /// + /// Gets a head of this queue. + /// + public T? HeadTask + { + get + { + lock (array) + { + return count > 0 ? array[head] : null; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void ChangeCount([ConstantExpected] bool increment) + { + Debug.Assert(Monitor.IsEntered(array)); + + count += increment ? +1 : -1; + if (signal?.TrySetResult() ?? false) + signal = null; + } + + private void MoveNext(ref int index) + { + Debug.Assert(Monitor.IsEntered(array)); + + var value = index + 1; + index = value == array.Length ? 0 : value; + } + + /// + /// Tries to enqueue the task. + /// + /// The task to enqueue. + /// if the task is enqueued successfully; if this queue is full. + public bool TryEnqueue(T task) + { + ArgumentNullException.ThrowIfNull(task); + + bool result; + lock (array) + { + if (result = count < capacity) + { + array[tail] = task; + MoveNext(ref tail); + ChangeCount(increment: true); + } + } + + return result; + } + + private bool TryEnqueue(T task, out Task waitTask) + { + bool result; + lock (array) + { + if (result = count < capacity) + { + array[tail] = task; + MoveNext(ref tail); + ChangeCount(increment: true); + waitTask = Task.CompletedTask; + } + else + { + signal ??= new(); + waitTask = signal.Task; + } + } + + return result; + } + + /// + /// Enqueues the task. + /// + /// + /// The caller suspends if the queue is full. + /// + /// The task to enqueue. + /// The token that can be used to cancel the operation. + /// The operation has been canceled. + public async ValueTask EnqueueAsync(T task, CancellationToken token = default) + { + ArgumentNullException.ThrowIfNull(task); + + while (!TryEnqueue(task, out var waitTask)) + { + await waitTask.WaitAsync(token).ConfigureAwait(false); + } + } + + private T? TryPeekOrDequeue(out int head, out Task enqueueTask, out bool completed) + { + T? result; + lock (array) + { + if (count > 0) + { + result = array[head = this.head]; + enqueueTask = Task.CompletedTask; + if (completed = result is { IsCompleted: true }) + { + MoveNext(ref head); + ChangeCount(increment: false); + } + } + else + { + head = default; + result = null; + completed = default; + signal ??= new(); + enqueueTask = signal.Task; + } + } + + return result; + } + + private bool TryDequeue(int expectedHead, T task) + { + bool result; + lock (array) + { + ref var element = ref array[expectedHead]; + if (result = count > 0 && head == expectedHead && ReferenceEquals(element, task)) + { + MoveNext(ref head); + element = null; + ChangeCount(increment: false); + } + } + + return result; + } + + /// + /// Tries to dequeue the completed task. + /// + /// The completed task. + /// if is completed; otherwise, . + public bool TryDequeue([NotNullWhen(true)] out T? task) + { + lock (array) + { + ref var element = ref array[head]; + task = element; + if (count > 0 && task is { IsCompleted: true }) + { + MoveNext(ref head); + element = null; + ChangeCount(increment: false); + } + else + { + task = null; + } + } + + return task is not null; + } + + /// + /// Dequeues the task asynchronously. + /// + /// The caller suspends if the queue is empty. + /// The token that can be used to cancel the operation. + /// The completed task. + /// The operation has been canceled. + public async ValueTask DequeueAsync(CancellationToken token = default) + { + for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true);;) + { + if (TryPeekOrDequeue(out var expectedHead, out var enqueueTask, out var completed) is not { } task) + { + await enqueueTask.WaitAsync(token).ConfigureAwait(false); + continue; + } + + if (!completed) + { + await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + + if (!TryDequeue(expectedHead, task)) + continue; + } + + return task; + } + } + + /// + /// Tries to dequeue the completed task. + /// + /// The token that can be used to cancel the operation. + /// The completed task; or if the queue is empty. + /// The operation has been canceled. + public async ValueTask TryDequeueAsync(CancellationToken token = default) + { + for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true);;) + { + T? task; + if ((task = TryPeekOrDequeue(out var expectedHead, out _, out var completed)) is not null && !completed) + { + await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + + if (!TryDequeue(expectedHead, task)) + continue; + } + + return task; + } + } + + /// + /// Gets consuming enumerator over tasks in the queue. + /// + /// + /// The enumerator stops if the queue is empty. + /// + /// The token that can be used to cancel the operation. + /// The enumerator over completed tasks. + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken token) + { + for (var filter = token.CanBeCanceled ? null : Predicate.Constant(true); + TryPeekOrDequeue(out var expectedHead, out _, out var completed) is { } task;) + { + if (!completed) + { + await task.WaitAsync(token).SuspendException(filter ??= token.SuspendAllExceptCancellation).ConfigureAwait(false); + if (!TryDequeue(expectedHead, task)) + continue; + } + + yield return task; + } + } + + /// + /// Clears the queue. + /// + public void Clear() + { + lock (array) + { + head = tail = count = 0; + Array.Clear(array); + if (signal?.TrySetResult() ?? false) + signal = null; + } + } + + /// + void IResettable.Reset() => Clear(); + + private sealed class Signal() : TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); +} + +file static class CancellationTokenExtensions +{ + internal static bool SuspendAllExceptCancellation(this object token, Exception e) + => e is not OperationCanceledException canceledEx || !canceledEx.CancellationToken.Equals(token); +} \ No newline at end of file