Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overread bug in Reader Skip #8709

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/Orleans.Serialization/Buffers/Reader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,13 @@ public ref struct Reader<TInput>
private readonly static bool IsReadOnlySequenceInput = typeof(TInput) == typeof(ReadOnlySequenceInput);
private readonly static bool IsReaderInput = typeof(ReaderInput).IsAssignableFrom(typeof(TInput));
private readonly static bool IsBufferSliceInput = typeof(TInput) == typeof(BufferSliceReaderInput);

private ReadOnlySpan<byte> _currentSpan;
private int _bufferPos;
private int _bufferSize;
private readonly long _sequenceOffset;
private TInput _input;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal Reader(TInput input, SerializerSession session, long globalOffset)
{
Expand Down Expand Up @@ -318,7 +318,7 @@ internal Reader(ReadOnlySpan<byte> input, SerializerSession session, long global
if (IsSpanInput)
{
_input = default;
_currentSpan = input;
_currentSpan = input;
_bufferPos = 0;
_bufferSize = _currentSpan.Length;
_sequenceOffset = globalOffset;
Expand Down Expand Up @@ -410,11 +410,11 @@ public void Skip(long count)
{
if (IsReadOnlySequenceInput || IsBufferSliceInput)
ReubenBond marked this conversation as resolved.
Show resolved Hide resolved
{
var previousBuffersSize = Unsafe.As<TInput, ReadOnlySequenceInput>(ref _input).PreviousBuffersSize;
var end = Position + count;
while (Position < end)
{
if (Position + _bufferSize >= end)
var previousBuffersSize = Unsafe.As<TInput, ReadOnlySequenceInput>(ref _input).PreviousBuffersSize;
if (end - previousBuffersSize <= _bufferSize)
{
_bufferPos = (int)(end - previousBuffersSize);
}
Expand All @@ -426,11 +426,11 @@ public void Skip(long count)
}
else if (IsBufferSliceInput)
{
var previousBuffersSize = Unsafe.As<TInput, BufferSliceReaderInput>(ref _input).PreviousBuffersSize;
var end = Position + count;
while (Position < end)
{
if (Position + _bufferSize >= end)
var previousBuffersSize = Unsafe.As<TInput, BufferSliceReaderInput>(ref _input).PreviousBuffersSize;
if (end - previousBuffersSize <= _bufferSize)
{
_bufferPos = (int)(end - previousBuffersSize);
}
Expand Down Expand Up @@ -517,13 +517,13 @@ public void ForkFrom(long position, out Reader<TInput> forked)
{
throw new NotSupportedException($"Type {typeof(TInput)} is not supported");
}

static void ThrowInvalidPosition(long expectedPosition, long actualPosition)
{
throw new InvalidOperationException($"Expected to arrive at position {expectedPosition} after ForkFrom, but resulting position is {actualPosition}");
}
}

/// <summary>
/// Resumes the reader from the specified position after forked readers are no longer in use.
/// </summary>
Expand Down
23 changes: 20 additions & 3 deletions test/Orleans.Serialization.UnitTests/ReaderWriterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
using System.IO;
using Xunit;
using Xunit.Abstractions;
using Orleans.Serialization.Codecs;
using Orleans.Serialization.WireProtocol;

namespace Orleans.Serialization.UnitTests
{
Expand Down Expand Up @@ -209,6 +211,20 @@ protected override void DisposeBuffer(TestMultiSegmentBufferWriter buffer, TestM

[Fact]
protected override void ByteRoundTrip() => ByteRoundTripTest();

[Fact]
public void SkipBufferEdge()
{
byte[] b = new byte[] { 25, 84, 101, 115, 116, 32, 97, 99, 99, 111, 117, 110 };
byte[] b2 = new byte[] { 116, 64, 0, 0, 0 };

var seq = ReadOnlySequenceHelper.CreateReadOnlySequence(b, b2);
using SerializerSession session = this.GetSession();
var reader = Reader.Create(seq, session);
SkipFieldExtension.SkipField(ref reader, new Field(new Tag((byte)WireType.LengthPrefixed)));

Assert.Equal(64, reader.ReadInt32());
}
}

public abstract class ReaderWriterTestBase<TBuffer, TOutput, TInput> where TOutput : IBufferWriter<byte>
Expand All @@ -219,7 +235,7 @@ public abstract class ReaderWriterTestBase<TBuffer, TOutput, TInput> where TOutp

private delegate T ReadValue<T>(ref Reader<TInput> reader);
private delegate void WriteValue<T>(ref Writer<TOutput> writer, T value);

public ReaderWriterTestBase(ITestOutputHelper testOutputHelper)
{
var services = new ServiceCollection();
Expand All @@ -229,6 +245,7 @@ public ReaderWriterTestBase(ITestOutputHelper testOutputHelper)
_testOutputHelper = testOutputHelper;
}

protected SerializerSession GetSession() => _sessionPool.GetSession();
protected abstract TBuffer CreateBuffer();
protected abstract Reader<TInput> CreateReader(TBuffer buffer, SerializerSession session);
protected abstract Writer<TOutput> CreateWriter(TBuffer buffer, SerializerSession session);
Expand Down Expand Up @@ -307,7 +324,7 @@ protected void Int64RoundTripTest()
static void Write(ref Writer<TOutput> writer, long expected) => writer.WriteInt64(expected);

Gen.Long.Sample(CreateTestPredicate(Write, Read));

}

protected void Int32RoundTripTest()
Expand All @@ -323,7 +340,7 @@ protected void UInt64RoundTripTest()
static ulong Read(ref Reader<TInput> reader) => reader.ReadUInt64();
static void Write(ref Writer<TOutput> writer, ulong expected) => writer.WriteUInt64(expected);

Gen.ULong.Sample(CreateTestPredicate(Write, Read));
Gen.ULong.Sample(CreateTestPredicate(Write, Read));
}

protected void UInt32RoundTripTest()
Expand Down