Skip to content

Commit

Permalink
[QUIC] Cleaned up TlsSecret and added test. (#93119)
Browse files Browse the repository at this point in the history
* Cleaned up TlsSecret and added test.

* To be removed: test where the test actually runs

* Update MsQuicRemoteExecutorTests.cs

* Exclude the test in release

* Feedback

---------

Co-authored-by: Natalia Kondratyeva <knatalia@microsoft.com>
  • Loading branch information
ManickaP and CarnaViire authored Oct 12, 2023
1 parent b331f23 commit dae5947
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,30 +58,35 @@ internal static unsafe T GetMsQuicParameter<T>(MsQuicSafeHandle handle, uint par
where T : unmanaged
{
T value;
uint length = (uint)sizeof(T);

GetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value);
return value;
}
internal static unsafe void GetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value)
{
int status = MsQuicApi.Api.GetParam(
handle,
parameter,
&length,
(byte*)&value);
value);

if (StatusFailed(status))
{
ThrowHelper.ThrowMsQuicException(status, $"GetParam({handle}, {parameter}) failed");
}

return value;
}

internal static unsafe void SetMsQuicParameter<T>(MsQuicSafeHandle handle, uint parameter, T value)
where T : unmanaged
{
SetMsQuicParameter(handle, parameter, (uint)sizeof(T), (byte*)&value);
}
internal static unsafe void SetMsQuicParameter(MsQuicSafeHandle handle, uint parameter, uint length, byte* value)
{
int status = MsQuicApi.Api.SetParam(
handle,
parameter,
(uint)sizeof(T),
(byte*)&value);
length,
value);

if (StatusFailed(status))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.InteropServices;
using Microsoft.Quic;

Expand Down Expand Up @@ -92,22 +93,19 @@ internal sealed class MsQuicContextSafeHandle : MsQuicSafeHandle
/// </summary>
private readonly MsQuicSafeHandle? _parent;

#if DEBUG
/// <summary>
/// Native memory to hold TLS secrets. It needs to live same cycle as the underlying connection.
/// Additional, dependent object to be disposed only after the safe handle gets released.
/// </summary>
private unsafe QUIC_TLS_SECRETS* _tlsSecrets;
private IDisposable? _disposable;

public unsafe QUIC_TLS_SECRETS* GetSecretsBuffer()
public IDisposable Disposable
{
if (_tlsSecrets == null)
set
{
_tlsSecrets = (QUIC_TLS_SECRETS*)NativeMemory.Alloc((nuint)sizeof(QUIC_TLS_SECRETS));
Debug.Assert(_disposable is null);
_disposable = value;
}

return _tlsSecrets;
}
#endif

public unsafe MsQuicContextSafeHandle(QUIC_HANDLE* handle, GCHandle context, SafeHandleType safeHandleType, MsQuicSafeHandle? parent = null)
: base(handle, safeHandleType)
Expand Down Expand Up @@ -140,13 +138,7 @@ protected override unsafe bool ReleaseHandle()
NetEventSource.Info(this, $"{this} {_parent} ref count decremented");
}
}
#if DEBUG
if (_tlsSecrets != null)
{
NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
NativeMemory.Free(_tlsSecrets);
}
#endif
_disposable?.Dispose();
return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

#if DEBUG
using System.Diagnostics;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
Expand All @@ -19,88 +20,93 @@ internal sealed class MsQuicTlsSecret : IDisposable

public static unsafe MsQuicTlsSecret? Create(MsQuicContextSafeHandle handle)
{
if (s_fileStream != null)
if (s_fileStream is null)
{
try
{
QUIC_TLS_SECRETS* ptr = handle.GetSecretsBuffer();
if (ptr != null)
{
int status = MsQuicApi.Api.SetParam(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), ptr);
return null;
}

if (StatusSucceeded(status))
{
return new MsQuicTlsSecret(ptr);
}
else
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(handle, "Failed to set native memory for TLS secret.");
}
}
}
QUIC_TLS_SECRETS* tlsSecrets = null;
try
{
tlsSecrets = (QUIC_TLS_SECRETS*)NativeMemory.AllocZeroed((nuint)sizeof(QUIC_TLS_SECRETS));
MsQuicHelpers.SetMsQuicParameter(handle, QUIC_PARAM_CONN_TLS_SECRETS, (uint)sizeof(QUIC_TLS_SECRETS), (byte*)tlsSecrets);
MsQuicTlsSecret instance = new MsQuicTlsSecret(tlsSecrets);
handle.Disposable = instance;
return instance;
}
catch (Exception ex)
{
if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Error(handle, $"Failed to set native memory for TLS secret: {ex}");
}
if (tlsSecrets is not null)
{
NativeMemory.Free(tlsSecrets);
}
catch { };
return null;
}

return null;
}

private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* memory)
private unsafe MsQuicTlsSecret(QUIC_TLS_SECRETS* tlsSecrets)
{
_tlsSecrets = memory;
_tlsSecrets = tlsSecrets;
}

public void WriteSecret() => WriteSecret(s_fileStream);
public unsafe void WriteSecret(FileStream? stream)
public unsafe void WriteSecret()
{
if (stream != null && _tlsSecrets != null)
Debug.Assert(_tlsSecrets is not null);
Debug.Assert(s_fileStream is not null);

lock (s_fileStream)
{
lock (stream)
string clientRandom = string.Empty;
if (_tlsSecrets->IsSet.ClientRandom != 0)
{
string clientRandom = string.Empty;

if (_tlsSecrets->IsSet.ClientRandom != 0)
{
clientRandom = HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientRandom, 32));
}

if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}

