Skip to content

Commit

Permalink
Add Task.WaitAsync methods
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub committed Mar 11, 2021
1 parent 877a8df commit da10cce
Show file tree
Hide file tree
Showing 65 changed files with 498 additions and 341 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -558,9 +558,12 @@ protected async Task WhenAllOrAnyFailed(Task task1, Task task2)
}
else
{
var cts = new CancellationTokenSource();
await Task.WhenAny(incomplete, Task.Delay(500, cts.Token)); // give second task a chance to complete
cts.Cancel();
try
{
await incomplete.WaitAsync(TimeSpan.FromMilliseconds(500)); // give second task a chance to complete
}
catch (TimeoutException) { }

await (incomplete.IsCompleted ? Task.WhenAll(completed, incomplete) : completed);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public Task CreateClientAndServerAsync(Func<Uri, Task> clientFunc, Func<GenericL
Task serverTask = serverFunc(server);
await new Task[] { clientTask, serverTask }.WhenAllOrAnyFailed().ConfigureAwait(false);
}, options: options).TimeoutAfter(millisecondsTimeout);
}, options: options).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,7 @@ public async Task ExpectSettingsAckAsync(int timeoutMs = 5000)
Task currentTask = _ignoredSettingsAckPromise?.Task;
if (currentTask != null)
{
var timeout = TimeSpan.FromMilliseconds(timeoutMs);
await currentTask.TimeoutAfter(timeout);
await currentTask.WaitAsync(TimeSpan.FromMilliseconds(timeoutMs));
}

