Skip to content

Commit

Permalink
Support zero-byte read in gRPC client (#1985)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Jan 11, 2023
1 parent fa8d1b1 commit 62927b4
Show file tree
Hide file tree
Showing 14 changed files with 336 additions and 96 deletions.
4 changes: 2 additions & 2 deletions build/dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<GrpcPackageVersion>2.46.5</GrpcPackageVersion>
<GrpcToolsPackageVersion>2.51.0</GrpcToolsPackageVersion>
<MicrosoftAspNetCoreAppPackageVersion>7.0.0</MicrosoftAspNetCoreAppPackageVersion>
<MicrosoftAspNetCoreApp6PackageVersion>6.0.0</MicrosoftAspNetCoreApp6PackageVersion>
<MicrosoftAspNetCoreApp6PackageVersion>6.0.11</MicrosoftAspNetCoreApp6PackageVersion>
<MicrosoftAspNetCoreApp5PackageVersion>5.0.3</MicrosoftAspNetCoreApp5PackageVersion>
<MicrosoftAspNetCoreApp31PackageVersion>3.1.3</MicrosoftAspNetCoreApp31PackageVersion>
<MicrosoftBuildLocatorPackageVersion>1.5.5</MicrosoftBuildLocatorPackageVersion>
Expand All @@ -32,7 +32,7 @@
<SystemDiagnosticsDiagnosticSourcePackageVersion>4.5.1</SystemDiagnosticsDiagnosticSourcePackageVersion>
<SystemIOPipelinesPackageVersion>5.0.1</SystemIOPipelinesPackageVersion>
<SystemMemoryPackageVersion>4.5.3</SystemMemoryPackageVersion>
<SystemNetHttpWinHttpHandlerPackageVersion>6.0.1</SystemNetHttpWinHttpHandlerPackageVersion>
<SystemNetHttpWinHttpHandlerPackageVersion>7.0.0</SystemNetHttpWinHttpHandlerPackageVersion>
<SystemSecurityPrincipalWindowsPackageVersion>4.7.0</SystemSecurityPrincipalWindowsPackageVersion>
<SystemThreadingChannelsPackageVersion>4.6.0</SystemThreadingChannelsPackageVersion>
</PropertyGroup>
Expand Down
4 changes: 2 additions & 2 deletions examples/Interceptor/Client/Client.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
<PackageReference Include="Google.Protobuf" Version="$(GoogleProtobufPackageVersion)" />
<PackageReference Include="Grpc.Net.Client" Version="$(GrpcDotNetPackageVersion)" />
<PackageReference Include="Grpc.Tools" Version="$(GrpcToolsPackageVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="$(MicrosoftAspNetCoreApp6PackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(MicrosoftAspNetCoreApp6PackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging" Version="$(MicrosoftExtensionsPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="$(MicrosoftExtensionsPackageVersion)" />
</ItemGroup>
</Project>
8 changes: 8 additions & 0 deletions src/Grpc.Net.Client.Web/Internal/Base64ResponseStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ public override async ValueTask<int> ReadAsync(Memory<byte> data, CancellationTo
var data = buffer.AsMemory(offset, count);
#endif

// Handle zero byte reads.
if (data.Length == 0)
{
var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false);
Debug.Assert(read == 0);
return 0;
}

// There is enough remaining data to fill passed in data
if (data.Length <= _remainder)
{
Expand Down
150 changes: 105 additions & 45 deletions src/Grpc.Net.Client.Web/Internal/GrpcWebResponseStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#endregion

using System.Buffers.Binary;
using System.Data;
using System.Diagnostics;
using System.Net.Http.Headers;
using System.Text;

Expand All @@ -29,13 +31,15 @@ namespace Grpc.Net.Client.Web.Internal;
/// </summary>
internal class GrpcWebResponseStream : Stream
{
// This uses C# compiler's ability to refer to static data directly. For more information see https://vcsjones.dev/2019/02/01/csharp-readonly-span-bytes-static
private static ReadOnlySpan<byte> BytesNewLine => new byte[] { (byte)'\r', (byte)'\n' };
private const int HeaderLength = 5;

private readonly Stream _inner;
private readonly HttpHeaders _responseTrailers;
private int _contentRemaining;
private ResponseState _state;
private byte[]? _headerBuffer;

// Internal for testing
internal ResponseState _state;
internal int _contentRemaining;

public GrpcWebResponseStream(Stream inner, HttpHeaders responseTrailers)
{
Expand All @@ -52,59 +56,113 @@ public override async ValueTask<int> ReadAsync(Memory<byte> data, CancellationTo
#if NETSTANDARD2_0
var data = buffer.AsMemory(offset, count);
#endif
var headerBuffer = Memory<byte>.Empty;

if (data.Length == 0)
{
// Handle zero byte reads.
var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false);
Debug.Assert(read == 0);
return 0;
}

switch (_state)
{
case ResponseState.Ready:
// Read the header first
// - 1 byte flag for compression
// - 4 bytes for the content length
Memory<byte> headerBuffer;

if (data.Length >= 5)
{
headerBuffer = data.Slice(0, 5);
// Read the header first
// - 1 byte flag for compression
// - 4 bytes for the content length
_contentRemaining = HeaderLength;
_state = ResponseState.Header;
goto case ResponseState.Header;
}
else
case ResponseState.Header:
{
// Should never get here. Client always passes 5 to read the header.
throw new InvalidOperationException("Buffer is not large enough for header");
Debug.Assert(_contentRemaining > 0);

headerBuffer = data.Length >= _contentRemaining ? data.Slice(0, _contentRemaining) : data;
var success = await TryReadDataAsync(_inner, headerBuffer, cancellationToken).ConfigureAwait(false);
if (!success)
{
return 0;
}

// On first read of header data, check first byte to see if this is a trailer.
if (_contentRemaining == HeaderLength)
{
var compressed = headerBuffer.Span[0];
var isTrailer = IsBitSet(compressed, pos: 7);
if (isTrailer)
{
_state = ResponseState.Trailer;
goto case ResponseState.Trailer;
}
}

var read = headerBuffer.Length;

// The buffer was less than header length either because this is the first read and the passed in buffer is small,
// or it is an additonal read to finish getting header data.
if (headerBuffer.Length < HeaderLength)
{
_headerBuffer ??= new byte[HeaderLength];
headerBuffer.CopyTo(_headerBuffer.AsMemory(HeaderLength - _contentRemaining));

_contentRemaining -= headerBuffer.Length;
if (_contentRemaining > 0)
{
return read;
}

headerBuffer = _headerBuffer;
}

var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1));

_contentRemaining = length;
// If there is no content then state is reset to ready.
_state = _contentRemaining > 0 ? ResponseState.Content : ResponseState.Ready;
return read;
}

var success = await TryReadDataAsync(_inner, headerBuffer, cancellationToken).ConfigureAwait(false);
if (!success)
case ResponseState.Content:
{
return 0;
if (data.Length >= _contentRemaining)
{
data = data.Slice(0, _contentRemaining);
}

var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false);
_contentRemaining -= read;
if (_contentRemaining == 0)
{
_state = ResponseState.Ready;
}

return read;
}

var compressed = headerBuffer.Span[0];
var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1));

var isTrailer = IsBitSet(compressed, pos: 7);
if (isTrailer)
case ResponseState.Trailer:
{
Debug.Assert(headerBuffer.Length > 0);

// The trailer needs to be completely read before returning 0 to the caller.
// Ensure buffer is large enough to contain the trailer header.
if (headerBuffer.Length < HeaderLength)
{
var newBuffer = new byte[HeaderLength];
headerBuffer.CopyTo(newBuffer);
var success = await TryReadDataAsync(_inner, newBuffer.AsMemory(headerBuffer.Length), cancellationToken).ConfigureAwait(false);
if (!success)
{
return 0;
}
headerBuffer = newBuffer;
}
var length = (int)BinaryPrimitives.ReadUInt32BigEndian(headerBuffer.Span.Slice(1));

await ReadTrailersAsync(length, data, cancellationToken).ConfigureAwait(false);
return 0;
}

_contentRemaining = length;
// If there is no content then state is still ready
_state = _contentRemaining > 0 ? ResponseState.Content : ResponseState.Ready;
return 5;
case ResponseState.Content:
if (data.Length >= _contentRemaining)
{
data = data.Slice(0, _contentRemaining);
}

var read = await StreamHelpers.ReadAsync(_inner, data, cancellationToken).ConfigureAwait(false);
_contentRemaining -= read;
if (_contentRemaining == 0)
{
_state = ResponseState.Ready;
}

return read;
default:
throw new InvalidOperationException("Unexpected state.");
}
Expand Down Expand Up @@ -163,7 +221,7 @@ private void ParseTrailers(ReadOnlySpan<byte> span)
{
ReadOnlySpan<byte> line;

var lineEndIndex = remainingContent.IndexOf(BytesNewLine);
var lineEndIndex = remainingContent.IndexOf("\r\n"u8);
if (lineEndIndex == -1)
{
line = remainingContent;
Expand Down Expand Up @@ -257,10 +315,12 @@ private static async Task<bool> TryReadDataAsync(Stream responseStream, Memory<b
throw new InvalidDataException("Unexpected end of content while reading response stream.");
}

private enum ResponseState
internal enum ResponseState
{
Ready,
Header,
Content,
Trailer,
Complete
}

Expand Down
80 changes: 40 additions & 40 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr

var resolverFactory = GetResolverFactory(channelOptions);
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(address, _isSecure, channelOptions);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(address, _isSecure, channelOptions);

SubchannelTransportFactory = channelOptions.ResolveService<ISubchannelTransportFactory>(new SubChannelTransportFactory(this));

if (!IsHttpOrHttpsAddress(Address) || channelOptions.ServiceConfig?.LoadBalancingConfigs.Count > 0)
if (!IsHttpOrHttpsAddress(Address) || channelOptions.ServiceConfig?.LoadBalancingConfigs.Count > 0)
{
ValidateHttpHandlerSupportsConnectivity();
}
Expand Down Expand Up @@ -215,12 +215,12 @@ private void ResolveCredentials(GrpcChannelOptions channelOptions, out bool isSe
}
}

private static bool IsHttpOrHttpsAddress(Uri address)
private static bool IsHttpOrHttpsAddress(Uri address)
{
return address.Scheme == Uri.UriSchemeHttps || address.Scheme == Uri.UriSchemeHttp;
return address.Scheme == Uri.UriSchemeHttps || address.Scheme == Uri.UriSchemeHttp;
}

private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSecure, GrpcChannelOptions channelOptions)
private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSecure, GrpcChannelOptions channelOptions)
{
if (channelOptions.HttpHandler == null)
{
Expand Down Expand Up @@ -261,17 +261,17 @@ private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSe
}
}

// If a proxy is specified then requests could be sent via an SSL tunnel.
// A CONNECT request is made to the proxy to establish the transport stream and then
// gRPC calls are sent via stream. This feature isn't supported by load balancer.
// Proxy can be specified via:
// - SocketsHttpHandler.Proxy. Set via app code.
// - HttpClient.DefaultProxy. Set via environment variables, e.g. HTTPS_PROXY.
if (IsProxied(socketsHttpHandler, address, isSecure))
{
type = HttpHandlerType.Custom;
connectTimeout = null;
}
// If a proxy is specified then requests could be sent via an SSL tunnel.
// A CONNECT request is made to the proxy to establish the transport stream and then
// gRPC calls are sent via stream. This feature isn't supported by load balancer.
// Proxy can be specified via:
// - SocketsHttpHandler.Proxy. Set via app code.
// - HttpClient.DefaultProxy. Set via environment variables, e.g. HTTPS_PROXY.
if (IsProxied(socketsHttpHandler, address, isSecure))
{
type = HttpHandlerType.Custom;
connectTimeout = null;
}
#else
type = HttpHandlerType.SocketsHttpHandler;
connectTimeout = null;
Expand All @@ -287,30 +287,30 @@ private static HttpHandlerContext CalculateHandlerContext(Uri address, bool isSe
}

#if NET5_0_OR_GREATER
private static readonly Uri HttpLoadBalancerTemporaryUri = new Uri("http://loadbalancer.temporary.invalid");
private static readonly Uri HttpsLoadBalancerTemporaryUri = new Uri("https://loadbalancer.temporary.invalid");
private static readonly Uri HttpLoadBalancerTemporaryUri = new Uri("http://loadbalancer.temporary.invalid");
private static readonly Uri HttpsLoadBalancerTemporaryUri = new Uri("https://loadbalancer.temporary.invalid");

private static bool IsProxied(SocketsHttpHandler socketsHttpHandler, Uri address, bool isSecure)
private static bool IsProxied(SocketsHttpHandler socketsHttpHandler, Uri address, bool isSecure)
{
// Check standard address directly.
// When load balancing the channel doesn't know the final addresses yet so use temporary address.
Uri resolvedAddress;
if (IsHttpOrHttpsAddress(address))
{
// Check standard address directly.
// When load balancing the channel doesn't know the final addresses yet so use temporary address.
Uri resolvedAddress;
if (IsHttpOrHttpsAddress(address))
{
resolvedAddress = address;
}
else if (isSecure)
{
resolvedAddress = HttpsLoadBalancerTemporaryUri;
}
else
{
resolvedAddress = HttpLoadBalancerTemporaryUri;
}

var proxy = socketsHttpHandler.Proxy ?? HttpClient.DefaultProxy;
return proxy.GetProxy(resolvedAddress) != null;
resolvedAddress = address;
}
else if (isSecure)
{
resolvedAddress = HttpsLoadBalancerTemporaryUri;
}
else
{
resolvedAddress = HttpLoadBalancerTemporaryUri;
}

var proxy = socketsHttpHandler.Proxy ?? HttpClient.DefaultProxy;
return proxy.GetProxy(resolvedAddress) != null;
}
#endif

#if SUPPORT_LOAD_BALANCING
Expand All @@ -321,7 +321,7 @@ private ResolverFactory GetResolverFactory(GrpcChannelOptions options)
//
// Even with just one address we still want to use the load balancing infrastructure. This enables
// the connectivity APIs on channel like GrpcChannel.State and GrpcChannel.WaitForStateChanged.
if (IsHttpOrHttpsAddress(Address))
if (IsHttpOrHttpsAddress(Address))
{
return new StaticResolverFactory(uri => new[] { new BalancerAddress(Address.Host, Address.Port) });
}
Expand Down Expand Up @@ -370,7 +370,7 @@ private LoadBalancerFactory[] ResolveLoadBalancerFactories(GrpcChannelOptions ch
{
return serviceFactories.Union(LoadBalancerFactory.KnownLoadBalancerFactories).ToArray();
}

return LoadBalancerFactory.KnownLoadBalancerFactories;
}
#endif
Expand Down Expand Up @@ -436,7 +436,7 @@ private HttpMessageInvoker CreateInternalHttpInvoker(HttpMessageHandler? handler
{
// GetHttpHandlerType recurses through DelegatingHandlers that may wrap the HttpClientHandler.
var httpClientHandler = HttpRequestHelpers.GetHttpHandlerType<HttpClientHandler>(handler);

if (httpClientHandler != null && RuntimeHelpers.QueryRuntimeSettingSwitch("System.Net.Http.UseNativeHttpHandler", defaultValue: false))
{
throw new InvalidOperationException("The channel configuration isn't valid on Android devices. " +
Expand Down
7 changes: 7 additions & 0 deletions src/Grpc.Net.Client/Internal/StreamExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ private static Status CreateUnknownMessageEncodingMessageStatus(string unsupport
GrpcCallLog.ReadingMessage(call.Logger);
cancellationToken.ThrowIfCancellationRequested();

#if NET6_0_OR_GREATER
// Start with zero-byte read.
// A zero-byte read avoids renting buffer until the response is ready. Especially useful for long running streaming calls.
var readCount = await responseStream.ReadAsync(Memory<byte>.Empty, cancellationToken).ConfigureAwait(false);
Debug.Assert(readCount == 0);
#endif

// Buffer is used to read header, then message content.
// This size was randomly chosen to hopefully be big enough for many small messages.
// If the message is larger then the array will be replaced when the message size is known.
Expand Down
Loading

0 comments on commit 62927b4

Please sign in to comment.