Skip to content

Commit

Permalink
Use Task.WaitAsync in SemaphoreSlim.WaitAsync (#55262)
Browse files Browse the repository at this point in the history
* Use Task.WaitAsync in SemaphoreSlim

* Fix failing test

The Cancel_WaitAsync_ContinuationInvokedAsynchronously test was failing,
highlighting that we were no longer invoking the continuation
asynchronously from the Cancel call.  But in fact we were incompletely
doing so in the past, such that we'd only force that asynchrony if no
timeout was provided... if both a timeout and a token were provided,
then we wouldn't.  I've enhanced the test to validate both cases, and
made sure we now pass.
  • Loading branch information
stephentoub authored Jul 8, 2021
1 parent 7f88911 commit e30c200
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -701,35 +701,22 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
Debug.Assert(asyncWaiter != null, "Waiter should have been constructed");
Debug.Assert(Monitor.IsEntered(m_lockObjAndDisposed), "Requires the lock be held");

if (millisecondsTimeout != Timeout.Infinite)
await new ConfiguredNoThrowAwaiter<bool>(asyncWaiter.WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout), cancellationToken));

if (cancellationToken.IsCancellationRequested)
{
// Wait until either the task is completed, cancellation is requested, or the timeout occurs.
// We need to ensure that the Task.Delay task is appropriately cleaned up if the await
// completes due to the asyncWaiter completing, so we use our own token that we can explicitly
// cancel, and we chain the caller's supplied token into it.
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
{
if (asyncWaiter == await Task.WhenAny(asyncWaiter, Task.Delay(millisecondsTimeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel(); // ensure that the Task.Delay task is cleaned up
return true; // successfully acquired
}
}
// If we might be running as part of a cancellation callback, force the completion to be asynchronous
// so as to maintain semantics similar to when no token is passed (neither Release nor Cancel would invoke
// continuations off of this task).
await TaskScheduler.Default;
}
else // millisecondsTimeout == Timeout.Infinite

if (asyncWaiter.IsCompleted)
{
// Wait until either the task is completed or cancellation is requested.
var cancellationTask = new Task(null, TaskCreationOptions.RunContinuationsAsynchronously, promiseStyle: true);
using (cancellationToken.UnsafeRegister(static s => ((Task)s!).TrySetResult(), cancellationTask))
{
if (asyncWaiter == await Task.WhenAny(asyncWaiter, cancellationTask).ConfigureAwait(false))
{
return true; // successfully acquired
}
}
return true; // successfully acquired
}

// If we get here, the wait has timed out or been canceled.
// The wait has timed out or been canceled.

// If the await completed synchronously, we still hold the lock. If it didn't,
// we no longer hold the lock. As such, acquire it.
Expand All @@ -750,6 +737,19 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
return await asyncWaiter.ConfigureAwait(false);
}

// TODO https://github.com/dotnet/runtime/issues/22144: Replace with official nothrow await solution once available.
/// <summary>Awaiter used to await a task.ConfigureAwait(false) but without throwing any exceptions for faulted or canceled tasks.</summary>
private readonly struct ConfiguredNoThrowAwaiter<T> : ICriticalNotifyCompletion
{
private readonly Task<T> _task;
public ConfiguredNoThrowAwaiter(Task<T> task) => _task = task;
public ConfiguredNoThrowAwaiter<T> GetAwaiter() => this;
public bool IsCompleted => _task.IsCompleted;
public void GetResult() { }
public void UnsafeOnCompleted(Action continuation) => _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(continuation);
public void OnCompleted(Action continuation) => _task.ConfigureAwait(false).GetAwaiter().OnCompleted(continuation);
}

/// <summary>
/// Exits the <see cref="SemaphoreSlim"/> once.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,32 @@ public SystemThreadingTasks_TaskSchedulerDebugView(TaskScheduler scheduler)
// returns the scheduler's GetScheduledTasks
public IEnumerable<Task>? ScheduledTasks => m_taskScheduler.GetScheduledTasks();
}
}


// TODO https://github.com/dotnet/runtime/issues/20025: Consider exposing publicly.
/// <summary>Gets an awaiter used to queue a continuation to this TaskScheduler.</summary>
internal TaskSchedulerAwaiter GetAwaiter() => new TaskSchedulerAwaiter(this);

/// <summary>Awaiter used to queue a continuation to a specified task scheduler.</summary>
internal readonly struct TaskSchedulerAwaiter : ICriticalNotifyCompletion
{
private readonly TaskScheduler _scheduler;
public TaskSchedulerAwaiter(TaskScheduler scheduler) => _scheduler = scheduler;
public bool IsCompleted => false;
public void GetResult() { }
public void OnCompleted(Action continuation) => Task.Factory.StartNew(continuation, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _scheduler);
public void UnsafeOnCompleted(Action continuation)
{
if (ReferenceEquals(_scheduler, Default))
{
ThreadPool.UnsafeQueueUserWorkItem(s => s(), continuation, preferLocal: true);
}
else
{
OnCompleted(continuation);
}
}
}
}

/// <summary>
/// A TaskScheduler implementation that executes all tasks queued to it through a call to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ public static void CancelAfterWait()
// currently we don't expose this.. but it was verified manually
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public static async Task Cancel_WaitAsync_ContinuationInvokedAsynchronously()
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(false)]
[InlineData(true)]
public static async Task Cancel_WaitAsync_ContinuationInvokedAsynchronously(bool withTimeout)
{
await Task.Run(async () => // escape xunit's SynchronizationContext
{
Expand All @@ -60,7 +62,10 @@ await Task.Run(async () => // escape xunit's SynchronizationContext
var sentinel = new object();
var sem = new SemaphoreSlim(0);
Task continuation = sem.WaitAsync(cts.Token).ContinueWith(prev =>
Task waitTask = withTimeout ?
sem.WaitAsync(TimeSpan.FromDays(1), cts.Token) :
sem.WaitAsync(cts.Token);
Task continuation = waitTask.ContinueWith(prev =>
{
Assert.Equal(TaskStatus.Canceled, prev.Status);
Assert.NotSame(sentinel, tl.Value);
Expand Down

0 comments on commit e30c200

Please sign in to comment.