Skip to content

Commit

Permalink
Support idle connection timeout with pending sockets (#2213)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Jul 28, 2023
1 parent 946610f commit 3ad4bfe
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
private readonly ILogger _logger;
private readonly Subchannel _subchannel;
private readonly TimeSpan _socketPingInterval;
private readonly TimeSpan _socketIdleTimeout;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask> _socketConnect;
private readonly List<ActiveStream> _activeStreams;
private readonly Timer _socketConnectedTimer;
Expand All @@ -64,20 +65,23 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
internal Socket? _initialSocket;
private BalancerAddress? _initialSocketAddress;
private List<ReadOnlyMemory<byte>>? _initialSocketData;
private DateTime? _initialSocketCreatedTime;
private bool _disposed;
private BalancerAddress? _currentAddress;

public SocketConnectivitySubchannelTransport(
Subchannel subchannel,
TimeSpan socketPingInterval,
TimeSpan? connectTimeout,
TimeSpan socketIdleTimeout,
ILoggerFactory loggerFactory,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_logger = loggerFactory.CreateLogger<SocketConnectivitySubchannelTransport>();
_subchannel = subchannel;
_socketPingInterval = socketPingInterval;
ConnectTimeout = connectTimeout;
_socketIdleTimeout = socketIdleTimeout;
_socketConnect = socketConnect ?? OnConnect;
_activeStreams = new List<ActiveStream>();
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
Expand Down Expand Up @@ -125,6 +129,7 @@ private void DisconnectUnsynchronized()
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;
_lastEndPointIndex = 0;
_currentAddress = null;
}
Expand Down Expand Up @@ -162,6 +167,7 @@ public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
_initialSocket = socket;
_initialSocketAddress = currentAddress;
_initialSocketData = null;
_initialSocketCreatedTime = DateTime.UtcNow;

// Schedule ping. Don't set a periodic interval to avoid any chance of the timer causing the target method to run multiple times in parallel.
// This could happen because of execution delays (e.g. hitting a debugger breakpoint).
Expand Down Expand Up @@ -338,6 +344,7 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
Socket? socket = null;
BalancerAddress? socketAddress = null;
List<ReadOnlyMemory<byte>>? socketData = null;
DateTime? socketCreatedTime = null;
lock (Lock)
{
if (_initialSocket != null)
Expand All @@ -347,9 +354,11 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat
socket = _initialSocket;
socketAddress = _initialSocketAddress;
socketData = _initialSocketData;
socketCreatedTime = _initialSocketCreatedTime;
_initialSocket = null;
_initialSocketAddress = null;
_initialSocketData = null;
_initialSocketCreatedTime = null;

// Double check the address matches the socket address and only use socket on match.
// Not sure if this is possible in practice, but better safe than sorry.
Expand All @@ -365,10 +374,23 @@ public async ValueTask<Stream> GetStreamAsync(BalancerAddress address, Cancellat

if (socket != null)
{
if (IsSocketInBadState(socket, address))
Debug.Assert(socketCreatedTime != null);

var closeSocket = false;

if (DateTime.UtcNow > socketCreatedTime.Value.Add(_socketIdleTimeout))
{
SocketConnectivitySubchannelTransportLog.ClosingSocketFromIdleTimeoutOnCreateStream(_logger, _subchannel.Id, address, _socketIdleTimeout);
closeSocket = true;
}
else if (IsSocketInBadState(socket, address))
{
SocketConnectivitySubchannelTransportLog.ClosingUnusableSocketOnCreateStream(_logger, _subchannel.Id, address);
closeSocket = true;
}

if (closeSocket)
{
socket.Dispose();
socket = null;
socketData = null;
Expand Down Expand Up @@ -530,6 +552,9 @@ internal static class SocketConnectivitySubchannelTransportLog
private static readonly Action<ILogger, int, BalancerAddress, Exception?> _closingUnusableSocketOnCreateStream =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Debug, new EventId(16, "ClosingUnusableSocketOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it can't be used. The socket either can't receive data or it has received unexpected data.");

// Cached log delegate for the debug message written when a socket is closed on stream creation
// because it exceeds the idle timeout.
// Event IDs should be unique within a log class: ID 16 is already taken by
// ClosingUnusableSocketOnCreateStream, so this event uses the next ID instead of reusing 16.
// NOTE(review): 17 is assumed to be unused by the other events in this class — confirm.
private static readonly Action<ILogger, int, BalancerAddress, TimeSpan, Exception?> _closingSocketFromIdleTimeoutOnCreateStream =
    LoggerMessage.Define<int, BalancerAddress, TimeSpan>(LogLevel.Debug, new EventId(17, "ClosingSocketFromIdleTimeoutOnCreateStream"), "Subchannel id '{SubchannelId}' socket {Address} is being closed because it exceeds the idle timeout of {SocketIdleTimeout}.");

public static void ConnectingSocket(ILogger logger, int subchannelId, BalancerAddress address)
{
_connectingSocket(logger, subchannelId, address, null);
Expand Down Expand Up @@ -609,5 +634,10 @@ public static void ClosingUnusableSocketOnCreateStream(ILogger logger, int subch
{
_closingUnusableSocketOnCreateStream(logger, subchannelId, address, null);
}

/// <summary>
/// Writes the "ClosingSocketFromIdleTimeoutOnCreateStream" log message for a subchannel socket.
/// </summary>
/// <param name="logger">Logger the message is written to.</param>
/// <param name="subchannelId">ID of the subchannel that owns the socket.</param>
/// <param name="address">Address of the socket being closed.</param>
/// <param name="socketIdleTimeout">Idle timeout value reported in the message.</param>
public static void ClosingSocketFromIdleTimeoutOnCreateStream(ILogger logger, int subchannelId, BalancerAddress address, TimeSpan socketIdleTimeout)
    => _closingSocketFromIdleTimeoutOnCreateStream(logger, subchannelId, address, socketIdleTimeout, null); // no exception is associated with this event
}
#endif
14 changes: 10 additions & 4 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
internal Uri Address { get; }
internal HttpMessageInvoker HttpInvoker { get; }
internal TimeSpan? ConnectTimeout { get; }
internal TimeSpan? ConnectionIdleTimeout { get; }
internal HttpHandlerType HttpHandlerType { get; }
internal TimeSpan InitialReconnectBackoff { get; }
internal TimeSpan? MaxReconnectBackoff { get; }
Expand Down Expand Up @@ -125,7 +126,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr

var resolverFactory = GetResolverFactory(channelOptions);
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);

SubchannelTransportFactory = channelOptions.ResolveService<ISubchannelTransportFactory>(new SubChannelTransportFactory(this));

Expand Down Expand Up @@ -154,7 +155,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
throw new ArgumentException($"Address '{address.OriginalString}' doesn't have a host. Address should include a scheme, host, and optional port. For example, 'https://localhost:5001'.");
}
ResolveCredentials(channelOptions, out _isSecure, out _callCredentials);
(HttpHandlerType, ConnectTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
(HttpHandlerType, ConnectTimeout, ConnectionIdleTimeout) = CalculateHandlerContext(Logger, address, _isSecure, channelOptions);
#endif

HttpInvoker = channelOptions.HttpClient ?? CreateInternalHttpInvoker(channelOptions.HttpHandler);
Expand Down Expand Up @@ -243,12 +244,14 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
{
HttpHandlerType type;
TimeSpan? connectTimeout;
TimeSpan? connectionIdleTimeout;

#if NET5_0_OR_GREATER
var socketsHttpHandler = HttpRequestHelpers.GetHttpHandlerType<SocketsHttpHandler>(channelOptions.HttpHandler)!;

type = HttpHandlerType.SocketsHttpHandler;
connectTimeout = socketsHttpHandler.ConnectTimeout;
connectionIdleTimeout = socketsHttpHandler.PooledConnectionIdleTimeout;

// Check if the SocketsHttpHandler is being shared by channels.
// It has already been setup by another channel (i.e. ConnectCallback is set) then
Expand All @@ -261,6 +264,7 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
{
type = HttpHandlerType.Custom;
connectTimeout = null;
connectionIdleTimeout = null;
}
}

Expand All @@ -282,8 +286,9 @@ private static HttpHandlerContext CalculateHandlerContext(ILogger logger, Uri ad
#else
type = HttpHandlerType.SocketsHttpHandler;
connectTimeout = null;
connectionIdleTimeout = null;
#endif
return new HttpHandlerContext(type, connectTimeout);
return new HttpHandlerContext(type, connectTimeout, connectionIdleTimeout);
}
if (HttpRequestHelpers.GetHttpHandlerType<HttpClientHandler>(channelOptions.HttpHandler) != null)
{
Expand Down Expand Up @@ -837,6 +842,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
subchannel,
SocketConnectivitySubchannelTransport.SocketPingInterval,
_channel.ConnectTimeout,
_channel.ConnectionIdleTimeout ?? TimeSpan.FromMinutes(1),
_channel.LoggerFactory,
socketConnect: null);
}
Expand Down Expand Up @@ -895,7 +901,7 @@ public static void AddressPathUnused(ILogger logger, string address)
}
}

private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null);
private readonly record struct HttpHandlerContext(HttpHandlerType HttpHandlerType, TimeSpan? ConnectTimeout = null, TimeSpan? ConnectionIdleTimeout = null);
}