if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0)
{
stream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}

stream.Flush();
clientRandom = HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientRandom, 32));
}
if (_tlsSecrets->IsSet.ClientHandshakeTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ServerHandshakeTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_HANDSHAKE_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerHandshakeTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ClientTrafficSecret0 != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ServerTrafficSecret0 != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"SERVER_TRAFFIC_SECRET_0 {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ServerTrafficSecret0, _tlsSecrets->SecretLength))}\n"));
}
if (_tlsSecrets->IsSet.ClientEarlyTrafficSecret != 0)
{
s_fileStream.Write(Encoding.ASCII.GetBytes($"CLIENT_EARLY_TRAFFIC_SECRET {clientRandom} {HexConverter.ToString(new ReadOnlySpan<byte>(_tlsSecrets->ClientEarlyTrafficSecret, _tlsSecrets->SecretLength))}\n"));
}
s_fileStream.Flush();
}

NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
}

public unsafe void Dispose()
{
if (_tlsSecrets != null)
if (_tlsSecrets is null)
{
return;
}
lock (this)
{
NativeMemory.Clear(_tlsSecrets, (nuint)sizeof(QUIC_TLS_SECRETS));
if (_tlsSecrets is null)
{
return;
}

QUIC_TLS_SECRETS* tlsSecrets = _tlsSecrets;
_tlsSecrets = null;
NativeMemory.Free(_tlsSecrets);
}
}
}
Expand Down
23 changes: 13 additions & 10 deletions src/libraries/System.Net.Quic/src/System/Net/Quic/QuicConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ namespace System.Net.Quic;
/// </remarks>
public sealed partial class QuicConnection : IAsyncDisposable
{
#if DEBUG
/// <summary>
/// The actual secret structure wrapper passed to MsQuic.
/// </summary>
private readonly MsQuicTlsSecret? _tlsSecret;
#endif

/// <summary>
/// Returns <c>true</c> if QUIC is supported on the current machine and can be used; otherwise, <c>false</c>.
/// </summary>
Expand Down Expand Up @@ -152,6 +145,15 @@ static async ValueTask<QuicConnection> StartConnectAsync(QuicClientConnectionOpt
/// Set when CONNECTED is received.
/// </summary>
private SslApplicationProtocol _negotiatedApplicationProtocol;

#if DEBUG
/// <summary>
/// Will contain TLS secret after CONNECTED event is received and store it into SSLKEYLOGFILE.
/// MsQuic holds the underlying pointer so this object can be disposed only after connection native handle gets closed.
/// </summary>
private readonly MsQuicTlsSecret? _tlsSecret;
#endif

/// <summary>
/// The remote endpoint used for this connection.
/// </summary>
Expand Down Expand Up @@ -467,6 +469,10 @@ private unsafe int HandleEventConnected(ref CONNECTED_DATA data)
QuicAddr localAddress = MsQuicHelpers.GetMsQuicParameter<QuicAddr>(_handle, QUIC_PARAM_CONN_LOCAL_ADDRESS);
_localEndPoint = MsQuicHelpers.QuicAddrToIPEndPoint(&localAddress);

#if DEBUG
_tlsSecret?.WriteSecret();
#endif

if (NetEventSource.Log.IsEnabled())
{
NetEventSource.Info(this, $"{this} Connection connected {LocalEndPoint} -> {RemoteEndPoint} for {_negotiatedApplicationProtocol} protocol");
Expand Down Expand Up @@ -596,9 +602,6 @@ public async ValueTask DisposeAsync()
return;
}

#if DEBUG
_tlsSecret?.Dispose();
#endif
// Check if the connection has been shut down and if not, shut it down.
if (_shutdownTcs.TryInitialize(out ValueTask valueTask, this))
{
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Security;
using System.Threading.Tasks;
using Microsoft.DotNet.RemoteExecutor;
using Microsoft.DotNet.XUnitExtensions;
using Xunit;
using Xunit.Abstractions;

namespace System.Net.Quic.Tests
{
[Collection(nameof(DisableParallelization))]
[ConditionalClass(typeof(QuicTestBase), nameof(QuicTestBase.IsSupported))]
public class MsQuicRemoteExecutorTests : QuicTestBase
{
public MsQuicRemoteExecutorTests()
: base(null!) { }

[ConditionalFact(typeof(RemoteExecutor), nameof(RemoteExecutor.IsSupported))]
public void SslKeyLogFile_IsCreatedAndFilled()
{
if (PlatformDetection.IsReleaseRuntime)
{
throw new SkipTestException("Retrieving SSL secrets is not supported in Release mode.");
}

var psi = new ProcessStartInfo();
var tempFile = Path.GetTempFileName();
psi.Environment.Add("SSLKEYLOGFILE", tempFile);

RemoteExecutor.Invoke(async () =>
{
(QuicConnection clientConnection, QuicConnection serverConnection) = await CreateConnectedQuicConnection();
await clientConnection.DisposeAsync();
await serverConnection.DisposeAsync();
}, new RemoteInvokeOptions { StartInfo = psi }).Dispose();

Assert.True(File.Exists(tempFile));
Assert.True(File.ReadAllText(tempFile).Length > 0);
}
}
}

0 comments on commit dae5947

Please sign in to comment.