diff --git a/src/Common/src/System/Net/Http/HttpHandlerDefaults.cs b/src/Common/src/System/Net/Http/HttpHandlerDefaults.cs index b32f4d7d81d9..9f0f6b72b911 100644 --- a/src/Common/src/System/Net/Http/HttpHandlerDefaults.cs +++ b/src/Common/src/System/Net/Http/HttpHandlerDefaults.cs @@ -2,7 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Diagnostics; +using System.IO; using System.Threading; +using System.Threading.Tasks; namespace System.Net.Http { diff --git a/src/System.Net.Http/ref/System.Net.Http.cs b/src/System.Net.Http/ref/System.Net.Http.cs index 8e4860a69790..378f70de73b1 100644 --- a/src/System.Net.Http/ref/System.Net.Http.cs +++ b/src/System.Net.Http/ref/System.Net.Http.cs @@ -262,6 +262,7 @@ public SocketsHttpHandler() { } public System.Net.IWebProxy Proxy { get { throw null; } set { } } public System.TimeSpan ResponseDrainTimeout { get { throw null; } set { } } public System.Net.Security.SslClientAuthenticationOptions SslOptions { get { throw null; } set { } } + public System.Func> ConnectCallback { get { throw null; } set { } } public bool UseCookies { get { throw null; } set { } } public bool UseProxy { get { throw null; } set { } } protected override void Dispose(bool disposing) { } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs index 843882ad6e5c..3a9cce0e2272 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionPool.cs @@ -622,11 +622,11 @@ public Task SendAsync(HttpRequestMessage request, bool doRe case HttpConnectionKind.Http: case HttpConnectionKind.Https: case HttpConnectionKind.ProxyConnect: - stream = await ConnectHelper.ConnectAsync(_host, _port, cancellationToken).ConfigureAwait(false); + stream = await Settings._customConnect(_host, _port, cancellationToken).ConfigureAwait(false); break; case HttpConnectionKind.Proxy: - stream = await ConnectHelper.ConnectAsync(_proxyUri.IdnHost, _proxyUri.Port, cancellationToken).ConfigureAwait(false); + stream = await Settings._customConnect(_proxyUri.IdnHost, _proxyUri.Port, cancellationToken).ConfigureAwait(false); break; case HttpConnectionKind.ProxyTunnel: diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs index f6788731ce8c..c8497196b631 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/HttpConnectionSettings.cs @@ -3,7 +3,11 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.Diagnostics; +using System.IO; using System.Net.Security; +using System.Threading; +using System.Threading.Tasks; namespace System.Net.Http { @@ -48,6 +52,8 @@ internal sealed class HttpConnectionSettings internal IDictionary _properties; + internal Func> _customConnect = ConnectHelper.ConnectAsync; + public HttpConnectionSettings() { bool allowHttp2 = AllowHttp2; @@ -88,6 +94,7 @@ public HttpConnectionSettings CloneAndNormalize() _useCookies = _useCookies, _useProxy = _useProxy, _allowUnencryptedHttp2 = _allowUnencryptedHttp2, + _customConnect = _customConnect }; } diff --git a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs index 9de8d147c7bb..d5484b15d460 100644 --- a/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs +++ b/src/System.Net.Http/src/System/Net/Http/SocketsHttpHandler/SocketsHttpHandler.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System.Collections.Generic; +using System.IO; using System.Net.Security; using System.Threading; using System.Threading.Tasks; @@ -32,6 +33,16 @@ private void CheckDisposedOrStarted() } } + public Func> ConnectCallback + { + get => _settings._customConnect; + set + { + CheckDisposedOrStarted(); + _settings._customConnect = value; + } + } + public bool UseCookies { get => _settings._useCookies; diff --git a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs index d10679ea5be2..847299c0a071 100644 --- a/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs +++ b/src/System.Net.Http/tests/FunctionalTests/HttpClientHandlerTest.cs @@ -2689,5 +2689,41 @@ public async Task GetAsync_InvalidUrl_ExpectedExceptionThrown() await Assert.ThrowsAsync(() => client.GetStringAsync(invalidUri)); } } + + [Fact] + public async Task ConnectCallback_Success() + { + if (!UseSocketsHttpHandler || UseHttp2) return; + + await LoopbackServer.CreateClientAndServerAsync( + async uri => + { + bool dialerCalled = false; + + Func> dialer = (host, port, token) => + { + dialerCalled = true; + + byte[] buffer = Encoding.ASCII.GetBytes("HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Length: 3\r\nContent-Type: text/plain\r\n\r\nfoo"); + + var stream = new MemoryStream(buffer); + var delegateStream = new DelegateStream(canReadFunc: () => true, canWriteFunc: () => true, readFunc: stream.Read, writeFunc: delegate { }); + return new ValueTask(delegateStream); + }; + + using var handler = new SocketsHttpHandler(); + handler.ConnectCallback = dialer; + + using HttpClient client = CreateHttpClient(handler); + + // If GetStringAsync fails with a connect timeout, the dialer is not getting called -- it's hitting a loopback socket that will never accept. + string result = await client.GetStringAsync(uri); + + Assert.Equal("foo", result); + Assert.True(dialerCalled); + }, + server => Task.CompletedTask, + new LoopbackServer.Options { ListenBacklog = 0 }); + } } }