Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flow CancellationTokens to BindAsync #31377

Merged
merged 1 commit into from
Mar 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.Extensions.Logging;
Expand All @@ -18,22 +19,22 @@ internal AnyIPListenOptions(int port)
{
}

internal override async Task BindAsync(AddressBindContext context)
internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
Debug.Assert(IPEndPoint != null);

// when address is 'http://hostname:port', 'http://*:port', or 'http://+:port'
try
{
await base.BindAsync(context).ConfigureAwait(false);
await base.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
{
context.Logger.LogDebug(CoreStrings.FormatFallbackToIPv4Any(IPEndPoint.Port));

// for machines that do not support IPv6
EndPoint = new IPEndPoint(IPAddress.Any, IPEndPoint.Port);
await base.BindAsync(context).ConfigureAwait(false);
await base.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

Expand All @@ -14,7 +15,7 @@ public AddressBindContext(
ServerAddressesFeature serverAddressesFeature,
KestrelServerOptions serverOptions,
ILogger logger,
Func<ListenOptions, Task> createBinding)
Func<ListenOptions, CancellationToken, Task> createBinding)
{
ServerAddressesFeature = serverAddressesFeature;
ServerOptions = serverOptions;
Expand All @@ -28,6 +29,6 @@ public AddressBindContext(
public KestrelServerOptions ServerOptions { get; }
public ILogger Logger { get; }

public Func<ListenOptions, Task> CreateBinding { get; }
public Func<ListenOptions, CancellationToken, Task> CreateBinding { get; }
}
}
33 changes: 17 additions & 16 deletions src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.IO;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
Expand All @@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
{
internal class AddressBinder
{
public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, AddressBindContext context)
public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, AddressBindContext context, CancellationToken cancellationToken)
{
var strategy = CreateStrategy(
listenOptions.ToArray(),
Expand All @@ -32,7 +33,7 @@ public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, Add
context.ServerOptions.OptionsInUse.Clear();
context.Addresses.Clear();

await strategy.BindAsync(context).ConfigureAwait(false);
await strategy.BindAsync(context, cancellationToken).ConfigureAwait(false);
}

private static IStrategy CreateStrategy(ListenOptions[] listenOptions, string[] addresses, bool preferAddresses)
Expand Down Expand Up @@ -86,11 +87,11 @@ protected internal static bool TryCreateIPEndPoint(BindingAddress address, [NotN
return true;
}

internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context)
internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context, CancellationToken cancellationToken)
{
try
{
await context.CreateBinding(endpoint).ConfigureAwait(false);
await context.CreateBinding(endpoint, cancellationToken).ConfigureAwait(false);
}
catch (AddressInUseException ex)
{
Expand Down Expand Up @@ -144,24 +145,24 @@ internal static ListenOptions ParseAddress(string address, out bool https)

private interface IStrategy
{
Task BindAsync(AddressBindContext context);
Task BindAsync(AddressBindContext context, CancellationToken cancellationToken);
}

private class DefaultAddressStrategy : IStrategy
{
public async Task BindAsync(AddressBindContext context)
public async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var httpDefault = ParseAddress(Constants.DefaultServerAddress, out _);
context.ServerOptions.ApplyEndpointDefaults(httpDefault);
await httpDefault.BindAsync(context).ConfigureAwait(false);
await httpDefault.BindAsync(context, cancellationToken).ConfigureAwait(false);

// Conditional https default, only if a cert is available
var httpsDefault = ParseAddress(Constants.DefaultServerHttpsAddress, out _);
context.ServerOptions.ApplyEndpointDefaults(httpsDefault);

if (httpsDefault.IsTls || httpsDefault.TryUseHttps())
{
await httpsDefault.BindAsync(context).ConfigureAwait(false);
await httpsDefault.BindAsync(context, cancellationToken).ConfigureAwait(false);
context.Logger.LogDebug(CoreStrings.BindingToDefaultAddresses,
Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress);
}
Expand All @@ -180,12 +181,12 @@ public OverrideWithAddressesStrategy(IReadOnlyCollection<string> addresses)
{
}

public override Task BindAsync(AddressBindContext context)
public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var joined = string.Join(", ", _addresses);
context.Logger.LogInformation(CoreStrings.OverridingWithPreferHostingUrls, nameof(IServerAddressesFeature.PreferHostingUrls), joined);

return base.BindAsync(context);
return base.BindAsync(context, cancellationToken);
}
}

Expand All @@ -199,12 +200,12 @@ public OverrideWithEndpointsStrategy(IReadOnlyCollection<ListenOptions> endpoint
_originalAddresses = originalAddresses;
}

public override Task BindAsync(AddressBindContext context)
public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var joined = string.Join(", ", _originalAddresses);
context.Logger.LogWarning(CoreStrings.OverridingWithKestrelOptions, joined);

return base.BindAsync(context);
return base.BindAsync(context, cancellationToken);
}
}

Expand All @@ -217,11 +218,11 @@ public EndpointsStrategy(IReadOnlyCollection<ListenOptions> endpoints)
_endpoints = endpoints;
}

public virtual async Task BindAsync(AddressBindContext context)
public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
foreach (var endpoint in _endpoints)
{
await endpoint.BindAsync(context).ConfigureAwait(false);
await endpoint.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand All @@ -235,7 +236,7 @@ public AddressesStrategy(IReadOnlyCollection<string> addresses)
_addresses = addresses;
}