_ignoredSettingsAckPromise = new TaskCompletionSource<bool>();
Expand Down Expand Up @@ -909,7 +908,7 @@ public override async Task WaitForCancellationAsync(bool ignoreIncomingData = tr
Frame frame;
do
{
frame = await ReadFrameAsync(TimeSpan.FromMilliseconds(TestHelper.PassingTestTimeoutMilliseconds));
frame = await ReadFrameAsync(TestHelper.PassingTestTimeout);
Assert.NotNull(frame); // We should get Rst before closing connection.
Assert.Equal(0, (int)(frame.Flags & FrameFlags.EndStream));
if (ignoreIncomingData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ public static async Task CreateServerAsync(Func<Http2LoopbackServer, Uri, Task>
{
using (var server = Http2LoopbackServer.CreateServer())
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down Expand Up @@ -223,7 +223,7 @@ public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Ta
{
using (var server = CreateServer(options))
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ public override GenericLoopbackServer CreateServer(GenericLoopbackOptions option
public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Task> funcAsync, int millisecondsTimeout = 60000, GenericLoopbackOptions options = null)
{
using GenericLoopbackServer server = CreateServer(options);
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}

public override Task<GenericLoopbackConnection> CreateConnectionAsync(Socket socket, Stream stream, GenericLoopbackOptions options = null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ public static async Task CreateServerAsync(Func<HttpAgnosticLoopbackServer, Uri,
{
using (var server = HttpAgnosticLoopbackServer.CreateServer())
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down Expand Up @@ -240,7 +240,7 @@ public override async Task CreateServerAsync(Func<GenericLoopbackServer, Uri, Ta
{
using (var server = CreateServer(options))
{
await funcAsync(server, server.Address).TimeoutAfter(millisecondsTimeout).ConfigureAwait(false);
await funcAsync(server, server.Address).WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ await LoopbackServer.CreateServerAsync(async (proxyServer, proxyUrl) =>
// Send Digest challenge.
Task<List<string>> serverTask = proxyServer.AcceptConnectionSendResponseAndCloseAsync(HttpStatusCode.ProxyAuthenticationRequired, authHeader);
if (clientTask == await Task.WhenAny(clientTask, serverTask).TimeoutAfter(TestHelper.PassingTestTimeoutMilliseconds))
if (clientTask == await Task.WhenAny(clientTask, serverTask).WaitAsync(TestHelper.PassingTestTimeout))
{
// Client task shouldn't have completed successfully; propagate failure.
Assert.NotEqual(TaskStatus.RanToCompletion, clientTask.Status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1513,7 +1513,7 @@ await server.AcceptConnectionAsync(async connection =>
await connection.ReadRequestDataAsync(readBody: true);
}
catch { } // Eat errors from client disconnect.
await clientFinished.Task.TimeoutAfter(TimeSpan.FromMinutes(2));
await clientFinished.Task.WaitAsync(TimeSpan.FromMinutes(2));
});
});
}
Expand Down
2 changes: 2 additions & 0 deletions src/libraries/Common/tests/System/Net/Http/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ namespace System.Net.Http.Functional.Tests
{
public static class TestHelper
{
public static TimeSpan PassingTestTimeout => TimeSpan.FromMilliseconds(PassingTestTimeoutMilliseconds);
public static int PassingTestTimeoutMilliseconds => 60 * 1000;

public static bool JsonMessageContainsKeyValue(string message, string key, string value)
{
// Deal with JSON encoding of '\' and '"' in value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,91 +3,50 @@

using System.Collections.Generic;
using System.Diagnostics;
using System.Runtime.CompilerServices;

/// <summary>
/// Task timeout helper based on https://devblogs.microsoft.com/pfxteam/crafting-a-task-timeoutafter-method/
/// </summary>
namespace System.Threading.Tasks
{
public static class TaskTimeoutExtensions
{
public static async Task WithCancellation(this Task task, CancellationToken cancellationToken)
{
var tcs = new TaskCompletionSource<bool>();
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetResult(true), tcs))
{
if (task != await Task.WhenAny(task, tcs.Task).ConfigureAwait(false))
{
throw new OperationCanceledException(cancellationToken);
}
await task; // already completed; propagate any exception
}
}

public static Task TimeoutAfter(this Task task, int millisecondsTimeout)
=> task.TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task TimeoutAfter(this Task task, TimeSpan timeout)
{
var cts = new CancellationTokenSource();
#region WaitAsync polyfills
// Test polyfills when targeting a platform that doesn't have these ConfigureAwait overloads on Task

if (task == await Task.WhenAny(task, Task.Delay(timeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel();
await task.ConfigureAwait(false);
}
else
{
throw new TimeoutException($"Task timed out after {timeout}");
}
}
public static Task WaitAsync(this Task task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

public static Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, int millisecondsTimeout)
=> task.TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));
public static Task WaitAsync(this Task task, CancellationToken cancellationToken) =>
WaitAsync(task, Timeout.InfiniteTimeSpan, cancellationToken);

public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
public async static Task WaitAsync(this Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
var cts = new CancellationTokenSource();

if (task == await Task<TResult>.WhenAny(task, Task<TResult>.Delay(timeout, cts.Token)).ConfigureAwait(false))
{
cts.Cancel();
return await task.ConfigureAwait(false);
}
else
var tcs = new TaskCompletionSource<bool>();
using (new Timer(s => ((TaskCompletionSource<bool>)s).TrySetException(new TimeoutException()), tcs, timeout, Timeout.InfiniteTimeSpan))
using (cancellationToken.Register(s => ((TaskCompletionSource<bool>)s).TrySetCanceled(), tcs))
{
throw new TimeoutException($"Task timed out after {timeout}");
await(await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)).ConfigureAwait(false);
}
}

#if !NETFRAMEWORK
public static Task TimeoutAfter(this ValueTask task, int millisecondsTimeout)
=> task.AsTask().TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static Task TimeoutAfter(this ValueTask task, TimeSpan timeout)
=> task.AsTask().TimeoutAfter(timeout);

public static Task<TResult> TimeoutAfter<TResult>(this ValueTask<TResult> task, int millisecondsTimeout)
=> task.AsTask().TimeoutAfter(TimeSpan.FromMilliseconds(millisecondsTimeout));
public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout) =>
WaitAsync(task, timeout, default);

public static Task<TResult> TimeoutAfter<TResult>(this ValueTask<TResult> task, TimeSpan timeout)
=> task.AsTask().TimeoutAfter(timeout);
#endif
public static Task<TResult> WaitAsync<TResult>(this Task<TResult> task, CancellationToken cancellationToken) =>
WaitAsync(task, Timeout.InfiniteTimeSpan, cancellationToken);

public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout)
public static async Task<TResult> WaitAsync<TResult>(this Task<TResult> task, TimeSpan timeout, CancellationToken cancellationToken)
{
var cts = new CancellationTokenSource();
Task task = tasks.WhenAllOrAnyFailed();
if (task == await Task.WhenAny(task, Task.Delay(millisecondsTimeout, cts.Token)).ConfigureAwait(false))
var tcs = new TaskCompletionSource<TResult>();
using (new Timer(s => ((TaskCompletionSource<TResult>)s).TrySetException(new TimeoutException()), tcs, timeout, Timeout.InfiniteTimeSpan))
using (cancellationToken.Register(s => ((TaskCompletionSource<TResult>)s).TrySetCanceled(), tcs))
{
cts.Cancel();
await task.ConfigureAwait(false);
}
else
{
throw new TimeoutException($"{nameof(WhenAllOrAnyFailed)} timed out after {millisecondsTimeout}ms");
return await (await Task.WhenAny(task, tcs.Task).ConfigureAwait(false)).ConfigureAwait(false);
}
}
#endregion

public static async Task WhenAllOrAnyFailed(this Task[] tasks, int millisecondsTimeout) =>
await tasks.WhenAllOrAnyFailed().WaitAsync(TimeSpan.FromMilliseconds(millisecondsTimeout));

public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
Expand All @@ -99,12 +58,11 @@ public static async Task WhenAllOrAnyFailed(this Task[] tasks)
{
// Wait a bit to allow other tasks to complete so we can include their exceptions
// in the error we throw.
using (var cts = new CancellationTokenSource())
try
{
await Task.WhenAny(
Task.WhenAll(tasks),
Task.Delay(3_000, cts.Token)).ConfigureAwait(false); // arbitrary delay; can be dialed up or down in the future
await Task.WhenAll(tasks).WaitAsync(TimeSpan.FromSeconds(3)); // arbitrary delay; can be dialed up or down in the future
}
catch { }

var exceptions = new List<Exception>();
foreach (Task t in tasks)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1286,7 +1286,7 @@ void Fail(object state)

fileSystemWatcher.CallOnRenamed(new RenamedEventArgs(WatcherChangeTypes.Renamed, root.RootPath, newDirectoryName, oldDirectoryName));

await Task.WhenAll(oldDirectoryTcs.Task, newDirectoryTcs.Task, newSubDirectoryTcs.Task, newFileTcs.Task).TimeoutAfter(TimeSpan.FromSeconds(30));
await Task.WhenAll(oldDirectoryTcs.Task, newDirectoryTcs.Task, newSubDirectoryTcs.Task, newFileTcs.Task).WaitAsync(TimeSpan.FromSeconds(30));

Assert.False(oldSubDirectoryToken.HasChanged, "Old subdirectory token should not have changed");
Assert.False(oldFileToken.HasChanged, "Old file token should not have changed");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ protected async Task<CancellationToken> StartSelfHostAsync()
// The timeout here is large, because we don't know how long the test could need
// We cover a lot of error cases above, but I want to make sure we eventually give up and don't hang the build
// just in case we missed one -anurse
await started.Task.TimeoutAfter(TimeSpan.FromMinutes(10));
await started.Task.WaitAsync(TimeSpan.FromMinutes(10));
}

return hostExitTokenSource.Token;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ private async Task ExecuteShutdownTest(string testName, string shutdownMechanic)
}
};

await started.Task.TimeoutAfter(TimeSpan.FromSeconds(60));
await started.Task.WaitAsync(TimeSpan.FromSeconds(60));

SendShutdownSignal(deployer.HostProcess);

await completed.Task.TimeoutAfter(TimeSpan.FromSeconds(60));
await completed.Task.WaitAsync(TimeSpan.FromSeconds(60));

WaitForExitOrKill(deployer.HostProcess);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ private async Task DictionaryConcurrentAccessDetection<TKey, TValue>(Dictionary<
}, TaskCreationOptions.LongRunning);

// If Dictionary regresses, we do not want to hang here indefinitely
Assert.True((await Task.WhenAny(task, Task.Delay(TimeSpan.FromSeconds(60))) == task) && task.IsCompletedSuccessfully);
await task.WaitAsync(TimeSpan.FromSeconds(60));
}

[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ public async Task RunWorkerAsync_NoOnWorkHandler_SetsResultToNull()

backgroundWorker.RunWorkerAsync();

await Task.WhenAny(tcs.Task, Task.Delay(TimeSpan.FromSeconds(10))); // Usually takes 100th of a sec
Assert.True(tcs.Task.IsCompleted);
await tcs.Task.WaitAsync(TimeSpan.FromSeconds(10)); // Usually takes 100th of a sec
}

#region TestCancelAsync
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@
Link="System\PasteArguments.cs" />
<Compile Include="$(CommonPath)Interop\Windows\Interop.Errors.cs"
Link="Common\Interop\Windows\Interop.Errors.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs"
Link="Common\System\Threading\Tasks\TaskCompletionSourceWithCancellation.cs" />
<Compile Include="$(CommonPath)System\Threading\Tasks\TaskTimeoutExtensions.cs"
Link="Common\System\Threading\Tasks\TaskTimeoutExtensions.cs" />
<Compile Include="$(CommonPath)System\Text\ValueStringBuilder.cs"
Link="Common\System\Text\ValueStringBuilder.cs" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,25 +251,7 @@ private bool FlushMessageQueue(bool rethrowInNewThread)
}
}

// Wait until we hit EOF. This is called from Process.WaitForExit
// We will lose some information if we don't do this.
internal void WaitUntilEOF()
{
if (_readToBufferTask is Task task)
{
task.GetAwaiter().GetResult();
}
}

internal Task WaitUntilEOFAsync(CancellationToken cancellationToken)
{
if (_readToBufferTask is Task task)
{
return task.WithCancellation(cancellationToken);
}

return Task.CompletedTask;
}
internal Task EOF => _readToBufferTask ?? Task.CompletedTask;

public void Dispose()
{
Expand Down
Loading

0 comments on commit da10cce

Please sign in to comment.