Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in HttpClient timeout handling and GetAsync_ContentCanBeCanceled #44169

Merged
merged 1 commit into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 28 additions & 22 deletions src/libraries/System.Net.Http/src/System/Net/Http/HttpClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private async Task<string> GetStringAsyncCore(HttpRequestMessage request, Cancel
bool telemetryStarted = StartSend(request);
bool responseContentTelemetryStarted = false;

(CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
(CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
HttpResponseMessage? response = null;
try
{
Expand Down Expand Up @@ -214,7 +214,7 @@ private async Task<string> GetStringAsyncCore(HttpRequestMessage request, Cancel
}
catch (Exception e)
{
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
throw;
}
finally
Expand Down Expand Up @@ -247,7 +247,7 @@ private async Task<byte[]> GetByteArrayAsyncCore(HttpRequestMessage request, Can
bool telemetryStarted = StartSend(request);
bool responseContentTelemetryStarted = false;

(CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
(CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
HttpResponseMessage? response = null;
try
{
Expand Down Expand Up @@ -293,7 +293,7 @@ private async Task<byte[]> GetByteArrayAsyncCore(HttpRequestMessage request, Can
}
catch (Exception e)
{
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
throw;
}
finally
Expand Down Expand Up @@ -325,7 +325,7 @@ private async Task<Stream> GetStreamAsyncCore(HttpRequestMessage request, Cancel
{
bool telemetryStarted = StartSend(request);

(CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
(CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
HttpResponseMessage? response = null;
try
{
Expand All @@ -339,7 +339,7 @@ private async Task<Stream> GetStreamAsyncCore(HttpRequestMessage request, Cancel
}
catch (Exception e)
{
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, timeoutTime);
HandleFailure(e, telemetryStarted, response, cts, cancellationToken, pendingRequestsCts);
throw;
}
finally
Expand Down Expand Up @@ -458,8 +458,8 @@ public HttpResponseMessage Send(HttpRequestMessage request, HttpCompletionOption
// Called outside of async state machine to propagate certain exception even without awaiting the returned task.
CheckRequestBeforeSend(request);

(CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, cts, disposeCts, timeoutTime, cancellationToken);
(CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
ValueTask<HttpResponseMessage> sendTask = SendAsyncCore(request, completionOption, async: false, cts, disposeCts, pendingRequestsCts, cancellationToken);
Debug.Assert(sendTask.IsCompleted);
return sendTask.GetAwaiter().GetResult();
}
Expand All @@ -478,8 +478,8 @@ public Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, HttpCompl
// Called outside of async state machine to propagate certain exception even without awaiting the returned task.
CheckRequestBeforeSend(request);

(CancellationTokenSource cts, bool disposeCts, long timeoutTime) = PrepareCancellationTokenSource(cancellationToken);
return SendAsyncCore(request, completionOption, async: true, cts, disposeCts, timeoutTime, cancellationToken).AsTask();
(CancellationTokenSource cts, bool disposeCts, CancellationTokenSource pendingRequestsCts) = PrepareCancellationTokenSource(cancellationToken);
return SendAsyncCore(request, completionOption, async: true, cts, disposeCts, pendingRequestsCts, cancellationToken).AsTask();
}

private void CheckRequestBeforeSend(HttpRequestMessage request)
Expand All @@ -499,7 +499,8 @@ private void CheckRequestBeforeSend(HttpRequestMessage request)

private async ValueTask<HttpResponseMessage> SendAsyncCore(
HttpRequestMessage request, HttpCompletionOption completionOption,
bool async, CancellationTokenSource cts, bool disposeCts, long timeoutTime, CancellationToken originalCancellationToken)
bool async, CancellationTokenSource cts, bool disposeCts,
CancellationTokenSource pendingRequestsCts, CancellationToken originalCancellationToken)
{
bool telemetryStarted = StartSend(request);
bool responseContentTelemetryStarted = false;
Expand Down Expand Up @@ -537,7 +538,7 @@ await base.SendAsync(request, cts.Token).ConfigureAwait(false) :
}
catch (Exception e)
{
HandleFailure(e, telemetryStarted, response, cts, originalCancellationToken, timeoutTime);
HandleFailure(e, telemetryStarted, response, cts, originalCancellationToken, pendingRequestsCts);
throw;
}
finally
Expand All @@ -554,18 +555,18 @@ private static void ThrowForNullResponse([NotNull] HttpResponseMessage? response
}
}

private void HandleFailure(Exception e, bool telemetryStarted, HttpResponseMessage? response, CancellationTokenSource cts, CancellationToken cancellationToken, long timeoutTime)
private void HandleFailure(Exception e, bool telemetryStarted, HttpResponseMessage? response, CancellationTokenSource cts, CancellationToken cancellationToken, CancellationTokenSource pendingRequestsCts)
{
LogRequestFailed(telemetryStarted);

response?.Dispose();

Exception? toThrow = null;

if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && Environment.TickCount64 >= timeoutTime)
if (e is OperationCanceledException oce && !cancellationToken.IsCancellationRequested && !pendingRequestsCts.IsCancellationRequested)
{
// If this exception is for cancellation, but cancellation wasn't requested and instead we find that we've passed a timeout end time,
// treat this instead as a timeout.
// If this exception is for cancellation, but cancellation wasn't requested, either by the caller's token or by the pending requests source,
// the only other cause could be a timeout. Treat it as such.
e = toThrow = new TaskCanceledException(string.Format(SR.net_http_request_timedout, _timeout.TotalSeconds), new TimeoutException(e.Message, e), oce.CancellationToken);
}
else if (cts.IsCancellationRequested && e is HttpRequestException) // if cancellationToken is canceled, cts will also be canceled
Expand Down Expand Up @@ -745,28 +746,33 @@ private void PrepareRequestMessage(HttpRequestMessage request)
}
}

private (CancellationTokenSource TokenSource, bool DisposeTokenSource, long TimeoutTime) PrepareCancellationTokenSource(CancellationToken cancellationToken)
private (CancellationTokenSource TokenSource, bool DisposeTokenSource, CancellationTokenSource PendingRequestsCts) PrepareCancellationTokenSource(CancellationToken cancellationToken)
{
// We need a CancellationTokenSource to use with the request. We always have the global
// _pendingRequestsCts to use, plus we may have a token provided by the caller, and we may
// have a timeout. If we have a timeout or a caller-provided token, we need to create a new
// CTS (we can't, for example, timeout the pending requests CTS, as that could cancel other
// unrelated operations). Otherwise, we can use the pending requests CTS directly.

// Snapshot the current pending requests cancellation source. It can change concurrently due to cancellation being requested
// and it being replaced, and we need a stable view of it: if cancellation occurs and the caller's token hasn't been canceled,
// it's either due to this source or due to the timeout, and checking whether this source is the culprit is reliable whereas
// it's more approximate checking elapsed time.
CancellationTokenSource pendingRequestsCts = _pendingRequestsCts;

bool hasTimeout = _timeout != s_infiniteTimeout;
long timeoutTime = long.MaxValue;
if (hasTimeout || cancellationToken.CanBeCanceled)
{
CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _pendingRequestsCts.Token);
CancellationTokenSource cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, pendingRequestsCts.Token);
if (hasTimeout)
{
timeoutTime = Environment.TickCount64 + (_timeout.Ticks / TimeSpan.TicksPerMillisecond);
cts.CancelAfter(_timeout);
}

return (cts, DisposeTokenSource: true, timeoutTime);
return (cts, DisposeTokenSource: true, pendingRequestsCts);
}

return (_pendingRequestsCts, DisposeTokenSource: false, timeoutTime);
return (pendingRequestsCts, DisposeTokenSource: false, pendingRequestsCts);
}

private static void CheckBaseAddress(Uri? baseAddress, string parameterName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ await LoopbackServerFactory.CreateClientAndServerAsync(
},
async server =>
{
await server.AcceptConnectionAsync(async connection =>
Task serverHandling = server.AcceptConnectionAsync(async connection =>
{
await connection.ReadRequestDataAsync(readBody: false);
await connection.SendResponseAsync(HttpStatusCode.OK, headers: new HttpHeaderData[] { new HttpHeaderData("Content-Length", "5") });
Expand All @@ -497,11 +497,20 @@ await server.AcceptConnectionAsync(async connection =>
httpClient.CancelPendingRequests();
break;

// case 2: timeout fires on its own
// case 2: timeout fires on its own
}

await tcs.Task;
});

// The client may have completed before even sending a request when testing HttpClient.Timeout.
await Task.WhenAny(serverHandling, tcs.Task);
if (cancelMode != 2)
{
// If using a timeout to cancel requests, it's possible the server's processing could have gotten interrupted,
// so we want to ignore any exceptions from the server when in that mode. For anything else, let exceptions propagate.
await serverHandling;
}
});
}

Expand Down