Skip to content

Commit

Permalink
Correctly check socket on stream creation (#2215)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Aug 2, 2023
1 parent 1d12340 commit a2d005c
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,53 +253,7 @@ private void OnCheckSocketConnection(object? state)
{
CompatibilityHelpers.Assert(socketAddress != null);

try
{
SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

// Poll socket to check if it can be read from. Unfortunately this requires reading pending data.
// The server might send data, e.g. HTTP/2 SETTINGS frame, so we need to read and cache it.
//
// Available data needs to be read now because the only way to determine whether the connection is closed is to
// get the results of polling after available data is received.
bool hasReadData;
do
{
closeSocket = IsSocketInBadState(socket, socketAddress);
var available = socket.Available;
if (available > 0)
{
hasReadData = true;
var serverDataAvailable = CalculateInitialSocketDataLength(_initialSocketData) + available;
if (serverDataAvailable > MaximumInitialSocketDataSize)
{
// Data sent to the client before a connection is started shouldn't be large.
// Put a maximum limit on the buffer size to prevent an unexpected scenario from consuming too much memory.
throw new InvalidOperationException($"The server sent {serverDataAvailable} bytes to the client before a connection was established. Maximum allowed data exceeded.");
}

SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, available);

// Data is already available so this won't block.
var buffer = new byte[available];
var readCount = socket.Receive(buffer);

_initialSocketData ??= new List<ReadOnlyMemory<byte>>();
_initialSocketData.Add(buffer.AsMemory(0, readCount));
}
else
{
hasReadData = false;
}
}
while (hasReadData);
}
catch (Exception ex)
{
closeSocket = true;
checkException = ex;
SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
}
closeSocket = ShouldCloseSocket(socket, socketAddress, ref _initialSocketData, out checkException);
}
}

Expand Down Expand Up @@ -383,7 +337,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
closeSocket = true;
}
else if (IsSocketInBadState(socket, address))
else if (ShouldCloseSocket(socket, address, ref socketData, out _))
{
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
closeSocket = true;
Expand Down Expand Up @@ -419,7 +373,75 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
return stream;
}

private bool IsSocketInBadState(Socket socket, BalancerAddress address)
/// <summary>
/// Determines whether the socket is unusable and should be closed.
/// Data the server has already sent is drained from the socket and appended to
/// <paramref name="socketData"/> so that it is not lost.
/// </summary>
/// <param name="socket">The socket to check.</param>
/// <param name="socketAddress">The address associated with the socket. Used for logging.</param>
/// <param name="socketData">
/// Buffer list that receives any data read from the socket. The list is created on first use if it is <c>null</c>.
/// </param>
/// <param name="checkException">
/// Set to the exception thrown while checking the socket, or <c>null</c> if no exception was thrown.
/// </param>
/// <returns>
/// <c>true</c> if polling reported an unhealthy state or an exception was thrown while checking;
/// <c>false</c> once there is no more data to read and the socket is still healthy.
/// </returns>
private bool ShouldCloseSocket(Socket socket, BalancerAddress socketAddress, ref List<ReadOnlyMemory<byte>>? socketData, out Exception? checkException)
{
    checkException = null;

    try
    {
        SocketConnectivitySubchannelTransportLog.CheckingSocket(_logger, _subchannel.Id, socketAddress);

        // The socket health can only be trusted after all pending data has been consumed, so alternate
        // between polling and draining until nothing is left to read. The server may legitimately send data
        // before the connection is used (e.g. an HTTP/2 SETTINGS or GOAWAY frame), so everything that is
        // read here is cached in socketData rather than discarded.
        while (true)
        {
            if (PollSocket(socket, socketAddress))
            {
                // The poll says the socket is unhealthy. No point in reading further.
                return true;
            }

            var pendingByteCount = socket.Available;
            if (pendingByteCount <= 0)
            {
                // Nothing left to drain and the last poll was healthy.
                return false;
            }

            // Cap the total amount cached before a connection is started. A large amount of data here
            // is unexpected and shouldn't be allowed to consume too much memory.
            var totalBufferedLength = CalculateInitialSocketDataLength(socketData) + pendingByteCount;
            if (totalBufferedLength > MaximumInitialSocketDataSize)
            {
                // Caught by the handler below, which reports the socket as one to close.
                throw new InvalidOperationException($"The server sent {totalBufferedLength} bytes to the client before a connection was established. Maximum allowed data exceeded.");
            }

            SocketConnectivitySubchannelTransportLog.SocketReceivingAvailable(_logger, _subchannel.Id, socketAddress, pendingByteCount);

            // The bytes are already available so this receive won't block.
            var chunk = new byte[pendingByteCount];
            var bytesRead = socket.Receive(chunk);

            socketData ??= new List<ReadOnlyMemory<byte>>();
            socketData.Add(chunk.AsMemory(0, bytesRead));
        }
    }
    catch (Exception ex)
    {
        // Any failure while checking means the socket can't be trusted. Surface the error to the caller.
        checkException = ex;
        SocketConnectivitySubchannelTransportLog.ErrorCheckingSocket(_logger, _subchannel.Id, socketAddress, ex);
        return true;
    }
}

