From 2be676ad3bb784a876c0f681d95565bd85d509dc Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Thu, 22 Jun 2023 15:12:17 +0800 Subject: [PATCH] WriteAsync cancellation throws an error with the calls completed status if possible (#2170) --- .../Internal/GrpcCall.NonGeneric.cs | 43 +++++++- src/Grpc.Net.Client/Internal/GrpcCall.cs | 42 +------ .../Internal/Http/PushStreamContent.cs | 10 +- .../Internal/StreamExtensions.cs | 47 +++++--- test/FunctionalTests/Client/StreamingTests.cs | 2 +- .../AsyncClientStreamingCallTests.cs | 103 +++++++++++++++++- .../GrpcCallSerializationContextTests.cs | 3 +- .../StreamSerializationHelper.cs | 3 +- 8 files changed, 189 insertions(+), 64 deletions(-) diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs index c69721850..97d8fa68c 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -53,6 +53,7 @@ public DefaultDeserializationContext DeserializationContext public string? RequestGrpcEncoding { get; internal set; } + public abstract Task CallTask { get; } public abstract CancellationToken CancellationToken { get; } public abstract Type RequestType { get; } public abstract Type ResponseType { get; } @@ -64,6 +65,29 @@ protected GrpcCall(CallOptions options, GrpcChannel channel) Logger = channel.LoggerFactory.CreateLogger(LoggerName); } + public Exception CreateCanceledStatusException(Exception? ex = null) + { + var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex); + return CreateRpcException(status); + } + + public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken) + { + if (methodCancellationToken.IsCancellationRequested) + { + return methodCancellationToken; + } + else if (Options.CancellationToken.IsCancellationRequested) + { + return Options.CancellationToken; + } + else if (CancellationToken.IsCancellationRequested) + { + return CancellationToken; + } + return CancellationToken.None; + } + internal RpcException CreateRpcException(Status status) { // This code can be called from a background task. @@ -84,6 +108,23 @@ internal RpcException CreateRpcException(Status status) return new RpcException(status, trailers ?? Metadata.Empty); } + public Exception CreateFailureStatusException(Status status) + { + if (Channel.ThrowOperationCanceledOnCancellation && + (status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled)) + { + // Convert status response of DeadlineExceeded to OperationCanceledException when + // ThrowOperationCanceledOnCancellation is true. + // This avoids a race between the client-side timer and the server status throwing different + // errors on deadline exceeded. + return new OperationCanceledException(); + } + else + { + return CreateRpcException(status); + } + } + protected bool TryGetTrailers([NotNullWhen(true)] out Metadata? trailers) { if (Trailers == null) diff --git a/src/Grpc.Net.Client/Internal/GrpcCall.cs b/src/Grpc.Net.Client/Internal/GrpcCall.cs index 7ca2cb938..1e0b35a41 100644 --- a/src/Grpc.Net.Client/Internal/GrpcCall.cs +++ b/src/Grpc.Net.Client/Internal/GrpcCall.cs @@ -90,7 +90,7 @@ private void ValidateDeadline(DateTime? deadline) } } - public Task CallTask => _callTcs.Task; + public override Task CallTask => _callTcs.Task; public override CancellationToken CancellationToken => _callCts.Token; @@ -248,12 +248,6 @@ public void EnsureNotDisposed() } } - public Exception CreateCanceledStatusException(Exception? ex = null) - { - var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex); - return CreateRpcException(status); - } - private void FinishResponseAndCleanUp(Status status) { ResponseFinished = true; @@ -760,23 +754,6 @@ public Exception EnsureUserCancellationTokenReported(Exception ex, CancellationT return ex; } - public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken) - { - if (methodCancellationToken.IsCancellationRequested) - { - return methodCancellationToken; - } - else if (Options.CancellationToken.IsCancellationRequested) - { - return Options.CancellationToken; - } - else if (CancellationToken.IsCancellationRequested) - { - return CancellationToken; - } - return CancellationToken.None; - } - private void SetFailedResult(Status status) { CompatibilityHelpers.Assert(_responseTcs != null); @@ -795,23 +772,6 @@ private void SetFailedResult(Status status) } } - public Exception CreateFailureStatusException(Status status) - { - if (Channel.ThrowOperationCanceledOnCancellation && - (status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled)) - { - // Convert status response of DeadlineExceeded to OperationCanceledException when - // ThrowOperationCanceledOnCancellation is true. - // This avoids a race between the client-side timer and the server status throwing different - // errors on deadline exceeded. - return new OperationCanceledException(); - } - else - { - return CreateRpcException(status); - } - } - private (bool diagnosticSourceEnabled, Activity? activity) InitializeCall(HttpRequestMessage request, TimeSpan? timeout) { GrpcCallLog.StartingCall(Logger, Method.Type, request.RequestUri!); diff --git a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs index 163e581e2..16932242a 100644 --- a/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs +++ b/src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -72,4 +72,10 @@ protected override bool TryComputeLength(out long length) // Hacky. ReadAsStreamAsync does not complete until SerializeToStreamAsync finishes. // WARNING: Will run SerializeToStreamAsync again on .NET Framework. internal Task PushComplete => ReadAsStreamAsync(); -} \ No newline at end of file + + // Internal for testing. + internal Task SerializeToStreamAsync(Stream stream) + { + return SerializeToStreamAsync(stream, context: null); + } +} diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index 0699e8713..2b8d2ac12 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -23,6 +23,7 @@ using System.Runtime.InteropServices; using Grpc.Core; using Grpc.Net.Compression; +using Grpc.Shared; using Microsoft.Extensions.Logging; #if NETSTANDARD2_0 @@ -318,10 +319,9 @@ public static async ValueTask WriteMessageAsync( { GrpcCallLog.ErrorSendingMessage(call.Logger, ex); - // Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException. - if (ex is ObjectDisposedException && call.CancellationToken.IsCancellationRequested) + if (TryCreateCallCompleteException(ex, call, out var statusException)) { - throw new OperationCanceledException(); + throw statusException; } throw; @@ -342,24 +342,41 @@ public static async ValueTask WriteMessageAsync( { GrpcCallLog.SendingMessage(call.Logger); - try - { - // Sending the header+content in a single WriteAsync call has significant performance benefits - // https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981 - await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); - } - catch (ObjectDisposedException) when (call.CancellationToken.IsCancellationRequested) - { - // Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException. - throw new OperationCanceledException(); - } + // Sending the header+content in a single WriteAsync call has significant performance benefits + // https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981 + await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false); GrpcCallLog.MessageSent(call.Logger); } catch (Exception ex) { GrpcCallLog.ErrorSendingMessage(call.Logger, ex); + + if (TryCreateCallCompleteException(ex, call, out var statusException)) + { + throw statusException; + } + throw; } } + + private static bool TryCreateCallCompleteException(Exception originalException, GrpcCall call, [NotNullWhen(true)] out Exception? exception) + { + // The call may have been completed while WriteAsync was running and caused WriteAsync to throw. + // In this situation, report the call's completed status. + // + // Replace exception with the status error if: + // 1. The original exception is one Stream.WriteAsync throws if the call was completed during a write, and + // 2. The call has already been successfully completed. + if (originalException is OperationCanceledException or ObjectDisposedException && + call.CallTask.IsCompletedSuccessfully()) + { + exception = call.CreateFailureStatusException(call.CallTask.Result); + return true; + } + + exception = null; + return false; + } } diff --git a/test/FunctionalTests/Client/StreamingTests.cs b/test/FunctionalTests/Client/StreamingTests.cs index 291ca9878..8f6e731e4 100644 --- a/test/FunctionalTests/Client/StreamingTests.cs +++ b/test/FunctionalTests/Client/StreamingTests.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // diff --git a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs index daac9fbac..9f41d6c21 100644 --- a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -107,7 +107,6 @@ public async Task AsyncClientStreamingCall_Success_RequestContentSent() var responseTask = call.ResponseAsync; Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete."); - await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout(); await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout(); @@ -268,6 +267,106 @@ public async Task ClientStreamWriter_WriteAfterResponseHasFinished_ErrorThrown() Assert.AreEqual("Hello world", result.Message); } + [Test] + public async Task AsyncClientStreamingCall_ErrorWhileWriting_StatusExceptionThrown() + { + // Arrange + PushStreamContent? content = null; + + var responseTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var httpClient = ClientTestHelpers.CreateTestClient(request => + { + content = (PushStreamContent)request.Content!; + return responseTcs.Task; + }); + + var invoker = HttpClientCallInvokerFactory.Create(httpClient); + + // Act + + // Client starts call + var call = invoker.AsyncClientStreamingCall(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions()); + // Client starts request stream write + var writeTask = call.RequestStream.WriteAsync(new HelloRequest()); + + // Simulate HttpClient starting to accept the write. Stream.WriteAsync is blocked. + var writeSyncPoint = new SyncPoint(runContinuationsAsynchronously: true); + var testStream = new TestStream(writeSyncPoint); + var serializeToStreamTask = content!.SerializeToStreamAsync(testStream); + + // Server completes response. + await writeSyncPoint.WaitForSyncPoint().DefaultTimeout(); + responseTcs.SetResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new ByteArrayContent(Array.Empty()), grpcStatusCode: StatusCode.InvalidArgument)); + + await ExceptionAssert.ThrowsAsync(() => call.ResponseAsync).DefaultTimeout(); + Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode); + + // Unblock Stream.WriteAsync + writeSyncPoint.Continue(); + + // Get error thrown from write task. It should have the status returned by the server. + var ex = await ExceptionAssert.ThrowsAsync(() => writeTask).DefaultTimeout(); + + // Assert + Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode); + Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode); + Assert.AreEqual(string.Empty, call.GetStatus().Detail); + } + + private sealed class TestStream : Stream + { + private readonly SyncPoint _writeSyncPoint; + + public TestStream(SyncPoint writeSyncPoint) + { + _writeSyncPoint = writeSyncPoint; + } + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite { get; } + public override long Length { get; } + public override long Position { get; set; } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + +#if !NET472_OR_GREATER + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + await _writeSyncPoint.WaitToContinue(); + throw new OperationCanceledException(); + } +#else + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _writeSyncPoint.WaitToContinue(); + throw new OperationCanceledException(); + } +#endif + } + [Test] public async Task ClientStreamWriter_CancelledBeforeCallStarts_ThrowsError() { diff --git a/test/Grpc.Net.Client.Tests/GrpcCallSerializationContextTests.cs b/test/Grpc.Net.Client.Tests/GrpcCallSerializationContextTests.cs index 8db812370..1f803483f 100644 --- a/test/Grpc.Net.Client.Tests/GrpcCallSerializationContextTests.cs +++ b/test/Grpc.Net.Client.Tests/GrpcCallSerializationContextTests.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -318,6 +318,7 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel) : base(options, ch public override Type RequestType { get; } = typeof(int); public override Type ResponseType { get; } = typeof(string); public override CancellationToken CancellationToken { get; } + public override Task CallTask => Task.FromResult(Status.DefaultCancelled); } private GrpcCallSerializationContext CreateSerializationContext(string? requestGrpcEncoding = null, int? maxSendMessageSize = null) diff --git a/test/Grpc.Net.Client.Tests/Infrastructure/StreamSerializationHelper.cs b/test/Grpc.Net.Client.Tests/Infrastructure/StreamSerializationHelper.cs index 03564627f..32ead947e 100644 --- a/test/Grpc.Net.Client.Tests/Infrastructure/StreamSerializationHelper.cs +++ b/test/Grpc.Net.Client.Tests/Infrastructure/StreamSerializationHelper.cs @@ -1,4 +1,4 @@ -#region Copyright notice and license +#region Copyright notice and license // Copyright 2019 The gRPC Authors // @@ -65,5 +65,6 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel, Type type) : base( public override Type RequestType => _type; public override Type ResponseType => _type; public override CancellationToken CancellationToken { get; } + public override Task CallTask => Task.FromResult(Status.DefaultCancelled); } }