From a258cdb5207b30ad253ae680ffa4103b3588c3dc Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 24 May 2020 13:57:13 +0800 Subject: [PATCH 1/6] Avoid unnessary ByteBufferUtil.PrettyHexDump in test --- test/DotNetty.Handlers.Tests/SniHandlerTest.cs | 10 ++++++++-- test/DotNetty.Handlers.Tests/TlsHandlerTest.cs | 10 ++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/test/DotNetty.Handlers.Tests/SniHandlerTest.cs b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs index 0f4d09328..d73f46433 100644 --- a/test/DotNetty.Handlers.Tests/SniHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/SniHandlerTest.cs @@ -94,7 +94,10 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5)); IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024); await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) + { + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } if (!isClient) { @@ -171,7 +174,10 @@ await ReadOutboundAsync( return Unpooled.WrappedBuffer(readBuffer, 0, read); }, expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - Assert.True(ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer), $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) + { + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + } if (!isClient) { diff --git a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs index ba7748c59..1a99f3f91 100644 --- a/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs +++ b/test/DotNetty.Handlers.Tests/TlsHandlerTest.cs @@ -102,10 +102,9 @@ public async Task TlsRead(int[] frameLengths, bool isClient, IWriteStrategy writ await Task.WhenAll(writeTasks).WithTimeout(TimeSpan.FromSeconds(5)); IByteBuffer finalReadBuffer = Unpooled.Buffer(16 * 1024); await ReadOutboundAsync(async () => ch.ReadInbound(), expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); - if (!isEqual) + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) { - Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); } driverStream.Dispose(); Assert.False(ch.Finish()); @@ -192,10 +191,9 @@ await ReadOutboundAsync( return Unpooled.WrappedBuffer(readBuffer, 0, read); }, expectedBuffer.ReadableBytes, finalReadBuffer, TestTimeout); - bool isEqual = ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer); - if (!isEqual) + if (!ByteBufferUtil.Equals(expectedBuffer, finalReadBuffer)) { - Assert.True(isEqual, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); + Assert.True(false, $"---Expected:\n{ByteBufferUtil.PrettyHexDump(expectedBuffer)}\n---Actual:\n{ByteBufferUtil.PrettyHexDump(finalReadBuffer)}"); } driverStream.Dispose(); Assert.False(ch.Finish()); From 1a203bcb09c7eee71a35b6c6c45d3e905c7b2b3e Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 24 May 2020 15:40:08 +0800 Subject: [PATCH 2/6] Save old buffer to ownBuffer in ResetSource if there are some bytes which is not read. Maybe needed for mono(old version only?) and net5.0(without NETSTANDARD1_3 defined) --- src/DotNetty.Handlers/Tls/TlsHandler.cs | 170 +++++++++++++++++------- 1 file changed, 125 insertions(+), 45 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 063aa2db9..68ec0a6a9 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -319,7 +319,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng try { ArraySegment inputIoBuffer = packet.GetIoBuffer(offset, length); - this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset); + this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator); int packetIndex = 0; @@ -446,7 +446,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng } finally { - this.mediationStream.ResetSource(); + this.mediationStream.ResetSource(ctx.Allocator); if (!pending && outputBuffer != null) { if (outputBuffer.IsReadable()) @@ -668,6 +668,7 @@ void NotifyHandshakeFailure(Exception cause) sealed class MediationStream : Stream { readonly TlsHandler owner; + IByteBuffer ownBuffer; byte[] input; int inputStartOffset; int inputOffset; @@ -688,66 +689,112 @@ public MediationStream(TlsHandler owner) this.owner = owner; } + public int TotalReadableBytes => (this.ownBuffer?.ReadableBytes ?? 0) + SourceReadableBytes; + public int SourceReadableBytes => this.inputLength - this.inputOffset; - public void SetSource(byte[] source, int offset) + public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc) { - this.input = source; - this.inputStartOffset = offset; - this.inputOffset = 0; - this.inputLength = 0; + lock (this) + { + ResetSource(alloc); + + this.input = source; + this.inputStartOffset = offset; + this.inputOffset = 0; + this.inputLength = 0; + } } - public void ResetSource() + public void ResetSource(IByteBufferAllocator alloc) { - this.input = null; - this.inputLength = 0; + //Mono will run BeginRead in async and it's running with ResetSource at the same time + //net5.0 can also hit this with `leftLen > 0` while `!this.EnsureAuthenticated()` + lock (this) + { + int leftLen = this.SourceReadableBytes; + IByteBuffer buf = this.ownBuffer; + + if (leftLen > 0) + { + if (buf != null) + { + buf.DiscardSomeReadBytes(); + } + else + { + buf = alloc.Buffer(leftLen); + this.ownBuffer = buf; + } + buf.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, leftLen); + } + else if (buf != null) + { + if (!buf.IsReadable()) + { + buf.SafeRelease(); + this.ownBuffer = null; + } + else + { + buf.DiscardSomeReadBytes(); + } + } + + this.input = null; + this.inputStartOffset = 0; + this.inputOffset = 0; + this.inputLength = 0; + } } public void ExpandSource(int count) { Contract.Assert(this.input != null); - this.inputLength += count; - - ArraySegment sslBuffer = this.sslOwnedBuffer; - if (sslBuffer.Array == null) + lock (this) { - // there is no pending read operation - keep for future - return; - } - this.sslOwnedBuffer = default(ArraySegment); + this.inputLength += count; -#if NETSTANDARD1_3 - this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. - new Task( - ms => + ArraySegment sslBuffer = this.sslOwnedBuffer; + if (sslBuffer.Array == null) { - var self = (MediationStream)ms; - TaskCompletionSource p = self.readCompletionSource; - self.readCompletionSource = null; - p.TrySetResult(self.readByteCount); - }, - this) - .RunSynchronously(TaskScheduler.Default); + // there is no pending read operation - keep for future + return; + } + this.sslOwnedBuffer = default(ArraySegment); + +#if NETSTANDARD1_3 + this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + // hack: this tricks SslStream's continuation to run synchronously instead of dispatching to TP. Remove once Begin/EndRead are available. + new Task( + ms => + { + var self = (MediationStream)ms; + TaskCompletionSource p = self.readCompletionSource; + self.readCompletionSource = null; + p.TrySetResult(self.readByteCount); + }, + this) + .RunSynchronously(TaskScheduler.Default); #else - int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); + int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count); - TaskCompletionSource promise = this.readCompletionSource; - this.readCompletionSource = null; - promise.TrySetResult(read); + TaskCompletionSource promise = this.readCompletionSource; + this.readCompletionSource = null; + promise.TrySetResult(read); - AsyncCallback callback = this.readCallback; - this.readCallback = null; - callback?.Invoke(promise.Task); + AsyncCallback callback = this.readCallback; + this.readCallback = null; + callback?.Invoke(promise.Task); #endif + } } #if NETSTANDARD1_3 public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (this.SourceReadableBytes > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -763,7 +810,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel #else public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - if (this.SourceReadableBytes > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -899,12 +946,45 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa { Contract.Assert(destination != null); - byte[] source = this.input; - int readableBytes = this.SourceReadableBytes; - int length = Math.Min(readableBytes, destinationCapacity); - Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length); - this.inputOffset += length; - return length; + lock (this) + { + int length = 0; + do + { + int readableBytes; + IByteBuffer buf = this.ownBuffer; + if (buf != null) + { + readableBytes = buf.ReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(readableBytes, destinationCapacity); + buf.ReadBytes(destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; + + if (destinationCapacity == 0) + break; + } + } + + byte[] source = this.input; + if (source != null) + { + readableBytes = this.SourceReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(readableBytes, destinationCapacity); + Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; + + this.inputOffset += readableBytes; + } + } + } while (false); + return length; + } } public override void Flush() From de886b43cf837c631c94370a691e0bd8e233fe34 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 24 May 2020 15:43:59 +0800 Subject: [PATCH 3/6] Handle special flush mode for old version of mono In Mono(with btls provider) on linux, and maybe also for apple provider, Write is called in another thread, so it will run after the call to Flush. --- src/DotNetty.Handlers/Tls/TlsHandler.cs | 54 +++++++++++++++++++++---- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 68ec0a6a9..4f023dec2 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -38,6 +38,7 @@ public sealed class TlsHandler : ByteToMessageDecoder BatchingPendingWriteQueue pendingUnencryptedWrites; Task lastContextWriteTask; bool firedChannelRead; + volatile FlushMode flushMode = FlushMode.ForceFlush; IByteBuffer pendingSslStreamReadBuffer; Task pendingSslStreamReadFuture; @@ -136,8 +137,7 @@ static void HandleHandshakeCompleted(Task task, object state) if (oldState.Has(TlsHandlerState.FlushedBeforeHandshake)) { - self.Wrap(self.capturedContext); - self.capturedContext.Flush(); + self.WrapAndFlush(self.capturedContext); } break; } @@ -530,6 +530,12 @@ public override void Flush(IChannelHandlerContext context) return; } + this.WrapAndFlush(context); + } + + void WrapAndFlush(IChannelHandlerContext context) + { + this.flushMode = FlushMode.NoFlush; try { this.Wrap(context); @@ -537,7 +543,20 @@ public override void Flush(IChannelHandlerContext context) finally { // We may have written some parts of data before an exception was thrown so ensure we always flush. - context.Flush(); + if (this.flushMode == FlushMode.NoFlush) + { + this.flushMode = FlushMode.ForceFlush; + context.Flush(); + } + else + { + context.Executor.Execute((state) => { + var self = (TlsHandler)state; + + self.flushMode = FlushMode.ForceFlush; + self.capturedContext.Flush(); + }, this); + } } } @@ -595,6 +614,12 @@ void Wrap(IChannelHandlerContext context) void FinishWrap(byte[] buffer, int offset, int count) { + // In Mono(with btls provider) on linux, and maybe also for apple provider, Write is called in another thread, + // so it will run after the call to Flush. + if (this.flushMode == FlushMode.NoFlush && !this.capturedContext.Executor.InEventLoop) + { + this.flushMode = FlushMode.PendingFlush; + } IByteBuffer output; if (count == 0) { @@ -606,7 +631,7 @@ void FinishWrap(byte[] buffer, int offset, int count) output.WriteBytes(buffer, offset, count); } - this.lastContextWriteTask = this.capturedContext.WriteAsync(output); + this.lastContextWriteTask = (this.flushMode == FlushMode.ForceFlush) ? this.capturedContext.WriteAndFlushAsync(output) : this.capturedContext.WriteAsync(output); } Task FinishWrapNonAppDataAsync(byte[] buffer, int offset, int count) @@ -665,6 +690,22 @@ void NotifyHandshakeFailure(Exception cause) } } + enum FlushMode : byte + { + /// + /// Do nothing with Flush. + /// + NoFlush = 0, + /// + /// An Flush is or will be posted to IEventExecutor. + /// + PendingFlush = 1, + /// + /// Force FinishWrap to call Flush. + /// + ForceFlush = 2, + } + sealed class MediationStream : Stream { readonly TlsHandler owner; @@ -1018,10 +1059,7 @@ public override void SetLength(long value) throw new NotSupportedException(); } - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } + public override int Read(byte[] buffer, int offset, int count) => this.ReadAsync(buffer, offset, count).Result; public override bool CanRead => true; From 331ea1e54696d58a0d2ce548988c03a62ba21b8f Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 24 May 2020 15:53:08 +0800 Subject: [PATCH 4/6] Handle if the last data in async path. --- src/DotNetty.Handlers/Tls/TlsHandler.cs | 198 +++++++++++++++++------- 1 file changed, 139 insertions(+), 59 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 4f023dec2..a4c9da022 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -27,6 +27,7 @@ public sealed class TlsHandler : ByteToMessageDecoder static readonly Exception ChannelClosedException = new IOException("Channel is closed"); static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); + static readonly Action, object> UnwrapCompletedCallback = new Action, object>(UnwrapCompleted); readonly SslStream sslStream; readonly MediationStream mediationStream; @@ -40,6 +41,7 @@ public sealed class TlsHandler : ByteToMessageDecoder bool firedChannelRead; volatile FlushMode flushMode = FlushMode.ForceFlush; IByteBuffer pendingSslStreamReadBuffer; + int pendingSslStreamReadLength; Task pendingSslStreamReadFuture; public TlsHandler(TlsSettings settings) @@ -342,10 +344,11 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng Contract.Assert(this.pendingSslStreamReadBuffer != null); outputBuffer = this.pendingSslStreamReadBuffer; - outputBufferLength = outputBuffer.WritableBytes; + outputBufferLength = this.pendingSslStreamReadLength; this.pendingSslStreamReadFuture = null; this.pendingSslStreamReadBuffer = null; + this.pendingSslStreamReadLength = 0; } else { @@ -358,86 +361,78 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng int currentPacketLength = packetLengths[packetIndex]; this.mediationStream.ExpandSource(currentPacketLength); - if (currentReadFuture != null) + while (true) { - // there was a read pending already, so we make sure we completed that first - - if (!currentReadFuture.IsCompleted) + int totalRead = 0; + if (currentReadFuture != null) { - // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input + // there was a read pending already, so we make sure we completed that first - continue; - } + if (!currentReadFuture.IsCompleted) + { + // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input - int read = currentReadFuture.Result; + break; + } - if (read == 0) - { - //Stream closed - return; - } + int read = currentReadFuture.Result; + totalRead += read; - // Now output the result of previous read and decide whether to do an extra read on the same source or move forward - AddBufferToOutput(outputBuffer, read, output); + if (read == 0) + { + //Stream closed + return; + } - currentReadFuture = null; - outputBuffer = null; - if (this.mediationStream.SourceReadableBytes == 0) - { - // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there + // Now output the result of previous read and decide whether to do an extra read on the same source or move forward + AddBufferToOutput(outputBuffer, read, output); - if (read < outputBufferLength) + currentReadFuture = null; + outputBuffer = null; + if (this.mediationStream.TotalReadableBytes == 0) { - // SslStream returned non-full buffer and there's no more input to go through -> - // typically it means SslStream is done reading current frame so we skip - continue; - } + // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there - // we've read out `read` bytes out of current packet to fulfil previously outstanding read - outputBufferLength = currentPacketLength - read; - if (outputBufferLength <= 0) + if (read < outputBufferLength) + { + // SslStream returned non-full buffer and there's no more input to go through -> + // typically it means SslStream is done reading current frame so we skip + break; + } + + // we've read out `read` bytes out of current packet to fulfil previously outstanding read + outputBufferLength = currentPacketLength - totalRead; + if (outputBufferLength <= 0) + { + // after feeding to SslStream current frame it read out more bytes than current packet size + outputBufferLength = FallbackReadBufferSize; + } + } + else { - // after feeding to SslStream current frame it read out more bytes than current packet size - outputBufferLength = FallbackReadBufferSize; + // SslStream did not get to reading current frame so it completed previous read sync + // and the next read will likely read out the new frame + outputBufferLength = currentPacketLength; } } else { - // SslStream did not get to reading current frame so it completed previous read sync - // and the next read will likely read out the new frame + // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient outputBufferLength = currentPacketLength; } - } - else - { - // there was no pending read before so we estimate buffer of `currentPacketLength` bytes to be sufficient - outputBufferLength = currentPacketLength; - } - outputBuffer = ctx.Allocator.Buffer(outputBufferLength); - currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength); + outputBuffer = ctx.Allocator.Buffer(outputBufferLength); + currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, outputBufferLength); + } } - // read out the rest of SslStream's output (if any) at risk of going async - // using FallbackReadBufferSize - buffer size we're ok to have pinned with the SslStream until it's done reading - while (true) + if (currentReadFuture != null) { - if (currentReadFuture != null) - { - if (!currentReadFuture.IsCompleted) - { - break; - } - int read = currentReadFuture.Result; - AddBufferToOutput(outputBuffer, read, output); - } - outputBuffer = ctx.Allocator.Buffer(FallbackReadBufferSize); - currentReadFuture = this.ReadFromSslStreamAsync(outputBuffer, FallbackReadBufferSize); + pending = true; + this.pendingSslStreamReadBuffer = outputBuffer; + this.pendingSslStreamReadFuture = currentReadFuture; + this.pendingSslStreamReadLength = outputBufferLength; } - - pending = true; - this.pendingSslStreamReadBuffer = outputBuffer; - this.pendingSslStreamReadFuture = currentReadFuture; } catch (Exception ex) { @@ -458,6 +453,91 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng outputBuffer.SafeRelease(); } } + + if (pending) + { + //Can't use ExecuteSynchronously here for it may change the order of output if task is already completed here. + this.pendingSslStreamReadFuture?.ContinueWith(UnwrapCompletedCallback, this, TaskContinuationOptions.None); + } + } + } + + static void UnwrapCompleted(Task task, object state) + { + // Mono(with legacy provider) finish ReadAsync in async, + // so extra check is needed to receive data in async + var self = (TlsHandler)state; + Debug.Assert(self.capturedContext.Executor.InEventLoop); + + //Ignore task completed in Unwrap + if (task == self.pendingSslStreamReadFuture) + { + IByteBuffer buf = self.pendingSslStreamReadBuffer; + int outputBufferLength = self.pendingSslStreamReadLength; + + self.pendingSslStreamReadFuture = null; + self.pendingSslStreamReadBuffer = null; + self.pendingSslStreamReadLength = 0; + + while (true) + { + switch (task.Status) + { + case TaskStatus.RanToCompletion: + { + //The logic is the same as the one in Unwrap() + var read = task.Result; + //Stream Closed + if (read == 0) + return; + self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read)); + + if (self.mediationStream.TotalReadableBytes == 0) + { + self.capturedContext.FireChannelReadComplete(); + self.mediationStream.ResetSource(self.capturedContext.Allocator); + + if (read < outputBufferLength) + { + // SslStream returned non-full buffer and there's no more input to go through -> + // typically it means SslStream is done reading current frame so we skip + return; + } + } + + outputBufferLength = self.mediationStream.TotalReadableBytes; + if (outputBufferLength <= 0) + outputBufferLength = FallbackReadBufferSize; + + buf = self.capturedContext.Allocator.Buffer(outputBufferLength); + task = self.ReadFromSslStreamAsync(buf, outputBufferLength); + if (task.IsCompleted) + { + continue; + } + + self.pendingSslStreamReadFuture = task; + self.pendingSslStreamReadBuffer = buf; + self.pendingSslStreamReadLength = outputBufferLength; + task.ContinueWith(UnwrapCompletedCallback, self, TaskContinuationOptions.ExecuteSynchronously); + return; + } + + case TaskStatus.Canceled: + case TaskStatus.Faulted: + { + buf.SafeRelease(); + self.HandleFailure(task.Exception); + return; + } + + default: + { + buf.SafeRelease(); + throw new ArgumentOutOfRangeException(nameof(task), "Unexpected task status: " + task.Status); + } + } + } } } From 0df22af5d855e7d298a0ea5e5fc392c7b485cb6f Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 24 May 2020 15:56:57 +0800 Subject: [PATCH 5/6] Clean up unnessary rethrow `try{_=task.Result;}catch(AggregateException ex){ExceptionDispatchInfo.Capture(ex.InnerException).Throw();/*unreachable*/throw;}` => `task.GetAwaiter().GetResult()`, the latter one will not wrap Exception as AggregateException. --- src/DotNetty.Handlers/Tls/TlsHandler.cs | 20 ++------------------ 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index a4c9da022..b7e9b194a 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -959,15 +959,7 @@ public override int EndRead(IAsyncResult asyncResult) Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult); Contract.Assert(!((Task)asyncResult).IsCanceled); - try - { - return ((Task)asyncResult).Result; - } - catch (AggregateException ex) - { - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); - throw; // unreachable - } + return ((Task)asyncResult).GetAwaiter().GetResult(); } IAsyncResult PrepareSyncReadResult(int readBytes, object state) @@ -1051,15 +1043,7 @@ public override void EndWrite(IAsyncResult asyncResult) return; } - try - { - ((Task)asyncResult).Wait(); - } - catch (AggregateException ex) - { - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); - throw; - } + ((Task)asyncResult).GetAwaiter().GetResult(); } #endif From 2312aff27c116cd188eddd86fd011e2ad97c58ff Mon Sep 17 00:00:00 2001 From: SilverFox Date: Thu, 27 May 2021 20:59:21 +0800 Subject: [PATCH 6/6] Ensure TlsHandler.HandleHandshakeCompleted is called in event loop, fix concurrent issue for Wrap --- src/DotNetty.Handlers/Tls/TlsHandler.cs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index b7e9b194a..9976d1611 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -26,7 +26,7 @@ public sealed class TlsHandler : ByteToMessageDecoder const int UnencryptedWriteBatchSize = 14 * 1024; static readonly Exception ChannelClosedException = new IOException("Channel is closed"); - static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); + static readonly Action HandshakeCompletionCallback = new Action(HandleHandshakeCompleted); static readonly Action, object> UnwrapCompletedCallback = new Action, object>(UnwrapCompleted); readonly SslStream sslStream; @@ -118,9 +118,16 @@ bool IgnoreException(Exception t) return false; } - static void HandleHandshakeCompleted(Task task, object state) + static void HandleHandshakeCompleted(object context, object state) { var self = (TlsHandler)state; + var capturedContext = self.capturedContext; + if (!capturedContext.Executor.InEventLoop) + { + capturedContext.Executor.Execute(HandshakeCompletionCallback, context, state); + return; + } + var task = (Task)context; switch (task.Status) { case TaskStatus.RanToCompletion: