Skip to content

Commit

Permalink
WriteAsync cancellation throws an error with the calls completed stat…
Browse files Browse the repository at this point in the history
…us if possible (#2170)
  • Loading branch information
JamesNK authored Jun 22, 2023
1 parent 73c726b commit 2be676a
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 64 deletions.
43 changes: 42 additions & 1 deletion src/Grpc.Net.Client/Internal/GrpcCall.NonGeneric.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -53,6 +53,7 @@ public DefaultDeserializationContext DeserializationContext

public string? RequestGrpcEncoding { get; internal set; }

public abstract Task<Status> CallTask { get; }
public abstract CancellationToken CancellationToken { get; }
public abstract Type RequestType { get; }
public abstract Type ResponseType { get; }
Expand All @@ -64,6 +65,29 @@ protected GrpcCall(CallOptions options, GrpcChannel channel)
Logger = channel.LoggerFactory.CreateLogger(LoggerName);
}

public Exception CreateCanceledStatusException(Exception? ex = null)
{
var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex);
return CreateRpcException(status);
}

public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken)
{
if (methodCancellationToken.IsCancellationRequested)
{
return methodCancellationToken;
}
else if (Options.CancellationToken.IsCancellationRequested)
{
return Options.CancellationToken;
}
else if (CancellationToken.IsCancellationRequested)
{
return CancellationToken;
}
return CancellationToken.None;
}

internal RpcException CreateRpcException(Status status)
{
// This code can be called from a background task.
Expand All @@ -84,6 +108,23 @@ internal RpcException CreateRpcException(Status status)
return new RpcException(status, trailers ?? Metadata.Empty);
}

public Exception CreateFailureStatusException(Status status)
{
if (Channel.ThrowOperationCanceledOnCancellation &&
(status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled))
{
// Convert status response of DeadlineExceeded to OperationCanceledException when
// ThrowOperationCanceledOnCancellation is true.
// This avoids a race between the client-side timer and the server status throwing different
// errors on deadline exceeded.
return new OperationCanceledException();
}
else
{
return CreateRpcException(status);
}
}

protected bool TryGetTrailers([NotNullWhen(true)] out Metadata? trailers)
{
if (Trailers == null)
Expand Down
42 changes: 1 addition & 41 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ private void ValidateDeadline(DateTime? deadline)
}
}

public Task<Status> CallTask => _callTcs.Task;
public override Task<Status> CallTask => _callTcs.Task;

public override CancellationToken CancellationToken => _callCts.Token;

Expand Down Expand Up @@ -248,12 +248,6 @@ public void EnsureNotDisposed()
}
}

public Exception CreateCanceledStatusException(Exception? ex = null)
{
var status = (CallTask.IsCompletedSuccessfully()) ? CallTask.Result : new Status(StatusCode.Cancelled, string.Empty, ex);
return CreateRpcException(status);
}

private void FinishResponseAndCleanUp(Status status)
{
ResponseFinished = true;
Expand Down Expand Up @@ -760,23 +754,6 @@ public Exception EnsureUserCancellationTokenReported(Exception ex, CancellationT
return ex;
}

public CancellationToken GetCanceledToken(CancellationToken methodCancellationToken)
{
if (methodCancellationToken.IsCancellationRequested)
{
return methodCancellationToken;
}
else if (Options.CancellationToken.IsCancellationRequested)
{
return Options.CancellationToken;
}
else if (CancellationToken.IsCancellationRequested)
{
return CancellationToken;
}
return CancellationToken.None;
}

private void SetFailedResult(Status status)
{
CompatibilityHelpers.Assert(_responseTcs != null);
Expand All @@ -795,23 +772,6 @@ private void SetFailedResult(Status status)
}
}

public Exception CreateFailureStatusException(Status status)
{
if (Channel.ThrowOperationCanceledOnCancellation &&
(status.StatusCode == StatusCode.DeadlineExceeded || status.StatusCode == StatusCode.Cancelled))
{
// Convert status response of DeadlineExceeded to OperationCanceledException when
// ThrowOperationCanceledOnCancellation is true.
// This avoids a race between the client-side timer and the server status throwing different
// errors on deadline exceeded.
return new OperationCanceledException();
}
else
{
return CreateRpcException(status);
}
}

