diff --git a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs index 0848ca5ed21bc4..8ec4cb8521e94a 100644 --- a/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs +++ b/src/libraries/Common/src/Interop/Unix/System.Security.Cryptography.Native/Interop.Ssl.cs @@ -117,7 +117,14 @@ internal static unsafe ReadOnlySpan SslGetAlpnSelected(SafeSslHandle ssl) internal static partial IntPtr SslGetPeerCertificate(SafeSslHandle ssl); [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerCertChain")] - internal static partial SafeSharedX509StackHandle SslGetPeerCertChain(SafeSslHandle ssl); + private static partial SafeSharedX509StackHandle SslGetPeerCertChain_private(SafeSslHandle ssl); + + internal static SafeSharedX509StackHandle SslGetPeerCertChain(SafeSslHandle ssl) + { + return SafeInteriorHandle.OpenInteriorHandle( + SslGetPeerCertChain_private, + ssl); + } [LibraryImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetPeerFinished")] internal static partial int SslGetPeerFinished(SafeSslHandle ssl, IntPtr buf, int count); diff --git a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs index 38274fd85acd8d..bc7fae0b4b0ee2 100644 --- a/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs +++ b/src/libraries/System.Net.Security/src/System/Net/Security/SslStream.Protocol.cs @@ -1029,8 +1029,9 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot return true; } - _remoteCertificate = certificate; - if (_remoteCertificate == null) + // don't assign to _remoteCertificate yet, this prevents weird exceptions if SslStream is disposed in parallel with X509Chain building + + if (certificate == null) { if (NetEventSource.Log.IsEnabled() && RemoteCertRequired) NetEventSource.Error(this, $"Remote certificate required, but no remote certificate received"); sslPolicyErrors |= SslPolicyErrors.RemoteCertificateNotAvailable; @@ -1072,15 +1073,17 @@ internal bool VerifyRemoteCertificate(RemoteCertificateValidationCallback? remot sslPolicyErrors |= CertificateValidationPal.VerifyCertificateProperties( _securityContext!, chain, - _remoteCertificate, + certificate, _sslAuthenticationOptions.CheckCertName, _sslAuthenticationOptions.IsServer, TargetHostNameHelper.NormalizeHostName(_sslAuthenticationOptions.TargetHost)); } + _remoteCertificate = certificate; + if (remoteCertValidationCallback != null) { - success = remoteCertValidationCallback(this, _remoteCertificate, chain, sslPolicyErrors); + success = remoteCertValidationCallback(this, certificate, chain, sslPolicyErrors); } else { diff --git a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs index de7aa502933b02..9864029c7a0b34 100644 --- a/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs +++ b/src/libraries/System.Net.Security/tests/FunctionalTests/SslStreamDisposeTest.cs @@ -4,6 +4,7 @@ using System.IO; using System.Security.Cryptography.X509Certificates; using System.Threading; +using System.Security.Authentication; using System.Threading.Tasks; using Xunit; @@ -59,6 +60,7 @@ public async Task Dispose_PendingReadAsync_ThrowsODE(bool bufferedRead) using CancellationTokenSource cts = new CancellationTokenSource(); cts.CancelAfter(TestConfiguration.PassingTestTimeout); + (SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(leaveInnerStreamOpen: true); using (client) using (server) @@ -102,5 +104,65 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout( await Assert.ThrowsAnyAsync(() => client.ReadAsync(readBuffer, cts.Token).AsTask()); } } + + [Fact] + [OuterLoop("Computationally expensive")] + public async Task Dispose_ParallelWithHandshake_ThrowsODE() + { + using CancellationTokenSource cts = new CancellationTokenSource(); + cts.CancelAfter(TestConfiguration.PassingTestTimeout); + + await Parallel.ForEachAsync(System.Linq.Enumerable.Range(0, 10000), cts.Token, async (i, token) => + { + (Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams(); + + using SslStream client = new SslStream(clientStream); + using SslStream server = new SslStream(serverStream); + using X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate(); + using X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate(); + + SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions() + { + TargetHost = Guid.NewGuid().ToString("N"), + RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true, + }; + + SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions() + { + ServerCertificate = serverCertificate, + }; + + var clientTask = Task.Run(() => client.AuthenticateAsClientAsync(clientOptions, cts.Token)); + var serverTask = Task.Run(() => server.AuthenticateAsServerAsync(serverOptions, cts.Token)); + + // Dispose the instances while the handshake is in progress. + client.Dispose(); + server.Dispose(); + + await ValidateExceptionAsync(clientTask); + await ValidateExceptionAsync(serverTask); + }); + + static async Task ValidateExceptionAsync(Task task) + { + try + { + await task; + } + catch (InvalidOperationException ex) when (ex.StackTrace?.Contains("System.IO.StreamBuffer.WriteAsync") ?? true) + { + // Writing to a disposed ConnectedStream (test only, does not happen with NetworkStream) + return; + } + catch (Exception ex) when (ex + is ObjectDisposedException // disposed locally + or IOException // disposed remotely (received unexpected EOF) + or AuthenticationException) // disposed wrapped in AuthenticationException or error from platform library + { + // expected + return; + } + } + } } }