Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interrupt existing subchannel connect attempt when reconnect is requested #2410

Merged
merged 3 commits into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ internal interface ISubchannelTransport : IDisposable
TransportStatus TransportStatus { get; }

ValueTask<Stream> GetStreamAsync(DnsEndPoint endPoint, CancellationToken cancellationToken);
ValueTask<ConnectResult> TryConnectAsync(ConnectContext context);
ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt);

void Disconnect();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public void Disconnect()
_subchannel.UpdateConnectivityState(ConnectivityState.Idle, "Disconnected.");
}

public ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
public ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt)
{
Debug.Assert(_subchannel._addresses.Count == 1);
Debug.Assert(CurrentEndPoint == null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private void DisconnectUnsynchronized()
_currentEndPoint = null;
}

public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context)
public async ValueTask<ConnectResult> TryConnectAsync(ConnectContext context, int attempt)
{
Debug.Assert(CurrentEndPoint == null);

Expand Down
36 changes: 32 additions & 4 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ public void UpdateAddresses(IReadOnlyList<BalancerAddress> addresses)
return;
}

SubchannelLog.AddressesUpdated(_logger, Id, addresses);

// Get a copy of the current address before updating addresses.
// Updating addresses to not contain this value changes the property to return null.
var currentAddress = CurrentAddress;
Expand Down Expand Up @@ -278,6 +280,8 @@ private void CancelInProgressConnect()
_connectContext.CancelConnect();
_connectContext.Dispose();
}

_delayInterruptTcs?.TrySetResult(null);
}
}

Expand Down Expand Up @@ -313,7 +317,7 @@ private async Task ConnectTransportAsync()
}
}

switch (await _transport.TryConnectAsync(connectContext).ConfigureAwait(false))
switch (await _transport.TryConnectAsync(connectContext, attempt).ConfigureAwait(false))
{
case ConnectResult.Success:
return;
Expand Down Expand Up @@ -345,17 +349,21 @@ private async Task ConnectTransportAsync()
{
// Task.Delay won. Check CTS to see if it won because of cancellation.
delayCts.Token.ThrowIfCancellationRequested();
SubchannelLog.ConnectBackoffComplete(_logger, Id);
}
else
{
SubchannelLog.ConnectBackoffInterrupted(_logger, Id);

// Delay interrupt was triggered. Reset back-off.
backoffPolicy = _manager.BackoffPolicyFactory.Create();

// Cancel the Task.Delay that's no longer needed.
// https://github.com/davidfowl/AspNetCoreDiagnosticScenarios/blob/519ef7d231c01116f02bc04354816a735f2a36b6/AsyncGuidance.md#using-a-timeout
delayCts.Cancel();

// Check to connect context token to see if the delay was interrupted because of a connect cancellation.
connectContext.CancellationToken.ThrowIfCancellationRequested();

// Delay interrupt was triggered. Reset back-off.
backoffPolicy = _manager.BackoffPolicyFactory.Create();
}
}
}
Expand Down Expand Up @@ -532,6 +540,12 @@ internal static class SubchannelLog
private static readonly Action<ILogger, string, Exception?> _cancelingConnect =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(17, "CancelingConnect"), "Subchannel id '{SubchannelId}' canceling connect.");

private static readonly Action<ILogger, string, Exception?> _connectBackoffComplete =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(18, "ConnectBackoffComplete"), "Subchannel id '{SubchannelId}' connect backoff complete.");

private static readonly Action<ILogger, string, string, Exception?> _addressesUpdated =
LoggerMessage.Define<string, string>(LogLevel.Trace, new EventId(19, "AddressesUpdated"), "Subchannel id '{SubchannelId}' updated with addresses: {Addresses}");

public static void SubchannelCreated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Debug))
Expand Down Expand Up @@ -620,5 +634,19 @@ public static void CancelingConnect(ILogger logger, string subchannelId)
{
_cancelingConnect(logger, subchannelId, null);
}

public static void ConnectBackoffComplete(ILogger logger, string subchannelId)
{
_connectBackoffComplete(logger, subchannelId, null);
}

public static void AddressesUpdated(ILogger logger, string subchannelId, IReadOnlyList<BalancerAddress> addresses)
{
if (logger.IsEnabled(LogLevel.Trace))
{
var addressesText = string.Join(", ", addresses.Select(a => a.EndPoint.Host + ":" + a.EndPoint.Port));
_addressesUpdated(logger, subchannelId, addressesText, null);
}
}
}
#endif
6 changes: 3 additions & 3 deletions test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ public async Task PickAsync_ErrorConnectingToSubchannel_ThrowsError()
new BalancerAddress("localhost", 80)
});