internal enum HttpHandlerType
Expand Down
15 changes: 10 additions & 5 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,14 @@ public static Task<GrpcChannel> CreateChannel(
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout, connectionIdleTimeout);
}

public static async Task<GrpcChannel> CreateChannel(
Expand All @@ -152,12 +153,13 @@ public static async Task<GrpcChannel> CreateChannel(
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
TimeSpan? connectTimeout = null,
TimeSpan? connectionIdleTimeout = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, socketConnect));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, connectionIdleTimeout ?? TimeSpan.FromMinutes(1), socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down Expand Up @@ -214,12 +216,14 @@ internal class TestSubchannelTransportFactory : ISubchannelTransportFactory
{
private readonly TimeSpan _socketPingInterval;
private readonly TimeSpan? _connectTimeout;
private readonly TimeSpan _connectionIdleTimeout;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? _socketConnect;

public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, TimeSpan connectionIdleTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_socketPingInterval = socketPingInterval;
_connectTimeout = connectTimeout;
_connectionIdleTimeout = connectionIdleTimeout;
_socketConnect = socketConnect;
}

Expand All @@ -230,6 +234,7 @@ public ISubchannelTransport Create(Subchannel subchannel)
subchannel,
_socketPingInterval,
_connectTimeout,
_connectionIdleTimeout,
subchannel._manager.LoggerFactory,
_socketConnect);
#else
Expand Down
39 changes: 39 additions & 0 deletions test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,45 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
await ExceptionAssert.ThrowsAsync<OperationCanceledException>(() => connectTask).DefaultTimeout();
}

