Skip to content

Commit

Permalink
Fix capturing ExecutionContext by timers and background tasks (#2129)
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK authored May 17, 2023
1 parent 2dc971e commit f9a00bc
Show file tree
Hide file tree
Showing 27 changed files with 392 additions and 141 deletions.
3 changes: 2 additions & 1 deletion global.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"sdk": {
"version": "7.0.201"
"version": "7.0.201",
"rollForward": "latestFeature"
}
}
1 change: 1 addition & 0 deletions src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
<Compile Include="..\Shared\Server\UnaryServerMethodInvoker.cs" Link="Model\Internal\UnaryServerMethodInvoker.cs" />
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
<Compile Include="..\Shared\CodeAnalysisAttributes.cs" Link="Internal\CodeAnalysisAttributes.cs" />
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -91,12 +91,12 @@ public ServerCallDeadlineManager(HttpContextServerCallContext serverCallContext,
// Ensures there is no weird situation where the timer triggers
// before the field is set. Shouldn't happen because only long deadlines
// will take this path but better to be safe than sorry.
_longDeadlineTimer = new Timer(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.Infinite, Timeout.Infinite);
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_longDeadlineTimer.Change(timerMilliseconds, Timeout.Infinite);
}
else
{
_longDeadlineTimer = new Timer(DeadlineExceededDelegate, this, timerMilliseconds, Timeout.Infinite);
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededDelegate, this, TimeSpan.FromMilliseconds(timerMilliseconds), Timeout.InfiniteTimeSpan);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Balancer/DnsResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected override void OnStarted()

if (_refreshInterval != Timeout.InfiniteTimeSpan)
{
_timer = new Timer(OnTimerCallback, null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_timer = NonCapturingTimer.Create(OnTimerCallback, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_timer.Change(_refreshInterval, _refreshInterval);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -77,7 +77,7 @@ public SocketConnectivitySubchannelTransport(
ConnectTimeout = connectTimeout;
_socketConnect = socketConnect ?? OnConnect;
_activeStreams = new List<ActiveStream>();
_socketConnectedTimer = new Timer(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
}

private object Lock => _subchannel.Lock;
Expand Down
4 changes: 2 additions & 2 deletions src/Grpc.Net.Client/Balancer/PollingResolver.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -86,7 +86,7 @@ protected PollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory? b
/// </para>
/// </summary>
/// <param name="listener">The callback used to receive updates on the target.</param>
public override sealed void Start(Action<ResolverResult> listener)
public sealed override void Start(Action<ResolverResult> listener)
{
if (listener == null)
{
Expand Down
18 changes: 16 additions & 2 deletions src/Grpc.Net.Client/Balancer/Subchannel.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -233,7 +233,21 @@ public void RequestConnection()
}
}

_ = ConnectTransportAsync();
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
bool restoreFlow = false;
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

_ = Task.Run(ConnectTransportAsync);

// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}

private void CancelInProgressConnect()
Expand Down
3 changes: 2 additions & 1 deletion src/Grpc.Net.Client/Grpc.Net.Client.csproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<Description>.NET client for gRPC</Description>
Expand Down Expand Up @@ -35,6 +35,7 @@
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
<Compile Include="..\Shared\Http2ErrorCode.cs" Link="Internal\Http2ErrorCode.cs" />
<Compile Include="..\Shared\Http3ErrorCode.cs" Link="Internal\Http3ErrorCode.cs" />
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
<Compile Include="..\Shared\NonDisposableMemoryStream.cs" Link="Internal\NonDisposableMemoryStream.cs" />
</ItemGroup>

Expand Down
22 changes: 15 additions & 7 deletions src/Grpc.Net.Client/GrpcChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
private readonly Dictionary<MethodKey, MethodConfig>? _serviceConfigMethods;
private readonly bool _isSecure;
private readonly List<CallCredentials>? _callCredentials;
// Internal for testing
internal readonly HashSet<IDisposable> ActiveCalls;
private readonly HashSet<IDisposable> _activeCalls;

internal Uri Address { get; }
internal HttpMessageInvoker HttpInvoker { get; }
Expand Down Expand Up @@ -165,7 +164,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
ThrowOperationCanceledOnCancellation = channelOptions.ThrowOperationCanceledOnCancellation;
UnsafeUseInsecureChannelCallCredentials = channelOptions.UnsafeUseInsecureChannelCallCredentials;
_createMethodInfoFunc = CreateMethodInfo;
ActiveCalls = new HashSet<IDisposable>();
_activeCalls = new HashSet<IDisposable>();
if (channelOptions.ServiceConfig is { } serviceConfig)
{
RetryThrottling = serviceConfig.RetryThrottling != null ? CreateChannelRetryThrottling(serviceConfig.RetryThrottling) : null;
Expand Down Expand Up @@ -490,15 +489,15 @@ internal void RegisterActiveCall(IDisposable grpcCall)
throw new ObjectDisposedException(nameof(GrpcChannel));
}

ActiveCalls.Add(grpcCall);
_activeCalls.Add(grpcCall);
}
}

internal void FinishActiveCall(IDisposable grpcCall)
{
lock (_lock)
{
ActiveCalls.Remove(grpcCall);
_activeCalls.Remove(grpcCall);
}
}

Expand Down Expand Up @@ -749,9 +748,9 @@ public void Dispose()
return;
}

if (ActiveCalls.Count > 0)
if (_activeCalls.Count > 0)
{
activeCallsCopy = ActiveCalls.ToArray();
activeCallsCopy = _activeCalls.ToArray();
}

Disposed = true;
Expand Down Expand Up @@ -807,6 +806,15 @@ internal int GetRandomNumber(int minValue, int maxValue)
}
}

