Skip to content

Commit

Permalink
Use configure await to yield to threadpool
Browse files Browse the repository at this point in the history
  • Loading branch information
rzikm committed Feb 26, 2024
1 parent 6b9142f commit 9fdb790
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public SslConnectionOptions(QuicConnection connection, bool isClient,
_certificateChainPolicy = certificateChainPolicy;
}

internal unsafe void StartAsyncCertificateValidation(void* certificatePtr, void* chainPtr)
internal async void StartAsyncCertificateValidation(IntPtr certificatePtr, IntPtr chainPtr)
{
//
// The provided data pointers are valid only while still inside this function, so they need to be
Expand All @@ -80,96 +80,90 @@ internal unsafe void StartAsyncCertificateValidation(void* certificatePtr, void*
byte[]? chainDataRented = null;
Memory<byte> chainData = default;

if (certificatePtr != null)
if (certificatePtr != IntPtr.Zero)
{
if (MsQuicApi.UsesSChannelBackend)
{
certificate = new X509Certificate2((IntPtr)certificatePtr);
// provided data is a pointer to a CERT_CONTEXT
certificate = new X509Certificate2(certificatePtr);
// TODO: what about chainPtr?
}
else
{
// On non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the content is buffers
// with DER encoded cert and chain.
QUIC_BUFFER* certificateBuffer = (QUIC_BUFFER*)certificatePtr;
QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainPtr;

if (certificateBuffer->Length > 0)
unsafe
{
certDataRented = ArrayPool<byte>.Shared.Rent((int)certificateBuffer->Length);
certData = certDataRented.AsMemory(0, (int)certificateBuffer->Length);
certificateBuffer->Span.CopyTo(certData.Span);
}
// On non-SChannel backends we specify USE_PORTABLE_CERTIFICATES and the contents are buffers
// with DER encoded cert and chain.
QUIC_BUFFER* certificateBuffer = (QUIC_BUFFER*)certificatePtr;
QUIC_BUFFER* chainBuffer = (QUIC_BUFFER*)chainPtr;

if (chainBuffer->Length > 0)
{
chainDataRented = ArrayPool<byte>.Shared.Rent((int)chainBuffer->Length);
chainData = chainDataRented.AsMemory(0, (int)chainBuffer->Length);
chainBuffer->Span.CopyTo(chainData.Span);
if (certificateBuffer->Length > 0)
{
certDataRented = ArrayPool<byte>.Shared.Rent((int)certificateBuffer->Length);
certData = certDataRented.AsMemory(0, (int)certificateBuffer->Length);
certificateBuffer->Span.CopyTo(certData.Span);
}

if (chainBuffer->Length > 0)
{
chainDataRented = ArrayPool<byte>.Shared.Rent((int)chainBuffer->Length);
chainData = chainDataRented.AsMemory(0, (int)chainBuffer->Length);
chainBuffer->Span.CopyTo(chainData.Span);
}
}
}
}

QuicConnection connection = _connection;
// We wan't to do the certificate validation asynchronously, but due to a bug in MsQuic, we need to call the callback synchronously on some versions
if (MsQuicApi.SupportsAsyncCertValidation)
{
// hand-off rest of the work to the thread pool, certificatePtr and chainPtr are invalid beyond this point
_ = Task.Run(() =>
{
StartAsyncCertificateValidationCore(connection, certificate, certData, chainData, certDataRented, chainDataRented);
});
}
else
{
// due to a bug in MsQuic, we need to call the callback synchronously to close the connection properly when
// we reject the certificate
StartAsyncCertificateValidationCore(connection, certificate, certData, chainData, certDataRented, chainDataRented);
// force yield to the thread pool to free up MsQuic worker thread.
await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding);
}

static void StartAsyncCertificateValidationCore(QuicConnection thisConnection, X509Certificate2? certificate, Memory<byte> certData, Memory<byte> chainData, byte[]? certDataRented, byte[]? chainDataRented)
// certificatePtr and chainPtr are invalid beyond this point

QUIC_TLS_ALERT_CODES result;
try
{
QUIC_TLS_ALERT_CODES result;
try
if (certData.Length > 0)
{
if (certData.Length > 0)
{
Debug.Assert(certificate == null);
certificate = new X509Certificate2(certData.Span);
}

result = thisConnection._sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span);
thisConnection._remoteCertificate = certificate;
Debug.Assert(certificate == null);
certificate = new X509Certificate2(certData.Span);
}
catch (Exception ex)

result = _connection._sslConnectionOptions.ValidateCertificate(certificate, certData.Span, chainData.Span);
_connection._remoteCertificate = certificate;
}
catch (Exception ex)
{
certificate?.Dispose();
_connection._connectedTcs.TrySetException(ex);
result = QUIC_TLS_ALERT_CODES.USER_CANCELED;
}
finally
{
if (certDataRented != null)
{
certificate?.Dispose();
thisConnection._connectedTcs.TrySetException(ex);
result = QUIC_TLS_ALERT_CODES.USER_CANCELED;
ArrayPool<byte>.Shared.Return(certDataRented);
}
finally
{
if (certDataRented != null)
{
ArrayPool<byte>.Shared.Return(certDataRented);
}

if (chainDataRented != null)
{
ArrayPool<byte>.Shared.Return(chainDataRented);
}
if (chainDataRented != null)
{
ArrayPool<byte>.Shared.Return(chainDataRented);
}
}

int status = MsQuicApi.Api.ConnectionCertificateValidationComplete(
thisConnection._handle,
result == QUIC_TLS_ALERT_CODES.SUCCESS ? (byte)1 : (byte)0,
result);
int status = MsQuicApi.Api.ConnectionCertificateValidationComplete(
_connection._handle,
result == QUIC_TLS_ALERT_CODES.SUCCESS ? (byte)1 : (byte)0,
result);

if (MsQuic.StatusFailed(status))
if (MsQuic.StatusFailed(status))
{
if (NetEventSource.Log.IsEnabled())
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(thisConnection, $"{thisConnection} ConnectionCertificateValidationComplete failed with {ThrowHelper.GetErrorMessageForStatus(status)}");
}
NetEventSource.Error(_connection, $"{_connection} ConnectionCertificateValidationComplete failed with {ThrowHelper.GetErrorMessageForStatus(status)}");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ private unsafe int HandleEventPeerCertificateReceived(ref PEER_CERTIFICATE_RECEI
// worker threads.
//

_sslConnectionOptions.StartAsyncCertificateValidation(data.Certificate, data.Chain);
_sslConnectionOptions.StartAsyncCertificateValidation((IntPtr)data.Certificate, (IntPtr)data.Chain);
return QUIC_STATUS_PENDING;
}

Expand Down

0 comments on commit 9fdb790

Please sign in to comment.