private (bool diagnosticSourceEnabled, Activity? activity) InitializeCall(HttpRequestMessage request, TimeSpan? timeout)
{
GrpcCallLog.StartingCall(Logger, Method.Type, request.RequestUri!);
Expand Down
10 changes: 8 additions & 2 deletions src/Grpc.Net.Client/Internal/Http/PushStreamContent.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -72,4 +72,10 @@ protected override bool TryComputeLength(out long length)
// Hacky. ReadAsStreamAsync does not complete until SerializeToStreamAsync finishes.
// WARNING: Will run SerializeToStreamAsync again on .NET Framework.
internal Task PushComplete => ReadAsStreamAsync();
}

// Internal for testing.
internal Task SerializeToStreamAsync(Stream stream)
{
return SerializeToStreamAsync(stream, context: null);
}
}
47 changes: 32 additions & 15 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand All @@ -23,6 +23,7 @@
using System.Runtime.InteropServices;
using Grpc.Core;
using Grpc.Net.Compression;
using Grpc.Shared;
using Microsoft.Extensions.Logging;

#if NETSTANDARD2_0
Expand Down Expand Up @@ -318,10 +319,9 @@ public static async ValueTask WriteMessageAsync<TMessage>(
{
GrpcCallLog.ErrorSendingMessage(call.Logger, ex);

// Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException.
if (ex is ObjectDisposedException && call.CancellationToken.IsCancellationRequested)
if (TryCreateCallCompleteException(ex, call, out var statusException))
{
throw new OperationCanceledException();
throw statusException;
}

throw;
Expand All @@ -342,24 +342,41 @@ public static async ValueTask WriteMessageAsync(
{
GrpcCallLog.SendingMessage(call.Logger);

try
{
// Sending the header+content in a single WriteAsync call has significant performance benefits
// https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981
await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);
}
catch (ObjectDisposedException) when (call.CancellationToken.IsCancellationRequested)
{
// Cancellation from disposing response while waiting for WriteAsync can throw ObjectDisposedException.
throw new OperationCanceledException();
}
// Sending the header+content in a single WriteAsync call has significant performance benefits
// https://github.com/dotnet/runtime/issues/35184#issuecomment-626304981
await stream.WriteAsync(data, cancellationToken).ConfigureAwait(false);

GrpcCallLog.MessageSent(call.Logger);
}
catch (Exception ex)
{
GrpcCallLog.ErrorSendingMessage(call.Logger, ex);

if (TryCreateCallCompleteException(ex, call, out var statusException))
{
throw statusException;
}

throw;
}
}

private static bool TryCreateCallCompleteException(Exception originalException, GrpcCall call, [NotNullWhen(true)] out Exception? exception)
{
// The call may have been completed while WriteAsync was running and caused WriteAsync to throw.
// In this situation, report the call's completed status.
//
// Replace exception with the status error if:
// 1. The original exception is one Stream.WriteAsync throws if the call was completed during a write, and
// 2. The call has already been successfully completed.
if (originalException is OperationCanceledException or ObjectDisposedException &&
call.CallTask.IsCompletedSuccessfully())
{
exception = call.CreateFailureStatusException(call.CallTask.Result);
return true;
}

exception = null;
return false;
}
}
2 changes: 1 addition & 1 deletion test/FunctionalTests/Client/StreamingTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down
103 changes: 101 additions & 2 deletions test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -107,7 +107,6 @@ public async Task AsyncClientStreamingCall_Success_RequestContentSent()
var responseTask = call.ResponseAsync;
Assert.IsFalse(responseTask.IsCompleted, "Response not returned until client stream is complete.");


await call.RequestStream.WriteAsync(new HelloRequest { Name = "1" }).DefaultTimeout();
await call.RequestStream.WriteAsync(new HelloRequest { Name = "2" }).DefaultTimeout();

