diff --git a/AspNetCore.slnx b/AspNetCore.slnx index 76933202a2d8..42988a869787 100644 --- a/AspNetCore.slnx +++ b/AspNetCore.slnx @@ -868,6 +868,7 @@ + diff --git a/src/Servers/HttpSys/HttpSysServer.slnf b/src/Servers/HttpSys/HttpSysServer.slnf index d950668e34a8..7c081bb272c4 100644 --- a/src/Servers/HttpSys/HttpSysServer.slnf +++ b/src/Servers/HttpSys/HttpSysServer.slnf @@ -38,6 +38,7 @@ "src\\Servers\\HttpSys\\samples\\QueueSharing\\QueueSharing.csproj", "src\\Servers\\HttpSys\\samples\\SelfHostServer\\SelfHostServer.csproj", "src\\Servers\\HttpSys\\samples\\TestClient\\TestClient.csproj", + "src\\Servers\\HttpSys\\samples\\TlsFeaturesObserve\\TlsFeaturesObserve.csproj", "src\\Servers\\HttpSys\\src\\Microsoft.AspNetCore.Server.HttpSys.csproj", "src\\Servers\\HttpSys\\test\\FunctionalTests\\Microsoft.AspNetCore.Server.HttpSys.FunctionalTests.csproj", "src\\Servers\\HttpSys\\test\\NonHelixTests\\Microsoft.AspNetCore.Server.HttpSys.NonHelixTests.csproj", diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/HttpSysConfigurator.cs b/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/HttpSysConfigurator.cs new file mode 100644 index 000000000000..3865ecd59451 --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/HttpSysConfigurator.cs @@ -0,0 +1,123 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using System.Runtime.InteropServices; + +namespace TlsFeaturesObserve.HttpSys; + +internal static class HttpSysConfigurator +{ + const uint HTTP_INITIALIZE_CONFIG = 0x00000002; + const uint ERROR_ALREADY_EXISTS = 183; + + static readonly HTTPAPI_VERSION HttpApiVersion = new HTTPAPI_VERSION(1, 0); + + internal static void ConfigureCacheTlsClientHello() + { + // Arbitrarily chosen port, but must match the port used in the web server. Via UrlPrefixes or launchsettings. + var ipPort = new IPEndPoint(new IPAddress([0, 0, 0, 0]), 6000); + var certThumbprint = "" /* your cert thumbprint here */; + var appId = Guid.NewGuid(); + var sslCertStoreName = "My"; + + CallHttpApi(() => SetConfiguration(ipPort, certThumbprint, appId, sslCertStoreName)); + } + + static void SetConfiguration(IPEndPoint ipPort, string certThumbprint, Guid appId, string sslCertStoreName) + { + var sockAddrHandle = CreateSockaddrStructure(ipPort); + var pIpPort = sockAddrHandle.AddrOfPinnedObject(); + var httpServiceConfigSslKey = new HTTP_SERVICE_CONFIG_SSL_KEY(pIpPort); + + var hash = GetHash(certThumbprint); + var handleHash = GCHandle.Alloc(hash, GCHandleType.Pinned); + var configSslParam = new HTTP_SERVICE_CONFIG_SSL_PARAM + { + AppId = appId, + DefaultFlags = 0x00008000 /* HTTP_SERVICE_CONFIG_SSL_FLAG_ENABLE_CACHE_CLIENT_HELLO */, + DefaultRevocationFreshnessTime = 0, + DefaultRevocationUrlRetrievalTimeout = 15, + pSslCertStoreName = sslCertStoreName, + pSslHash = handleHash.AddrOfPinnedObject(), + SslHashLength = hash.Length, + pDefaultSslCtlIdentifier = null, + pDefaultSslCtlStoreName = sslCertStoreName + }; + + var configSslSet = new HTTP_SERVICE_CONFIG_SSL_SET + { + ParamDesc = configSslParam, + KeyDesc = httpServiceConfigSslKey + }; + + var pInputConfigInfo = Marshal.AllocCoTaskMem( + Marshal.SizeOf(typeof(HTTP_SERVICE_CONFIG_SSL_SET))); + Marshal.StructureToPtr(configSslSet, pInputConfigInfo, false); + + var status = HttpSetServiceConfiguration(nint.Zero, + HTTP_SERVICE_CONFIG_ID.HttpServiceConfigSSLCertInfo, + pInputConfigInfo, + Marshal.SizeOf(configSslSet), + nint.Zero); + + if (status == ERROR_ALREADY_EXISTS || status == 0) // already present or success + { + Console.WriteLine($"HttpServiceConfiguration is correct"); + } + else + { + Console.WriteLine("Failed to HttpSetServiceConfiguration: " + status); + } + } + + static byte[] GetHash(string thumbprint) + { + var length = thumbprint.Length; + var bytes = new byte[length / 2]; + for (var i = 0; i < length; i += 2) + { + bytes[i / 2] = Convert.ToByte(thumbprint.Substring(i, 2), 16); + } + + return bytes; + } + + static GCHandle CreateSockaddrStructure(IPEndPoint ipEndPoint) + { + var socketAddress = ipEndPoint.Serialize(); + + // use an array of bytes instead of the sockaddr structure + var sockAddrStructureBytes = new byte[socketAddress.Size]; + var sockAddrHandle = GCHandle.Alloc(sockAddrStructureBytes, GCHandleType.Pinned); + for (var i = 0; i < socketAddress.Size; ++i) + { + sockAddrStructureBytes[i] = socketAddress[i]; + } + return sockAddrHandle; + } + + static void CallHttpApi(Action body) + { + const uint flags = HTTP_INITIALIZE_CONFIG; + var retVal = HttpInitialize(HttpApiVersion, flags, IntPtr.Zero); + body(); + } + +// disabled warning since it is just a sample +#pragma warning disable SYSLIB1054 // Use 'LibraryImportAttribute' instead of 'DllImportAttribute' to generate P/Invoke marshalling code at compile time + [DllImport("httpapi.dll", SetLastError = true)] + private static extern uint HttpInitialize( + HTTPAPI_VERSION version, + uint flags, + IntPtr pReserved); + + [DllImport("httpapi.dll", SetLastError = true)] + public static extern uint HttpSetServiceConfiguration( + nint serviceIntPtr, + HTTP_SERVICE_CONFIG_ID configId, + nint pConfigInformation, + int configInformationLength, + nint pOverlapped); +#pragma warning restore SYSLIB1054 // Use 'LibraryImportAttribute' instead of 'DllImportAttribute' to generate P/Invoke marshalling code at compile time +} diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/Native.cs b/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/Native.cs new file mode 100644 index 000000000000..b939163d2252 --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/HttpSys/Native.cs @@ -0,0 +1,97 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace TlsFeaturesObserve.HttpSys; + +// Http.Sys types from https://learn.microsoft.com/windows/win32/api/http/ + +[StructLayout(LayoutKind.Sequential, Pack = 2)] +public struct HTTPAPI_VERSION +{ + public ushort HttpApiMajorVersion; + public ushort HttpApiMinorVersion; + + public HTTPAPI_VERSION(ushort majorVersion, ushort minorVersion) + { + HttpApiMajorVersion = majorVersion; + HttpApiMinorVersion = minorVersion; + } +} + +public enum HTTP_SERVICE_CONFIG_ID +{ + HttpServiceConfigIPListenList = 0, + HttpServiceConfigSSLCertInfo, + HttpServiceConfigUrlAclInfo, + HttpServiceConfigMax +} + +[StructLayout(LayoutKind.Sequential)] +public struct HTTP_SERVICE_CONFIG_SSL_SET +{ + public HTTP_SERVICE_CONFIG_SSL_KEY KeyDesc; + public HTTP_SERVICE_CONFIG_SSL_PARAM ParamDesc; +} + +[StructLayout(LayoutKind.Sequential)] +public struct HTTP_SERVICE_CONFIG_SSL_KEY +{ + public IntPtr pIpPort; + + public HTTP_SERVICE_CONFIG_SSL_KEY(IntPtr pIpPort) + { + this.pIpPort = pIpPort; + } +} + +[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] +public struct HTTP_SERVICE_CONFIG_SSL_PARAM +{ + public int SslHashLength; + public IntPtr pSslHash; + public Guid AppId; + [MarshalAs(UnmanagedType.LPWStr)] + public string pSslCertStoreName; + public CertCheckModes DefaultCertCheckMode; + public int DefaultRevocationFreshnessTime; + public int DefaultRevocationUrlRetrievalTimeout; + [MarshalAs(UnmanagedType.LPWStr)] + public string pDefaultSslCtlIdentifier; + [MarshalAs(UnmanagedType.LPWStr)] + public string pDefaultSslCtlStoreName; + public uint DefaultFlags; // HTTP_SERVICE_CONFIG_SSL_FLAG +} + +[Flags] +public enum CertCheckModes : uint +{ + /// + /// Enables the client certificate revocation check. + /// + None = 0, + + /// + /// Client certificate is not to be verified for revocation. + /// + DoNotVerifyCertificateRevocation = 1, + + /// + /// Only cached certificate is to be used the revocation check. + /// + VerifyRevocationWithCachedCertificateOnly = 2, + + /// + /// The RevocationFreshnessTime setting is enabled. + /// + EnableRevocationFreshnessTime = 4, + + /// + /// No usage check is to be performed. + /// + NoUsageCheck = 0x10000 +} diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/Program.cs b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Program.cs new file mode 100644 index 000000000000..3742211a9a8d --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Program.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Reflection; +using System.Runtime.InteropServices; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.HttpSys; +using Microsoft.Extensions.Hosting; +using TlsFeatureObserve; +using TlsFeaturesObserve.HttpSys; + +HttpSysConfigurator.ConfigureCacheTlsClientHello(); +CreateHostBuilder(args).Build().Run(); + +static IHostBuilder CreateHostBuilder(string[] args) => + Host.CreateDefaultBuilder(args) + .ConfigureWebHost(webBuilder => + { + webBuilder.UseStartup() + .UseHttpSys(options => + { + // If you want to use https locally: https://stackoverflow.com/a/51841893 + options.UrlPrefixes.Add("https://*:6000"); // HTTPS + + options.Authentication.Schemes = AuthenticationSchemes.None; + options.Authentication.AllowAnonymous = true; + + options.TlsClientHelloBytesCallback = ProcessTlsClientHello; + }); + }); + +static void ProcessTlsClientHello(IFeatureCollection features, ReadOnlySpan tlsClientHelloBytes) +{ + var httpConnectionFeature = features.Get(); + + var myTlsFeature = new MyTlsFeature( + connectionId: httpConnectionFeature.ConnectionId, + tlsClientHelloLength: tlsClientHelloBytes.Length); + + features.Set(myTlsFeature); +} + +public interface IMyTlsFeature +{ + string ConnectionId { get; } + int TlsClientHelloLength { get; } +} + +public class MyTlsFeature : IMyTlsFeature +{ + public string ConnectionId { get; } + public int TlsClientHelloLength { get; } + + public MyTlsFeature(string connectionId, int tlsClientHelloLength) + { + ConnectionId = connectionId; + TlsClientHelloLength = tlsClientHelloLength; + } +} diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/Properties/launchSettings.json b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Properties/launchSettings.json new file mode 100644 index 000000000000..c9d6b5efcb3c --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Properties/launchSettings.json @@ -0,0 +1,10 @@ +{ + "profiles": { + "TlsFeaturesObserve": { + "commandName": "Project", + "launchBrowser": true, + "applicationUrl": "http://localhost:5000", + "nativeDebugging": true + } + } +} \ No newline at end of file diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/Startup.cs b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Startup.cs new file mode 100644 index 000000000000..8ba6d27aef98 --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/Startup.cs @@ -0,0 +1,28 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.HttpSys; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace TlsFeatureObserve; + +public class Startup +{ + public void Configure(IApplicationBuilder app) + { + app.Run(async (HttpContext context) => + { + context.Response.ContentType = "text/plain"; + + var tlsFeature = context.Features.Get(); + await context.Response.WriteAsync("TlsClientHello data: " + $"connectionId={tlsFeature?.ConnectionId}; length={tlsFeature?.TlsClientHelloLength}"); + }); + } +} diff --git a/src/Servers/HttpSys/samples/TlsFeaturesObserve/TlsFeaturesObserve.csproj b/src/Servers/HttpSys/samples/TlsFeaturesObserve/TlsFeaturesObserve.csproj new file mode 100644 index 000000000000..f65f8a98a72a --- /dev/null +++ b/src/Servers/HttpSys/samples/TlsFeaturesObserve/TlsFeaturesObserve.csproj @@ -0,0 +1,14 @@ + + + + $(DefaultNetCoreTargetFramework) + Exe + true + + + + + + + + diff --git a/src/Servers/HttpSys/src/HttpSysListener.cs b/src/Servers/HttpSys/src/HttpSysListener.cs index 933c627cd56e..0f168e53bd2f 100644 --- a/src/Servers/HttpSys/src/HttpSysListener.cs +++ b/src/Servers/HttpSys/src/HttpSysListener.cs @@ -5,6 +5,7 @@ using System.Diagnostics; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.HttpSys.Internal; +using Microsoft.AspNetCore.Server.HttpSys.RequestProcessing; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; using Windows.Win32; @@ -41,6 +42,7 @@ internal sealed partial class HttpSysListener : IDisposable private readonly UrlGroup _urlGroup; private readonly RequestQueue _requestQueue; private readonly DisconnectListener _disconnectListener; + private readonly TlsListener? _tlsListener; private readonly object _internalLock; @@ -71,12 +73,14 @@ public HttpSysListener(HttpSysOptions options, ILoggerFactory loggerFactory) try { _serverSession = new ServerSession(); - _requestQueue = new RequestQueue(options.RequestQueueName, options.RequestQueueMode, Logger); - _urlGroup = new UrlGroup(_serverSession, _requestQueue, Logger); _disconnectListener = new DisconnectListener(_requestQueue, Logger); + if (options.TlsClientHelloBytesCallback is not null) + { + _tlsListener = new TlsListener(Logger, options.TlsClientHelloBytesCallback); + } } catch (Exception exception) { @@ -84,6 +88,7 @@ public HttpSysListener(HttpSysOptions options, ILoggerFactory loggerFactory) _requestQueue?.Dispose(); _urlGroup?.Dispose(); _serverSession?.Dispose(); + _tlsListener?.Dispose(); Log.HttpSysListenerCtorError(Logger, exception); throw; } @@ -98,20 +103,10 @@ internal enum State internal ILogger Logger { get; private set; } - internal UrlGroup UrlGroup - { - get { return _urlGroup; } - } - - internal RequestQueue RequestQueue - { - get { return _requestQueue; } - } - - internal DisconnectListener DisconnectListener - { - get { return _disconnectListener; } - } + internal UrlGroup UrlGroup => _urlGroup; + internal RequestQueue RequestQueue => _requestQueue; + internal TlsListener? TlsListener => _tlsListener; + internal DisconnectListener DisconnectListener => _disconnectListener; public HttpSysOptions Options { get; } @@ -264,6 +259,7 @@ private void DisposeInternal() Debug.Assert(!_serverSession.Id.IsInvalid, "ServerSessionHandle is invalid in CloseV2Config"); _serverSession.Dispose(); + _tlsListener?.Dispose(); } /// diff --git a/src/Servers/HttpSys/src/HttpSysOptions.cs b/src/Servers/HttpSys/src/HttpSysOptions.cs index 44bc0bc5faa8..3e83e10212f9 100644 --- a/src/Servers/HttpSys/src/HttpSysOptions.cs +++ b/src/Servers/HttpSys/src/HttpSysOptions.cs @@ -237,10 +237,21 @@ public Http503VerbosityLevel Http503Verbosity /// Configures request headers to use encoding. /// /// - /// Defaults to `false`, in which case will be used. />. + /// Defaults to false, in which case will be used. />. /// public bool UseLatin1RequestHeaders { get; set; } + /// + /// A callback to be invoked to get the TLS client hello bytes. + /// Null by default. + /// + /// + /// Works only if HTTP_SERVICE_CONFIG_SSL_FLAG_ENABLE_CACHE_CLIENT_HELLO flag is set on http.sys service configuration. + /// See + /// and + /// + public Action>? TlsClientHelloBytesCallback { get; set; } + // Not called when attaching to an existing queue. internal void Apply(UrlGroup urlGroup, RequestQueue? requestQueue) { diff --git a/src/Servers/HttpSys/src/LoggerEventIds.cs b/src/Servers/HttpSys/src/LoggerEventIds.cs index 5bc0b6b65ed6..e6d745f506be 100644 --- a/src/Servers/HttpSys/src/LoggerEventIds.cs +++ b/src/Servers/HttpSys/src/LoggerEventIds.cs @@ -59,4 +59,5 @@ internal static class LoggerEventIds public const int AcceptCancelExpectationMismatch = 52; public const int AcceptObserveExpectationMismatch = 53; public const int RequestParsingError = 54; + public const int TlsListenerError = 55; } diff --git a/src/Servers/HttpSys/src/NativeInterop/ErrorCodes.cs b/src/Servers/HttpSys/src/NativeInterop/ErrorCodes.cs index d66397390430..b8263ef12a1a 100644 --- a/src/Servers/HttpSys/src/NativeInterop/ErrorCodes.cs +++ b/src/Servers/HttpSys/src/NativeInterop/ErrorCodes.cs @@ -13,6 +13,7 @@ internal static class ErrorCodes internal const uint ERROR_HANDLE_EOF = 38; internal const uint ERROR_NOT_SUPPORTED = 50; internal const uint ERROR_INVALID_PARAMETER = 87; + internal const uint ERROR_INSUFFICIENT_BUFFER = 122; internal const uint ERROR_INVALID_NAME = 123; internal const uint ERROR_ALREADY_EXISTS = 183; internal const uint ERROR_MORE_DATA = 234; diff --git a/src/Servers/HttpSys/src/NativeInterop/HttpApi.cs b/src/Servers/HttpSys/src/NativeInterop/HttpApi.cs index 1fec2ffea7e6..f5dfbc96a6cd 100644 --- a/src/Servers/HttpSys/src/NativeInterop/HttpApi.cs +++ b/src/Servers/HttpSys/src/NativeInterop/HttpApi.cs @@ -35,22 +35,41 @@ internal static partial class HttpApi internal static partial uint CancelIoEx(SafeHandle handle, SafeNativeOverlapped overlapped); internal unsafe delegate uint HttpGetRequestPropertyInvoker(SafeHandle requestQueueHandle, ulong requestId, HTTP_REQUEST_PROPERTY propertyId, - void* qualifier, uint qualifierSize, void* output, uint outputSize, uint* bytesReturned, IntPtr overlapped); + void* qualifier, uint qualifierSize, void* output, uint outputSize, IntPtr bytesReturned, IntPtr overlapped); - internal unsafe delegate uint HttpSetRequestPropertyInvoker(SafeHandle requestQueueHandle, ulong requestId, HTTP_REQUEST_PROPERTY propertyId, void* input, uint inputSize, IntPtr overlapped); + internal unsafe delegate uint HttpSetRequestPropertyInvoker(SafeHandle requestQueueHandle, ulong requestId, HTTP_REQUEST_PROPERTY propertyId, + void* input, uint inputSize, IntPtr overlapped); // HTTP_PROPERTY_FLAGS.Present (1) internal static HTTP_PROPERTY_FLAGS HTTP_PROPERTY_FLAGS_PRESENT { get; } = new() { _bitfield = 0x00000001 }; // This property is used by HttpListener to pass the version structure to the native layer in API calls. internal static HTTPAPI_VERSION Version { get; } = new () { HttpApiMajorVersion = 2 }; internal static SafeLibraryHandle? HttpApiModule { get; } - internal static HttpGetRequestPropertyInvoker? HttpGetRequestProperty { get; } - internal static HttpSetRequestPropertyInvoker? HttpSetRequestProperty { get; } - [MemberNotNullWhen(true, nameof(HttpSetRequestProperty))] + + private static HttpGetRequestPropertyInvoker? HttpGetRequestInvoker { get; } + private static HttpSetRequestPropertyInvoker? HttpSetRequestInvoker { get; } + + internal static bool HttpGetRequestPropertySupported => HttpGetRequestInvoker is not null; + internal static bool HttpSetRequestPropertySupported => HttpSetRequestInvoker is not null; + + internal static unsafe uint HttpGetRequestProperty(SafeHandle requestQueueHandle, ulong requestId, HTTP_REQUEST_PROPERTY propertyId, + void* qualifier, uint qualifierSize, void* output, uint outputSize, IntPtr bytesReturned, IntPtr overlapped) + { + return HttpGetRequestInvoker!(requestQueueHandle, requestId, propertyId, qualifier, qualifierSize, output, outputSize, bytesReturned, overlapped); + } + + internal static unsafe uint HttpSetRequestProperty(SafeHandle requestQueueHandle, ulong requestId, HTTP_REQUEST_PROPERTY propertyId, + void* input, uint inputSize, IntPtr overlapped) + { + return HttpSetRequestInvoker!(requestQueueHandle, requestId, propertyId, input, inputSize, overlapped); + } + + [MemberNotNullWhen(true, nameof(HttpSetRequestInvoker))] internal static bool SupportsTrailers { get; } - [MemberNotNullWhen(true, nameof(HttpSetRequestProperty))] + [MemberNotNullWhen(true, nameof(HttpSetRequestInvoker))] internal static bool SupportsReset { get; } internal static bool SupportsDelegation { get; } + internal static bool SupportsClientHello { get; } internal static bool Supported { get; } static unsafe HttpApi() @@ -61,11 +80,12 @@ static unsafe HttpApi() { Supported = true; HttpApiModule = SafeLibraryHandle.Open(HTTPAPI); - HttpGetRequestProperty = HttpApiModule.GetProcAddress("HttpQueryRequestProperty", throwIfNotFound: false); - HttpSetRequestProperty = HttpApiModule.GetProcAddress("HttpSetRequestProperty", throwIfNotFound: false); - SupportsReset = HttpSetRequestProperty != null; + HttpGetRequestInvoker = HttpApiModule.GetProcAddress("HttpQueryRequestProperty", throwIfNotFound: false); + HttpSetRequestInvoker = HttpApiModule.GetProcAddress("HttpSetRequestProperty", throwIfNotFound: false); + SupportsReset = HttpSetRequestPropertySupported; SupportsTrailers = IsFeatureSupported(HTTP_FEATURE_ID.HttpFeatureResponseTrailers); SupportsDelegation = IsFeatureSupported(HTTP_FEATURE_ID.HttpFeatureDelegateEx); + SupportsClientHello = IsFeatureSupported((HTTP_FEATURE_ID)11 /* HTTP_FEATURE_ID.HttpFeatureCacheTlsClientHello */) && HttpGetRequestPropertySupported; } } diff --git a/src/Servers/HttpSys/src/PublicAPI.Unshipped.txt b/src/Servers/HttpSys/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..e18d576e45d3 100644 --- a/src/Servers/HttpSys/src/PublicAPI.Unshipped.txt +++ b/src/Servers/HttpSys/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Server.HttpSys.HttpSysOptions.TlsClientHelloBytesCallback.get -> System.Action>? +Microsoft.AspNetCore.Server.HttpSys.HttpSysOptions.TlsClientHelloBytesCallback.set -> void diff --git a/src/Servers/HttpSys/src/RequestProcessing/Request.cs b/src/Servers/HttpSys/src/RequestProcessing/Request.cs index 8e4babf7ca21..478a8a587db6 100644 --- a/src/Servers/HttpSys/src/RequestProcessing/Request.cs +++ b/src/Servers/HttpSys/src/RequestProcessing/Request.cs @@ -9,6 +9,7 @@ using System.Security.Cryptography.X509Certificates; using System.Security.Principal; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.HttpSys.Internal; using Microsoft.AspNetCore.Shared; using Microsoft.Extensions.Logging; @@ -369,6 +370,9 @@ private void GetTlsHandshakeResults() SniHostName = sni.Hostname.ToString(); } + internal bool GetAndInvokeTlsClientHelloCallback(IFeatureCollection features, Action> tlsClientHelloBytesCallback) + => RequestContext.GetAndInvokeTlsClientHelloMessageBytesCallback(features, tlsClientHelloBytesCallback); + public X509Certificate2? ClientCertificate { get diff --git a/src/Servers/HttpSys/src/RequestProcessing/RequestContext.Log.cs b/src/Servers/HttpSys/src/RequestProcessing/RequestContext.Log.cs index 41b1cf480d5a..d7766698bc41 100644 --- a/src/Servers/HttpSys/src/RequestProcessing/RequestContext.Log.cs +++ b/src/Servers/HttpSys/src/RequestProcessing/RequestContext.Log.cs @@ -20,5 +20,8 @@ private static partial class Log [LoggerMessage(LoggerEventIds.RequestParsingError, LogLevel.Debug, "Failed to parse request.", EventName = "RequestParsingError")] public static partial void RequestParsingError(ILogger logger, Exception exception); + + [LoggerMessage(LoggerEventIds.RequestParsingError, LogLevel.Debug, "Failed to invoke QueryTlsClientHello; RequestId: {RequestId}; Win32 Error code: {Win32Error}", EventName = "TlsClientHelloRetrieveError")] + public static partial void TlsClientHelloRetrieveError(ILogger logger, ulong requestId, uint win32Error); } } diff --git a/src/Servers/HttpSys/src/RequestProcessing/RequestContext.cs b/src/Servers/HttpSys/src/RequestProcessing/RequestContext.cs index cd39566c5266..9e34d23f8584 100644 --- a/src/Servers/HttpSys/src/RequestProcessing/RequestContext.cs +++ b/src/Servers/HttpSys/src/RequestProcessing/RequestContext.cs @@ -1,9 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Buffers; using System.Runtime.InteropServices; using System.Security.Principal; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.HttpSys.Internal; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.Logging; @@ -219,29 +221,117 @@ internal void ForceCancelRequest() } } - internal unsafe HTTP_REQUEST_PROPERTY_SNI GetClientSni() + /// + /// Attempts to get the client hello message bytes from the http.sys. + /// If not successful, will return false. + /// + internal unsafe bool GetAndInvokeTlsClientHelloMessageBytesCallback(IFeatureCollection features, Action> tlsClientHelloBytesCallback) { - if (HttpApi.HttpGetRequestProperty != null) + if (!HttpApi.SupportsClientHello) + { + // not supported, so we just return and don't invoke the callback + return false; + } + + uint bytesReturnedValue = 0; + uint* bytesReturned = &bytesReturnedValue; + uint statusCode; + + var requestId = PinsReleased ? Request.RequestId : RequestId; + + // we will try with some "random" buffer size + var buffer = ArrayPool.Shared.Rent(512); + try { - var buffer = new byte[HttpApiTypes.SniPropertySizeInBytes]; fixed (byte* pBuffer = buffer) { - var statusCode = HttpApi.HttpGetRequestProperty( - Server.RequestQueue.Handle, - RequestId, - HTTP_REQUEST_PROPERTY.HttpRequestPropertySni, + statusCode = HttpApi.HttpGetRequestProperty( + requestQueueHandle: Server.RequestQueue.Handle, + requestId, + propertyId: (HTTP_REQUEST_PROPERTY)11 /* HTTP_REQUEST_PROPERTY.HttpRequestPropertyTlsClientHello */, qualifier: null, qualifierSize: 0, - (void*)pBuffer, - (uint)buffer.Length, - bytesReturned: null, - IntPtr.Zero); + output: pBuffer, + outputSize: (uint)buffer.Length, + bytesReturned: (IntPtr)bytesReturned, + overlapped: IntPtr.Zero); + + if (statusCode is ErrorCodes.ERROR_SUCCESS) + { + tlsClientHelloBytesCallback(features, buffer.AsSpan(0, (int)bytesReturnedValue)); + return true; + } + } + } + finally + { + ArrayPool.Shared.Return(buffer, clearArray: true); + } + + // if buffer supplied is too small, `bytesReturned` will have proper size + // so retry should succeed with the properly allocated buffer + if (statusCode is ErrorCodes.ERROR_MORE_DATA or ErrorCodes.ERROR_INSUFFICIENT_BUFFER) + { + try + { + var correctSize = (int)bytesReturnedValue; + buffer = ArrayPool.Shared.Rent(correctSize); - if (statusCode == ErrorCodes.ERROR_SUCCESS) + fixed (byte* pBuffer = buffer) { - return Marshal.PtrToStructure((IntPtr)pBuffer); + statusCode = HttpApi.HttpGetRequestProperty( + requestQueueHandle: Server.RequestQueue.Handle, + requestId, + propertyId: (HTTP_REQUEST_PROPERTY)11 /* HTTP_REQUEST_PROPERTY.HttpRequestPropertyTlsClientHello */, + qualifier: null, + qualifierSize: 0, + output: pBuffer, + outputSize: (uint)buffer.Length, + bytesReturned: (IntPtr)bytesReturned, + overlapped: IntPtr.Zero); + + if (statusCode is ErrorCodes.ERROR_SUCCESS) + { + tlsClientHelloBytesCallback(features, buffer.AsSpan(0, correctSize)); + return true; + } } } + finally + { + ArrayPool.Shared.Return(buffer, clearArray: true); + } + } + + Log.TlsClientHelloRetrieveError(Logger, requestId, statusCode); + return false; + } + + internal unsafe HTTP_REQUEST_PROPERTY_SNI GetClientSni() + { + if (!HttpApi.HttpGetRequestPropertySupported) + { + return default; + } + + var buffer = new byte[HttpApiTypes.SniPropertySizeInBytes]; + fixed (byte* pBuffer = buffer) + { + var statusCode = HttpApi.HttpGetRequestProperty( + Server.RequestQueue.Handle, + RequestId, + HTTP_REQUEST_PROPERTY.HttpRequestPropertySni, + qualifier: null, + qualifierSize: 0, + pBuffer, + (uint)buffer.Length, + bytesReturned: IntPtr.Zero, + IntPtr.Zero); + + if (statusCode == ErrorCodes.ERROR_SUCCESS) + { + return Marshal.PtrToStructure((IntPtr)pBuffer); + } } return default; diff --git a/src/Servers/HttpSys/src/RequestProcessing/RequestContextOfT.cs b/src/Servers/HttpSys/src/RequestProcessing/RequestContextOfT.cs index 2a1d06a06d26..399f1292d60d 100644 --- a/src/Servers/HttpSys/src/RequestProcessing/RequestContextOfT.cs +++ b/src/Servers/HttpSys/src/RequestProcessing/RequestContextOfT.cs @@ -48,6 +48,12 @@ public override async Task ExecuteAsync() context = application.CreateContext(Features); try { + if (Server.Options.TlsClientHelloBytesCallback is not null && Server.TlsListener is not null + && Request.IsHttps) + { + Server.TlsListener.InvokeTlsClientHelloCallback(Request.RawConnectionId, Features, Request.GetAndInvokeTlsClientHelloCallback); + } + await application.ProcessRequestAsync(context); await CompleteAsync(); } diff --git a/src/Servers/HttpSys/src/RequestProcessing/TlsListener.Log.cs b/src/Servers/HttpSys/src/RequestProcessing/TlsListener.Log.cs new file mode 100644 index 000000000000..20ffe5c74b6f --- /dev/null +++ b/src/Servers/HttpSys/src/RequestProcessing/TlsListener.Log.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.HttpSys.RequestProcessing; + +internal sealed partial class TlsListener : IDisposable +{ + private static partial class Log + { + [LoggerMessage(LoggerEventIds.TlsListenerError, LogLevel.Error, "Error during closed connection cleanup.", EventName = "TlsListenerCleanupClosedConnectionError")] + public static partial void CleanupClosedConnectionError(ILogger logger, Exception exception); + } +} diff --git a/src/Servers/HttpSys/src/RequestProcessing/TlsListener.cs b/src/Servers/HttpSys/src/RequestProcessing/TlsListener.cs new file mode 100644 index 000000000000..8e7edb9bb47d --- /dev/null +++ b/src/Servers/HttpSys/src/RequestProcessing/TlsListener.cs @@ -0,0 +1,143 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Concurrent; +using System.Collections.ObjectModel; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Server.HttpSys.RequestProcessing; + +internal sealed partial class TlsListener : IDisposable +{ + private readonly ConcurrentDictionary _connectionTimestamps = new(); + private readonly Action> _tlsClientHelloBytesCallback; + private readonly ILogger _logger; + + private readonly PeriodicTimer _cleanupTimer; + private readonly Task _cleanupTask; + private readonly TimeProvider _timeProvider; + + private readonly TimeSpan ConnectionIdleTime = TimeSpan.FromMinutes(5); + private readonly TimeSpan CleanupDelay = TimeSpan.FromSeconds(10); + internal readonly int CacheSizeLimit = 1_000_000; + + // Internal for testing purposes + internal ReadOnlyDictionary ConnectionTimeStamps => _connectionTimestamps.AsReadOnly(); + + internal TlsListener(ILogger logger, Action> tlsClientHelloBytesCallback, TimeProvider? timeProvider = null) + { + if (AppContext.GetData("Microsoft.AspNetCore.Server.HttpSys.TlsListener.CacheSizeLimit") is int limit) + { + CacheSizeLimit = limit; + } + + if (AppContext.GetData("Microsoft.AspNetCore.Server.HttpSys.TlsListener.ConnectionIdleTime") is int idleTime) + { + ConnectionIdleTime = TimeSpan.FromSeconds(idleTime); + } + + if (AppContext.GetData("Microsoft.AspNetCore.Server.HttpSys.TlsListener.CleanupDelay") is int cleanupDelay) + { + CleanupDelay = TimeSpan.FromSeconds(cleanupDelay); + } + + _logger = logger; + _tlsClientHelloBytesCallback = tlsClientHelloBytesCallback; + + _timeProvider = timeProvider ?? TimeProvider.System; + _cleanupTimer = new PeriodicTimer(CleanupDelay, _timeProvider); + _cleanupTask = CleanupLoopAsync(); + } + + // Method looks weird because we want it to be testable by not directly requiring a Request object + internal void InvokeTlsClientHelloCallback(ulong connectionId, IFeatureCollection features, + Func>, bool> invokeTlsClientHelloCallback) + { + if (!_connectionTimestamps.TryAdd(connectionId, _timeProvider.GetUtcNow())) + { + // update TTL + _connectionTimestamps[connectionId] = _timeProvider.GetUtcNow(); + return; + } + + _ = invokeTlsClientHelloCallback(features, _tlsClientHelloBytesCallback); + } + + internal async Task CleanupLoopAsync() + { + while (await _cleanupTimer.WaitForNextTickAsync()) + { + try + { + var now = _timeProvider.GetUtcNow(); + + // Remove idle connections + foreach (var kvp in _connectionTimestamps) + { + if (now - kvp.Value >= ConnectionIdleTime) + { + _connectionTimestamps.TryRemove(kvp.Key, out _); + } + } + + // Evict oldest items if above CacheSizeLimit + var currentCount = _connectionTimestamps.Count; + if (currentCount > CacheSizeLimit) + { + var excessCount = currentCount - CacheSizeLimit; + + // Find the oldest items in a single pass + var oldestTimestamps = new SortedSet>(TimeComparer.Instance); + + foreach (var kvp in _connectionTimestamps) + { + if (oldestTimestamps.Count < excessCount) + { + oldestTimestamps.Add(new KeyValuePair(kvp.Key, kvp.Value)); + } + else if (kvp.Value < oldestTimestamps.Max.Value) + { + oldestTimestamps.Remove(oldestTimestamps.Max); + oldestTimestamps.Add(new KeyValuePair(kvp.Key, kvp.Value)); + } + } + + // Remove the oldest keys + foreach (var item in oldestTimestamps) + { + _connectionTimestamps.TryRemove(item.Key, out _); + } + } + } + catch (Exception ex) + { + Log.CleanupClosedConnectionError(_logger, ex); + } + } + } + + public void Dispose() + { + _cleanupTimer.Dispose(); + _cleanupTask.Wait(); + } + + private sealed class TimeComparer : IComparer> + { + public static TimeComparer Instance { get; } = new TimeComparer(); + + public int Compare(KeyValuePair x, KeyValuePair y) + { + // Compare timestamps first + int timestampComparison = x.Value.CompareTo(y.Value); + if (timestampComparison != 0) + { + return timestampComparison; + } + + // Use the key as a tiebreaker to ensure uniqueness + return x.Key.CompareTo(y.Key); + } + } +} diff --git a/src/Servers/HttpSys/test/FunctionalTests/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests.csproj b/src/Servers/HttpSys/test/FunctionalTests/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests.csproj index 08276e6a23fd..56f300b89198 100644 --- a/src/Servers/HttpSys/test/FunctionalTests/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests.csproj +++ b/src/Servers/HttpSys/test/FunctionalTests/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests.csproj @@ -32,6 +32,7 @@ + diff --git a/src/Servers/HttpSys/test/FunctionalTests/TlsListenerTests.cs b/src/Servers/HttpSys/test/FunctionalTests/TlsListenerTests.cs new file mode 100644 index 000000000000..d0ff2731a017 --- /dev/null +++ b/src/Servers/HttpSys/test/FunctionalTests/TlsListenerTests.cs @@ -0,0 +1,140 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.HttpSys.RequestProcessing; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Time.Testing; +using Moq; + +namespace Microsoft.AspNetCore.Server.HttpSys.FunctionalTests; + +public class TlsListenerTests +{ + [Fact] + public void AddsAndUpdatesConnectionTimestamps() + { + // Arrange + var logger = Mock.Of(); + var timeProvider = new FakeTimeProvider(); + var callbackInvoked = false; + var tlsListener = new TlsListener(logger, (_, __) => { callbackInvoked = true; }, timeProvider); + + var features = Mock.Of(); + + // Act + tlsListener.InvokeTlsClientHelloCallback(connectionId: 1UL, features, + invokeTlsClientHelloCallback: (f, cb) => { cb(f, ReadOnlySpan.Empty); return true; }); + + var originalTime = timeProvider.GetUtcNow(); + + // Assert + Assert.True(callbackInvoked); + Assert.Equal(originalTime, Assert.Single(tlsListener.ConnectionTimeStamps).Value); + + timeProvider.Advance(TimeSpan.FromSeconds(1)); + callbackInvoked = false; + // Update the timestamp + tlsListener.InvokeTlsClientHelloCallback(connectionId: 1UL, features, + invokeTlsClientHelloCallback: (f, cb) => { cb(f, ReadOnlySpan.Empty); return true; }); + + // Callback should not be invoked again and the timestamp should be updated + Assert.False(callbackInvoked); + Assert.Equal(timeProvider.GetUtcNow(), Assert.Single(tlsListener.ConnectionTimeStamps).Value); + Assert.NotEqual(originalTime, timeProvider.GetUtcNow()); + } + + [Fact] + public async Task RemovesIdleConnections() + { + // Arrange + var logger = Mock.Of(); + var timeProvider = new FakeTimeProvider(); + using var tlsListener = new TlsListener(logger, (_, __) => { }, timeProvider); + + var features = Mock.Of(); + + bool InvokeCallback(IFeatureCollection f, Action> cb) + { + cb(f, ReadOnlySpan.Empty); + return true; + } + + // Act + tlsListener.InvokeTlsClientHelloCallback(connectionId: 1UL, features, InvokeCallback); + + // 1 less minute than the idle time cleanup + timeProvider.Advance(TimeSpan.FromMinutes(4)); + Assert.Single(tlsListener.ConnectionTimeStamps); + + tlsListener.InvokeTlsClientHelloCallback(connectionId: 2UL, features, InvokeCallback); + Assert.Equal(2, tlsListener.ConnectionTimeStamps.Count); + + // With the previous 4 minutes, this should be 5 minutes and remove the first connection + timeProvider.Advance(TimeSpan.FromMinutes(1)); + + var timeout = TimeSpan.FromSeconds(5); + while (timeout > TimeSpan.Zero) + { + // Wait for the cleanup loop to run + if (tlsListener.ConnectionTimeStamps.Count == 1) + { + break; + } + timeout -= TimeSpan.FromMilliseconds(100); + await Task.Delay(100); + } + + // Assert + Assert.Single(tlsListener.ConnectionTimeStamps); + Assert.Contains(2UL, tlsListener.ConnectionTimeStamps.Keys); + } + + [Fact] + public async Task EvictsOldestConnectionsWhenExceedingCacheSizeLimit() + { + // Arrange + var logger = Mock.Of(); + var timeProvider = new FakeTimeProvider(); + var tlsListener = new TlsListener(logger, (_, __) => { }, timeProvider); + var features = Mock.Of(); + + ulong i = 0; + for (; i < (ulong)tlsListener.CacheSizeLimit; i++) + { + tlsListener.InvokeTlsClientHelloCallback(i, features, (f, cb) => { cb(f, ReadOnlySpan.Empty); return true; }); + } + + timeProvider.Advance(TimeSpan.FromSeconds(5)); + + for (; i < (ulong)tlsListener.CacheSizeLimit + 3; i++) + { + tlsListener.InvokeTlsClientHelloCallback(i, features, (f, cb) => { cb(f, ReadOnlySpan.Empty); return true; }); + } + + // 'touch' first connection to update its timestamp + tlsListener.InvokeTlsClientHelloCallback(0, features, (f, cb) => { cb(f, ReadOnlySpan.Empty); return true; }); + + // Make sure the cleanup loop has run to evict items since we're above the cache size limit + timeProvider.Advance(TimeSpan.FromMinutes(1)); + + var timeout = TimeSpan.FromSeconds(5); + while (timeout > TimeSpan.Zero) + { + // Wait for the cleanup loop to run + if (tlsListener.ConnectionTimeStamps.Count == tlsListener.CacheSizeLimit) + { + break; + } + timeout -= TimeSpan.FromMilliseconds(100); + await Task.Delay(100); + } + + Assert.Equal(tlsListener.CacheSizeLimit, tlsListener.ConnectionTimeStamps.Count); + Assert.Contains(0UL, tlsListener.ConnectionTimeStamps.Keys); + // 3 newest connections should be present + Assert.Contains(i - 1, tlsListener.ConnectionTimeStamps.Keys); + Assert.Contains(i - 2, tlsListener.ConnectionTimeStamps.Keys); + Assert.Contains(i - 3, tlsListener.ConnectionTimeStamps.Keys); + } +} diff --git a/src/Shared/HttpSys/RequestProcessing/NativeRequestContext.cs b/src/Shared/HttpSys/RequestProcessing/NativeRequestContext.cs index e4a6622b5271..a9a25b8f0092 100644 --- a/src/Shared/HttpSys/RequestProcessing/NativeRequestContext.cs +++ b/src/Shared/HttpSys/RequestProcessing/NativeRequestContext.cs @@ -34,9 +34,11 @@ internal unsafe class NativeRequestContext : IDisposable private MemoryHandle _memoryHandle; private readonly int _bufferAlignment; private readonly bool _permanentlyPinned; - private bool _disposed; private IReadOnlyDictionary>? _requestInfo; + private bool _disposed; + private bool _pinsReleased; + [MemberNotNullWhen(false, nameof(_backingBuffer))] private bool PermanentlyPinned => _permanentlyPinned; @@ -171,6 +173,11 @@ internal uint Size } } + /// + /// Shows whether was already invoked on this native request context + /// + internal bool PinsReleased => _pinsReleased; + // ReleasePins() should be called exactly once. It must be called before Dispose() is called, which means it must be called // before an object (Request) which closes the RequestContext on demand is returned to the application. internal void ReleasePins() @@ -180,6 +187,7 @@ internal void ReleasePins() _memoryHandle.Dispose(); _memoryHandle = default; _nativeRequest = null; + _pinsReleased = true; } public bool TryGetTimestamp(HttpSysRequestTimingType timestampType, out long timestamp)