Skip to content

Commit

Permalink
Rework some logic in TlsHandler
Browse files Browse the repository at this point in the history
* Make sure TlsHandler.MediationStream works well with different style of aync calls(Still not work for Mono, see Azure#374)
* Rework some logic in Azure#366, now always close TlsHandler.MediationStream in TlsHandler.HandleFailure since it's never exported.
  • Loading branch information
yyjdelete committed Jul 8, 2018
1 parent f9b86a0 commit 2b8e544
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/DotNetty.Handlers/Tls/SniHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public sealed class SniHandler : ByteToMessageDecoder
bool readPending;

public SniHandler(ServerTlsSniSettings settings)
: this(stream => new SslStream(stream, false), settings)
: this(stream => new SslStream(stream, true), settings)
{
}

Expand Down
72 changes: 46 additions & 26 deletions src/DotNetty.Handlers/Tls/TlsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.Contracts;
using System.IO;
using System.Net.Security;
Expand Down Expand Up @@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
Task<int> pendingSslStreamReadFuture;

public TlsHandler(TlsSettings settings)
: this(stream => new SslStream(stream, false), settings)
: this(stream => new SslStream(stream, true), settings)
{
}

Expand Down Expand Up @@ -344,6 +345,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng

outputBuffer = this.pendingSslStreamReadBuffer;
outputBufferLength = outputBuffer.WritableBytes;

this.pendingSslStreamReadFuture = null;
this.pendingSslStreamReadBuffer = null;
}
else
{
Expand All @@ -363,17 +367,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
if (!currentReadFuture.IsCompleted)
{
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
Contract.Assert(this.mediationStream.SourceReadableBytes == 0);

continue;
}

int read = currentReadFuture.Result;

if (read == 0)
{
//Stream closed
return;
}

// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
AddBufferToOutput(outputBuffer, read, output);

currentReadFuture = null;
outputBuffer = null;
if (this.mediationStream.SourceReadableBytes == 0)
{
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
Expand Down Expand Up @@ -620,6 +630,7 @@ void HandleFailure(Exception cause)
// Release all resources such as internal buffers that SSLEngine
// is managing.

this.mediationStream.Dispose();
try
{
this.sslStream.Dispose();
Expand Down Expand Up @@ -701,14 +712,13 @@ public void ExpandSource(int count)

this.inputLength += count;

TaskCompletionSource<int> promise = this.readCompletionSource;
if (promise == null)
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
if (sslBuffer.Array == null)
{
// there is no pending read operation - keep for future
return;
}

ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
this.sslOwnedBuffer = default(ArraySegment<byte>);

#if NETSTANDARD1_3
this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
Expand All @@ -718,29 +728,35 @@ public void ExpandSource(int count)
{
var self = (MediationStream)ms;
TaskCompletionSource<int> p = self.readCompletionSource;
this.readCompletionSource = null;
self.readCompletionSource = null;
p.TrySetResult(self.readByteCount);
},
this)
.RunSynchronously(TaskScheduler.Default);
#else
int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);

TaskCompletionSource<int> promise = this.readCompletionSource;
this.readCompletionSource = null;
promise.TrySetResult(read);
this.readCallback?.Invoke(promise.Task);

AsyncCallback callback = this.readCallback;
this.readCallback = null;
callback?.Invoke(promise.Task);
#endif
}

#if NETSTANDARD1_3
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (this.inputLength - this.inputOffset > 0)
if (this.SourceReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
return Task.FromResult(read);
}

Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>();
Expand All @@ -749,13 +765,16 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
#else
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
if (this.inputLength - this.inputOffset > 0)
if (this.SourceReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
return this.PrepareSyncReadResult(read, state);
var res = this.PrepareSyncReadResult(read, state);
callback?.Invoke(res);
return res;
}

Contract.Assert(this.sslOwnedBuffer.Array == null);
// take note of buffer - we will pass bytes there once available
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
this.readCompletionSource = new TaskCompletionSource<int>(state);
Expand All @@ -771,6 +790,7 @@ public override int EndRead(IAsyncResult asyncResult)
return syncResult.Result;
}

Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult);
Contract.Assert(!((Task<int>)asyncResult).IsCanceled);

try
Expand All @@ -782,12 +802,6 @@ public override int EndRead(IAsyncResult asyncResult)
ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
throw; // unreachable
}
finally
{
this.readCompletionSource = null;
this.readCallback = null;
this.sslOwnedBuffer = default(ArraySegment<byte>);
}
}

IAsyncResult PrepareSyncReadResult(int readBytes, object state)
Expand Down Expand Up @@ -817,10 +831,11 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
// write+flush completed synchronously (and successfully)
var result = new SynchronousAsyncResult<int>();
result.AsyncState = state;
callback(result);
callback?.Invoke(result);
return result;
default:
this.writeCallback = callback;
Contract.Assert(this.writeCompletion == null);
var tcs = new TaskCompletionSource(state);
this.writeCompletion = tcs;
task.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously);
Expand All @@ -831,34 +846,39 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
static void HandleChannelWriteComplete(Task writeTask, object state)
{
var self = (MediationStream)state;

AsyncCallback callback = self.writeCallback;
self.writeCallback = null;

var promise = self.writeCompletion;
self.writeCompletion = null;

switch (writeTask.Status)
{
case TaskStatus.RanToCompletion:
self.writeCompletion.TryComplete();
promise.TryComplete();
break;
case TaskStatus.Canceled:
self.writeCompletion.TrySetCanceled();
promise.TrySetCanceled();
break;
case TaskStatus.Faulted:
self.writeCompletion.TrySetException(writeTask.Exception);
promise.TrySetException(writeTask.Exception);
break;
default:
throw new ArgumentOutOfRangeException("Unexpected task status: " + writeTask.Status);
}

self.writeCallback?.Invoke(self.writeCompletion.Task);
callback?.Invoke(promise.Task);
}

public override void EndWrite(IAsyncResult asyncResult)
{
this.writeCallback = null;
this.writeCompletion = null;

if (asyncResult is SynchronousAsyncResult<int>)
{
return;
}

Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult);
try
{
((Task<int>)asyncResult).Wait();
Expand All @@ -876,7 +896,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa
Contract.Assert(destination != null);

byte[] source = this.input;
int readableBytes = this.inputLength - this.inputOffset;
int readableBytes = this.SourceReadableBytes;
int length = Math.Min(readableBytes, destinationCapacity);
Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length);
this.inputOffset += length;
Expand Down

0 comments on commit 2b8e544

Please sign in to comment.