From 62927b4bee240b3c6c5fc13288a986dde5f0f403 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Wed, 11 Jan 2023 10:35:00 +0800 Subject: [PATCH] Support zero-byte read in gRPC client (#1985) --- build/dependencies.props | 4 +- examples/Interceptor/Client/Client.csproj | 4 +- .../Internal/Base64ResponseStream.cs | 8 + .../Internal/GrpcWebResponseStream.cs | 150 ++++++++++++------ src/Grpc.Net.Client/GrpcChannel.cs | 80 +++++----- .../Internal/StreamExtensions.cs | 7 + .../PipeExtensionsTestsBase.cs | 4 +- .../AsyncClientStreamingCallTests.cs | 2 +- .../AsyncDuplexStreamingCallTests.cs | 2 +- .../AsyncServerStreamingCallTests.cs | 2 +- .../Grpc.Net.Client.Tests/GetTrailersTests.cs | 2 +- .../Grpc.Net.Client.Tests/Retry/RetryTests.cs | 2 +- .../GrpcWebResponseStreamTests.cs | 142 +++++++++++++++++ test/Shared/SyncPointMemoryStream.cs | 23 +++ 14 files changed, 336 insertions(+), 96 deletions(-) diff --git a/build/dependencies.props b/build/dependencies.props index 4036d132b..5b96e3e94 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -7,7 +7,7 @@ 2.46.5 2.51.0 7.0.0 - 6.0.0 + 6.0.11 5.0.3 3.1.3 1.5.5 @@ -32,7 +32,7 @@ 4.5.1 5.0.1 4.5.3 - 6.0.1 + 7.0.0 4.7.0 4.6.0 diff --git a/examples/Interceptor/Client/Client.csproj b/examples/Interceptor/Client/Client.csproj index 87dd6e0ed..7e8916529 100644 --- a/examples/Interceptor/Client/Client.csproj +++ b/examples/Interceptor/Client/Client.csproj @@ -11,7 +11,7 @@ - - + + diff --git a/src/Grpc.Net.Client.Web/Internal/Base64ResponseStream.cs b/src/Grpc.Net.Client.Web/Internal/Base64ResponseStream.cs index 0a679306a..ec37ae87a 100644 --- a/src/Grpc.Net.Client.Web/Internal/Base64ResponseStream.cs +++ b/src/Grpc.Net.Client.Web/Internal/Base64ResponseStream.cs @@ -46,6 +46,14 @@ public override async ValueTask ReadAsync(Memory data, CancellationTo var data = buffer.AsMemory(offset, count); #endif + // Handle zero byte reads. + if (data.Length == 0) + { + var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false); + Debug.Assert(read == 0); + return 0; + } + // There is enough remaining data to fill passed in data if (data.Length <= _remainder) { diff --git a/src/Grpc.Net.Client.Web/Internal/GrpcWebResponseStream.cs b/src/Grpc.Net.Client.Web/Internal/GrpcWebResponseStream.cs index 81bc65d2f..3f29028d5 100644 --- a/src/Grpc.Net.Client.Web/Internal/GrpcWebResponseStream.cs +++ b/src/Grpc.Net.Client.Web/Internal/GrpcWebResponseStream.cs @@ -17,6 +17,8 @@ #endregion using System.Buffers.Binary; +using System.Data; +using System.Diagnostics; using System.Net.Http.Headers; using System.Text; @@ -29,13 +31,15 @@ namespace Grpc.Net.Client.Web.Internal; /// internal class GrpcWebResponseStream : Stream { - // This uses C# compiler's ability to refer to static data directly. For more information see https://vcsjones.dev/2019/02/01/csharp-readonly-span-bytes-static - private static ReadOnlySpan BytesNewLine => new byte[] { (byte)'\r', (byte)'\n' }; + private const int HeaderLength = 5; private readonly Stream _inner; private readonly HttpHeaders _responseTrailers; - private int _contentRemaining; - private ResponseState _state; + private byte[]? _headerBuffer; + + // Internal for testing + internal ResponseState _state; + internal int _contentRemaining; public GrpcWebResponseStream(Stream inner, HttpHeaders responseTrailers) { @@ -52,59 +56,113 @@ public override async ValueTask ReadAsync(Memory data, CancellationTo #if NETSTANDARD2_0 var data = buffer.AsMemory(offset, count); #endif + var headerBuffer = Memory.Empty; + + if (data.Length == 0) + { + // Handle zero byte reads. + var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false); + Debug.Assert(read == 0); + return 0; + } switch (_state) { case ResponseState.Ready: - // Read the header first - // - 1 byte flag for compression - // - 4 bytes for the content length - Memory headerBuffer; - - if (data.Length >= 5) { - headerBuffer = data.Slice(0, 5); + // Read the header first + // - 1 byte flag for compression + // - 4 bytes for the content length + _contentRemaining = HeaderLength; + _state = ResponseState.Header; + goto case ResponseState.Header; } - else + case ResponseState.Header: { - // Should never get here. Client always passes 5 to read the header. - throw new InvalidOperationException("Buffer is not large enough for header"); + Debug.Assert(_contentRemaining > 0); + + headerBuffer = data.Length >= _contentRemaining ? data.Slice(0, _contentRemaining) : data; + var success = await TryReadDataAsync(_inner, headerBuffer, cancellationToken).ConfigureAwait(false); + if (!success) + { + return 0; + } + + // On first read of header data, check first byte to see if this is a trailer. + if (_contentRemaining == HeaderLength) + { + var compressed = headerBuffer.Span[0]; + var isTrailer = IsBitSet(compressed, pos: 7); + if (isTrailer) + { + _state = ResponseState.Trailer; + goto case ResponseState.Trailer; + } + } + + var read = headerBuffer.Length; + + // The buffer was less than header length either because this is the first read and the passed in buffer is small, + // or it is an additonal read to finish getting header data. + if (headerBuffer.Length < HeaderLength) + { + _headerBuffer ??= new byte[HeaderLength]; + headerBuffer.CopyTo(_headerBuffer.AsMemory(HeaderLength - _contentRemaining)); + + _contentRemaining -= headerBuffer.Length; + if (_contentRemaining > 0) + { + return read; + } + + headerBuffer = _headerBuffer; + } + + var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1)); + + _contentRemaining = length; + // If there is no content then state is reset to ready. + _state = _contentRemaining > 0 ? ResponseState.Content : ResponseState.Ready; + return read; } - - var success = await TryReadDataAsync(_inner, headerBuffer, cancellationToken).ConfigureAwait(false); - if (!success) + case ResponseState.Content: { - return 0; + if (data.Length >= _contentRemaining) + { + data = data.Slice(0, _contentRemaining); + } + + var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false); + _contentRemaining -= read; + if (_contentRemaining == 0) + { + _state = ResponseState.Ready; + } + + return read; } - - var compressed = headerBuffer.Span[0]; - var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1)); - - var isTrailer = IsBitSet(compressed, pos: 7); - if (isTrailer) + case ResponseState.Trailer: { + Debug.Assert(headerBuffer.Length > 0); + + // The trailer needs to be completely read before returning 0 to the caller. + // Ensure buffer is large enough to contain the trailer header. + if (headerBuffer.Length < HeaderLength) + { + var newBuffer = new byte[HeaderLength]; + headerBuffer.CopyTo(newBuffer); + var success = await TryReadDataAsync(_inner, newBuffer.AsMemory(headerBuffer.Length), cancellationToken).ConfigureAwait(false); + if (!success) + { + return 0; + } + headerBuffer = newBuffer; + } + var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1)); + await ReadTrailersAsync(length, data, cancellationToken).ConfigureAwait(false); return 0; } - - _contentRemaining = length; - // If there is no content then state is still ready - _state = _contentRemaining > 0 ? ResponseState.Content : ResponseState.Ready; - return 5; - case ResponseState.Content: - if (data.Length >= _contentRemaining) - { - data = data.Slice(0, _contentRemaining); - } - - var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false); - _contentRemaining -= read; - if (_contentRemaining == 0) - { - _state = ResponseState.Ready; - } - - return read; default: throw new InvalidOperationException("Unexpected state."); } @@ -163,7 +221,7 @@ private void ParseTrailers(ReadOnlySpan span) { ReadOnlySpan line; - var lineEndIndex = remainingContent.IndexOf(BytesNewLine); + var lineEndIndex = remainingContent.IndexOf("\r\n"u8); if (lineEndIndex == -1) { line = remainingContent; @@ -257,10 +315,12 @@ private static async Task TryReadDataAsync(Stream responseStream, Memory(new SubChannelTransportFactory(this)); - if (!IsHttpOrHttpsAddress(Address) || channelOptions.ServiceConfig?.LoadBalancingConfigs.Count > 0) + if (!IsHttpOrHttpsAddress(Address) || channelOptions.ServiceConfig?.LoadBalancingConfigs.Count > 0) { ValidateHttpHandlerSupportsConnectivity(); } @@ -215,12 +215,12 @@ private void ResolveCredentials(GrpcChannelOptions channelOptions, out bool isSe } } - private static bool IsHttpOrHttpsAddress(Uri address) + private static bool IsHttpOrHttpsAddress(Uri address) { - return address.Scheme == Uri.UriSchemeHttps || address.Scheme == Uri.UriSchemeHttp; + return address.Scheme == Uri.UriSchemeHttps || address.Scheme == Uri.UriSchemeHttp; } - private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSecure, GrpcChannelOptions channelOptions) + private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSecure, GrpcChannelOptions channelOptions) { if (channelOptions.HttpHandler == null) { @@ -261,17 +261,17 @@ private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSe } } - // If a proxy is specified then requests could be sent via an SSL tunnel. - // A CONNECT request is made to the proxy to establish the transport stream and then - // gRPC calls are sent via stream. This feature isn't supported by load balancer. - // Proxy can be specified via: - // - SocketsHttpHandler.Proxy. Set via app code. - // - HttpClient.DefaultProxy. Set via environment variables, e.g. HTTPS_PROXY. - if (IsProxied(socketsHttpHandler, address, isSecure)) - { - type = HttpHandlerType.Custom; - connectTimeout = null; - } + // If a proxy is specified then requests could be sent via an SSL tunnel. + // A CONNECT request is made to the proxy to establish the transport stream and then + // gRPC calls are sent via stream. This feature isn't supported by load balancer. + // Proxy can be specified via: + // - SocketsHttpHandler.Proxy. Set via app code. + // - HttpClient.DefaultProxy. Set via environment variables, e.g. HTTPS_PROXY. + if (IsProxied(socketsHttpHandler, address, isSecure)) + { + type = HttpHandlerType.Custom; + connectTimeout = null; + } #else type = HttpHandlerType.SocketsHttpHandler; connectTimeout = null; @@ -287,30 +287,30 @@ private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSe } #if NET5_0_OR_GREATER - private static readonly Uri HttpLoadBalancerTemporaryUri = new Uri("http://loadbalancer.temporary.invalid"); - private static readonly Uri HttpsLoadBalancerTemporaryUri = new Uri("https://loadbalancer.temporary.invalid"); + private static readonly Uri HttpLoadBalancerTemporaryUri = new Uri("http://loadbalancer.temporary.invalid"); + private static readonly Uri HttpsLoadBalancerTemporaryUri = new Uri("https://loadbalancer.temporary.invalid"); - private static bool IsProxied(SocketsHttpHandler socketsHttpHandler, Uri address, bool isSecure) + private static bool IsProxied(SocketsHttpHandler socketsHttpHandler, Uri address, bool isSecure) + { + // Check standard address directly. + // When load balancing the channel doesn't know the final addresses yet so use temporary address. + Uri resolvedAddress; + if (IsHttpOrHttpsAddress(address)) { - // Check standard address directly. - // When load balancing the channel doesn't know the final addresses yet so use temporary address. - Uri resolvedAddress; - if (IsHttpOrHttpsAddress(address)) - { - resolvedAddress = address; - } - else if (isSecure) - { - resolvedAddress = HttpsLoadBalancerTemporaryUri; - } - else - { - resolvedAddress = HttpLoadBalancerTemporaryUri; - } - - var proxy = socketsHttpHandler.Proxy ?? HttpClient.DefaultProxy; - return proxy.GetProxy(resolvedAddress) != null; + resolvedAddress = address; + } + else if (isSecure) + { + resolvedAddress = HttpsLoadBalancerTemporaryUri; + } + else + { + resolvedAddress = HttpLoadBalancerTemporaryUri; } + + var proxy = socketsHttpHandler.Proxy ?? HttpClient.DefaultProxy; + return proxy.GetProxy(resolvedAddress) != null; + } #endif #if SUPPORT_LOAD_BALANCING @@ -321,7 +321,7 @@ private ResolverFactory GetResolverFactory(GrpcChannelOptions options) // // Even with just one address we still want to use the load balancing infrastructure. This enables // the connectivity APIs on channel like GrpcChannel.State and GrpcChannel.WaitForStateChanged. - if (IsHttpOrHttpsAddress(Address)) + if (IsHttpOrHttpsAddress(Address)) { return new StaticResolverFactory(uri => new[] { new BalancerAddress(Address.Host, Address.Port) }); } @@ -370,7 +370,7 @@ private LoadBalancerFactory[] ResolveLoadBalancerFactories(GrpcChannelOptions ch { return serviceFactories.Union(LoadBalancerFactory.KnownLoadBalancerFactories).ToArray(); } - + return LoadBalancerFactory.KnownLoadBalancerFactories; } #endif @@ -436,7 +436,7 @@ private HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler { // GetHttpHandlerType recurses through DelegatingHandlers that may wrap the HttpClientHandler. var httpClientHandler = HttpRequestHelpers.GetHttpHandlerType(handler); - + if (httpClientHandler != null && RuntimeHelpers.QueryRuntimeSettingSwitch("System.Net.Http.UseNativeHttpHandler", defaultValue: false)) { throw new InvalidOperationException("The channel configuration isn't valid on Android devices. " + diff --git a/src/Grpc.Net.Client/Internal/StreamExtensions.cs b/src/Grpc.Net.Client/Internal/StreamExtensions.cs index 60f22f75b..399492bed 100644 --- a/src/Grpc.Net.Client/Internal/StreamExtensions.cs +++ b/src/Grpc.Net.Client/Internal/StreamExtensions.cs @@ -62,6 +62,13 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport GrpcCallLog.ReadingMessage(call.Logger); cancellationToken.ThrowIfCancellationRequested(); +#if NET6_0_OR_GREATER + // Start with zero-byte read. + // A zero-byte read avoids renting buffer until the response is ready. Especially useful for long running streaming calls. + var readCount = await responseStream.ReadAsync(Memory.Empty, cancellationToken).ConfigureAwait(false); + Debug.Assert(readCount == 0); +#endif + // Buffer is used to read header, then message content. // This size was randomly chosen to hopefully be big enough for many small messages. // If the message is larger then the array will be replaced when the message size is known. diff --git a/test/Grpc.AspNetCore.Server.Tests/PipeExtensionsTestsBase.cs b/test/Grpc.AspNetCore.Server.Tests/PipeExtensionsTestsBase.cs index fc64f463e..d12ba5f3e 100644 --- a/test/Grpc.AspNetCore.Server.Tests/PipeExtensionsTestsBase.cs +++ b/test/Grpc.AspNetCore.Server.Tests/PipeExtensionsTestsBase.cs @@ -275,7 +275,7 @@ public async Task ReadStreamMessageAsync_MessageSplitAcrossReadsWithAdditionalDa // Act 3 var messageData3Task = pipeReader.ReadStreamMessageAsync(testServerCallContext, TestDataMarshaller.ContextualDeserializer).AsTask(); - await requestStream.AddDataAndWait(Array.Empty()).DefaultTimeout(); + await requestStream.EndStreamAndWait().DefaultTimeout(); // Assert 3 var ex = await ExceptionAssert.ThrowsAsync(() => messageData3Task).DefaultTimeout(); @@ -436,7 +436,7 @@ public async Task ReadSingleMessageAsync_MessageInMultiplePipeReads_ReadMessageD } } - await requestStream.AddDataAndWait(Array.Empty()).DefaultTimeout(); + await requestStream.EndStreamAndWait().DefaultTimeout(); var readMessageData = await readTask.DefaultTimeout(); diff --git a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs index f9adc1400..daac9fbac 100644 --- a/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncClientStreamingCallTests.cs @@ -209,7 +209,7 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync( { Message = "Hello world 1" }).DefaultTimeout()).DefaultTimeout(); - await streamContent.AddDataAndWait(Array.Empty()); + await streamContent.EndStreamAndWait(); var result = await resultTask.DefaultTimeout(); Assert.AreEqual("Hello world 1", result.Message); diff --git a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs index 5dd8f95f1..53aae9f99 100644 --- a/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncDuplexStreamingCallTests.cs @@ -169,7 +169,7 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync( var moveNextTask3 = responseStream.MoveNext(CancellationToken.None); Assert.IsFalse(moveNextTask3.IsCompleted); - await streamContent.AddDataAndWait(Array.Empty()).DefaultTimeout(); + await streamContent.EndStreamAndWait().DefaultTimeout(); Assert.IsFalse(await moveNextTask3.DefaultTimeout()); diff --git a/test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs b/test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs index 89d9d7dd7..5b0d696d6 100644 --- a/test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs +++ b/test/Grpc.Net.Client.Tests/AsyncServerStreamingCallTests.cs @@ -137,7 +137,7 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync( var moveNextTask3 = responseStream.MoveNext(CancellationToken.None); Assert.IsFalse(moveNextTask3.IsCompleted); - await streamContent.AddDataAndWait(Array.Empty()).DefaultTimeout(); + await streamContent.EndStreamAndWait().DefaultTimeout(); Assert.IsFalse(await moveNextTask3.DefaultTimeout()); diff --git a/test/Grpc.Net.Client.Tests/GetTrailersTests.cs b/test/Grpc.Net.Client.Tests/GetTrailersTests.cs index 8f77a14bb..ee022613a 100644 --- a/test/Grpc.Net.Client.Tests/GetTrailersTests.cs +++ b/test/Grpc.Net.Client.Tests/GetTrailersTests.cs @@ -359,7 +359,7 @@ public async Task AsyncClientStreamingCall_CompleteWriter_ReturnsTrailers() var messageData = await ClientTestHelpers.GetResponseDataAsync(new HelloReply { Message = "Hello world" }).DefaultTimeout(); await stream.AddDataAndWait(messageData).DefaultTimeout(); - await stream.AddDataAndWait(Array.Empty()).DefaultTimeout(); + await stream.EndStreamAndWait().DefaultTimeout(); response.TrailingHeaders().Add("custom-header", "value"); trailingHeadersWrittenTcs.SetResult(true); diff --git a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs index f3db8143a..f65a82d61 100644 --- a/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs +++ b/test/Grpc.Net.Client.Tests/Retry/RetryTests.cs @@ -611,7 +611,7 @@ await streamContent.AddDataAndWait(await ClientTestHelpers.GetResponseDataAsync( { Message = "Hello world 1" }).DefaultTimeout()).DefaultTimeout(); - await streamContent.AddDataAndWait(Array.Empty()); + await streamContent.EndStreamAndWait(); var result = await resultTask.DefaultTimeout(); Assert.AreEqual("Hello world 1", result.Message); diff --git a/test/Grpc.Net.Client.Web.Tests/GrpcWebResponseStreamTests.cs b/test/Grpc.Net.Client.Web.Tests/GrpcWebResponseStreamTests.cs index f89b5275b..a33261af7 100644 --- a/test/Grpc.Net.Client.Web.Tests/GrpcWebResponseStreamTests.cs +++ b/test/Grpc.Net.Client.Web.Tests/GrpcWebResponseStreamTests.cs @@ -59,6 +59,110 @@ public async Task ReadAsync_EmptyMessage_ParseMessageAndTrailers() Assert.AreEqual("0", trailingHeaders.GetValues("grpc-status").Single()); } + [Test] + public async Task ReadAsync_HasMessage_OneByteBuffer_ParseMessageAndTrailers() + { + // Arrange + var header = new byte[] { 0, 0, 0, 1, 2 }; + var content = new byte[258]; + for (var i = 0; i < content.Length; i++) + { + content[i] = (byte)(i % byte.MaxValue); + } + var trailer = new byte[] { 128, 0, 0, 0, 16, 13, 10, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 32, 48 }; + var data = header.Concat(content).Concat(trailer).ToArray(); + + var trailingHeaders = new TestHttpHeaders(); + var ms = new MemoryStream(data); + var responseStream = new GrpcWebResponseStream(ms, trailingHeaders); + + // Act & Assert header + var contentHeaderData = new byte[1]; + + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(258, responseStream._contentRemaining); + Assert.AreEqual(GrpcWebResponseStream.ResponseState.Content, responseStream._state); + + // Act & Assert content + var readContent = new List(); + while (responseStream._contentRemaining > 0) + { + Assert.AreEqual(1, await ReadAsync(responseStream, contentHeaderData)); + readContent.Add(contentHeaderData[0]); + } + + CollectionAssert.AreEqual(content, readContent); + + // Act trailer + var read2 = await ReadAsync(responseStream, contentHeaderData); + + // Assert trailer + Assert.AreEqual(0, read2); + Assert.AreEqual(1, trailingHeaders.Count()); + Assert.AreEqual("0", trailingHeaders.GetValues("grpc-status").Single()); + } + + [Test] + public async Task ReadAsync_HasMessage_ZeroAndOneByteBuffer_ParseMessageAndTrailers() + { + // Arrange + var header = new byte[] { 0, 0, 0, 1, 2 }; + var content = new byte[258]; + for (var i = 0; i < content.Length; i++) + { + content[i] = (byte)(i % byte.MaxValue); + } + var trailer = new byte[] { 128, 0, 0, 0, 16, 13, 10, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 32, 48 }; + var data = header.Concat(content).Concat(trailer).ToArray(); + + var trailingHeaders = new TestHttpHeaders(); + var ms = new MemoryStream(data); + var responseStream = new GrpcWebResponseStream(ms, trailingHeaders); + + // Act & Assert header + var contentHeaderData = new byte[1]; + + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + Assert.AreEqual(258, responseStream._contentRemaining); + Assert.AreEqual(GrpcWebResponseStream.ResponseState.Content, responseStream._state); + + // Act & Assert content + var readContent = new List(); + while (responseStream._contentRemaining > 0) + { + Assert.AreEqual(1, await ZeroAndContentReadAsync(responseStream, contentHeaderData)); + readContent.Add(contentHeaderData[0]); + } + + CollectionAssert.AreEqual(content, readContent); + + // Act trailer + var read2 = await ZeroAndContentReadAsync(responseStream, contentHeaderData); + + // Assert trailer + Assert.AreEqual(0, read2); + Assert.AreEqual(1, trailingHeaders.Count()); + Assert.AreEqual("0", trailingHeaders.GetValues("grpc-status").Single()); + + static async Task ZeroAndContentReadAsync(Stream stream, Memory data, CancellationToken cancellationToken = default) + { + // Zero byte read to ensure this works in the current stream state. + var zeroRead = await ReadAsync(stream, Memory.Empty, cancellationToken); + Assert.AreEqual(0, zeroRead); + + // Actual read. + return await ReadAsync(stream, data, cancellationToken); + } + } + [Test] public async Task ReadAsync_EmptyMessageAndTrailers_ParseMessageAndTrailers() { @@ -88,6 +192,40 @@ public async Task ReadAsync_EmptyMessageAndTrailers_ParseMessageAndTrailers() Assert.AreEqual(0, trailingHeaders.Count()); } + [Test] + public async Task ReadAsync_EmptyMessageAndTrailers_OneByteBuffer_ParseMessageAndTrailers() + { + // Arrange + var data = new byte[] { 0, 0, 0, 0, 0, 128, 0, 0, 0, 0 }; + var trailingHeaders = new TestHttpHeaders(); + var ms = new MemoryStream(data); + var responseStream = new GrpcWebResponseStream(ms, trailingHeaders); + var contentHeaderData = new byte[1]; + + await ReadByteAsync(responseStream, contentHeaderData); + await ReadByteAsync(responseStream, contentHeaderData); + await ReadByteAsync(responseStream, contentHeaderData); + await ReadByteAsync(responseStream, contentHeaderData); + await ReadByteAsync(responseStream, contentHeaderData); + + // Act 2 + var read2 = await ReadAsync(responseStream, contentHeaderData); + + // Assert 2 + Assert.AreEqual(0, read2); + Assert.AreEqual(0, trailingHeaders.Count()); + + async Task ReadByteAsync(GrpcWebResponseStream responseStream, byte[] buffer) + { + // Act + var read = await ReadAsync(responseStream, buffer); + + // Assert + Assert.AreEqual(1, read); + Assert.AreEqual(0, buffer[0]); + } + } + [Test] public async Task ReadAsync_ReadContentWithLargeBuffer_ParseMessageAndTrailers() { @@ -107,12 +245,16 @@ public async Task ReadAsync_ReadContentWithLargeBuffer_ParseMessageAndTrailers() Assert.AreEqual(0, contentHeaderData[2]); Assert.AreEqual(0, contentHeaderData[3]); Assert.AreEqual(1, contentHeaderData[4]); + Assert.AreEqual(1, responseStream._contentRemaining); + Assert.AreEqual(GrpcWebResponseStream.ResponseState.Content, responseStream._state); // Act 2 var read2 = await ReadAsync(responseStream, contentHeaderData); // Assert 2 Assert.AreEqual(1, read2); + Assert.AreEqual(99, contentHeaderData[0]); + Assert.AreEqual(GrpcWebResponseStream.ResponseState.Ready, responseStream._state); // Act 2 var read3 = await ReadAsync(responseStream, contentHeaderData); diff --git a/test/Shared/SyncPointMemoryStream.cs b/test/Shared/SyncPointMemoryStream.cs index fec583189..fdda6e66b 100644 --- a/test/Shared/SyncPointMemoryStream.cs +++ b/test/Shared/SyncPointMemoryStream.cs @@ -28,6 +28,8 @@ public class SyncPointMemoryStream : Stream private SyncPoint _syncPoint; private Func _awaiter; private byte[] _currentData; + private bool _streamEnded; + private bool _streamEndedObserved; private Exception? _exception; public SyncPointMemoryStream(bool runContinuationsAsynchronously = true) @@ -37,6 +39,17 @@ public SyncPointMemoryStream(bool runContinuationsAsynchronously = true) _awaiter = SyncPoint.Create(out _syncPoint, _runContinuationsAsynchronously); } + /// + /// End stream and wait for at least one more read that returns zero bytes. + /// Note that because of zero-byte reads, the stream may be read multiple times. + /// + public Task EndStreamAndWait() + { + AddDataCore(Array.Empty()); + _streamEnded = true; + return _awaiter(); + } + /// /// Give the stream more data and wait until it is all read. /// @@ -73,6 +86,12 @@ private void AddDataCore(byte[] data) public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + if (_streamEndedObserved) + { + // Stream has ended and ReadAsync has been called again. Just return immediately. + return 0; + } + // Still have leftover data? if (_currentData.Length > 0) { @@ -111,6 +130,10 @@ private int ReadInternalBuffer(byte[] buffer, int offset, int count) if (_currentData.Length == 0) { + if (_streamEnded) + { + _streamEndedObserved = true; + } ResetSyncPointAndContinuePrevious(); }