Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add basic support for client certificate to Quic #54302

Merged
merged 3 commits into from
Jul 9, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,36 @@ protected override bool ReleaseHandle()
// TODO: consider moving the static code from here to keep all the handle classes small and simple.
public static unsafe SafeMsQuicConfigurationHandle Create(QuicClientConnectionOptions options)
{
// TODO: lots of ClientAuthenticationOptions are not yet supported by MsQuic.
return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: null, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols);
X509Certificate? certificate = null;
if (options.ClientAuthenticationOptions?.ClientCertificates != null)
{
foreach (var cert in options.ClientAuthenticationOptions.ClientCertificates)
{
try
{
if (((X509Certificate2)cert).HasPrivateKey)
{
// Pick first certificate with private key.
certificate = cert;
break;
}
}
catch { };
wfurt marked this conversation as resolved.
Show resolved Hide resolved
}
}

return Create(options, QUIC_CREDENTIAL_FLAGS.CLIENT, certificate: certificate, certificateContext: null, options.ClientAuthenticationOptions?.ApplicationProtocols);
}

public static unsafe SafeMsQuicConfigurationHandle Create(QuicListenerOptions options)
{
// TODO: lots of ServerAuthenticationOptions are not yet supported by MsQuic.
return Create(options, QUIC_CREDENTIAL_FLAGS.NONE, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols);
QUIC_CREDENTIAL_FLAGS flags = QUIC_CREDENTIAL_FLAGS.NONE;
if (options.ServerAuthenticationOptions != null && options.ServerAuthenticationOptions.ClientCertificateRequired)
{
flags |= QUIC_CREDENTIAL_FLAGS.REQUIRE_CLIENT_AUTHENTICATION | QUIC_CREDENTIAL_FLAGS.INDICATE_CERTIFICATE_RECEIVED | QUIC_CREDENTIAL_FLAGS.NO_CERTIFICATE_VALIDATION;
}

return Create(options, flags, options.ServerAuthenticationOptions?.ServerCertificate, options.ServerAuthenticationOptions?.ServerCertificateContext, options.ServerAuthenticationOptions?.ApplicationProtocols);
}

// TODO: this is called from MsQuicListener and when it fails it wreaks havoc in MsQuicListener finalizer.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,15 +126,23 @@ public void SetClosing()
}