var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
return Task.FromException<TryConnectResult>(new Exception("Test error!"));
});
Expand Down Expand Up @@ -357,7 +357,7 @@ public async Task UpdateAddresses_ConnectIsInProgress_InProgressConnectIsCancele

var syncPoint = new SyncPoint(runContinuationsAsynchronously: true);

var transportFactory = new TestSubchannelTransportFactory(async (s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create(async (s, c) =>
{
c.Register(state => ((SyncPoint)state!).Continue(), syncPoint);

Expand Down Expand Up @@ -548,7 +548,7 @@ public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect(

var callbackAsyncLocalValues = new List<object>();

var transportFactory = new TestSubchannelTransportFactory((subchannel, cancellationToken) =>
var transportFactory = TestSubchannelTransportFactory.Create((subchannel, cancellationToken) =>
{
callbackAsyncLocalValues.Add(asyncLocal.Value);
if (callbackAsyncLocalValues.Count >= 2)
Expand Down
20 changes: 10 additions & 10 deletions test/Grpc.Net.Client.Tests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public async Task ChangeAddresses_HasReadySubchannel_OldSubchannelShutdown()
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));

var subChannelConnections = new List<Subchannel>();
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
lock (subChannelConnections)
{
Expand Down Expand Up @@ -176,7 +176,7 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
new BalancerAddress("localhost", 80)
});

var transportFactory = new TestSubchannelTransportFactory((s, c) => Task.FromResult(new TryConnectResult(ConnectivityState.TransientFailure)));
var transportFactory = TestSubchannelTransportFactory.Create((s, c) => Task.FromResult(new TryConnectResult(ConnectivityState.TransientFailure)));
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(transportFactory);
var serviceProvider = services.BuildServiceProvider();
Expand Down Expand Up @@ -234,7 +234,7 @@ public async Task RequestConnection_InitialConnectionFails_ExponentialBackoff()
var connectivityState = ConnectivityState.TransientFailure;

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(async (s, c) =>
services.AddSingleton<ISubchannelTransportFactory>(TestSubchannelTransportFactory.Create(async (s, c) =>
{
await syncPoint.WaitToContinue();
return new TryConnectResult(connectivityState);
Expand Down Expand Up @@ -290,7 +290,7 @@ public async Task RequestConnection_InitialConnectionEnds_EntersIdleState()
});

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -340,7 +340,7 @@ public async Task RequestConnection_IdleConnectionConnectAsync_StateToReady()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -385,7 +385,7 @@ public async Task RequestConnection_IdleConnectionPick_StateToReady()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;
return Task.FromResult(new TryConnectResult(ConnectivityState.Ready));
Expand Down Expand Up @@ -448,7 +448,7 @@ public async Task UnaryCall_TransportConnecting_OnePickStarted()

var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;

Expand Down Expand Up @@ -510,7 +510,7 @@ public async Task UnaryCall_TransportConnecting_ErrorAfterTransientFailure()

var tcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportConnectCount = 0;
var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
transportConnectCount++;

Expand Down Expand Up @@ -576,7 +576,7 @@ public async Task DeadlineExceeded_MultipleCalls_CallsWaitForDeadline()
var resolver = new TestResolver();
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var transportFactory = new TestSubchannelTransportFactory((s, c) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, c) =>
{
return Task.FromResult(new TryConnectResult(ConnectivityState.Connecting));
});
Expand Down Expand Up @@ -640,7 +640,7 @@ public async Task ConnectTimeout_MultipleCalls_AttemptReconnect()
resolver.UpdateAddresses(new List<BalancerAddress> { new BalancerAddress("localhost", 80) });

var tryConnectTcs = new TaskCompletionSource<TryConnectResult>(TaskCreationOptions.RunContinuationsAsynchronously);
var transportFactory = new TestSubchannelTransportFactory((s, ct) =>
var transportFactory = TestSubchannelTransportFactory.Create((s, ct) =>
{
return tryConnectTcs.Task;
});
Expand Down
132 changes: 131 additions & 1 deletion test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#if SUPPORT_LOAD_BALANCING
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Threading;
Expand Down Expand Up @@ -298,7 +299,7 @@ public async Task Resolver_ServiceConfigInResult()
var currentConnectivityState = ConnectivityState.Ready;

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory(async (s, c) =>
services.AddSingleton<ISubchannelTransportFactory>(TestSubchannelTransportFactory.Create(async (s, i, c) =>
{
await syncPoint.WaitToContinue();
return new TryConnectResult(currentConnectivityState);
Expand Down Expand Up @@ -551,5 +552,134 @@ public async Task ResolveServiceConfig_ErrorOnSecondResolve_PickSuccess()
var balancer = (ChildHandlerLoadBalancer)channel.ConnectionManager._balancer!;
return (T?)balancer._current?.LoadBalancer;
}

internal class TestBackoffPolicyFactory : IBackoffPolicyFactory
{
private readonly TimeSpan _backoff;

public TestBackoffPolicyFactory() : this(TimeSpan.FromSeconds(20))
{
}

public TestBackoffPolicyFactory(TimeSpan backoff)
{
_backoff = backoff;
}

public IBackoffPolicy Create()
{
return new TestBackoffPolicy(_backoff);
}

private class TestBackoffPolicy : IBackoffPolicy
{
private readonly TimeSpan _backoff;

public TestBackoffPolicy(TimeSpan backoff)
{
_backoff = backoff;
}

public TimeSpan NextBackoff()
{
return _backoff;
}
}
}

[Test]
public async Task Resolver_UpdateResultsAfterPreviousConnect_InterruptConnect()
{
// Arrange
var services = new ServiceCollection();

// add logger
services.AddNUnitLogger();
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger<ResolverTests>();

// add resolver and balancer
var resolver = new TestResolver(loggerFactory);
var result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 80) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<IBackoffPolicyFactory>(new TestBackoffPolicyFactory(TimeSpan.FromSeconds(0.2)));

var tryConnectData = new List<(IReadOnlyList<BalancerAddress> BalancerAddresses, int Attempt, bool IsCancellationRequested)>();

var tryConnectCount = 0;
services.AddSingleton<ISubchannelTransportFactory>(
TestSubchannelTransportFactory.Create((subchannel, attempt, cancellationToken) =>
{
var addresses = subchannel.GetAddresses();
var isCancellationRequested = cancellationToken.IsCancellationRequested;
ConnectivityState state;

var i = Interlocked.Increment(ref tryConnectCount);
if (i == 1)
{
state = ConnectivityState.Ready;
}
else
{
state = attempt >= 2 ? ConnectivityState.Ready : ConnectivityState.TransientFailure;
}

logger.LogInformation("TryConnect attempt {Attempt} to addresses {Addresses}. State: {ConnectivityState}, IsCancellationRequested: {IsCancellationRequested}", attempt, string.Join(", ", addresses), state, isCancellationRequested);

lock (tryConnectData)
{
tryConnectData.Add((addresses, attempt, isCancellationRequested));
}

return Task.FromResult(new TryConnectResult(state));
}));

var channelOptions = new GrpcChannelOptions
{
Credentials = ChannelCredentials.Insecure,
ServiceProvider = services.BuildServiceProvider(),
};

// Act
var channel = GrpcChannel.ForAddress("test:///test_addr", channelOptions);

logger.LogInformation("Client connecting.");
await channel.ConnectionManager.ConnectAsync(waitForReady: true, CancellationToken.None);

logger.LogInformation("Client updating resolver.");
result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 81) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

logger.LogInformation("Client picking.");
await ExceptionAssert.ThrowsAsync<RpcException>(async () => await channel.ConnectionManager.PickAsync(
new PickContext(),
waitForReady: false,
CancellationToken.None));

logger.LogInformation("Client updating Resolver.");
result = ResolverResult.ForResult(new List<BalancerAddress> { new BalancerAddress("localhost", 82) }, serviceConfig: null, serviceConfigStatus: null);
resolver.UpdateResult(result);

logger.LogInformation("Client picking and waiting for ready.");
await channel.ConnectionManager.PickAsync(
new PickContext(),
waitForReady: true,
CancellationToken.None);

// Assert
logger.LogInformation("TryConnectData count: {Count}", tryConnectData.Count);
foreach (var data in tryConnectData)
{
logger.LogInformation("Attempt: {Attempt}, BalancerAddresses: {BalancerAddresses}, IsCancellationRequested: {IsCancellationRequested}", data.Attempt, string.Join(", ", data.BalancerAddresses), data.IsCancellationRequested);
}

var duplicate = tryConnectData.GroupBy(d => new { Address = d.BalancerAddresses.Single(), d.Attempt }).FirstOrDefault(g => g.Count() >= 2);
if (duplicate != null)
{
Assert.Fail($"Duplicate attempts to address. Count: {duplicate.Count()}, Address: {duplicate.Key.Address}");
}
}
}
#endif
Loading
Loading