Skip to content

Commit

Permalink
improve SslStream exception after disposal (#79329)
Browse files Browse the repository at this point in the history
* improve SslStream exception after disposal

* add tests

* add StreamUse

* fix cleanup

* fix condition

* avoid casting
  • Loading branch information
wfurt authored Jan 6, 2023
1 parent c3d1dd9 commit 2a27452
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ private void CloseInternal()

// Ensure a Read or Auth operation is not in progress,
// block potential future read and auth operations since SslStream is disposing.
// This leaves the _nestedRead = 1 and _nestedAuth = 1, but that's ok, since
// This leaves the _nestedRead = 2 and _nestedAuth = 2, but that's ok, since
// subsequent operations check the _exception sentinel first
if (Interlocked.Exchange(ref _nestedRead, 1) == 0 &&
Interlocked.Exchange(ref _nestedAuth, 1) == 0)
if (Interlocked.Exchange(ref _nestedRead, StreamDisposed) == StreamNotInUse &&
Interlocked.Exchange(ref _nestedAuth, StreamDisposed) == StreamNotInUse)
{
_buffer.ReturnBuffer();
}
Expand Down Expand Up @@ -162,19 +162,22 @@ private async Task ReplyOnReAuthenticationAsync<TIOAdapter>(byte[]? buffer, Canc
private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
if (Interlocked.CompareExchange(ref _nestedAuth, StreamInUse, StreamNotInUse) != StreamNotInUse)
{
ObjectDisposedException.ThrowIf(_nestedAuth == StreamDisposed, this);
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}

if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse)
{
ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this);
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}

if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
// Write is different since we do not do anything special in Dispose
if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) != StreamNotInUse)
{
_nestedRead = 0;
_nestedRead = StreamNotInUse;
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}

Expand Down Expand Up @@ -231,8 +234,8 @@ private async Task RenegotiateAsync<TIOAdapter>(CancellationToken cancellationTo
_buffer.ReturnBuffer();
}

_nestedRead = 0;
_nestedWrite = 0;
_nestedRead = StreamNotInUse;
_nestedWrite = StreamNotInUse;
_isRenego = false;
// We will not release _nestedAuth at this point to prevent another renegotiation attempt.
}
Expand All @@ -248,7 +251,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
if (reAuthenticationData == null)
{
// prevent nesting only when authentication functions are called explicitly. e.g. handle renegotiation transparently.
if (Interlocked.Exchange(ref _nestedAuth, 1) == 1)
if (Interlocked.Exchange(ref _nestedAuth, StreamInUse) == StreamInUse)
{
throw new InvalidOperationException(SR.Format(SR.net_io_invalidnestedcall, "authenticate"));
}
Expand Down Expand Up @@ -335,7 +338,7 @@ private async Task ForceAuthenticationAsync<TIOAdapter>(bool receiveFirst, byte[
{
if (reAuthenticationData == null)
{
_nestedAuth = 0;
_nestedAuth = StreamNotInUse;
_isRenego = false;
}
}
Expand Down Expand Up @@ -500,7 +503,7 @@ private bool CompleteHandshake(ref ProtocolToken? alertToken, out SslPolicyError
{
ProcessHandshakeSuccess();

if (_nestedAuth != 1)
if (_nestedAuth != StreamInUse)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, $"Ignoring unsolicited renegotiated certificate.");
// ignore certificates received outside of handshake or requested renegotiation.
Expand Down Expand Up @@ -769,13 +772,16 @@ private SecurityStatusPal DecryptData(int frameSize)
private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer, CancellationToken cancellationToken)
where TIOAdapter : IReadWriteAdapter
{
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
// Throw first if we already have exception.
// Check for disposal is not atomic so we will check again below.
ThrowIfExceptionalOrNotAuthenticated();

if (Interlocked.CompareExchange(ref _nestedRead, StreamInUse, StreamNotInUse) != StreamNotInUse)
{
ObjectDisposedException.ThrowIf(_nestedRead == StreamDisposed, this);
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}

ThrowIfExceptionalOrNotAuthenticated();

try
{
int processedLength = 0;
Expand Down Expand Up @@ -910,7 +916,7 @@ private async ValueTask<int> ReadAsyncInternal<TIOAdapter>(Memory<byte> buffer,
finally
{
ReturnReadBufferIfEmpty();
_nestedRead = 0;
_nestedRead = StreamNotInUse;
}
}

Expand All @@ -925,7 +931,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
return;
}

if (Interlocked.Exchange(ref _nestedWrite, 1) == 1)
if (Interlocked.Exchange(ref _nestedWrite, StreamInUse) == StreamInUse)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "write"));
}
Expand All @@ -948,7 +954,7 @@ private async ValueTask WriteAsyncInternal<TIOAdapter>(ReadOnlyMemory<byte> buff
}
finally
{
_nestedWrite = 0;
_nestedWrite = StreamNotInUse;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ public void ReturnBuffer()
}
}

