Skip to content

Commit

Permalink
Fix reconnecting after connection timeout (#1998)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored Jan 4, 2023
1 parent d9dbf45 commit f6180ae
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 75 deletions.
23 changes: 16 additions & 7 deletions src/Grpc.Net.Client/Balancer/Internal/ISubchannelTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,32 @@ internal interface ISubchannelTransport : IDisposable
#endif

#if !NETSTANDARD2_0
ValueTask<bool>
ValueTask<ConnectResult>
#else
Task<bool>
Task<ConnectResult>
#endif
TryConnectAsync(ConnectContext context);

void Disconnect();
}

internal enum ConnectResult
{
Success,
Failure,
Timeout
}

internal sealed class ConnectContext
{
private readonly CancellationTokenSource _cts;
private readonly CancellationToken _token;
private bool _disposed;

// This flag allows the transport to determine why the cancellation token was canceled.
// - Explicit cancellation, e.g. the channel was disposed.
// - Connection timeout, e.g. SocketsHttpHandler.ConnectTimeout was exceeded.
public bool IsConnectCanceled { get; private set; }
public bool Disposed { get; private set; }

public CancellationToken CancellationToken => _token;

Expand All @@ -67,18 +74,20 @@ public ConnectContext(TimeSpan connectTimeout)
public void CancelConnect()
{
// Check disposed because CTS.Cancel throws if the CTS is disposed.
if (!_disposed)
if (Disposed)
{
IsConnectCanceled = true;
_cts.Cancel();
throw new ObjectDisposedException(nameof(ConnectContext));
}

IsConnectCanceled = true;
_cts.Cancel();
}

public void Dispose()
{
// Dispose the CTS because it could be created with an internal timer.
_cts.Dispose();
_disposed = true;
Disposed = true;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ public void Disconnect()

public
#if !NETSTANDARD2_0
ValueTask<bool>
ValueTask<ConnectResult>
#else
Task<bool>
Task<ConnectResult>
#endif
TryConnectAsync(ConnectContext context)
{
Expand All @@ -69,9 +69,9 @@ public void Disconnect()
_subchannel.UpdateConnectivityState(ConnectivityState.Ready, "Passively connected.");

#if !NETSTANDARD2_0
return new ValueTask<bool>(true);
return new ValueTask<ConnectResult>(ConnectResult.Success);
#else
return Task.FromResult(true);
return Task.FromResult(ConnectResult.Success);
#endif
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
private readonly ILogger _logger;
private readonly Subchannel _subchannel;
private readonly TimeSpan _socketPingInterval;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask> _socketConnect;
private readonly List<ActiveStream> _activeStreams;
private readonly Timer _socketConnectedTimer;

Expand All @@ -63,12 +64,18 @@ internal class SocketConnectivitySubchannelTransport : ISubchannelTransport, IDi
private bool _disposed;
private BalancerAddress? _currentAddress;

public SocketConnectivitySubchannelTransport(Subchannel subchannel, TimeSpan socketPingInterval, TimeSpan? connectTimeout, ILoggerFactory loggerFactory)
public SocketConnectivitySubchannelTransport(
Subchannel subchannel,
TimeSpan socketPingInterval,
TimeSpan? connectTimeout,
ILoggerFactory loggerFactory,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_logger = loggerFactory.CreateLogger<SocketConnectivitySubchannelTransport>();
_subchannel = subchannel;
_socketPingInterval = socketPingInterval;
ConnectTimeout = connectTimeout;
_socketConnect = socketConnect ?? OnConnect;
_activeStreams = new List<ActiveStream>();
_socketConnectedTimer = new Timer(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
}
Expand All @@ -87,6 +94,11 @@ internal IReadOnlyList<ActiveStream> GetActiveStreams()
}
}

private static ValueTask OnConnect(Socket socket, DnsEndPoint endpoint, CancellationToken cancellationToken)
{
return socket.ConnectAsync(endpoint, cancellationToken);
}

public void Disconnect()
{
lock (Lock)
Expand Down Expand Up @@ -114,7 +126,7 @@ private void DisconnectUnsynchronized()
_currentAddress = null;
}

public async ValueTask<bool> TryConnectAsync(ConnectContext context)
public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
{
Debug.Assert(CurrentAddress == null);

Expand All @@ -137,7 +149,7 @@ public async ValueTask<bool> TryConnectAsync(ConnectContext context)
try
{
SocketConnectivitySubchannelTransportLog.ConnectingSocket(_logger, _subchannel.Id, currentAddress);
await socket.ConnectAsync(currentAddress.EndPoint, context.CancellationToken).ConfigureAwait(false);
await _socketConnect(socket, currentAddress.EndPoint, context.CancellationToken).ConfigureAwait(false);
SocketConnectivitySubchannelTransportLog.ConnectedSocket(_logger, _subchannel.Id, currentAddress);

lock (Lock)
Expand All @@ -150,7 +162,7 @@ public async ValueTask<bool> TryConnectAsync(ConnectContext context)
}

_subchannel.UpdateConnectivityState(ConnectivityState.Ready, "Successfully connected to socket.");
return true;
return ConnectResult.Success;
}
catch (Exception ex)
{
Expand All @@ -169,12 +181,15 @@ public async ValueTask<bool> TryConnectAsync(ConnectContext context)
}
}

var result = ConnectResult.Failure;

// Check if cancellation happened because of timeout.
if (firstConnectionError is OperationCanceledException oce &&
oce.CancellationToken == context.CancellationToken &&
!context.IsConnectCanceled)
{
firstConnectionError = new TimeoutException("A connection could not be established within the configured ConnectTimeout.", firstConnectionError);
result = ConnectResult.Timeout;
}

// All connections failed
Expand All @@ -188,7 +203,7 @@ public async ValueTask<bool> TryConnectAsync(ConnectContext context)
_socketConnectedTimer.Change(TimeSpan.Zero, TimeSpan.Zero);
}
}
return false;
return result;
}