/// <summary>
/// Poll the socket to check for health and available data.
/// Shouldn't be used by itself as data needs to be consumed to accurately report the socket health.
/// <see cref="ShouldCloseSocket"/> handles consuming data and getting the socket health.
/// </summary>
private bool PollSocket(Socket socket, BalancerAddress address)
{
// From https://github.com/dotnet/runtime/blob/3195fbbd82fdb7f132d6698591ba6489ad6dd8cf/src/libraries/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnection.cs#L158-L168
try
Expand Down
25 changes: 18 additions & 7 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ public static EndpointContext<TRequest, TResponse> CreateGrpcEndpoint<TRequest,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null)
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
where TRequest : class, IMessage, new()
where TResponse : class, IMessage, new()
{
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory);
var server = CreateServer(port, protocols, isHttps, certificate, loggerFactory, configureServer);
var method = server.DynamicGrpc.AddUnaryMethod(callHandler, methodName);
var url = server.GetUrl(isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2);

Expand Down Expand Up @@ -88,7 +89,13 @@ public void Dispose()
}
}

public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? protocols = null, bool? isHttps = null, X509Certificate2? certificate = null, ILoggerFactory? loggerFactory = null)
public static GrpcTestFixture<Startup> CreateServer(
int port,
HttpProtocols? protocols = null,
bool? isHttps = null,
X509Certificate2? certificate = null,
ILoggerFactory? loggerFactory = null,
Action<KestrelServerOptions>? configureServer = null)
{
var endpointName = isHttps.GetValueOrDefault(false) ? TestServerEndpointName.Http2WithTls : TestServerEndpointName.Http2;

Expand All @@ -102,6 +109,8 @@ public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? pro
},
(options, urls) =>
{
configureServer?.Invoke(options);

urls[endpointName] = isHttps.GetValueOrDefault(false)
? $"https://127.0.0.1:{port}"
: $"http://127.0.0.1:{port}";
Expand Down Expand Up @@ -136,13 +145,14 @@ public static Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout, socketPingInterval);
}

public static async Task<GrpcChannel> CreateChannel(
Expand All @@ -154,12 +164,13 @@ public static async Task<GrpcChannel> CreateChannel(
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
TimeSpan? connectionIdleTimeout = null,
TimeSpan? socketPingInterval = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(socketPingInterval ?? TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down
51 changes: 50 additions & 1 deletion test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
}

// Arrange
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod), loggerFactory: LoggerFactory);

var connectionIdleTimeout = TimeSpan.FromSeconds(1);
var channel = await BalancerHelpers.CreateChannel(
Expand All @@ -180,6 +180,55 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
}

[Test]
public async Task Active_UnaryCall_ServerCloseOnKeepAlive_SocketRecreatedOnRequest()
{
    // Ignore errors
    SetExpectedErrorsFilter(writeContext =>
    {
        return true;
    });

    // Simple echo handler: replies with the name from the request.
    Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
    {
        return Task.FromResult(new HelloReply { Message = request.Name });
    }

    // In this test the client connects to the server, and the server then closes it after keep-alive is triggered.
    // The client then starts a gRPC call to the server. The client should discard the closed socket and create a new one.

    // Arrange
    using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(
        50051,
        UnaryMethod,
        nameof(UnaryMethod),
        loggerFactory: LoggerFactory,
        configureServer: o => o.Limits.KeepAliveTimeout = TimeSpan.FromSeconds(1));

    // Don't timeout the socket or ping it from the client.
    // Long intervals ensure the closed socket is only noticed when the call below creates a stream.
    var channel = await BalancerHelpers.CreateChannel(
        LoggerFactory,
        new RoundRobinConfig(),
        new[] { endpoint.Address },
        connectionIdleTimeout: TimeSpan.FromMinutes(30),
        socketPingInterval: TimeSpan.FromMinutes(30)).DefaultTimeout();

    Logger.LogInformation("Connecting channel.");
    await channel.ConnectAsync();

    // Fails when this test is run with debugging. Kestrel doesn't trigger keepalive timeout if debugging is enabled.
    await TestHelpers.AssertIsTrueRetryAsync(() =>
    {
        // Logger category names are identifiers, not linguistic text, so compare them ordinally.
        return Logs.Any(l => l.LoggerName.StartsWith("Microsoft.AspNetCore.Server.Kestrel", StringComparison.Ordinal) && l.EventId.Name == "ConnectionStop");
    }, "Wait for server to close connection.");

    // Act
    var client = TestClientFactory.Create(channel, endpoint.Method);
    var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();

    // Assert
    Assert.AreEqual("Test!", response.Message);
}

[Test]
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
{
Expand Down

0 comments on commit a2d005c

Please sign in to comment.