// used to track ussage in _nested* variables bellow
private const int StreamNotInUse = 0;
private const int StreamInUse = 1;
private const int StreamDisposed = 2;

private int _nestedWrite;
private int _nestedRead;

Expand Down Expand Up @@ -703,7 +708,7 @@ public override async ValueTask DisposeAsync()
public override int ReadByte()
{
ThrowIfExceptionalOrNotAuthenticated();
if (Interlocked.Exchange(ref _nestedRead, 1) == 1)
if (Interlocked.Exchange(ref _nestedRead, StreamInUse) == StreamInUse)
{
throw new NotSupportedException(SR.Format(SR.net_io_invalidnestedcall, "read"));
}
Expand All @@ -724,7 +729,7 @@ public override int ReadByte()
// Regardless of whether we were able to read a byte from the buffer,
// reset the read tracking. If we weren't able to read a byte, the
// subsequent call to Read will set the flag again.
_nestedRead = 0;
_nestedRead = StreamNotInUse;
}

// Otherwise, fall back to reading a byte via Read, the same way Stream.ReadByte does.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.IO;
using System.Net.Test.Common;
using System.Security.Cryptography.X509Certificates;
using System.Threading;
using System.Threading.Tasks;

using Xunit;
Expand All @@ -12,13 +12,13 @@ namespace System.Net.Security.Tests
{
using Configuration = System.Net.Test.Common.Configuration;

public abstract class SslStreamDisposeTest
public class SslStreamDisposeTest
{
[Fact]
public async Task DisposeAsync_NotConnected_ClosesStream()
{
bool disposed = false;
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true), false, delegate { return true; });
var stream = new SslStream(new DelegateStream(disposeFunc: _ => disposed = true, canReadFunc: () => true, canWriteFunc: () => true), false, delegate { return true; });

Assert.False(disposed);
await stream.DisposeAsync();
Expand Down Expand Up @@ -50,5 +50,57 @@ await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
await serverStream.DisposeAsync();
Assert.NotEqual(0, trackingStream2.TimesCalled(nameof(Stream.DisposeAsync)));
}

[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task Dispose_PendingReadAsync_ThrowsODE(bool bufferedRead)
{
using CancellationTokenSource cts = new CancellationTokenSource();
cts.CancelAfter(TestConfiguration.PassingTestTimeout);

(SslStream client, SslStream server) = TestHelper.GetConnectedSslStreams(leaveInnerStreamOpen: true);
using (client)
using (server)
using (X509Certificate2 serverCertificate = Configuration.Certificates.GetServerCertificate())
using (X509Certificate2 clientCertificate = Configuration.Certificates.GetClientCertificate())
{
SslClientAuthenticationOptions clientOptions = new SslClientAuthenticationOptions()
{
TargetHost = Guid.NewGuid().ToString("N"),
};
clientOptions.RemoteCertificateValidationCallback = (sender, certificate, chain, sslPolicyErrors) => true;

SslServerAuthenticationOptions serverOptions = new SslServerAuthenticationOptions()
{
ServerCertificate = serverCertificate,
};

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(
client.AuthenticateAsClientAsync(clientOptions, default),
server.AuthenticateAsServerAsync(serverOptions, default));

await TestHelper.PingPong(client, server, cts.Token);

await server.WriteAsync("PINGPONG"u8.ToArray(), cts.Token);
var readBuffer = new byte[1024];

Task<int>? task = null;
if (bufferedRead)
{
// This will read everything into internal buffer. Following ReadAsync will not need IO.
task = client.ReadAsync(readBuffer, 0, 4, cts.Token);
client.Dispose();
int readLength = await task.ConfigureAwait(false);
Assert.Equal(4, readLength);
}
else
{
client.Dispose();
}

await Assert.ThrowsAnyAsync<ObjectDisposedException>(() => client.ReadAsync(readBuffer, cts.Token).AsTask());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ public static bool AllowAnyServerCertificate(object sender, X509Certificate cert
return true;
}

public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams()
public static (SslStream ClientStream, SslStream ServerStream) GetConnectedSslStreams(bool leaveInnerStreamOpen = false)
{
(Stream clientStream, Stream serverStream) = GetConnectedStreams();
return (new SslStream(clientStream), new SslStream(serverStream));
return (new SslStream(clientStream, leaveInnerStreamOpen), new SslStream(serverStream, leaveInnerStreamOpen));
}

public static (Stream ClientStream, Stream ServerStream) GetConnectedStreams()
Expand Down

0 comments on commit 2a27452

Please sign in to comment.