private async void OnCheckSocketConnection(object? state)
Expand Down
50 changes: 38 additions & 12 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,24 +238,34 @@ public void RequestConnection()

private void CancelInProgressConnect()
{
var connectContext = _connectContext;
if (connectContext != null)
lock (Lock)
{
lock (Lock)
if (_connectContext != null && !_connectContext.Disposed)
{
SubchannelLog.CancelingConnect(_logger, Id);

// Cancel connect cancellation token.
connectContext.CancelConnect();
connectContext.Dispose();
_connectContext.CancelConnect();
_connectContext.Dispose();
}
}
}

private async Task ConnectTransportAsync()
private ConnectContext GetConnectContext()
{
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnect();
lock (Lock)
{
// There shouldn't be a previous connect in progress, but cancel the CTS to ensure they're no longer running.
CancelInProgressConnect();

var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
var connectContext = _connectContext = new ConnectContext(_transport.ConnectTimeout ?? Timeout.InfiniteTimeSpan);
return connectContext;
}
}

private async Task ConnectTransportAsync()
{
var connectContext = GetConnectContext();

var backoffPolicy = _manager.BackoffPolicyFactory.Create();

Expand All @@ -273,9 +283,17 @@ private async Task ConnectTransportAsync()
}
}

if (await _transport.TryConnectAsync(connectContext).ConfigureAwait(false))
switch (await _transport.TryConnectAsync(connectContext).ConfigureAwait(false))
{
return;
case ConnectResult.Success:
return;
case ConnectResult.Timeout:
// Reset connectivity state back to idle so that new calls try to reconnect.
UpdateConnectivityState(ConnectivityState.Idle, new Status(StatusCode.Unavailable, "Timeout connecting to subchannel."));
return;
case ConnectResult.Failure:
default:
break;
}

connectContext.CancellationToken.ThrowIfCancellationRequested();
Expand Down Expand Up @@ -448,7 +466,7 @@ internal static class SubchannelLog
LoggerMessage.Define<int>(LogLevel.Trace, new EventId(9, "ConnectCanceled"), "Subchannel id '{SubchannelId}' connect canceled.");

private static readonly Action<ILogger, int, Exception?> _connectError =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(10, "ConnectError"), "Subchannel id '{SubchannelId}' error while connecting to transport.");
LoggerMessage.Define<int>(LogLevel.Error, new EventId(10, "ConnectError"), "Subchannel id '{SubchannelId}' unexpected error while connecting to transport.");

private static readonly Action<ILogger, int, ConnectivityState, string, Exception?> _subchannelStateChanged =
LoggerMessage.Define<int, ConnectivityState, string>(LogLevel.Debug, new EventId(11, "SubchannelStateChanged"), "Subchannel id '{SubchannelId}' state changed to {State}. Detail: '{Detail}'.");
Expand All @@ -468,6 +486,9 @@ internal static class SubchannelLog
private static readonly Action<ILogger, int, BalancerAddress, Exception?> _subchannelPreserved =
LoggerMessage.Define<int, BalancerAddress>(LogLevel.Trace, new EventId(16, "SubchannelPreserved"), "Subchannel id '{SubchannelId}' matches address '{Address}' and is preserved.");

