Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Task.WaitAsync in SemaphoreSlim.WaitAsync #55262

Merged
merged 2 commits into from
Jul 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -701,35 +701,22 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
Debug.Assert(asyncWaiter != null, "Waiter should have been constructed");
Debug.Assert(Monitor.IsEntered(m_lockObjAndDisposed), "Requires the lock be held");

if (millisecondsTimeout != Timeout.Infinite)
await new ConfiguredNoThrowAwaiter<bool>(asyncWaiter.WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout), cancellationToken));

if (cancellationToken.IsCancellationRequested)
{
// Wait until either the task is completed, cancellation is requested, or the timeout occurs.
// We need to ensure that the Task.Delay task is appropriately cleaned up if the await
// completes due to the asyncWaiter completing, so we use our own token that we can explicitly
// cancel, and we chain the caller's supplied token into it.
using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
{
if (asyncWaiter == await Task.WhenAny(asyncWaiter, Task.Delay(millisecondsTimeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel(); // ensure that the Task.Delay task is cleaned up
return true; // successfully acquired
}
}
// If we might be running as part of a cancellation callback, force the completion to be asynchronous
// so as to maintain semantics similar to when no token is passed (neither Release nor Cancel would invoke
// continuations off of this task).
await TaskScheduler.Default;
}
else // millisecondsTimeout == Timeout.Infinite

if (asyncWaiter.IsCompleted)
{
// Wait until either the task is completed or cancellation is requested.
var cancellationTask = new Task(null, TaskCreationOptions.RunContinuationsAsynchronously, promiseStyle: true);
using (cancellationToken.UnsafeRegister(static s => ((Task)s!).TrySetResult(), cancellationTask))
{
if (asyncWaiter == await Task.WhenAny(asyncWaiter, cancellationTask).ConfigureAwait(false))
{
return true; // successfully acquired
}
}
return true; // successfully acquired
}

// If we get here, the wait has timed out or been canceled.
// The wait has timed out or been canceled.

// If the await completed synchronously, we still hold the lock. If it didn't,
// we no longer hold the lock. As such, acquire it.
Expand All @@ -750,6 +737,19 @@ private async Task<bool> WaitUntilCountOrTimeoutAsync(TaskNode asyncWaiter, int
return await asyncWaiter.ConfigureAwait(false);
}

// TODO https://github.com/dotnet/runtime/issues/22144: Replace with official nothrow await solution once available.
/// <summary>Awaiter used to await a task.ConfigureAwait(false) but without throwing any exceptions for faulted or canceled tasks.</summary>
private readonly struct ConfiguredNoThrowAwaiter<T> : ICriticalNotifyCompletion
{
private readonly Task<T> _task;
public ConfiguredNoThrowAwaiter(Task<T> task) => _task = task;
public ConfiguredNoThrowAwaiter<T> GetAwaiter() => this;
public bool IsCompleted => _task.IsCompleted;
public void GetResult() { }
public void UnsafeOnCompleted(Action continuation) => _task.ConfigureAwait(false).GetAwaiter().UnsafeOnCompleted(continuation);
public void OnCompleted(Action continuation) => _task.ConfigureAwait(false).GetAwaiter().OnCompleted(continuation);
}

/// <summary>
/// Exits the <see cref="SemaphoreSlim"/> once.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -529,10 +529,32 @@ public SystemThreadingTasks_TaskSchedulerDebugView(TaskScheduler scheduler)
// returns the scheduler's GetScheduledTasks
public IEnumerable<Task>? ScheduledTasks => m_taskScheduler.GetScheduledTasks();
}
}


// TODO https://github.com/dotnet/runtime/issues/20025: Consider exposing publicly.
/// <summary>Gets an awaiter used to queue a continuation to this TaskScheduler.</summary>
internal TaskSchedulerAwaiter GetAwaiter() => new TaskSchedulerAwaiter(this);

/// <summary>Awaiter used to queue a continuation to a specified task scheduler.</summary>
internal readonly struct TaskSchedulerAwaiter : ICriticalNotifyCompletion
{
private readonly TaskScheduler _scheduler;
public TaskSchedulerAwaiter(TaskScheduler scheduler) => _scheduler = scheduler;
public bool IsCompleted => false;
public void GetResult() { }
public void OnCompleted(Action continuation) => Task.Factory.StartNew(continuation, CancellationToken.None, TaskCreationOptions.DenyChildAttach, _scheduler);
public void UnsafeOnCompleted(Action continuation)
{
if (ReferenceEquals(_scheduler, Default))
{
ThreadPool.UnsafeQueueUserWorkItem(s => s(), continuation, preferLocal: true);
}
else
{
OnCompleted(continuation);
}
}
}
}

/// <summary>
/// A TaskScheduler implementation that executes all tasks queued to it through a call to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ public static void CancelAfterWait()
// currently we don't expose this.. but it was verified manually
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public static async Task Cancel_WaitAsync_ContinuationInvokedAsynchronously()
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(false)]
[InlineData(true)]
public static async Task Cancel_WaitAsync_ContinuationInvokedAsynchronously(bool withTimeout)
{
await Task.Run(async () => // escape xunit's SynchronizationContext
{
Expand All @@ -60,7 +62,10 @@ await Task.Run(async () => // escape xunit's SynchronizationContext
var sentinel = new object();

var sem = new SemaphoreSlim(0);
Task continuation = sem.WaitAsync(cts.Token).ContinueWith(prev =>
Task waitTask = withTimeout ?
sem.WaitAsync(TimeSpan.FromDays(1), cts.Token) :
sem.WaitAsync(cts.Token);
Task continuation = waitTask.ContinueWith(prev =>
{
Assert.Equal(TaskStatus.Canceled, prev.Status);
Assert.NotSame(sentinel, tl.Value);
Expand Down