// Verifies that a unary call made after the connection idle timeout has elapsed succeeds, and that
// the logs show the existing socket being closed for idle timeout and a new stream socket being connected.
[Test]
public async Task Active_UnaryCall_ConnectionIdleTimeout_SocketRecreated()
{
    // Ignore errors
    SetExpectedErrorsFilter(writeContext =>
    {
        return true;
    });

    Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
    {
        return Task.FromResult(new HelloReply { Message = request.Name });
    }

    // Arrange
    using var endpoint = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));

    var connectionIdleTimeout = TimeSpan.FromSeconds(1);
    var channel = await BalancerHelpers.CreateChannel(
        LoggerFactory,
        new PickFirstConfig(),
        new[] { endpoint.Address },
        connectionIdleTimeout: connectionIdleTimeout).DefaultTimeout();

    Logger.LogInformation("Connecting channel.");
    await channel.ConnectAsync();

    // Wait for the idle timeout plus a little extra. The transport only closes the socket when
    // DateTime.UtcNow is strictly greater than the socket created time plus the timeout, so waiting
    // exactly the timeout can miss the threshold when timers are imprecise.
    await Task.Delay(connectionIdleTimeout + TimeSpan.FromMilliseconds(50));

    // Act
    var client = TestClientFactory.Create(channel, endpoint.Method);
    var response = await client.UnaryCall(new HelloRequest { Name = "Test!" }).ResponseAsync.DefaultTimeout();

    // Assert
    Assert.AreEqual("Test!", response.Message);

    AssertHasLog(LogLevel.Debug, "ClosingSocketFromIdleTimeoutOnCreateStream", "Subchannel id '1' socket 127.0.0.1:50051 is being closed because it exceeds the idle timeout of 00:00:01.");
    AssertHasLog(LogLevel.Trace, "ConnectingOnCreateStream", "Subchannel id '1' doesn't have a connected socket available. Connecting new stream socket for 127.0.0.1:50051.");
}

