diff --git a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs index e76fa9b67e5bcb..42b58b7df198d7 100644 --- a/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs +++ b/src/libraries/Common/src/Interop/Windows/WinHttp/Interop.winhttp_types.cs @@ -336,6 +336,16 @@ public struct WINHTTP_ASYNC_RESULT public uint dwError; } + [StructLayout(LayoutKind.Sequential)] + public unsafe struct WINHTTP_CONNECTION_INFO + { + // This field is actually 4 bytes, but we use nuint to avoid alignment issues for x64. + // If we want to read this field in the future, we need to change type and make sure + // alignment is correct for necessary archs. + public nuint cbSize; + public fixed byte LocalAddress[128]; + public fixed byte RemoteAddress[128]; + } [StructLayout(LayoutKind.Sequential)] public struct tcp_keepalive diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj b/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj index 9e4f19ec066f65..4af49d748e9b28 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System.Net.Http.WinHttpHandler.csproj @@ -80,6 +80,7 @@ System.Net.Http.WinHttpHandler Link="Common\System\Runtime\ExceptionServices\ExceptionStackTrace.cs" /> + @@ -117,6 +118,7 @@ System.Net.Http.WinHttpHandler + diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/CachedCertificateValue.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/CachedCertificateValue.cs new file mode 100644 index 00000000000000..58ca950a8855a1 --- /dev/null +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/CachedCertificateValue.cs @@ -0,0 +1,45 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading; + +namespace System.Net.Http +{ + internal sealed class CachedCertificateValue(byte[] rawCertificateData, long lastUsedTime) + { + private long _lastUsedTime = lastUsedTime; + public byte[] RawCertificateData { get; } = rawCertificateData; + public long LastUsedTime + { + get => Volatile.Read(ref _lastUsedTime); + set => Volatile.Write(ref _lastUsedTime, value); + } + } + + internal readonly struct CachedCertificateKey : IEquatable + { + public CachedCertificateKey(IPAddress address, HttpRequestMessage message) + { + Debug.Assert(message.RequestUri != null); + Address = address; + Host = message.Headers.Host ?? message.RequestUri.Host; + } + public IPAddress Address { get; } + public string Host { get; } + + public bool Equals(CachedCertificateKey other) => + Address.Equals(other.Address) && + Host == other.Host; + + public override bool Equals(object? obj) + { + throw new Exception("Unreachable"); + } + + public override int GetHashCode() => HashCode.Combine(Address, Host); + } +} diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs index 86c893169270b2..aa797ecf4ca5ea 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpHandler.cs @@ -1,8 +1,10 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.IO; using System.Net.Http.Headers; using System.Net.Security; @@ -41,11 +43,14 @@ public class WinHttpHandler : HttpMessageHandler internal static readonly Version HttpVersion20 = new Version(2, 0); internal static readonly Version HttpVersion30 = new Version(3, 0); internal static readonly Version HttpVersionUnknown = new Version(0, 0); + internal static bool CertificateCachingAppContextSwitchEnabled { get; } = AppContext.TryGetSwitch("System.Net.Http.UseWinHttpCertificateCaching", out bool enabled) && enabled; private static readonly TimeSpan s_maxTimeout = TimeSpan.FromMilliseconds(int.MaxValue); private static readonly StringWithQualityHeaderValue s_gzipHeaderValue = new StringWithQualityHeaderValue("gzip"); private static readonly StringWithQualityHeaderValue s_deflateHeaderValue = new StringWithQualityHeaderValue("deflate"); private static readonly Lazy s_supportsTls13 = new Lazy(CheckTls13Support); + private static readonly TimeSpan s_cleanCachedCertificateTimeout = TimeSpan.FromMilliseconds((int?)AppDomain.CurrentDomain.GetData("System.Net.Http.WinHttpCertificateCachingCleanupTimerInterval") ?? 60_000); + private static readonly long s_staleTimeout = (long)(s_cleanCachedCertificateTimeout.TotalSeconds * Stopwatch.Frequency); [ThreadStatic] private static StringBuilder? t_requestHeadersBuilder; @@ -93,9 +98,44 @@ private Func< private volatile bool _disposed; private SafeWinHttpHandle? _sessionHandle; private readonly WinHttpAuthHelper _authHelper = new WinHttpAuthHelper(); + private readonly Timer? _certificateCleanupTimer; + private bool _isTimerRunning; + private readonly ConcurrentDictionary _cachedCertificates = new(); public WinHttpHandler() { + if (CertificateCachingAppContextSwitchEnabled) + { + WeakReference thisRef = new(this); + bool restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + _certificateCleanupTimer = new Timer( + static s => + { + if (((WeakReference)s!).TryGetTarget(out WinHttpHandler? thisRef)) + { + thisRef.ClearStaleCertificates(); + } + }, + thisRef, + Timeout.Infinite, + Timeout.Infinite); + } + finally + { + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } } #region Properties @@ -543,9 +583,12 @@ protected override void Dispose(bool disposing) { _disposed = true; - if (disposing && _sessionHandle != null) + if (disposing) { - SafeWinHttpHandle.DisposeAndClearHandle(ref _sessionHandle); + if (_sessionHandle is not null) { + SafeWinHttpHandle.DisposeAndClearHandle(ref _sessionHandle); + } + _certificateCleanupTimer?.Dispose(); } } @@ -1644,7 +1687,8 @@ private void SetStatusCallback( Interop.WinHttp.WINHTTP_CALLBACK_FLAG_ALL_COMPLETIONS | Interop.WinHttp.WINHTTP_CALLBACK_FLAG_HANDLES | Interop.WinHttp.WINHTTP_CALLBACK_FLAG_REDIRECT | - Interop.WinHttp.WINHTTP_CALLBACK_FLAG_SEND_REQUEST; + Interop.WinHttp.WINHTTP_CALLBACK_FLAG_SEND_REQUEST | + Interop.WinHttp.WINHTTP_CALLBACK_STATUS_CONNECTED_TO_SERVER; IntPtr oldCallback = Interop.WinHttp.WinHttpSetStatusCallback( requestHandle, @@ -1730,5 +1774,90 @@ private RendezvousAwaitable InternalReceiveResponseHeadersAsync(WinHttpRequ return state.LifecycleAwaitable; } + + internal bool GetCertificateFromCache(CachedCertificateKey key, [NotNullWhen(true)] out byte[]? rawCertificateBytes) + { + if (_cachedCertificates.TryGetValue(key, out CachedCertificateValue? cachedValue)) + { + cachedValue.LastUsedTime = Stopwatch.GetTimestamp(); + rawCertificateBytes = cachedValue.RawCertificateData; + return true; + } + + rawCertificateBytes = null; + return false; + } + + internal void AddCertificateToCache(CachedCertificateKey key, byte[] rawCertificateData) + { + if (_cachedCertificates.TryAdd(key, new CachedCertificateValue(rawCertificateData, Stopwatch.GetTimestamp()))) + { + EnsureCleanupTimerRunning(); + } + } + + internal bool TryRemoveCertificateFromCache(CachedCertificateKey key) + { + bool result = _cachedCertificates.TryRemove(key, out _); + if (result) + { + StopCleanupTimerIfEmpty(); + } + return result; + } + + private void ChangeCleanerTimer(TimeSpan timeout) + { + Debug.Assert(Monitor.IsEntered(_lockObject)); + Debug.Assert(_certificateCleanupTimer != null); + if (_certificateCleanupTimer!.Change(timeout, Timeout.InfiniteTimeSpan)) + { + _isTimerRunning = timeout != Timeout.InfiniteTimeSpan; + } + } + + private void ClearStaleCertificates() + { + foreach (KeyValuePair kvPair in _cachedCertificates) + { + if (IsStale(kvPair.Value.LastUsedTime)) + { + _cachedCertificates.TryRemove(kvPair.Key, out _); + } + } + + lock (_lockObject) + { + ChangeCleanerTimer(_cachedCertificates.IsEmpty ? Timeout.InfiniteTimeSpan : s_cleanCachedCertificateTimeout); + } + + static bool IsStale(long lastUsedTime) + { + long now = Stopwatch.GetTimestamp(); + return (now - lastUsedTime) > s_staleTimeout; + } + } + + private void EnsureCleanupTimerRunning() + { + lock (_lockObject) + { + if (!_cachedCertificates.IsEmpty && !_isTimerRunning) + { + ChangeCleanerTimer(s_cleanCachedCertificateTimeout); + } + } + } + + private void StopCleanupTimerIfEmpty() + { + lock (_lockObject) + { + if (_cachedCertificates.IsEmpty && _isTimerRunning) + { + ChangeCleanerTimer(Timeout.InfiniteTimeSpan); + } + } + } } } diff --git a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs index c30694a20460b1..6c50ee16817cf7 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/src/System/Net/Http/WinHttpRequestCallback.cs @@ -2,12 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers.Binary; using System.Diagnostics; using System.IO; +using System.Linq; using System.Net.Security; +using System.Net.Sockets; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.Cryptography.X509Certificates; - +using System.Threading; using SafeWinHttpHandle = Interop.WinHttp.SafeWinHttpHandle; namespace System.Net.Http @@ -56,6 +60,14 @@ private static void RequestCallback( { switch (internetStatus) { + case Interop.WinHttp.WINHTTP_CALLBACK_STATUS_CONNECTED_TO_SERVER: + if (WinHttpHandler.CertificateCachingAppContextSwitchEnabled) + { + IPAddress connectedToIPAddress = IPAddress.Parse(Marshal.PtrToStringUni(statusInformation)!); + OnRequestConnectedToServer(state, connectedToIPAddress); + } + return; + case Interop.WinHttp.WINHTTP_CALLBACK_STATUS_HANDLE_CLOSING: OnRequestHandleClosing(state); return; @@ -121,6 +133,22 @@ private static void RequestCallback( } } + private static void OnRequestConnectedToServer(WinHttpRequestState state, IPAddress connectedIPAddress) + { + Debug.Assert(state != null); + Debug.Assert(state.Handler != null); + Debug.Assert(state.RequestMessage != null); + + if (state.Handler.TryRemoveCertificateFromCache(new CachedCertificateKey(connectedIPAddress, state.RequestMessage))) + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"Removed cached certificate for {connectedIPAddress}"); + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"No cached certificate for {connectedIPAddress} to remove"); + } + } + private static void OnRequestHandleClosing(WinHttpRequestState state) { Debug.Assert(state != null, "OnRequestSendRequestComplete: state is null"); @@ -231,6 +259,7 @@ private static void OnRequestRedirect(WinHttpRequestState state, Uri redirectUri private static void OnRequestSendingRequest(WinHttpRequestState state) { Debug.Assert(state != null, "OnRequestSendingRequest: state is null"); + Debug.Assert(state.Handler != null, "OnRequestSendingRequest: state.Handler is null"); Debug.Assert(state.RequestMessage != null, "OnRequestSendingRequest: state.RequestMessage is null"); Debug.Assert(state.RequestMessage.RequestUri != null, "OnRequestSendingRequest: state.RequestMessage.RequestUri is null"); @@ -279,6 +308,62 @@ private static void OnRequestSendingRequest(WinHttpRequestState state) var serverCertificate = new X509Certificate2(certHandle); Interop.Crypt32.CertFreeCertificateContext(certHandle); + IPAddress? ipAddress = null; + if (WinHttpHandler.CertificateCachingAppContextSwitchEnabled) + { + unsafe + { + Interop.WinHttp.WINHTTP_CONNECTION_INFO connectionInfo; + Interop.WinHttp.WINHTTP_CONNECTION_INFO* pConnectionInfo = &connectionInfo; + uint infoSize = (uint)sizeof(Interop.WinHttp.WINHTTP_CONNECTION_INFO); + if (Interop.WinHttp.WinHttpQueryOption( + state.RequestHandle, + // This option is available on Windows XP SP2 and later; Windows 2003 with SP1 and later. + Interop.WinHttp.WINHTTP_OPTION_CONNECTION_INFO, + (IntPtr)pConnectionInfo, + ref infoSize)) + { + // RemoteAddress is SOCKADDR_STORAGE structure, which is 128 bytes. + // See: https://learn.microsoft.com/en-us/windows/win32/api/winhttp/ns-winhttp-winhttp_connection_info + // SOCKADDR_STORAGE can hold either IPv4 or IPv6 address. + // For offset numbers: https://learn.microsoft.com/en-us/windows/win32/winsock/sockaddr-2 + ReadOnlySpan remoteAddressSpan = new ReadOnlySpan(connectionInfo.RemoteAddress, 128); + AddressFamily addressFamily = (AddressFamily)(remoteAddressSpan[0] + (remoteAddressSpan[1] << 8)); + ipAddress = addressFamily switch + { + AddressFamily.InterNetwork => new IPAddress(BinaryPrimitives.ReadUInt32LittleEndian(remoteAddressSpan.Slice(4))), + AddressFamily.InterNetworkV6 => new IPAddress(remoteAddressSpan.Slice(8, 16).ToArray()), + _ => null + }; + Debug.Assert(ipAddress != null, "AddressFamily is not supported"); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"ipAddress: {ipAddress}"); + + } + else + { + int lastError = Marshal.GetLastWin32Error(); + if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(state, $"Error getting WINHTTP_OPTION_CONNECTION_INFO, {lastError}"); + } + } + + if (ipAddress is not null && + state.Handler.GetCertificateFromCache(new CachedCertificateKey(ipAddress, state.RequestMessage), out byte[]? rawCertData) && +#if NETFRAMEWORK + rawCertData.AsSpan().SequenceEqual(serverCertificate.RawData)) +#else + rawCertData.AsSpan().SequenceEqual(serverCertificate.RawDataMemory.Span)) +#endif + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"Skipping certificate validation. ipAddress: {ipAddress}, Thumbprint: {serverCertificate.Thumbprint}"); + serverCertificate.Dispose(); + return; + } + else + { + if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(state, $"Certificate validation is required! IPAddress = {ipAddress}, Thumbprint: {serverCertificate.Thumbprint}"); + } + } + X509Chain? chain = null; SslPolicyErrors sslPolicyErrors; bool result = false; @@ -298,6 +383,10 @@ private static void OnRequestSendingRequest(WinHttpRequestState state) serverCertificate, chain, sslPolicyErrors); + if (WinHttpHandler.CertificateCachingAppContextSwitchEnabled && result && ipAddress is not null) + { + state.Handler.AddCertificateToCache(new CachedCertificateKey(ipAddress, state.RequestMessage), serverCertificate.RawData); + } } catch (Exception ex) { diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs index 0abe14c11887bf..08d3d560c7b4d4 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/FunctionalTests/WinHttpHandlerTest.cs @@ -9,7 +9,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; - +using Microsoft.DotNet.RemoteExecutor; using Xunit; using Xunit.Abstractions; @@ -27,6 +27,8 @@ public class WinHttpHandlerTest private readonly ITestOutputHelper _output; + public static IEnumerable HttpVersions = [[HttpVersion.Version11, Configuration.Http.SecureRemoteEchoServer], [HttpVersion20.Value, Configuration.Http.Http2RemoteEchoServer]]; + public WinHttpHandlerTest(ITestOutputHelper output) { _output = output; @@ -46,6 +48,75 @@ public void SendAsync_SimpleGet_Success() } } + [OuterLoop] + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [MemberData(nameof(HttpVersions))] + public async Task SendAsync_ServerCertificateValidationCallback_CalledOnce(Version version, Uri uri) + { + await RemoteExecutor.Invoke(async (version, uri) => + { + AppContext.SetSwitch("System.Net.Http.UseWinHttpCertificateCaching", true); + int callbackCount = 0; + var handler = new WinHttpHandler() + { + ServerCertificateValidationCallback = (_, _, _, _) => + { + Interlocked.Increment(ref callbackCount); + return true; + } + }; + using (var client = new HttpClient(handler)) + { + for (int i = 0; i < 5; i++) + { + var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, uri) + { + Version = Version.Parse(version) + }); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + _ = await response.Content.ReadAsStringAsync(); + } + Assert.Equal(1, callbackCount); + } + }, version.ToString(), uri.ToString()).DisposeAsync(); + } + + [OuterLoop] + [ConditionalTheory(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))] + [MemberData(nameof(HttpVersions))] + public async Task SendAsync_ServerCertificateValidationCallbackCertificateTimerTriggered_CalledTwice(Version version, Uri uri) + { + await RemoteExecutor.Invoke(async (version, uri) => + { + const int certificateCacheCleanupInterval = 10; + AppContext.SetSwitch("System.Net.Http.UseWinHttpCertificateCaching", true); + AppDomain.CurrentDomain.SetData("System.Net.Http.WinHttpCertificateCachingCleanupTimerInterval", certificateCacheCleanupInterval); + int callbackCount = 0; + var handler = new WinHttpHandler() + { + ServerCertificateValidationCallback = (_, _, _, _) => + { + Interlocked.Increment(ref callbackCount); + return true; + } + }; + using (var client = new HttpClient(handler)) + { + for (int i = 0; i < 5; i++) + { + var response = await client.SendAsync(new HttpRequestMessage(HttpMethod.Get, uri) + { + Version = Version.Parse(version) + }); + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + _ = await response.Content.ReadAsStringAsync(); + await Task.Delay(TimeSpan.FromMilliseconds(certificateCacheCleanupInterval * 3)); + } + Assert.True(callbackCount > 1); + } + }, version.ToString(), uri.ToString()).DisposeAsync(); + } + [OuterLoop] [Theory] [InlineData(CookieUsePolicy.UseInternalCookieStoreOnly, "cookieName1", "cookieValue1")] diff --git a/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/System.Net.Http.WinHttpHandler.Unit.Tests.csproj b/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/System.Net.Http.WinHttpHandler.Unit.Tests.csproj index f8b72896871da4..68ca61edeb5817 100644 --- a/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/System.Net.Http.WinHttpHandler.Unit.Tests.csproj +++ b/src/libraries/System.Net.Http.WinHttpHandler/tests/UnitTests/System.Net.Http.WinHttpHandler.Unit.Tests.csproj @@ -56,6 +56,8 @@ Link="Common\System\Text\SimpleRegex.cs" /> +