Skip to content

Commit

Permalink
Flow Kestrel CancellationTokens to BindAsync
Browse files Browse the repository at this point in the history
  • Loading branch information
halter73 committed Mar 30, 2021
1 parent b00ae1b commit 8caa8e1
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 43 deletions.
7 changes: 4 additions & 3 deletions src/Servers/Kestrel/Core/src/AnyIPListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.Extensions.Logging;
Expand All @@ -18,22 +19,22 @@ internal AnyIPListenOptions(int port)
{
}

internal override async Task BindAsync(AddressBindContext context)
internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
Debug.Assert(IPEndPoint != null);

// when address is 'http://hostname:port', 'http://*:port', or 'http://+:port'
try
{
await base.BindAsync(context).ConfigureAwait(false);
await base.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
{
context.Logger.LogDebug(CoreStrings.FormatFallbackToIPv4Any(IPEndPoint.Port));

// for machines that do not support IPv6
EndPoint = new IPEndPoint(IPAddress.Any, IPEndPoint.Port);
await base.BindAsync(context).ConfigureAwait(false);
await base.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/Servers/Kestrel/Core/src/Internal/AddressBindContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;

Expand All @@ -14,7 +15,7 @@ public AddressBindContext(
ServerAddressesFeature serverAddressesFeature,
KestrelServerOptions serverOptions,
ILogger logger,
Func<ListenOptions, Task> createBinding)
Func<ListenOptions, CancellationToken, Task> createBinding)
{
ServerAddressesFeature = serverAddressesFeature;
ServerOptions = serverOptions;
Expand All @@ -28,6 +29,6 @@ public AddressBindContext(
public KestrelServerOptions ServerOptions { get; }
public ILogger Logger { get; }

public Func<ListenOptions, Task> CreateBinding { get; }
public Func<ListenOptions, CancellationToken, Task> CreateBinding { get; }
}
}
33 changes: 17 additions & 16 deletions src/Servers/Kestrel/Core/src/Internal/AddressBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.IO;
using System.Linq;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
Expand All @@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
{
internal class AddressBinder
{
public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, AddressBindContext context)
public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, AddressBindContext context, CancellationToken cancellationToken)
{
var strategy = CreateStrategy(
listenOptions.ToArray(),
Expand All @@ -32,7 +33,7 @@ public static async Task BindAsync(IEnumerable<ListenOptions> listenOptions, Add
context.ServerOptions.OptionsInUse.Clear();
context.Addresses.Clear();

await strategy.BindAsync(context).ConfigureAwait(false);
await strategy.BindAsync(context, cancellationToken).ConfigureAwait(false);
}

private static IStrategy CreateStrategy(ListenOptions[] listenOptions, string[] addresses, bool preferAddresses)
Expand Down Expand Up @@ -86,11 +87,11 @@ protected internal static bool TryCreateIPEndPoint(BindingAddress address, [NotN
return true;
}

internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context)
internal static async Task BindEndpointAsync(ListenOptions endpoint, AddressBindContext context, CancellationToken cancellationToken)
{
try
{
await context.CreateBinding(endpoint).ConfigureAwait(false);
await context.CreateBinding(endpoint, cancellationToken).ConfigureAwait(false);
}
catch (AddressInUseException ex)
{
Expand Down Expand Up @@ -144,24 +145,24 @@ internal static ListenOptions ParseAddress(string address, out bool https)

private interface IStrategy
{
Task BindAsync(AddressBindContext context);
Task BindAsync(AddressBindContext context, CancellationToken cancellationToken);
}

private class DefaultAddressStrategy : IStrategy
{
public async Task BindAsync(AddressBindContext context)
public async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var httpDefault = ParseAddress(Constants.DefaultServerAddress, out _);
context.ServerOptions.ApplyEndpointDefaults(httpDefault);
await httpDefault.BindAsync(context).ConfigureAwait(false);
await httpDefault.BindAsync(context, cancellationToken).ConfigureAwait(false);

// Conditional https default, only if a cert is available
var httpsDefault = ParseAddress(Constants.DefaultServerHttpsAddress, out _);
context.ServerOptions.ApplyEndpointDefaults(httpsDefault);

if (httpsDefault.IsTls || httpsDefault.TryUseHttps())
{
await httpsDefault.BindAsync(context).ConfigureAwait(false);
await httpsDefault.BindAsync(context, cancellationToken).ConfigureAwait(false);
context.Logger.LogDebug(CoreStrings.BindingToDefaultAddresses,
Constants.DefaultServerAddress, Constants.DefaultServerHttpsAddress);
}
Expand All @@ -180,12 +181,12 @@ public OverrideWithAddressesStrategy(IReadOnlyCollection<string> addresses)
{
}

public override Task BindAsync(AddressBindContext context)
public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var joined = string.Join(", ", _addresses);
context.Logger.LogInformation(CoreStrings.OverridingWithPreferHostingUrls, nameof(IServerAddressesFeature.PreferHostingUrls), joined);

return base.BindAsync(context);
return base.BindAsync(context, cancellationToken);
}
}

Expand All @@ -199,12 +200,12 @@ public OverrideWithEndpointsStrategy(IReadOnlyCollection<ListenOptions> endpoint
_originalAddresses = originalAddresses;
}

public override Task BindAsync(AddressBindContext context)
public override Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var joined = string.Join(", ", _originalAddresses);
context.Logger.LogWarning(CoreStrings.OverridingWithKestrelOptions, joined);

return base.BindAsync(context);
return base.BindAsync(context, cancellationToken);
}
}

Expand All @@ -217,11 +218,11 @@ public EndpointsStrategy(IReadOnlyCollection<ListenOptions> endpoints)
_endpoints = endpoints;
}

public virtual async Task BindAsync(AddressBindContext context)
public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
foreach (var endpoint in _endpoints)
{
await endpoint.BindAsync(context).ConfigureAwait(false);
await endpoint.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand All @@ -235,7 +236,7 @@ public AddressesStrategy(IReadOnlyCollection<string> addresses)
_addresses = addresses;
}

public virtual async Task BindAsync(AddressBindContext context)
public virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
foreach (var address in _addresses)
{
Expand All @@ -247,7 +248,7 @@ public virtual async Task BindAsync(AddressBindContext context)
options.UseHttps();
}

await options.BindAsync(context).ConfigureAwait(false);
await options.BindAsync(context, cancellationToken).ConfigureAwait(false);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,19 @@ public TransportManager(
private ConnectionManager ConnectionManager => _serviceContext.ConnectionManager;
private IKestrelTrace Trace => _serviceContext.Log;

public async Task<EndPoint> BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig)
public async Task<EndPoint> BindAsync(EndPoint endPoint, ConnectionDelegate connectionDelegate, EndpointConfig? endpointConfig, CancellationToken cancellationToken)
{
if (_transportFactory is null)
{
throw new InvalidOperationException($"Cannot bind with {nameof(ConnectionDelegate)} no {nameof(IConnectionListenerFactory)} is registered.");
}

var transport = await _transportFactory.BindAsync(endPoint).ConfigureAwait(false);
var transport = await _transportFactory.BindAsync(endPoint, cancellationToken).ConfigureAwait(false);
StartAcceptLoop(new GenericConnectionListener(transport), c => connectionDelegate(c), endpointConfig);
return transport.EndPoint;
}

public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions)
public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDelegate multiplexedConnectionDelegate, ListenOptions listenOptions, CancellationToken cancellationToken)
{
if (_multiplexedTransportFactory is null)
{
Expand All @@ -69,7 +69,7 @@ public async Task<EndPoint> BindAsync(EndPoint endPoint, MultiplexedConnectionDe
features.Set(sslServerAuthenticationOptions);
}

var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features).ConfigureAwait(false);
var transport = await _multiplexedTransportFactory.BindAsync(endPoint, features, cancellationToken).ConfigureAwait(false);
StartAcceptLoop(new GenericMultiplexedConnectionListener(transport), c => multiplexedConnectionDelegate(c), listenOptions.EndpointConfig);
return transport.EndPoint;
}
Expand Down
11 changes: 5 additions & 6 deletions src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public async Task StartAsync<TContext>(IHttpApplication<TContext> application, C

ServiceContext.Heartbeat?.Start();

async Task OnBind(ListenOptions options)
async Task OnBind(ListenOptions options, CancellationToken onBindCancellationToken)
{
// INVESTIGATE: For some reason, MsQuic needs to bind before
// sockets for it to successfully listen. It also seems racy.
Expand All @@ -177,7 +177,7 @@ async Task OnBind(ListenOptions options)
// Add the connection limit middleware
multiplexedConnectionDelegate = EnforceConnectionLimit(multiplexedConnectionDelegate, Options.Limits.MaxConcurrentConnections, Trace);

options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options).ConfigureAwait(false);
options.EndPoint = await _transportManager.BindAsync(options.EndPoint, multiplexedConnectionDelegate, options, onBindCancellationToken).ConfigureAwait(false);
}

