From 2b8e544ea3e5d67b8927f9303f8a516658c349ef Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 8 Jul 2018 14:38:35 +0800 Subject: [PATCH] Rework some logic in TlsHandler * Make sure TlsHandler.MediationStream works well with different style of aync calls(Still not work for Mono, see #374) * Rework some logic in #366, now always close TlsHandler.MediationStream in TlsHandler.HandleFailure since it's never exported. --- src/DotNetty.Handlers/Tls/SniHandler.cs | 2 +- src/DotNetty.Handlers/Tls/TlsHandler.cs | 72 ++++++++++++++++--------- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/SniHandler.cs b/src/DotNetty.Handlers/Tls/SniHandler.cs index 8a797a119..0e2bb8276 100644 --- a/src/DotNetty.Handlers/Tls/SniHandler.cs +++ b/src/DotNetty.Handlers/Tls/SniHandler.cs @@ -28,7 +28,7 @@ public sealed class SniHandler : ByteToMessageDecoder bool readPending; public SniHandler(ServerTlsSniSettings settings) - : this(stream => new SslStream(stream, false), settings) + : this(stream => new SslStream(stream, true), settings) { } diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index d58684ac8..4510d11b7 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls { using System; using System.Collections.Generic; + using System.Diagnostics; using System.Diagnostics.Contracts; using System.IO; using System.Net.Security; @@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder Task pendingSslStreamReadFuture; public TlsHandler(TlsSettings settings) - : this(stream => new SslStream(stream, false), settings) + : this(stream => new SslStream(stream, true), settings) { } @@ -344,6 +345,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng outputBuffer = this.pendingSslStreamReadBuffer; outputBufferLength = outputBuffer.WritableBytes; + + this.pendingSslStreamReadFuture = null; + this.pendingSslStreamReadBuffer = null; } else { @@ -363,17 +367,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng if (!currentReadFuture.IsCompleted) { // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input - Contract.Assert(this.mediationStream.SourceReadableBytes == 0); continue; } int read = currentReadFuture.Result; + if (read == 0) + { + //Stream closed + return; + } + // Now output the result of previous read and decide whether to do an extra read on the same source or move forward AddBufferToOutput(outputBuffer, read, output); currentReadFuture = null; + outputBuffer = null; if (this.mediationStream.SourceReadableBytes == 0) { // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there @@ -620,6 +630,7 @@ void HandleFailure(Exception cause) // Release all resources such as internal buffers that SSLEngine // is managing. + this.mediationStream.Dispose(); try { this.sslStream.Dispose(); @@ -701,14 +712,13 @@ public void ExpandSource(int count) this.inputLength += count; - TaskCompletionSource promise = this.readCompletionSource; - if (promise == null) + ArraySegment sslBuffer = this.sslOwnedBuffer; + if (sslBuffer.Array == null) { // there is no pending read operation - keep for future return; } - - ArraySegment sslBuffer = this.sslOwnedBuffer; + this.sslOwnedBuffer = default(ArraySegment); #if NETSTANDARD1_3 this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); @@ -718,29 +728,35 @@ public void ExpandSource(int count) { var self = (MediationStream)ms; TaskCompletionSource p = self.readCompletionSource; - this.readCompletionSource = null; + self.readCompletionSource = null; p.TrySetResult(self.readByteCount); }, this) .RunSynchronously(TaskScheduler.Default); #else int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + + TaskCompletionSource promise = this.readCompletionSource; this.readCompletionSource = null; promise.TrySetResult(read); - this.readCallback?.Invoke(promise.Task); + + AsyncCallback callback = this.readCallback; + this.readCallback = null; + callback?.Invoke(promise.Task); #endif } #if NETSTANDARD1_3 public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (this.inputLength - this.inputOffset > 0) + if (this.SourceReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); return Task.FromResult(read); } + Contract.Assert(this.sslOwnedBuffer.Array == null); // take note of buffer - we will pass bytes there once available this.sslOwnedBuffer = new ArraySegment(buffer, offset, count); this.readCompletionSource = new TaskCompletionSource(); @@ -749,13 +765,16 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel #else public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - if (this.inputLength - this.inputOffset > 0) + if (this.SourceReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); - return this.PrepareSyncReadResult(read, state); + var res = this.PrepareSyncReadResult(read, state); + callback?.Invoke(res); + return res; } + Contract.Assert(this.sslOwnedBuffer.Array == null); // take note of buffer - we will pass bytes there once available this.sslOwnedBuffer = new ArraySegment(buffer, offset, count); this.readCompletionSource = new TaskCompletionSource(state); @@ -771,6 +790,7 @@ public override int EndRead(IAsyncResult asyncResult) return syncResult.Result; } + Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult); Contract.Assert(!((Task)asyncResult).IsCanceled); try @@ -782,12 +802,6 @@ public override int EndRead(IAsyncResult asyncResult) ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); throw; // unreachable } - finally - { - this.readCompletionSource = null; - this.readCallback = null; - this.sslOwnedBuffer = default(ArraySegment); - } } IAsyncResult PrepareSyncReadResult(int readBytes, object state) @@ -817,10 +831,11 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As // write+flush completed synchronously (and successfully) var result = new SynchronousAsyncResult(); result.AsyncState = state; - callback(result); + callback?.Invoke(result); return result; default: this.writeCallback = callback; + Contract.Assert(this.writeCompletion == null); var tcs = new TaskCompletionSource(state); this.writeCompletion = tcs; task.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously); @@ -831,34 +846,39 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As static void HandleChannelWriteComplete(Task writeTask, object state) { var self = (MediationStream)state; + + AsyncCallback callback = self.writeCallback; + self.writeCallback = null; + + var promise = self.writeCompletion; + self.writeCompletion = null; + switch (writeTask.Status) { case TaskStatus.RanToCompletion: - self.writeCompletion.TryComplete(); + promise.TryComplete(); break; case TaskStatus.Canceled: - self.writeCompletion.TrySetCanceled(); + promise.TrySetCanceled(); break; case TaskStatus.Faulted: - self.writeCompletion.TrySetException(writeTask.Exception); + promise.TrySetException(writeTask.Exception); break; default: throw new ArgumentOutOfRangeException("Unexpected task status: " + writeTask.Status); } - self.writeCallback?.Invoke(self.writeCompletion.Task); + callback?.Invoke(promise.Task); } public override void EndWrite(IAsyncResult asyncResult) { - this.writeCallback = null; - this.writeCompletion = null; - if (asyncResult is SynchronousAsyncResult) { return; } + Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult); try { ((Task)asyncResult).Wait(); @@ -876,7 +896,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa Contract.Assert(destination != null); byte[] source = this.input; - int readableBytes = this.inputLength - this.inputOffset; + int readableBytes = this.SourceReadableBytes; int length = Math.Min(readableBytes, destinationCapacity); Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length); this.inputOffset += length;