private static readonly Action<ILogger, int, Exception?> _cancelingConnect =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(17, "CancelingConnect"), "Subchannel id '{SubchannelId}' canceling connect.");

public static void SubchannelCreated(ILogger logger, int subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Debug))
Expand Down Expand Up @@ -551,5 +572,10 @@ public static void SubchannelPreserved(ILogger logger, int subchannelId, Balance
{
_subchannelPreserved(logger, subchannelId, address, null);
}

public static void CancelingConnect(ILogger logger, int subchannelId)
{
_cancelingConnect(logger, subchannelId, null);
}
}
#endif
7 changes: 6 additions & 1 deletion src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -802,7 +802,12 @@ public ISubchannelTransport Create(Subchannel subchannel)
{
if (_channel.HttpHandlerType == HttpHandlerType.SocketsHttpHandler)
{
return new SocketConnectivitySubchannelTransport(subchannel, TimeSpan.FromSeconds(5), _channel.ConnectTimeout, _channel.LoggerFactory);
return new SocketConnectivitySubchannelTransport(
subchannel,
TimeSpan.FromSeconds(5),
_channel.ConnectTimeout,
_channel.LoggerFactory,
socketConnect: null);
}

return new PassiveSubchannelTransport(subchannel);
Expand Down
36 changes: 30 additions & 6 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Threading.Tasks;
using FunctionalTestsWebsite;
using Google.Protobuf;
Expand Down Expand Up @@ -110,21 +111,37 @@ public static GrpcTestFixture<Startup> CreateServer(int port, HttpProtocols? pro
endpointName);
}

public static Task<GrpcChannel> CreateChannel(ILoggerFactory loggerFactory, LoadBalancingConfig? loadBalancingConfig, Uri[] endpoints, HttpMessageHandler? httpMessageHandler = null, bool? connect = null, RetryPolicy? retryPolicy = null)
public static Task<GrpcChannel> CreateChannel(
ILoggerFactory loggerFactory,
LoadBalancingConfig? loadBalancingConfig,
Uri[] endpoints,
HttpMessageHandler? httpMessageHandler = null,
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
{
var resolver = new TestResolver();
var e = endpoints.Select(i => new BalancerAddress(i.Host, i.Port)).ToList();
resolver.UpdateAddresses(e);

return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy);
return CreateChannel(loggerFactory, loadBalancingConfig, resolver, httpMessageHandler, connect, retryPolicy, socketConnect, connectTimeout);
}

public static async Task<GrpcChannel> CreateChannel(ILoggerFactory loggerFactory, LoadBalancingConfig? loadBalancingConfig, TestResolver resolver, HttpMessageHandler? httpMessageHandler = null, bool? connect = null, RetryPolicy? retryPolicy = null)
public static async Task<GrpcChannel> CreateChannel(
ILoggerFactory loggerFactory,
LoadBalancingConfig? loadBalancingConfig,
TestResolver resolver,
HttpMessageHandler? httpMessageHandler = null,
bool? connect = null,
RetryPolicy? retryPolicy = null,
Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect = null,
TimeSpan? connectTimeout = null)
{
var services = new ServiceCollection();
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IRandomGenerator>(new TestRandomGenerator());
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout: null));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(TimeSpan.FromSeconds(0.5), connectTimeout, socketConnect));
services.AddSingleton<LoadBalancerFactory>(new LeastUsedBalancerFactory());

var serviceConfig = new ServiceConfig();
Expand Down Expand Up @@ -249,17 +266,24 @@ internal class TestSubchannelTransportFactory : ISubchannelTransportFactory
{
private readonly TimeSpan _socketPingInterval;
private readonly TimeSpan? _connectTimeout;
private readonly Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? _socketConnect;

public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout)
public TestSubchannelTransportFactory(TimeSpan socketPingInterval, TimeSpan? connectTimeout, Func<Socket, DnsEndPoint, CancellationToken, ValueTask>? socketConnect)
{
_socketPingInterval = socketPingInterval;
_connectTimeout = connectTimeout;
_socketConnect = socketConnect;
}

public ISubchannelTransport Create(Subchannel subchannel)
{
#if NET5_0_OR_GREATER
return new SocketConnectivitySubchannelTransport(subchannel, _socketPingInterval, _connectTimeout, subchannel._manager.LoggerFactory);
return new SocketConnectivitySubchannelTransport(
subchannel,
_socketPingInterval,
_connectTimeout,
subchannel._manager.LoggerFactory,
_socketConnect);
#else
return new PassiveSubchannelTransport(subchannel);
#endif
Expand Down
Loading

0 comments on commit f6180ae

Please sign in to comment.