// Internal for testing
internal IDisposable[] GetActiveCalls()
{
lock (_lock)
{
return _activeCalls.ToArray();
}
}

#if SUPPORT_LOAD_BALANCING
private sealed class SubChannelTransportFactory : ISubchannelTransportFactory
{
Expand Down
2 changes: 1 addition & 1 deletion src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,7 @@ public Exception CreateFailureStatusException(Status status)
GrpcCallLog.StartingDeadlineTimeout(Logger, timeout.Value);

var dueTime = CommonGrpcProtocolHelpers.GetTimerDueTime(timeout.Value, Channel.MaxTimerDueTime);
_deadlineTimer = new Timer(DeadlineExceededCallback, null, dueTime, Timeout.Infinite);
_deadlineTimer = NonCapturingTimer.Create(DeadlineExceededCallback, state: null, TimeSpan.FromMilliseconds(dueTime), Timeout.InfiniteTimeSpan);
}
}

Expand Down
39 changes: 39 additions & 0 deletions src/Shared/NonCapturingTimer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace Grpc.Shared;

// A convenience API for interacting with System.Threading.Timer in a way
// that doesn't capture the ExecutionContext. We should be using this (or equivalent)
// everywhere we use timers to avoid rooting any values stored in asynclocals.
internal static class NonCapturingTimer
{
public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period)
{
if (callback is null)
{
throw new ArgumentNullException(nameof(callback));
}

// Don't capture the current ExecutionContext and its AsyncLocals onto the timer
bool restoreFlow = false;
try
{
if (!ExecutionContext.IsFlowSuppressed())
{
ExecutionContext.SuppressFlow();
restoreFlow = true;
}

return new Timer(callback, state, dueTime, period);
}
finally
{
// Restore the current ExecutionContext
if (restoreFlow)
{
ExecutionContext.RestoreFlow();
}
}
}
}
70 changes: 1 addition & 69 deletions test/FunctionalTests/Balancer/BalancerHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -185,74 +185,6 @@ public static async Task<GrpcChannel> CreateChannel(
return channel;
}

public static Task WaitForChannelStateAsync(ILogger logger, GrpcChannel channel, ConnectivityState state, int channelId = 1)
{
return WaitForChannelStatesAsync(logger, channel, new[] { state }, channelId);
}

public static async Task WaitForChannelStatesAsync(ILogger logger, GrpcChannel channel, ConnectivityState[] states, int channelId = 1)
{
var statesText = string.Join(", ", states.Select(s => $"'{s}'"));
logger.LogInformation($"Channel id {channelId}: Waiting for channel states {statesText}.");

var currentState = channel.State;

while (!states.Contains(currentState))
{
logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' doesn't match expected states {statesText}.");

await channel.WaitForStateChangedAsync(currentState).DefaultTimeout();
currentState = channel.State;
}

logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' matches expected states {statesText}.");
}