// Add the HTTP middleware as the terminal connection middleware
Expand All @@ -197,7 +197,7 @@ async Task OnBind(ListenOptions options)
// Add the connection limit middleware
connectionDelegate = EnforceConnectionLimit(connectionDelegate, Options.Limits.MaxConcurrentConnections, Trace);

options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig).ConfigureAwait(false);
options.EndPoint = await _transportManager.BindAsync(options.EndPoint, connectionDelegate, options.EndpointConfig, onBindCancellationToken).ConfigureAwait(false);
}
}

Expand Down Expand Up @@ -275,7 +275,7 @@ private async Task BindAsync(CancellationToken cancellationToken)

Options.ConfigurationLoader?.Load();

await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!).ConfigureAwait(false);
await AddressBinder.BindAsync(Options.ListenOptions, AddressBindContext!, cancellationToken).ConfigureAwait(false);
_configChangedRegistration = reloadToken?.RegisterChangeCallback(TriggerRebind, this);
}
finally
Expand Down Expand Up @@ -342,8 +342,7 @@ private async Task RebindAsync()
{
try
{
// TODO: This should probably be canceled by the _stopCts too, but we don't currently support bind cancellation even in StartAsync().
await listenOption.BindAsync(AddressBindContext!).ConfigureAwait(false);
await listenOption.BindAsync(AddressBindContext!, _stopCts.Token).ConfigureAwait(false);
}
catch (Exception ex)
{
Expand Down
5 changes: 3 additions & 2 deletions src/Servers/Kestrel/Core/src/ListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Experimental;
Expand Down Expand Up @@ -176,9 +177,9 @@ MultiplexedConnectionDelegate IMultiplexedConnectionBuilder.Build()
return app;
}

internal virtual async Task BindAsync(AddressBindContext context)
internal virtual async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
await AddressBinder.BindEndpointAsync(this, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(this, context, cancellationToken).ConfigureAwait(false);
context.Addresses.Add(GetDisplayName());
}
}
Expand Down
11 changes: 6 additions & 5 deletions src/Servers/Kestrel/Core/src/LocalhostListenOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.Extensions.Logging;
Expand All @@ -30,16 +31,16 @@ internal override string GetDisplayName()
return $"{Scheme}://localhost:{IPEndPoint!.Port}";
}

internal override async Task BindAsync(AddressBindContext context)
internal override async Task BindAsync(AddressBindContext context, CancellationToken cancellationToken)
{
var exceptions = new List<Exception>();

try
{
var v4Options = Clone(IPAddress.Loopback);
await AddressBinder.BindEndpointAsync(v4Options, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(v4Options, context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
catch (Exception ex) when (!(ex is IOException or OperationCanceledException))
{
context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv4 loopback", ex.Message);
exceptions.Add(ex);
Expand All @@ -48,9 +49,9 @@ internal override async Task BindAsync(AddressBindContext context)
try
{
var v6Options = Clone(IPAddress.IPv6Loopback);
await AddressBinder.BindEndpointAsync(v6Options, context).ConfigureAwait(false);
await AddressBinder.BindEndpointAsync(v6Options, context, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex) when (!(ex is IOException))
catch (Exception ex) when (!(ex is IOException or OperationCanceledException))
{
context.Logger.LogInformation(0, CoreStrings.NetworkInterfaceBindingFailed, GetDisplayName(), "IPv6 loopback", ex.Message);
exceptions.Add(ex);
Expand Down
Loading

0 comments on commit 8caa8e1

Please sign in to comment.