Skip to content

follow-up: kestrel tls listener callback #62266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 5 additions & 33 deletions src/Servers/Kestrel/Core/test/TlsListenerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ public async Task RunTlsClientHelloCallbackTest_WithExtraShortLastingToken()
var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(3));

await writer.WriteAsync(new byte[1] { 0x16 });
await VerifyThrowsAnyAsync(
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
typeof(OperationCanceledException), typeof(TaskCanceledException));
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand All @@ -95,9 +93,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPreCanceledToken()
cts.Cancel();

await writer.WriteAsync(new byte[1] { 0x16 });
await VerifyThrowsAnyAsync(
async () => await listener.OnTlsClientHelloAsync(transportConnection, cts.Token),
typeof(OperationCanceledException), typeof(TaskCanceledException));
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listener.OnTlsClientHelloAsync(transportConnection, cts.Token));
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand All @@ -122,7 +118,7 @@ public async Task RunTlsClientHelloCallbackTest_WithPendingCancellation()
await writer.WriteAsync(new byte[2] { 0x03, 0x01 });
cts.Cancel();

await Assert.ThrowsAsync<OperationCanceledException>(async () => await listenerTask);
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => listenerTask);
Assert.False(tlsClientHelloCallbackInvoked);
}

Expand Down Expand Up @@ -158,8 +154,8 @@ public async Task RunTlsClientHelloCallbackTest_DeterministicallyReads()
Assert.Equal(5, readResult.Buffer.Length);

// ensuring that we have read limited number of times
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 4,
$"Expected ReadAsync() to happen about 2-4 times. Actually happened {reader.ReadAsyncCounter} times.");
Assert.True(reader.ReadAsyncCounter is >= 2 && reader.ReadAsyncCounter is <= 5,
$"Expected ReadAsync() to happen about 2-5 times. Actually happened {reader.ReadAsyncCounter} times.");
}

private async Task RunTlsClientHelloCallbackTest_WithMultipleSegments(
Expand Down Expand Up @@ -623,28 +619,4 @@ public static IEnumerable<object[]> InvalidClientHelloData_Segmented()
_invalidTlsClientHelloHeader, _invalid3BytesMessage, _invalid9BytesMessage,
_invalidUnknownProtocolVersion1, _invalidUnknownProtocolVersion2, _invalidIncorrectHandshakeMessageType
};

static async Task VerifyThrowsAnyAsync(Func<Task> code, params Type[] exceptionTypes)
{
if (exceptionTypes == null || exceptionTypes.Length == 0)
{
throw new ArgumentException("At least one exception type must be provided.", nameof(exceptionTypes));
}

try
{
await code();
}
catch (Exception ex)
{
if (exceptionTypes.Any(type => type.IsInstanceOfType(ex)))
{
return;
}

throw ThrowsException.ForIncorrectExceptionType(exceptionTypes.First(), ex);
}

throw ThrowsException.ForNoException(exceptionTypes.First());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Net;
using System.Net.Security;
using System.Security.Authentication;
Expand All @@ -18,6 +19,8 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using Xunit.Sdk;

namespace InMemory.FunctionalTests;

Expand Down Expand Up @@ -66,4 +69,40 @@ await sslStream.AuthenticateAsClientAsync(new SslClientAuthenticationOptions

Assert.True(tlsClientHelloCallbackInvoked);
}

[Fact]
public async Task TlsClientHelloBytesCallback_UsesOptionsTimeout()
{
var tlsClientHelloCallbackInvoked = false;
var testContext = new TestServiceContext(LoggerFactory);
await using (var server = new TestServer(context => Task.CompletedTask,
testContext,
listenOptions =>
{
listenOptions.UseHttps(_x509Certificate2, httpsOptions =>
{
httpsOptions.HandshakeTimeout = TimeSpan.FromMilliseconds(1);

httpsOptions.TlsClientHelloBytesCallback = (connection, clientHelloBytes) =>
{
Logger.LogDebug("[Received TlsClientHelloBytesCallback] Connection: {0}; TLS client hello buffer: {1}", connection.ConnectionId, clientHelloBytes.Length);
tlsClientHelloCallbackInvoked = true;
Assert.True(clientHelloBytes.Length > 32);
Assert.NotNull(connection);
};
});
}))
{
using (var connection = server.CreateConnection())
{
await connection.TransportConnection.Input.WriteAsync(new byte[] { 0x16 });
var readResult = await connection.TransportConnection.Output.ReadAsync();

// HttpsConnectionMiddleware catches the exception, so we can only check the effects of the timeout here
Assert.True(readResult.IsCompleted);
}
}

Assert.False(tlsClientHelloCallbackInvoked);
}
}
Loading