public virtual async Task BindAsync(AddressBindContext context)
public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
foreach (var address in _addresses)
{
Expand All @@ -247,7 +248,7 @@ public virtual async Task BindAsync(AddressBindContext context)
options.UseHttps();
}

await options.BindAsync(context).ConfigureAwait(false);
await options.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ public TransportManager(
private ConnectionManager ConnectionManager => _serviceContext.ConnectionManager;
private IKestrelTrace Trace => _serviceContext.Log;

public async Task<EndPoint> BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig)
public async Task<EndPoint> BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig, CancellationToken cancellationToken)
{
if (_transportFactory is null)
{
throw new InvalidOperationException($"Cannot bind with {nameof(ConnectionDelegate)} no {nameof(IConnectionListenerFactory)} is registered.");
}

var transport = await _transportFactory.BindAsync(endPoint).ConfigureAwait(false);
var transport = await _transportFactory.BindAsync(endPoint, cancellationToken).ConfigureAwait(false);
StartAcceptLoop(new GenericConnectionListener(transport), c => connectionDelegate(c), endpointConfig);
return transport.EndPoint;
}

public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions)
public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions, CancellationToken cancellationToken)
{
if (_multiplexedTransportFactory is null)
{
Expand All @@ -69,7 +69,7 @@ public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDe
features.Set(sslServerAuthenticationOptions);
}

var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features).ConfigureAwait(false);
var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features, cancellationToken).ConfigureAwait(false);
StartAcceptLoop(new GenericMultiplexedConnectionListener(transport), c => multiplexedConnectionDelegate(c), listenOptions.EndpointConfig);
return transport.EndPoint;
}
Expand Down
11 changes: 5 additions & 6 deletions src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public async Task StartAsync<TContext>(IHttpApplication<TContext> application, C

ServiceContext.Heartbeat?.Start();

async Task OnBind(ListenOptions options)
async Task OnBind(ListenOptions options, CancellationToken onBindCancellationToken)
{
// INVESTIGATE: For some reason, MsQuic needs to bind before
// sockets for it to successfully listen. It also seems racy.
Expand All @@ -177,7 +177,7 @@ async Task OnBind(ListenOptions options)
// Add the connection limit middleware
multiplexedConnectionDelegate = EnforceConnectionLimit(multiplexedConnectionDelegate, Options.Limits.MaxConcurrentConnections, Trace);

options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options).ConfigureAwait(false);
options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options, onBindCancellationToken).ConfigureAwait(false);
}

// Add the HTTP middleware as the terminal connection middleware
Expand All @@ -197,7 +197,7 @@ async Task OnBind(ListenOptions options)
// Add the connection limit middleware
connectionDelegate = EnforceConnectionLimit(connectionDelegate, Options.Limits.MaxConcurrentConnections, Trace);

options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig).ConfigureAwait(false);
options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig, onBindCancellationToken).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -275,7 +275,7 @@ private async Task BindAsync(CancellationToken cancellationToken)

Options.ConfigurationLoader?.Load();

await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!).ConfigureAwait(false);
await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!, cancellationToken).ConfigureAwait(false);
_configChangedRegistration = reloadToken?.RegisterChangeCallback(TriggerRebind, this);
}
finally
Expand Down Expand Up @@ -342,8 +342,7 @@ private async Task RebindAsync()
{
try
{
// TODO: This should probably be canceled by the _stopCts too, but we don't currently support bind cancellation even in StartAsync().
await listenOption.BindAsync(AddressBindContext!).ConfigureAwait(false);
await listenOption.BindAsync(AddressBindContext!, _stopCts.Token).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
5 changes: 3 additions & 2 deletions src/Servers/Kestrel/Core/src/ListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Experimental;
Expand Down Expand Up @@ -176,9 +177,9 @@ MultiplexedConnectionDelegate IMultiplexedConnectionBuilder.Build()
return app;
}

internal virtual async Task BindAsync(AddressBindContext context)
internal virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
await AddressBinder.BindEndpointAsync(this, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(this, context, cancellationToken).ConfigureAwait(false);
context.Addresses.Add(GetDisplayName());
}
}
Expand Down
11 changes: 6 additions & 5 deletions src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.Extensions.Logging;
Expand All @@ -30,16 +31,16 @@ internal override string GetDisplayName()
return $"{Scheme}://localhost:{IPEndPoint!.Port}";
}

internal override async Task BindAsync(AddressBindContext context)
internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var exceptions = new List<Exception>();

try
{
var v4Options = Clone(IPAddress.Loopback);
await AddressBinder.BindEndpointAsync(v4Options, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(v4Options, context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
catch (Exception ex) when (!(ex is IOException or OperationCanceledException))
{
context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv4 loopback", ex.Message);
exceptions.Add(ex);
Expand All @@ -48,9 +49,9 @@ internal override async Task BindAsync(AddressBindContext context)
try
{
var v6Options = Clone(IPAddress.IPv6Loopback);
await AddressBinder.BindEndpointAsync(v6Options, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(v6Options, context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
catch (Exception ex) when (!(ex is IOException or OperationCanceledException))
{
context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv6 loopback", ex.Message);
exceptions.Add(ex);
Expand Down
Loading