// constructor for inbound connections
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle)
public MsQuicConnection(IPEndPoint localEndPoint, IPEndPoint remoteEndPoint, SafeMsQuicConnectionHandle handle, bool remoteCertificateRequired = false, X509RevocationMode revocationMode = X509RevocationMode.Offline, RemoteCertificateValidationCallback? remoteCertificateValidationCallback = null)
{
_state.Handle = handle;
_state.StateGCHandle = GCHandle.Alloc(_state);
_state.Connected = true;
_isServer = true;
_localEndPoint = localEndPoint;
_remoteEndPoint = remoteEndPoint;
_remoteCertificateRequired = false;
_isServer = true;
_remoteCertificateRequired = remoteCertificateRequired;
_revocationMode = revocationMode;
_remoteCertificateValidationCallback = remoteCertificateValidationCallback;

if (_remoteCertificateRequired)
{
// We need to link connection for the validation callback.
_state.Connection = this;
}

try
{
Expand Down Expand Up @@ -333,6 +341,11 @@ private static uint HandleEventPeerCertificateReceived(State state, ref Connecti
return MsQuicStatusCodes.InvalidState;
}

if (connection._isServer)
wfurt marked this conversation as resolved.
Show resolved Hide resolved
{
state.Connection = null;
}

try
{
if (connectionEvent.Data.PeerCertificateReceived.PlatformCertificateHandle != IntPtr.Zero)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Net.Quic.Implementations.MsQuic.Internal;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Channels;
Expand All @@ -31,9 +32,19 @@ private sealed class State
public readonly SafeMsQuicConfigurationHandle ConnectionConfiguration;
public readonly Channel<MsQuicConnection> AcceptConnectionQueue;

public bool RemoteCertificateRequired;
public X509RevocationMode RevocationMode = X509RevocationMode.Offline;
public RemoteCertificateValidationCallback? RemoteCertificateValidationCallback;

public State(QuicListenerOptions options)
{
ConnectionConfiguration = SafeMsQuicConfigurationHandle.Create(options);
if (options.ServerAuthenticationOptions != null)
{
RemoteCertificateRequired = options.ServerAuthenticationOptions.ClientCertificateRequired;
RevocationMode = options.ServerAuthenticationOptions.CertificateRevocationCheckMode;
RemoteCertificateValidationCallback = options.ServerAuthenticationOptions.RemoteCertificateValidationCallback;
}

AcceptConnectionQueue = Channel.CreateBounded<MsQuicConnection>(new BoundedChannelOptions(options.ListenBacklog)
{
Expand Down Expand Up @@ -182,7 +193,7 @@ private static unsafe uint NativeCallbackHandler(
uint status = MsQuicApi.Api.ConnectionSetConfigurationDelegate(connectionHandle, state.ConnectionConfiguration);
QuicExceptionHelpers.ThrowIfFailed(status, "ConnectionSetConfiguration failed.");

var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle);
var msQuicConnection = new MsQuicConnection(localEndPoint, remoteEndPoint, connectionHandle, state.RemoteCertificateRequired, state.RevocationMode, state.RemoteCertificateValidationCallback);
msQuicConnection.SetNegotiatedAlpn(connectionInfo.NegotiatedAlpn, connectionInfo.NegotiatedAlpnLength);

if (!state.AcceptConnectionQueue.Writer.TryWrite(msQuicConnection))
Expand Down
42 changes: 42 additions & 0 deletions src/libraries/System.Net.Quic/tests/FunctionalTests/MsQuicTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,48 @@ public async Task ConnectWithCertificateChain()
}

[Fact]
[PlatformSpecific(TestPlatforms.Windows)]
public async Task ConnectWithClientCertificate()
{
bool clientCertificateOK = false;

var serverOptions = new QuicListenerOptions();
serverOptions.ListenEndPoint = new IPEndPoint(IPAddress.Loopback, 0);
serverOptions.ServerAuthenticationOptions = GetSslServerAuthenticationOptions();
serverOptions.ServerAuthenticationOptions.ClientCertificateRequired = true;
serverOptions.ServerAuthenticationOptions.RemoteCertificateValidationCallback = (sender, cert, chain, errors) =>
{
_output.WriteLine("client certificate {0}", cert);
Assert.NotNull(cert);
Assert.Equal(ClientCertificate.Thumbprint, ((X509Certificate2)cert).Thumbprint);

clientCertificateOK = true;
return true;
};
using QuicListener listener = new QuicListener(QuicImplementationProviders.MsQuic, serverOptions);

QuicClientConnectionOptions clientOptions = new QuicClientConnectionOptions()
{
RemoteEndPoint = listener.ListenEndPoint,
ClientAuthenticationOptions = GetSslClientAuthenticationOptions(),
};
clientOptions.ClientAuthenticationOptions.ClientCertificates = new X509CertificateCollection() { ClientCertificate };

using QuicConnection clientConnection = new QuicConnection(QuicImplementationProviders.MsQuic, clientOptions);
ValueTask clientTask = clientConnection.ConnectAsync();

using QuicConnection serverConnection = await listener.AcceptConnectionAsync();
await clientTask;
// Verify functionality of the connections.
await PingPong(clientConnection, serverConnection);
// check we completed the client certificate verification.
Assert.True(clientCertificateOK);

await serverConnection.CloseAsync(0);
}

[Fact]
[ActiveIssue("https://github.com/dotnet/runtime/issues/52048")]
public async Task WaitForAvailableUnidirectionStreamsAsyncWorks()
{
using QuicListener listener = CreateQuicListener(maxUnidirectionalStreams: 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ namespace System.Net.Quic.Tests
public abstract class QuicTestBase<T>
where T : IQuicImplProviderFactory, new()
{
private static readonly byte[] s_ping = Encoding.UTF8.GetBytes("PING");
private static readonly byte[] s_pong = Encoding.UTF8.GetBytes("PONG");
private static readonly IQuicImplProviderFactory s_factory = new T();

public static QuicImplementationProvider ImplementationProvider { get; } = s_factory.GetProvider();
Expand All @@ -23,6 +25,7 @@ public abstract class QuicTestBase<T>
public static SslApplicationProtocol ApplicationProtocol { get; } = new SslApplicationProtocol("quictest");

public X509Certificate2 ServerCertificate = System.Net.Test.Common.Configuration.Certificates.GetServerCertificate();
public X509Certificate2 ClientCertificate = System.Net.Test.Common.Configuration.Certificates.GetClientCertificate();

public bool RemoteCertificateValidationCallback(object sender, X509Certificate? certificate, X509Chain? chain, SslPolicyErrors sslPolicyErrors)
{
Expand Down Expand Up @@ -75,6 +78,36 @@ internal QuicListener CreateQuicListener(IPEndPoint endpoint)
return CreateQuicListener(options);
}

internal async Task PingPong(QuicConnection client, QuicConnection server)
{
using QuicStream clientStream = client.OpenBidirectionalStream();
ValueTask t = clientStream.WriteAsync(s_ping);
using QuicStream serverStream = await server.AcceptStreamAsync();

byte[] buffer = new byte[s_ping.Length];
int remains = s_ping.Length;
while (remains > 0)
{
int readLength = await serverStream.ReadAsync(buffer, buffer.Length - remains, remains);
Assert.True(readLength > 0);
remains -= readLength;
}
Assert.Equal(s_ping, buffer);
await t;

t = serverStream.WriteAsync(s_pong);
remains = s_pong.Length;
while (remains > 0)
{
int readLength = await clientStream.ReadAsync(buffer, buffer.Length - remains, remains);
Assert.True(readLength > 0);
remains -= readLength;
}

Assert.Equal(s_pong, buffer);
await t;
}
wfurt marked this conversation as resolved.
Show resolved Hide resolved

private QuicListener CreateQuicListener(QuicListenerOptions options) => new QuicListener(ImplementationProvider, options);

internal async Task RunClientServer(Func<QuicConnection, Task> clientFunction, Func<QuicConnection, Task> serverFunction, int iterations = 1, int millisecondsTimeout = 10_000)
Expand Down