diff --git a/src/Servers/Kestrel/Core/test/TlsListenerTests.cs b/src/Servers/Kestrel/Core/test/TlsListenerTests.cs index 4bce89de208d..6350465dbe41 100644 --- a/src/Servers/Kestrel/Core/test/TlsListenerTests.cs +++ b/src/Servers/Kestrel/Core/test/TlsListenerTests.cs @@ -70,9 +70,7 @@ public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken() var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(3)); await writer.WriteAsync(new byte[1] { 0x16 }); - await VerifyThrowsAnyAsync( - async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token), - typeof(OperationCanceledException), typeof(TaskCanceledException)); + await Assert.ThrowsAnyAsync(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token)); Assert.False(tlsClientHelloCallbackInvoked); } @@ -95,9 +93,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken() cts.Cancel(); await writer.WriteAsync(new byte[1] { 0x16 }); - await VerifyThrowsAnyAsync( - async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token), - typeof(OperationCanceledException), typeof(TaskCanceledException)); + await Assert.ThrowsAnyAsync(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token)); Assert.False(tlsClientHelloCallbackInvoked); } @@ -122,7 +118,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation() await writer.WriteAsync(new byte[2] { 0x03, 0x01 }); cts.Cancel(); - await Assert.ThrowsAsync(async () => await listenerTask); + await Assert.ThrowsAnyAsync(() => listenerTask); Assert.False(tlsClientHelloCallbackInvoked); } @@ -158,8 +154,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads() Assert.Equal(5, readResult.Buffer.Length); // ensuring that we have read limited number of times - Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4, - $"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times."); + Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5, + $"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times."); } private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments( @@ -623,28 +619,4 @@ public static IEnumerable InvalidClientHelloData_Segmented() _invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage, _invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType }; - - static async Task VerifyThrowsAnyAsync(Func code, params Type[] exceptionTypes) - { - if (exceptionTypes == null || exceptionTypes.Length == 0) - { - throw new ArgumentException("At least one exception type must be provided.", nameof(exceptionTypes)); - } - - try - { - await code(); - } - catch (Exception ex) - { - if (exceptionTypes.Any(type => type.IsInstanceOfType(ex))) - { - return; - } - - throw ThrowsException.ForIncorrectExceptionType(exceptionTypes.First(), ex); - } - - throw ThrowsException.ForNoException(exceptionTypes.First()); - } } diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs index f91ea27eae8f..f835e452ff21 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/TlsListenerTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Net; using System.Net.Security; using System.Security.Authentication; @@ -18,6 +19,8 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; +using Newtonsoft.Json.Linq; +using Xunit.Sdk; namespace InMemory.FunctionalTests; @@ -66,4 +69,40 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions Assert.True(tlsClientHelloCallbackInvoked); } + + [Fact] + public async Task TlsClientHelloBytesCallback_UsesOptionsTimeout() + { + var tlsClientHelloCallbackInvoked = false; + var testContext = new TestServiceContext(LoggerFactory); + await using (var server = new TestServer(context => Task.CompletedTask, + testContext, + listenOptions => + { + listenOptions.UseHttps(_x509Certificate2, httpsOptions => + { + httpsOptions.HandshakeTimeout = TimeSpan.FromMilliseconds(1); + + httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) => + { + Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length); + tlsClientHelloCallbackInvoked = true; + Assert.True(clientHelloBytes.Length > 32); + Assert.NotNull(connection); + }; + }); + })) + { + using (var connection = server.CreateConnection()) + { + await connection.TransportConnection.Input.WriteAsync(new byte[] { 0x16 }); + var readResult = await connection.TransportConnection.Output.ReadAsync(); + + // HttpsConnectionMiddleware catches the exception, so we can only check the effects of the timeout here + Assert.True(readResult.IsCompleted); + } + } + + Assert.False(tlsClientHelloCallbackInvoked); + } }