public static async Task<Subchannel> WaitForSubchannelToBeReadyAsync(ILogger logger, GrpcChannel channel, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
{
var subChannel = (await WaitForSubchannelsToBeReadyAsync(logger, channel, 1)).Single();
return subChannel;
}

public static async Task<Subchannel[]> WaitForSubchannelsToBeReadyAsync(ILogger logger, GrpcChannel channel, int expectedCount, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
{
if (getPickerSubchannels == null)
{
getPickerSubchannels = (picker) =>
{
return picker switch
{
RoundRobinPicker roundRobinPicker => roundRobinPicker._subchannels.ToArray(),
PickFirstPicker pickFirstPicker => new[] { pickFirstPicker.Subchannel },
EmptyPicker emptyPicker => Array.Empty<Subchannel>(),
null => Array.Empty<Subchannel>(),
_ => throw new Exception("Unexpected picker type: " + picker.GetType().FullName)
};
};
}

logger.LogInformation($"Waiting for subchannel ready count: {expectedCount}");

Subchannel[]? subChannelsCopy = null;
await TestHelpers.AssertIsTrueRetryAsync(() =>
{
var picker = channel.ConnectionManager._picker;
subChannelsCopy = getPickerSubchannels(picker);
logger.LogInformation($"Current subchannel ready count: {subChannelsCopy.Length}");
for (var i = 0; i < subChannelsCopy.Length; i++)
{
logger.LogInformation($"Ready subchannel: {subChannelsCopy[i]}");
}
return subChannelsCopy.Length == expectedCount;
}, "Wait for all subconnections to be connected.");

logger.LogInformation($"Finished waiting for subchannel ready.");

Debug.Assert(subChannelsCopy != null);
return subChannelsCopy;
}

public static T? GetInnerLoadBalancer<T>(GrpcChannel channel) where T : LoadBalancer
{
var balancer = (ChildHandlerLoadBalancer)channel.ConnectionManager._balancer!;
Expand Down
4 changes: 2 additions & 2 deletions test/FunctionalTests/Balancer/ConnectionTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -352,7 +352,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)

await channel.ConnectAsync().DefaultTimeout();

await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();

var client = TestClientFactory.Create(channel, endpoint1.Method);

Expand Down
4 changes: 2 additions & 2 deletions test/FunctionalTests/Balancer/LeastUsedBalancerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -67,7 +67,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte

var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new LoadBalancingConfig("least_used"), new[] { endpoint1.Address, endpoint2.Address }, connect: true);

await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(
Logger,
channel,
expectedCount: 2,
Expand Down
15 changes: 7 additions & 8 deletions test/FunctionalTests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -238,8 +238,8 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
Logger.LogInformation("Ending " + endpoint1.Address);
endpoint1.Dispose();

await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, expectedCount: 1,
getPickerSubchannels: picker=>
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, expectedCount: 1,
getPickerSubchannels: picker =>
{
// We want a subchannel that has no current address
if (picker is PickFirstPicker pickFirstPicker)
Expand Down Expand Up @@ -293,8 +293,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
Assert.AreEqual(ConnectivityState.Ready, channel.State);

// Wait for pooled connection to timeout and return to idle
await channel.WaitForStateChangedAsync(channel.State).DefaultTimeout();
Assert.AreEqual(ConnectivityState.Idle, channel.State);
await BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Idle).DefaultTimeout();

reply = await client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync.DefaultTimeout();
Assert.AreEqual("Balancer", reply.Message);
Expand Down Expand Up @@ -355,7 +354,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte

Logger.LogInformation($"All gRPC calls on server");

await BalancerHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Ready).DefaultTimeout();
await BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Ready).DefaultTimeout();

var balancer = BalancerHelpers.GetInnerLoadBalancer<PickFirstBalancer>(channel)!;
var subchannel = balancer._subchannel!;
Expand Down Expand Up @@ -468,8 +467,8 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
endpoint.Dispose();

await Task.WhenAll(
BalancerHelpers.WaitForChannelStateAsync(Logger, channel1, ConnectivityState.Idle, channelId: 1),
BalancerHelpers.WaitForChannelStateAsync(Logger, channel2, ConnectivityState.Idle, channelId: 2)).DefaultTimeout();
BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel1, ConnectivityState.Idle, channelId: 1),
BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel2, ConnectivityState.Idle, channelId: 2)).DefaultTimeout();

Logger.LogInformation("Restarting");
using var endpointNew = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));
Expand Down
Loading

0 comments on commit f9a00bc

Please sign in to comment.