Skip to content

Commit

Permalink
Rework some logic to remove an Assert from TlsHandler.MediationStream…
Browse files Browse the repository at this point in the history
….SetSource
  • Loading branch information
yyjdelete committed Apr 11, 2018
1 parent e85a443 commit 7515ab7
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/DotNetty.Codecs/ByteToMessageDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ protected void DiscardSomeReadBytes()
// See:
// - https://github.com/netty/netty/issues/2327
// - https://github.com/netty/netty/issues/1764
this.cumulation.DiscardReadBytes(); // todo: use discardSomeReadBytes
this.cumulation.DiscardSomeReadBytes();
}
}

Expand Down
110 changes: 79 additions & 31 deletions src/DotNetty.Handlers/Tls/TlsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
try
{
ArraySegment<byte> inputIoBuffer = packet.GetIoBuffer(offset, length);
this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset);
this.mediationStream.SetSource(inputIoBuffer.Array, inputIoBuffer.Offset, ctx.Allocator);

int packetIndex = 0;

Expand Down Expand Up @@ -391,7 +391,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng

currentReadFuture = null;
outputBuffer = null;
if (this.mediationStream.SourceReadableBytes == 0)
if (this.mediationStream.TotalReadableBytes == 0)
{
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there

Expand Down Expand Up @@ -446,7 +446,7 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
}
finally
{
this.mediationStream.ResetSource();
this.mediationStream.ResetSource(ctx.Allocator);
if (!pending && outputBuffer != null)
{
if (outputBuffer.IsReadable())
Expand Down Expand Up @@ -485,15 +485,15 @@ static void UnwrapCompleted(Task<int> task, object state)
case TaskStatus.RanToCompletion:
{
var read = task.Result;
//Stream Closed
//Stream Closed
if (read == 0)
return;
self.capturedContext.FireChannelRead(buf.SetWriterIndex(buf.WriterIndex + read));

if (self.mediationStream.SourceReadableBytes == 0)
if (self.mediationStream.TotalReadableBytes == 0)
{
self.capturedContext.FireChannelReadComplete();
self.mediationStream.ResetSource();
self.mediationStream.ResetSource(self.capturedContext.Allocator);

if (read < outputBufferLength)
{
Expand All @@ -503,7 +503,7 @@ static void UnwrapCompleted(Task<int> task, object state)
}
}

outputBufferLength = self.mediationStream.SourceReadableBytes;
outputBufferLength = self.mediationStream.TotalReadableBytes;
if (outputBufferLength <= 0)
outputBufferLength = FallbackReadBufferSize;

Expand Down Expand Up @@ -788,6 +788,7 @@ sealed class MediationStream : Stream
{
readonly TlsHandler owner;
object sourceLock = new object();
IByteBuffer ownBuffer;
byte[] input;
int inputStartOffset;
int inputOffset;
Expand All @@ -808,44 +809,61 @@ public MediationStream(TlsHandler owner)
this.owner = owner;
}

public int TotalReadableBytes => (this.ownBuffer?.ReadableBytes ?? 0) + SourceReadableBytes;

public int SourceReadableBytes => this.inputLength - this.inputOffset;

public void SetSource(byte[] source, int offset)
public void SetSource(byte[] source, int offset, IByteBufferAllocator alloc)
{
Contract.Assert(this.SourceReadableBytes == 0);
lock (sourceLock)
{
ResetSource(alloc);

this.input = source;
this.inputStartOffset = offset;
this.inputOffset = 0;
this.inputLength = 0;
}
}

public void ResetSource()
public void ResetSource(IByteBufferAllocator alloc)
{
//Mono will run BeginRead in async and it's running with ResetSource at the same time
lock (sourceLock)
{
int leftLen = this.SourceReadableBytes;
IByteBuffer buf = this.ownBuffer;

if (leftLen > 0)
{
var data = new byte[leftLen];
Buffer.BlockCopy(this.input, this.inputStartOffset + this.inputOffset, data, 0, leftLen);
this.input = data;
this.inputStartOffset = 0;
this.inputOffset = 0;
this.inputLength = leftLen;

return;
if (buf != null)
{
buf.DiscardSomeReadBytes();
}
else
{
buf = alloc.Buffer(leftLen);
this.ownBuffer = buf;
}
buf.WriteBytes(this.input, this.inputStartOffset + this.inputOffset, leftLen);
}
else
else if (buf != null)
{
this.input = null;
this.inputStartOffset = 0;
this.inputOffset = 0;
this.inputLength = 0;
if (!buf.IsReadable())
{
buf.SafeRelease();
this.ownBuffer = null;
}
else
{
buf.DiscardSomeReadBytes();
}
}

this.input = null;
this.inputStartOffset = 0;
this.inputOffset = 0;
this.inputLength = 0;
}
}

Expand Down Expand Up @@ -897,7 +915,7 @@ public void ExpandSource(int count)
#if NETSTANDARD1_3
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (this.inputLength - this.inputOffset > 0)
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
Expand All @@ -913,7 +931,7 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
#else
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
if (this.inputLength - this.inputOffset > 0)
if (this.TotalReadableBytes > 0)
{
// we have the bytes available upfront - write out synchronously
int read = this.ReadFromInput(buffer, offset, count);
Expand Down Expand Up @@ -1040,15 +1058,45 @@ public override void EndWrite(IAsyncResult asyncResult)

int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapacity)
{
Contract.Assert(destination != null);

lock (sourceLock)
{
Contract.Assert(destination != null);
int length = 0;
do
{
int readableBytes;
IByteBuffer buf = this.ownBuffer;
if (buf != null)
{
readableBytes = buf.ReadableBytes;
if (readableBytes > 0)
{
readableBytes = Math.Min(buf.ReadableBytes, destinationCapacity);
buf.ReadBytes(destination, destinationOffset, readableBytes);
length += readableBytes;
destinationCapacity -= readableBytes;

if (destinationCapacity == 0)
break;
}
}

byte[] source = this.input;
if (source != null)
{
readableBytes = this.SourceReadableBytes;
if (readableBytes > 0)
{
readableBytes = Math.Min(readableBytes, destinationCapacity);
Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, readableBytes);
length += readableBytes;
destinationCapacity -= readableBytes;

byte[] source = this.input;
int readableBytes = this.inputLength - this.inputOffset;
int length = Math.Min(readableBytes, destinationCapacity);
Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length);
this.inputOffset += length;
this.inputOffset += readableBytes;
}
}
} while (false);
return length;
}
}
Expand Down

0 comments on commit 7515ab7

Please sign in to comment.