Skip to content

Commit

Permalink
Adding queue back to maintain FIFO for extra security.
Browse files Browse the repository at this point in the history
  • Loading branch information
cheenamalhotra committed Nov 18, 2020
1 parent f76e1c0 commit 2d69793
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ namespace Microsoft.Data.SqlClient.SNI
/// </summary>
internal class SNISslStream : SslStream
{
private readonly SemaphoreSlim _writeAsyncSemaphore;
private readonly SemaphoreSlim _readAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNISslStream(Stream innerStream, bool leaveInnerStreamOpen, RemoteCertificateValidationCallback userCertificateValidationCallback)
: base(innerStream, leaveInnerStreamOpen, userCertificateValidationCallback)
{
_writeAsyncSemaphore = new SemaphoreSlim(1);
_readAsyncSemaphore = new SemaphoreSlim(1);
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent ReadAsync collisions by running the task in a Semaphore Slim
Expand All @@ -31,7 +31,7 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken);
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -45,7 +45,7 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken);
await base.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
finally
{
Expand All @@ -59,19 +59,19 @@ public override async Task WriteAsync(byte[] buffer, int offset, int count, Canc
/// </summary>
internal class SNINetworkStream : NetworkStream
{
private readonly SemaphoreSlim _writeAsyncSemaphore;
private readonly SemaphoreSlim _readAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _writeAsyncSemaphore;
private readonly ConcurrentQueueSemaphore _readAsyncSemaphore;

public SNINetworkStream(Socket socket, bool ownsSocket) : base(socket, ownsSocket)
{
_writeAsyncSemaphore = new SemaphoreSlim(1);
_readAsyncSemaphore = new SemaphoreSlim(1);
_writeAsyncSemaphore = new ConcurrentQueueSemaphore(1);
_readAsyncSemaphore = new ConcurrentQueueSemaphore(1);
}

// Prevent the ReadAsync collisions by running the task in a Semaphore Slim
// Prevent ReadAsync collisions by running the task in a Semaphore Slim
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _readAsyncSemaphore.WaitAsync().ConfigureAwait(false);
await _readAsyncSemaphore.WaitAsync();
try
{
return await base.ReadAsync(buffer, offset, count, cancellationToken);
Expand All @@ -85,7 +85,7 @@ public override async Task<int> ReadAsync(byte[] buffer, int offset, int count,
// Prevent the WriteAsync collisions by running the task in a Semaphore Slim
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await _writeAsyncSemaphore.WaitAsync().ConfigureAwait(false);
await _writeAsyncSemaphore.WaitAsync();
try
{
await base.WriteAsync(buffer, offset, count, cancellationToken);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2133,4 +2133,47 @@ public static MethodInfo GetPromotedToken
}
}

/// <summary>
/// This class implements a FIFO Queue with SemaphoreSlim for ordered execution of parallel tasks.
/// Currently used in Managed SNI (SNISslStream) to override SslStream's WriteAsync implementation.
/// </summary>
internal class ConcurrentQueueSemaphore
{
private readonly SemaphoreSlim _semaphore;
private readonly ConcurrentQueue<TaskCompletionSource<bool>> _queue =
new ConcurrentQueue<TaskCompletionSource<bool>>();

public ConcurrentQueueSemaphore(int initialCount)
{
_semaphore = new SemaphoreSlim(initialCount);
}

public ConcurrentQueueSemaphore(int initialCount, int maxCount)
{
_semaphore = new SemaphoreSlim(initialCount, maxCount);
}

public void Wait()
{
WaitAsync().Wait();
}

public Task WaitAsync()
{
var tcs = new TaskCompletionSource<bool>();
_queue.Enqueue(tcs);
_semaphore.WaitAsync().ContinueWith(t =>
{
if (_queue.TryDequeue(out TaskCompletionSource<bool> popped))
popped.SetResult(true);
});
return tcs.Task;
}

public void Release()
{
_semaphore.Release();
}
}

}//namespace

0 comments on commit 2d69793

Please sign in to comment.