Expand Down Expand Up @@ -268,6 +267,106 @@ public async Task ClientStreamWriter_WriteAfterResponseHasFinished_ErrorThrown()
Assert.AreEqual("Hello world", result.Message);
}

[Test]
public async Task AsyncClientStreamingCall_ErrorWhileWriting_StatusExceptionThrown()
{
// Arrange
PushStreamContent<HelloRequest, HelloReply>? content = null;

var responseTcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
var httpClient = ClientTestHelpers.CreateTestClient(request =>
{
content = (PushStreamContent<HelloRequest, HelloReply>)request.Content!;
return responseTcs.Task;
});

var invoker = HttpClientCallInvokerFactory.Create(httpClient);

// Act

// Client starts call
var call = invoker.AsyncClientStreamingCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions());
// Client starts request stream write
var writeTask = call.RequestStream.WriteAsync(new HelloRequest());

// Simulate HttpClient starting to accept the write. Stream.WriteAsync is blocked.
var writeSyncPoint = new SyncPoint(runContinuationsAsynchronously: true);
var testStream = new TestStream(writeSyncPoint);
var serializeToStreamTask = content!.SerializeToStreamAsync(testStream);

// Server completes response.
await writeSyncPoint.WaitForSyncPoint().DefaultTimeout();
responseTcs.SetResult(ResponseUtils.CreateResponse(HttpStatusCode.OK, new ByteArrayContent(Array.Empty<byte>()), grpcStatusCode: StatusCode.InvalidArgument));

await ExceptionAssert.ThrowsAsync<RpcException>(() => call.ResponseAsync).DefaultTimeout();
Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode);

// Unblock Stream.WriteAsync
writeSyncPoint.Continue();

// Get error thrown from write task. It should have the status returned by the server.
var ex = await ExceptionAssert.ThrowsAsync<RpcException>(() => writeTask).DefaultTimeout();

// Assert
Assert.AreEqual(StatusCode.InvalidArgument, ex.StatusCode);
Assert.AreEqual(StatusCode.InvalidArgument, call.GetStatus().StatusCode);
Assert.AreEqual(string.Empty, call.GetStatus().Detail);
}

private sealed class TestStream : Stream
{
private readonly SyncPoint _writeSyncPoint;

public TestStream(SyncPoint writeSyncPoint)
{
_writeSyncPoint = writeSyncPoint;
}

public override bool CanRead { get; }
public override bool CanSeek { get; }
public override bool CanWrite { get; }
public override long Length { get; }
public override long Position { get; set; }

public override void Flush()
{
}

public override int Read(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}

public override long Seek(long offset, SeekOrigin origin)
{
throw new NotImplementedException();
}

public override void SetLength(long value)
{
throw new NotImplementedException();
}

public override void Write(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}

#if !NET472_OR_GREATER
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
await _writeSyncPoint.WaitToContinue();
throw new OperationCanceledException();
}
#else
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeSyncPoint.WaitToContinue();
throw new OperationCanceledException();
}
#endif
}

[Test]
public async Task ClientStreamWriter_CancelledBeforeCallStarts_ThrowsError()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -318,6 +318,7 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel) : base(options, ch
public override Type RequestType { get; } = typeof(int);
public override Type ResponseType { get; } = typeof(string);
public override CancellationToken CancellationToken { get; }
public override Task<Status> CallTask => Task.FromResult(Status.DefaultCancelled);
}

private GrpcCallSerializationContext CreateSerializationContext(string? requestGrpcEncoding = null, int? maxSendMessageSize = null)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -65,5 +65,6 @@ public TestGrpcCall(CallOptions options, GrpcChannel channel, Type type) : base(
public override Type RequestType => _type;
public override Type ResponseType => _type;
public override CancellationToken CancellationToken { get; }
public override Task<Status> CallTask => Task.FromResult(Status.DefaultCancelled);
}
}

0 comments on commit 2be676a

Please sign in to comment.