[Test]
public async Task Active_UnaryCall_MultipleStreams_UnavailableAddress_FallbackToWorkingAddress()
{
Expand Down
48 changes: 48 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/StreamWrapperTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,54 @@ namespace Grpc.Net.Client.Tests.Balancer;
[TestFixture]
public class StreamWrapperTests
{
// Reads through a StreamWrapper with a 3-byte buffer: the first read yields the three list bytes,
// the second yields the single MemoryStream byte, and the third reports end of stream (0).
[Test]
public async Task ReadAsync_ExactSize_Read()
{
    // Arrange
    var innerStream = new MemoryStream(new byte[] { 4 });
    var bufferedData = new List<ReadOnlyMemory<byte>>();
    bufferedData.Add(new byte[] { 1, 2, 3 });
    var wrapper = new StreamWrapper(innerStream, _ => { }, bufferedData);
    var readBuffer = new byte[3];

    // Act & Assert
    var firstCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(3, firstCount);
    for (var i = 0; i < 3; i++)
    {
        // Expect the bytes 1, 2, 3 in order.
        Assert.AreEqual(i + 1, readBuffer[i]);
    }

    var secondCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(1, secondCount);
    Assert.AreEqual(4, readBuffer[0]);

    var finalCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(0, finalCount);
}

// Reads through a StreamWrapper with a 4-byte buffer (one byte larger than the list data):
// the first read still yields only the three list bytes, the second yields the single
// MemoryStream byte, and the third reports end of stream (0).
[Test]
public async Task ReadAsync_BiggerThanNeeded_Read()
{
    // Arrange
    var innerStream = new MemoryStream(new byte[] { 4 });
    var bufferedData = new List<ReadOnlyMemory<byte>>();
    bufferedData.Add(new byte[] { 1, 2, 3 });
    var wrapper = new StreamWrapper(innerStream, _ => { }, bufferedData);
    var readBuffer = new byte[4];

    // Act & Assert
    var firstCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(3, firstCount);
    for (var i = 0; i < 3; i++)
    {
        // Expect the bytes 1, 2, 3 in order.
        Assert.AreEqual(i + 1, readBuffer[i]);
    }

    var secondCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(1, secondCount);
    Assert.AreEqual(4, readBuffer[0]);

    var finalCount = await wrapper.ReadAsync(readBuffer);
    Assert.AreEqual(0, finalCount);
}

[Test]
public async Task ReadAsync_MultipleInitialData_ReadInOrder()
{
Expand Down
28 changes: 28 additions & 0 deletions test/Grpc.Net.Client.Tests/GrpcChannelTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,34 @@ public void Build_InsecureCredentialsWithHttps_ThrowsError()
Assert.AreEqual("Channel is configured with insecure channel credentials and can't use a HttpClient with a 'https' scheme.", ex.Message);
}
#if SUPPORT_LOAD_BALANCING
// Checks that a channel built with a SocketsHttpHandler exposes the handler's ConnectTimeout value.
[Test]
public void Build_ConnectTimeout_ReadFromSocketsHttpHandler()
{
    // Arrange
    var expectedTimeout = TimeSpan.FromSeconds(1);
    var options = CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler { ConnectTimeout = TimeSpan.FromSeconds(1) });

    // Act
    var channel = GrpcChannel.ForAddress("https://localhost", options);

    // Assert
    Assert.AreEqual(expectedTimeout, channel.ConnectTimeout);
}
// Checks that a channel built with a SocketsHttpHandler exposes the handler's
// PooledConnectionIdleTimeout value as its ConnectionIdleTimeout.
[Test]
public void Build_ConnectionIdleTimeout_ReadFromSocketsHttpHandler()
{
    // Arrange
    var expectedTimeout = TimeSpan.FromSeconds(1);
    var options = CreateGrpcChannelOptions(o => o.HttpHandler = new SocketsHttpHandler { PooledConnectionIdleTimeout = TimeSpan.FromSeconds(1) });

    // Act
    var channel = GrpcChannel.ForAddress("https://localhost", options);

    // Assert
    Assert.AreEqual(expectedTimeout, channel.ConnectionIdleTimeout);
}
#endif
[Test]
public void Build_HttpClientAndHttpHandler_ThrowsError()
{
Expand Down

0 comments on commit 3ad4bfe

Please sign in to comment.