From 7515ab7e530be34e98b5fac975e1feed1c847dd1 Mon Sep 17 00:00:00 2001 From: SilverFox Date: Sun, 11 Mar 2018 23:00:40 +0800 Subject: [PATCH] Rework some logic to remove an Assert from TlsHandler.MediationStream.SetSource --- src/DotNetty.Codecs/ByteToMessageDecoder.cs | 2 +- src/DotNetty.Handlers/Tls/TlsHandler.cs | 110 ++++++++++++++------ 2 files changed, 80 insertions(+), 32 deletions(-) diff --git a/src/DotNetty.Codecs/ByteToMessageDecoder.cs b/src/DotNetty.Codecs/ByteToMessageDecoder.cs index 1f676e978..5dd85907a 100644 --- a/src/DotNetty.Codecs/ByteToMessageDecoder.cs +++ b/src/DotNetty.Codecs/ByteToMessageDecoder.cs @@ -249,7 +249,7 @@ protected void DiscardSomeReadBytes() // See: // - https://github.com/netty/netty/issues/2327 // - https://github.com/netty/netty/issues/1764 - this.cumulation.DiscardReadBytes(); // todo: use discardSomeReadBytes + this.cumulation.DiscardSomeReadBytes(); } } diff --git a/src/DotNetty.Handlers/Tls/TlsHandler.cs b/src/DotNetty.Handlers/Tls/TlsHandler.cs index 3a93138c5..1ca436b4c 100644 --- a/src/DotNetty.Handlers/Tls/TlsHandler.cs +++ b/src/DotNetty.Handlers/Tls/TlsHandler.cs @@ -323,7 +323,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng try { ArraySegment inputIoBuffer = packet.GetIoBuffer(offset, length); - this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset); + this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator); int packetIndex = 0; @@ -391,7 +391,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng currentReadFuture = null; outputBuffer = null; - if (this.mediationStream.SourceReadableBytes == 0) + if (this.mediationStream.TotalReadableBytes == 0) { // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there @@ -446,7 +446,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng } finally { - this.mediationStream.ResetSource(); + this.mediationStream.ResetSource(ctx.Allocator); if (!pending && outputBuffer != null) { if (outputBuffer.IsReadable()) @@ -485,15 +485,15 @@ static void UnwrapCompleted(Task task, object state) case TaskStatus.RanToCompletion: { var read = task.Result; - //Stream Closed + //Stream Closed if (read == 0) return; self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read)); - if (self.mediationStream.SourceReadableBytes == 0) + if (self.mediationStream.TotalReadableBytes == 0) { self.capturedContext.FireChannelReadComplete(); - self.mediationStream.ResetSource(); + self.mediationStream.ResetSource(self.capturedContext.Allocator); if (read < outputBufferLength) { @@ -503,7 +503,7 @@ static void UnwrapCompleted(Task task, object state) } } - outputBufferLength = self.mediationStream.SourceReadableBytes; + outputBufferLength = self.mediationStream.TotalReadableBytes; if (outputBufferLength <= 0) outputBufferLength = FallbackReadBufferSize; @@ -788,6 +788,7 @@ sealed class MediationStream : Stream { readonly TlsHandler owner; object sourceLock = new object(); + IByteBuffer ownBuffer; byte[] input; int inputStartOffset; int inputOffset; @@ -808,13 +809,16 @@ public MediationStream(TlsHandler owner) this.owner = owner; } + public int TotalReadableBytes => (this.ownBuffer?.ReadableBytes ?? 0) + SourceReadableBytes; + public int SourceReadableBytes => this.inputLength - this.inputOffset; - public void SetSource(byte[] source, int offset) + public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc) { - Contract.Assert(this.SourceReadableBytes == 0); lock (sourceLock) { + ResetSource(alloc); + this.input = source; this.inputStartOffset = offset; this.inputOffset = 0; @@ -822,30 +826,44 @@ public void SetSource(byte[] source, int offset) } } - public void ResetSource() + public void ResetSource(IByteBufferAllocator alloc) { //Mono will run BeginRead in async and it's running with ResetSource at the same time lock (sourceLock) { int leftLen = this.SourceReadableBytes; + IByteBuffer buf = this.ownBuffer; + if (leftLen > 0) { - var data = new byte[leftLen]; - Buffer.BlockCopy(this.input, this.inputStartOffset + this.inputOffset, data, 0, leftLen); - this.input = data; - this.inputStartOffset = 0; - this.inputOffset = 0; - this.inputLength = leftLen; - - return; + if (buf != null) + { + buf.DiscardSomeReadBytes(); + } + else + { + buf = alloc.Buffer(leftLen); + this.ownBuffer = buf; + } + buf.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, leftLen); } - else + else if (buf != null) { - this.input = null; - this.inputStartOffset = 0; - this.inputOffset = 0; - this.inputLength = 0; + if (!buf.IsReadable()) + { + buf.SafeRelease(); + this.ownBuffer = null; + } + else + { + buf.DiscardSomeReadBytes(); + } } + + this.input = null; + this.inputStartOffset = 0; + this.inputOffset = 0; + this.inputLength = 0; } } @@ -897,7 +915,7 @@ public void ExpandSource(int count) #if NETSTANDARD1_3 public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - if (this.inputLength - this.inputOffset > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -913,7 +931,7 @@ public override Task ReadAsync(byte[] buffer, int offset, int count, Cancel #else public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - if (this.inputLength - this.inputOffset > 0) + if (this.TotalReadableBytes > 0) { // we have the bytes available upfront - write out synchronously int read = this.ReadFromInput(buffer, offset, count); @@ -1040,15 +1058,45 @@ public override void EndWrite(IAsyncResult asyncResult) int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity) { + Contract.Assert(destination != null); + lock (sourceLock) { - Contract.Assert(destination != null); + int length = 0; + do + { + int readableBytes; + IByteBuffer buf = this.ownBuffer; + if (buf != null) + { + readableBytes = buf.ReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(buf.ReadableBytes, destinationCapacity); + buf.ReadBytes(destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; + + if (destinationCapacity == 0) + break; + } + } + + byte[] source = this.input; + if (source != null) + { + readableBytes = this.SourceReadableBytes; + if (readableBytes > 0) + { + readableBytes = Math.Min(readableBytes, destinationCapacity); + Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, readableBytes); + length += readableBytes; + destinationCapacity -= readableBytes; - byte[] source = this.input; - int readableBytes = this.inputLength - this.inputOffset; - int length = Math.Min(readableBytes, destinationCapacity); - Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length); - this.inputOffset += length; + this.inputOffset += readableBytes; + } + } + } while (false); return length; } }