diff --git a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs
index b664a1c860d5b..d9fd59d0d4bdd 100644
--- a/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs
+++ b/sdk/storage/Azure.Storage.Blobs/src/BlobBaseClient.cs
@@ -1757,7 +1757,7 @@ private async ValueTask> StartDownloadAsyn
             if (response.GetRawResponse().Headers.TryGetValue(Constants.StructuredMessage.CrcStructuredMessageHeader, out string _) &&
                 response.GetRawResponse().Headers.TryGetValue(Constants.HeaderNames.ContentLength, out string rawContentLength))
             {
-                result.Content = new StructuredMessageDecodingStream(result.Content, long.Parse(rawContentLength));
+                (result.Content, _) = StructuredMessageDecodingStream.WrapStream(result.Content, long.Parse(rawContentLength));
             }
             // if not null, we expected a structured message response
             // but we didn't find one in the above condition
diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs
new file mode 100644
index 0000000000000..444fe3eb2e0a9
--- /dev/null
+++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingRetriableStream.cs
@@ -0,0 +1,208 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core;
+using Azure.Core.Pipeline;
+
+namespace Azure.Storage.Shared;
+
+/// <summary>
+/// Stream that layers retry handling, via <see cref="RetriableStream"/>, over a structured
+/// message decoding stream. Every replacement stream obtained from the factories is
+/// fast-forwarded past the decoded bytes this stream has already returned to its caller.
+/// </summary>
+internal class StructuredMessageDecodingRetriableStream : Stream
+{
+    private readonly Stream _innerRetriable;
+
+    // Count of decoded bytes already returned through the Read overloads of this stream.
+    private long _decodedBytesRead;
+
+    // Decoded data of the initial stream plus one entry per stream created by the factories.
+    private readonly List<StructuredMessageDecodingStream.DecodedData> _decodedDatas;
+
+    private readonly Func<long, (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)> _decodingStreamFactory;
+    private readonly Func<long, ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>> _decodingAsyncStreamFactory;
+
+    /// <summary>
+    /// Creates the retriable wrapper around an already-opened decoding stream.
+    /// </summary>
+    /// <param name="initialDecodingStream">Stream that is read first.</param>
+    /// <param name="initialDecodedData">Decoded data belonging to <paramref name="initialDecodingStream"/>; disposed together with this stream.</param>
+    /// <param name="decodingStreamFactory">Synchronous factory returning a replacement stream and its decoded data for a given offset.</param>
+    /// <param name="decodingAsyncStreamFactory">Asynchronous counterpart of <paramref name="decodingStreamFactory"/>.</param>
+    /// <param name="responseClassifier">Classifier passed through to <see cref="RetriableStream"/>.</param>
+    /// <param name="maxRetries">Retry limit passed through to <see cref="RetriableStream"/>.</param>
+    public StructuredMessageDecodingRetriableStream(
+        Stream initialDecodingStream,
+        StructuredMessageDecodingStream.DecodedData initialDecodedData,
+        Func<long, (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)> decodingStreamFactory,
+        Func<long, ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>> decodingAsyncStreamFactory,
+        ResponseClassifier responseClassifier,
+        int maxRetries)
+    {
+        _decodingStreamFactory = decodingStreamFactory;
+        _decodingAsyncStreamFactory = decodingAsyncStreamFactory;
+        _innerRetriable = RetriableStream.Create(initialDecodingStream, StreamFactory, StreamFactoryAsync, responseClassifier, maxRetries);
+        _decodedDatas = new() { initialDecodedData };
+    }
+
+    // Offset handed to the factories: the sum, over every stream used so far, of the end
+    // position of its last reported segment (0 for a stream that reported no segment).
+    private long GetResumeOffset()
+        => _decodedDatas.Select(d => d.SegmentCrcs?.LastOrDefault().SegmentEnd ?? 0).Sum();
+
+    private Stream StreamFactory(long _)
+    {
+        long offset = GetResumeOffset();
+        (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = _decodingStreamFactory(offset);
+        _decodedDatas.Add(decodedData);
+        FastForwardInternal(decodingStream, _decodedBytesRead - offset, false).EnsureCompleted();
+        return decodingStream;
+    }
+
+    private async ValueTask<Stream> StreamFactoryAsync(long _)
+    {
+        long offset = GetResumeOffset();
+        (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = await _decodingAsyncStreamFactory(offset).ConfigureAwait(false);
+        _decodedDatas.Add(decodedData);
+        await FastForwardInternal(decodingStream, _decodedBytesRead - offset, true).ConfigureAwait(false);
+        return decodingStream;
+    }
+
+    // Reads and discards the requested number of bytes; does nothing when bytes is not positive.
+    private static async ValueTask FastForwardInternal(Stream stream, long bytes, bool async)
+    {
+        using (ArrayPool<byte>.Shared.RentDisposable(4 * Constants.KB, out byte[] buffer))
+        {
+            while (bytes > 0)
+            {
+                int toRead = (int)Math.Min(bytes, buffer.Length);
+                int read = async
+                    ? await stream.ReadAsync(buffer, 0, toRead).ConfigureAwait(false)
+                    : stream.Read(buffer, 0, toRead);
+                if (read <= 0)
+                {
+                    // A zero-byte read means end of stream; without this check the loop never terminates.
+                    throw new EndOfStreamException("Stream ended before the expected number of bytes could be skipped.");
+                }
+                bytes -= read;
+            }
+        }
+    }
+
+    protected override void Dispose(bool disposing)
+    {
+        if (disposing)
+        {
+            foreach (StructuredMessageDecodingStream.DecodedData data in _decodedDatas)
+            {
+                data?.Dispose();
+            }
+            _decodedDatas.Clear();
+            _innerRetriable.Dispose();
+        }
+        base.Dispose(disposing);
+    }
+
+    #region Read
+    public override int Read(byte[] buffer, int offset, int count)
+    {
+        int read = _innerRetriable.Read(buffer, offset, count);
+        _decodedBytesRead += read;
+        return read;
+    }
+
+    public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+    {
+        int read = await _innerRetriable.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
+        _decodedBytesRead += read;
+        return read;
+    }
+
+#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER
+    public override int Read(Span<byte> buffer)
+    {
+        int read = _innerRetriable.Read(buffer);
+        _decodedBytesRead += read;
+        return read;
+    }
+
+    public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
+    {
+        int read = await _innerRetriable.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
+        _decodedBytesRead += read;
+        return read;
+    }
+#endif
+
+    public override int ReadByte()
+    {
+        int val = _innerRetriable.ReadByte();
+        // ReadByte returns -1 at end of stream; only count a byte that was actually read.
+        if (val >= 0)
+        {
+            _decodedBytesRead += 1;
+        }
+        return val;
+    }
+
+    public override int EndRead(IAsyncResult asyncResult)
+    {
+        int read = _innerRetriable.EndRead(asyncResult);
+        _decodedBytesRead += read;
+        return read;
+    }
+    #endregion
+
+    #region Passthru
+    public override bool CanRead => _innerRetriable.CanRead;
+
+    public override bool CanSeek => _innerRetriable.CanSeek;
+
+    public override bool CanWrite => _innerRetriable.CanWrite;
+
+    public override bool CanTimeout => _innerRetriable.CanTimeout;
+
+    public override long Length => _innerRetriable.Length;
+
+    public override long Position { get => _innerRetriable.Position; set => _innerRetriable.Position = value; }
+
+    public override void Flush() => _innerRetriable.Flush();
+
+    public override Task FlushAsync(CancellationToken cancellationToken) => _innerRetriable.FlushAsync(cancellationToken);
+
+    public override long Seek(long offset, SeekOrigin origin) => _innerRetriable.Seek(offset, origin);
+
+    public override void SetLength(long value) => _innerRetriable.SetLength(value);
+
+    public override void Write(byte[] buffer, int offset, int count) => _innerRetriable.Write(buffer, offset, count);
+
+    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) => _innerRetriable.WriteAsync(buffer, offset, count, cancellationToken);
+
+    public override void WriteByte(byte value) => _innerRetriable.WriteByte(value);
+
+    public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginWrite(buffer, offset, count, callback, state);
+
+    public override void EndWrite(IAsyncResult asyncResult) => _innerRetriable.EndWrite(asyncResult);
+
+    public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) => _innerRetriable.BeginRead(buffer, offset, count, callback, state);
+
+    public override int ReadTimeout { get => _innerRetriable.ReadTimeout; set => _innerRetriable.ReadTimeout = value; }
+
+    public override int WriteTimeout { get => _innerRetriable.WriteTimeout; set => _innerRetriable.WriteTimeout = value; }
+
+#if NETSTANDARD2_1_OR_GREATER || NETCOREAPP3_0_OR_GREATER
+    public override void Write(ReadOnlySpan<byte> buffer) => _innerRetriable.Write(buffer);
+
+    public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default) => _innerRetriable.WriteAsync(buffer, cancellationToken);
+#endif
+    #endregion
+}
diff --git a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs
index aa94b8df350d2..37b15a2245750 100644
--- a/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs
+++ b/sdk/storage/Azure.Storage.Common/src/Shared/StructuredMessageDecodingStream.cs
@@ -38,6 +38,57 @@ namespace Azure.Storage.Shared;
 ///
 internal class StructuredMessageDecodingStream : Stream
 {
+    internal class DecodedData : IDisposable
+    {
+        private byte[] _crcBackingArray;
+
+        public long? InnerStreamLength { get; private set; }
+        public int? TotalSegments { get; private set; }
+        public StructuredMessage.Flags? 
Flags { get; private set; } // null until SetStreamHeaderData is called
+        public List<(ReadOnlyMemory<byte> SegmentCrc, long SegmentEnd)> SegmentCrcs { get; private set; } // stays null unless the StorageCrc64 flag is set
+        public ReadOnlyMemory<byte> TotalCrc { get; private set; }
+        public bool DecodeCompleted { get; private set; }
+
+        internal void SetStreamHeaderData(int totalSegments, long innerStreamLength, StructuredMessage.Flags flags)
+        {
+            TotalSegments = totalSegments;
+            InnerStreamLength = innerStreamLength;
+            Flags = flags;
+
+            if (flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
+            {
+                _crcBackingArray = ArrayPool<byte>.Shared.Rent((totalSegments + 1) * StructuredMessage.Crc64Length); // one slot per segment plus one for the total CRC
+                SegmentCrcs = new();
+            }
+        }
+
+        internal void ReportSegmentCrc(ReadOnlySpan<byte> crc, int segmentNum, long segmentEnd)
+        {
+            int offset = (segmentNum - 1) * StructuredMessage.Crc64Length; // segment N is stored in slot N - 1
+            crc.CopyTo(new Span<byte>(_crcBackingArray, offset, StructuredMessage.Crc64Length));
+            SegmentCrcs.Add((new ReadOnlyMemory<byte>(_crcBackingArray, offset, StructuredMessage.Crc64Length), segmentEnd));
+        }
+
+        internal void ReportTotalCrc(ReadOnlySpan<byte> crc)
+        {
+            int offset = (TotalSegments.Value) * StructuredMessage.Crc64Length; // slot after the last segment slot
+            crc.CopyTo(new Span<byte>(_crcBackingArray, offset, StructuredMessage.Crc64Length));
+            TotalCrc = new ReadOnlyMemory<byte>(_crcBackingArray, offset, StructuredMessage.Crc64Length);
+        }
+        internal void MarkComplete()
+        {
+            DecodeCompleted = true;
+        }
+
+        public void Dispose()
+        {
+            if (System.Threading.Interlocked.Exchange(ref _crcBackingArray, null) is byte[] rented) // idempotent: the array is returned to the pool at most once
+            {
+                ArrayPool<byte>.Shared.Return(rented);
+            }
+        }
+    }
+
     private enum SMRegion
     {
         StreamHeader,
@@ -58,16 +109,16 @@ private enum SMRegion
     private int _segmentHeaderLength;
     private int _segmentFooterLength;
 
-    private int _totalSegments;
-    private long _innerStreamLength;
+    private long? _expectedInnerStreamLength; // null when the caller supplied no expected length
 
-    private StructuredMessage.Flags _flags;
-    private bool _processedFooter = false;
     private bool _disposed;
 
+    private readonly DecodedData _decodedData; // receives header values, CRCs and the completion flag as they are decoded
     private StorageCrc64HashAlgorithm _totalContentCrc;
     private StorageCrc64HashAlgorithm _segmentCrc;
 
+    private readonly bool _validateChecksums; // set to true by the constructor
+
     public override bool CanRead => true;
 
     public override bool CanWrite => false;
@@ -88,18 +139,31 @@ public override long Position
         set => throw new NotSupportedException();
     }
 
-    public StructuredMessageDecodingStream(
+    public static (Stream DecodedStream, DecodedData DecodedData) WrapStream(
+        Stream innerStream,
+        long? expextedStreamLength = default) // NOTE(review): name is misspelled ("expexted"); left as-is so named-argument callers keep compiling
+    {
+        DecodedData data = new();
+        return (new StructuredMessageDecodingStream(innerStream, data, expextedStreamLength), data);
+    }
+
+    private StructuredMessageDecodingStream(
         Stream innerStream,
-        long? expectedStreamLength = default)
+        DecodedData decodedData,
+        long? expectedStreamLength)
     {
         Argument.AssertNotNull(innerStream, nameof(innerStream));
+        Argument.AssertNotNull(decodedData, nameof(decodedData));
 
-        _innerStreamLength = expectedStreamLength ?? -1;
+        _expectedInnerStreamLength = expectedStreamLength;
         _innerBufferedStream = new BufferedStream(innerStream);
+        _decodedData = decodedData;
 
         // Assumes stream will be structured message 1.0. Will validate this when consuming stream.
         _streamHeaderLength = StructuredMessage.V1_0.StreamHeaderLength;
         _segmentHeaderLength = StructuredMessage.V1_0.SegmentHeaderLength;
+
+        _validateChecksums = true;
     }
 
     #region Write
@@ -191,14 +255,15 @@ public override async ValueTask ReadAsync(Memory buf, CancellationTok
 
     private void AssertDecodeFinished()
     {
-        if (_streamFooterLength > 0 && !_processedFooter)
+        if (_streamFooterLength > 0 && !_decodedData.DecodeCompleted)
         {
             throw Errors.InvalidStructuredMessage("Premature end of stream.");
         }
-        _processedFooter = true;
+        _decodedData.MarkComplete();
     }
 
     private long _innerStreamConsumed = 0;
+    private long _decodedContentConsumed = 0; // running total of decoded content bytes; reported as SegmentEnd
     private SMRegion _currentRegion = SMRegion.StreamHeader;
     private int _currentSegmentNum = 0;
     private long _currentSegmentContentLength;
@@ -243,6 +308,7 @@ private int Decode(Span buffer)
                     _totalContentCrc?.Append(buffer.Slice(bufferConsumed, read));
                     _segmentCrc?.Append(buffer.Slice(bufferConsumed, read));
                     bufferConsumed += read;
+                    _decodedContentConsumed += read;
                     _currentSegmentContentRemaining -= read;
                     if (_currentSegmentContentRemaining == 0)
                     {
@@ -370,24 +436,25 @@ private int ProcessStreamHeader(ReadOnlySpan span)
         StructuredMessage.V1_0.ReadStreamHeader(
             span.Slice(0, _streamHeaderLength),
             out long streamLength,
-            out _flags,
-            out _totalSegments);
+            out StructuredMessage.Flags flags,
+            out int totalSegments);
+
+        _decodedData.SetStreamHeaderData(totalSegments, streamLength, flags);
 
-        if (_innerStreamLength > 0 && streamLength != _innerStreamLength)
+        if (_expectedInnerStreamLength.HasValue && _expectedInnerStreamLength.Value != streamLength)
         {
             throw Errors.InvalidStructuredMessage("Unexpected message size.");
         }
-        else
-        {
-            _innerStreamLength = streamLength;
-        }
 
-        if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
+        if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64))
         {
-            _segmentFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
-            _streamFooterLength = _flags.HasFlag(StructuredMessage.Flags.StorageCrc64) ? StructuredMessage.Crc64Length : 0;
-            _segmentCrc = StorageCrc64HashAlgorithm.Create();
-            _totalContentCrc = StorageCrc64HashAlgorithm.Create();
+            _segmentFooterLength = StructuredMessage.Crc64Length;
+            _streamFooterLength = StructuredMessage.Crc64Length;
+            if (_validateChecksums) // hash state is only allocated when checksums will be compared
+            {
+                _segmentCrc = StorageCrc64HashAlgorithm.Create();
+                _totalContentCrc = StorageCrc64HashAlgorithm.Create();
+            }
         }
         _currentRegion = SMRegion.SegmentHeader;
         return _streamHeaderLength;
@@ -396,30 +463,34 @@ private int ProcessStreamHeader(ReadOnlySpan span)
     private int ProcessStreamFooter(ReadOnlySpan<byte> span)
     {
         int totalProcessed = 0;
-        if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
+        if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64))
         {
             totalProcessed += StructuredMessage.Crc64Length;
-            using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
+            ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
+            _decodedData.ReportTotalCrc(expected); // recorded before validation, so it is kept even when the comparison below throws
+            if (_validateChecksums)
             {
-                _totalContentCrc.GetCurrentHash(calculated);
-                ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
-                if (!calculated.SequenceEqual(expected))
+                using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
                 {
-                    throw Errors.ChecksumMismatch(calculated, expected);
+                    _totalContentCrc.GetCurrentHash(calculated);
+                    if (!calculated.SequenceEqual(expected))
+                    {
+                        throw Errors.ChecksumMismatch(calculated, expected);
+                    }
                 }
             }
         }
 
-        if (_innerStreamConsumed != _innerStreamLength)
+        if (_innerStreamConsumed != _decodedData.InnerStreamLength)
         {
             throw Errors.InvalidStructuredMessage("Unexpected message size.");
         }
-        if (_currentSegmentNum != _totalSegments)
+        if (_currentSegmentNum != _decodedData.TotalSegments)
         {
             throw Errors.InvalidStructuredMessage("Missing expected message segments.");
         }
 
-        _processedFooter = true;
+        _decodedData.MarkComplete();
         return totalProcessed;
     }
 
@@ -442,21 +513,25 @@ private int ProcessSegmentHeader(ReadOnlySpan span)
     private int ProcessSegmentFooter(ReadOnlySpan<byte> span)
     {
         int totalProcessed = 0;
-        if (_flags.HasFlag(StructuredMessage.Flags.StorageCrc64))
+        if (_decodedData.Flags.Value.HasFlag(StructuredMessage.Flags.StorageCrc64))
         {
             totalProcessed += StructuredMessage.Crc64Length;
-            using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
+            ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
+            if (_validateChecksums)
             {
-                _segmentCrc.GetCurrentHash(calculated);
-                _segmentCrc = StorageCrc64HashAlgorithm.Create();
-                ReadOnlySpan<byte> expected = span.Slice(0, StructuredMessage.Crc64Length);
-                if (!calculated.SequenceEqual(expected))
+                using (ArrayPool<byte>.Shared.RentAsSpanDisposable(StructuredMessage.Crc64Length, out Span<byte> calculated))
                 {
-                    throw Errors.ChecksumMismatch(calculated, expected);
+                    _segmentCrc.GetCurrentHash(calculated);
+                    _segmentCrc = StorageCrc64HashAlgorithm.Create();
+                    if (!calculated.SequenceEqual(expected))
+                    {
+                        throw Errors.ChecksumMismatch(calculated, expected);
+                    }
                 }
             }
+            _decodedData.ReportSegmentCrc(expected, _currentSegmentNum, _decodedContentConsumed); // only reached when validation did not throw
         }
-        _currentRegion = _currentSegmentNum == _totalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader;
+        _currentRegion = _currentSegmentNum == _decodedData.TotalSegments ? SMRegion.StreamFooter : SMRegion.SegmentHeader;
         return totalProcessed;
     }
     #endregion
diff --git a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj
index 0c3807d9b74ff..8bf802d14e766 100644
--- a/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj
+++ b/sdk/storage/Azure.Storage.Common/tests/Azure.Storage.Common.Tests.csproj
@@ -13,6 +13,8 @@
+
+
@@ -46,6 +48,7 @@
+
diff --git a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs
index 7411eb1499312..f4e4b92ed73c4 100644
--- a/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs
+++ b/sdk/storage/Azure.Storage.Common/tests/Shared/FaultyStream.cs
@@ -15,6 +15,7 @@ internal class FaultyStream : Stream
         private readonly Exception _exceptionToRaise;
         private int _remainingExceptions;
         private Action _onFault;
+        private long _position = 0;
 
         public FaultyStream(
             Stream innerStream,
@@ -40,7 +41,7 @@ public FaultyStream(
 
         public override long Position
         {
-            get => _innerStream.Position;
+            get => CanSeek ? 
_innerStream.Position : _position;
             set => _innerStream.Position = value;
         }
 
@@ -53,7 +54,9 @@ public override int Read(byte[] buffer, int offset, int count)
         {
             if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length)
             {
-                return _innerStream.Read(buffer, offset, count);
+                int read = _innerStream.Read(buffer, offset, count);
+                _position += read; // tracked locally so Position works when the inner stream cannot seek
+                return read;
             }
             else
             {
@@ -61,11 +64,13 @@ public override int Read(byte[] buffer, int offset, int count)
             }
         }
 
-        public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
+        public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
         {
             if (_remainingExceptions == 0 || Position + count <= _raiseExceptionAt || _raiseExceptionAt >= _innerStream.Length)
             {
-                return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
+                int read = await _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
+                _position += read;
+                return read;
             }
             else
             {
diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs
new file mode 100644
index 0000000000000..666933e546189
--- /dev/null
+++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingRetriableStreamTests.cs
@@ -0,0 +1,231 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+using System;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using Azure.Core;
+using Azure.Storage.Shared;
+using Azure.Storage.Test.Shared;
+using Moq;
+using NUnit.Framework;
+
+namespace Azure.Storage.Tests;
+
+[TestFixture(true)]
+[TestFixture(false)]
+public class StructuredMessageDecodingRetriableStreamTests
+{
+    public bool Async { get; }
+
+    public StructuredMessageDecodingRetriableStreamTests(bool async)
+    {
+        Async = async;
+    }
+
+    // Strict classifier mock that reports every exception as retriable.
+    private Mock<ResponseClassifier> AllExceptionsRetry()
+    {
+        Mock<ResponseClassifier> mock = new(MockBehavior.Strict);
+        mock.Setup(rc => rc.IsRetriableException(It.IsAny<Exception>())).Returns(true);
+        return mock;
+    }
+
+    // With no fault injected, the wrapper must copy the source bytes through unchanged.
+    [Test]
+    public async ValueTask UninterruptedStream()
+    {
+        byte[] data = new Random().NextBytesInline(4 * Constants.KB).ToArray();
+        byte[] dest = new byte[data.Length];
+
+        // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream
+        using (Stream src = new MemoryStream(data))
+        using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(src, new(), default, default, default, 1))
+        using (Stream dst = new MemoryStream(dest))
+        {
+            await retriableSrc.CopyToInternal(dst, Async, default);
+        }
+
+        Assert.AreEqual(data, dest);
+    }
+
+    // After one or several injected faults, the bytes delivered must still equal the source data.
+    [Test]
+    public async Task Interrupt_DataIntact([Values(true, false)] bool multipleInterrupts)
+    {
+        const int segments = 4;
+        const int segmentLen = Constants.KB;
+        const int readLen = 128;
+        const int interruptPos = segmentLen + (3 * readLen) + 10;
+
+        Random r = new();
+        byte[] data = r.NextBytesInline(segments * Constants.KB).ToArray();
+        byte[] dest = new byte[data.Length];
+
+        // Mock a decoded data for the mocked StructuredMessageDecodingStream
+        StructuredMessageDecodingStream.DecodedData initialDecodedData = new();
+        initialDecodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64);
+        // for test purposes, initialize a DecodedData, since we are not actively decoding in this test
+        initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen);
+
+        (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty)
+        {
+            Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset);
+            if (faulty)
+            {
+                stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { });
+            }
+            // Mock a decoded data for the mocked StructuredMessageDecodingStream
+            StructuredMessageDecodingStream.DecodedData decodedData = new();
+            decodedData.SetStreamHeaderData(segments, data.Length, StructuredMessage.Flags.StorageCrc64);
+            // for test purposes, initialize a DecodedData, since we are not actively decoding in this test
+            decodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen);
+            return (stream, decodedData);
+        }
+
+        // mock with a simple MemoryStream rather than an actual StructuredMessageDecodingStream
+        using (Stream src = new MemoryStream(data))
+        using (Stream faultySrc = new FaultyStream(src, interruptPos, 1, new Exception(), () => { }))
+        using (Stream retriableSrc = new StructuredMessageDecodingRetriableStream(
+            faultySrc,
+            initialDecodedData,
+            offset => Factory(offset, multipleInterrupts),
+            offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)),
+            AllExceptionsRetry().Object,
+            int.MaxValue))
+        using (Stream dst = new MemoryStream(dest))
+        {
+            await retriableSrc.CopyToInternal(dst, readLen, Async, default);
+        }
+
+        Assert.AreEqual(data, dest);
+    }
+
+    // After a fault, the caller-visible byte count must exclude the bytes consumed by the fast-forward.
+    [Test]
+    public async Task Interrupt_AppropriateRewind()
+    {
+        const int segments = 2;
+        const int segmentLen = Constants.KB;
+        const int dataLen = segments * segmentLen;
+        const int readLen = segmentLen / 4;
+        const int interruptOffset = 10;
+        const int interruptPos = segmentLen + (2 * readLen) + interruptOffset;
+        Random r = new();
+
+        // Mock a decoded data for the mocked StructuredMessageDecodingStream
+        StructuredMessageDecodingStream.DecodedData initialDecodedData = new();
+        initialDecodedData.SetStreamHeaderData(segments, segments * segmentLen, StructuredMessage.Flags.StorageCrc64);
+        // By the time of interrupt, there will be one segment reported
+        initialDecodedData.ReportSegmentCrc(r.NextBytesInline(StructuredMessage.Crc64Length), 1, segmentLen);
+
+        Mock<Stream> mock = new(MockBehavior.Strict);
+        mock.SetupGet(s => s.CanRead).Returns(true);
+        mock.SetupGet(s => s.CanSeek).Returns(false);
+        if (Async)
+        {
+            mock.SetupSequence(s => s.ReadAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), default))
+                .Returns(Task.FromResult(readLen)) // start first segment
+                .Returns(Task.FromResult(readLen))
+                .Returns(Task.FromResult(readLen))
+                .Returns(Task.FromResult(readLen)) // finish first segment
+                .Returns(Task.FromResult(readLen)) // start second segment
+                .Returns(Task.FromResult(readLen))
+                // faulty stream interrupt
+                .Returns(Task.FromResult(readLen * 2)) // restart second segment. fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once
+                .Returns(Task.FromResult(readLen))
+                .Returns(Task.FromResult(readLen)) // end second segment
+                .Returns(Task.FromResult(0)) // signal end of stream
+                .Returns(Task.FromResult(0)) // second signal needed for stream wrapping reasons
+                ;
+        }
+        else
+        {
+            mock.SetupSequence(s => s.Read(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>()))
+                .Returns(readLen) // start first segment
+                .Returns(readLen)
+                .Returns(readLen)
+                .Returns(readLen) // finish first segment
+                .Returns(readLen) // start second segment
+                .Returns(readLen)
+                // faulty stream interrupt
+                .Returns(readLen * 2) // restart second segment. fast-forward uses an internal 4KB buffer, so it will leap the 512 byte catchup all at once
+                .Returns(readLen)
+                .Returns(readLen) // end second segment
+                .Returns(0) // signal end of stream
+                .Returns(0) // second signal needed for stream wrapping reasons
+                ;
+        }
+        Stream faultySrc = new FaultyStream(mock.Object, interruptPos, 1, new Exception(), default);
+        Stream retriableSrc = new StructuredMessageDecodingRetriableStream(
+            faultySrc,
+            initialDecodedData,
+            offset => (mock.Object, new()),
+            offset => new(Task.FromResult((mock.Object, new StructuredMessageDecodingStream.DecodedData()))),
+            AllExceptionsRetry().Object,
+            1);
+
+        int totalRead = 0;
+        int read = 0;
+        byte[] buf = new byte[readLen];
+        if (Async)
+        {
+            while ((read = await retriableSrc.ReadAsync(buf, 0, buf.Length)) > 0)
+            {
+                totalRead += read;
+            }
+        }
+        else
+        {
+            while ((read = retriableSrc.Read(buf, 0, buf.Length)) > 0)
+            {
+                totalRead += read;
+            }
+        }
+        await retriableSrc.CopyToInternal(Stream.Null, readLen, Async, default);
+
+        // Asserts we read exactly the data length, excluding the fastforward of the inner stream
+        Assert.That(totalRead, Is.EqualTo(dataLen));
+    }
+
+    // Same as Interrupt_DataIntact, but through real encoding and decoding streams.
+    [Test]
+    public async Task Interrupt_ProperDecode([Values(true, false)] bool multipleInterrupts)
+    {
+        // decoding stream inserts a buffered layer of 4 KB. use larger sizes to avoid interference from it.
+        const int segments = 4;
+        const int segmentLen = 128 * Constants.KB;
+        const int readLen = 8 * Constants.KB;
+        const int interruptPos = segmentLen + (3 * readLen) + 10;
+
+        Random r = new();
+        byte[] data = r.NextBytesInline(segments * segmentLen).ToArray(); // must span every segment; a shorter payload ends before interruptPos and no fault is ever raised
+        byte[] dest = new byte[data.Length];
+
+        (Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData) Factory(long offset, bool faulty)
+        {
+            Stream stream = new MemoryStream(data, (int)offset, data.Length - (int)offset);
+            stream = new StructuredMessageEncodingStream(stream, segmentLen, StructuredMessage.Flags.StorageCrc64);
+            if (faulty)
+            {
+                stream = new FaultyStream(stream, interruptPos, 1, new Exception(), () => { });
+            }
+            return StructuredMessageDecodingStream.WrapStream(stream);
+        }
+
+        (Stream decodingStream, StructuredMessageDecodingStream.DecodedData decodedData) = Factory(0, true);
+        using Stream retriableSrc = new StructuredMessageDecodingRetriableStream(
+            decodingStream,
+            decodedData,
+            offset => Factory(offset, multipleInterrupts),
+            offset => new ValueTask<(Stream DecodingStream, StructuredMessageDecodingStream.DecodedData DecodedData)>(Factory(offset, multipleInterrupts)),
+            AllExceptionsRetry().Object,
+            int.MaxValue);
+        using Stream dst = new MemoryStream(dest);
+
+        await retriableSrc.CopyToInternal(dst, readLen, Async, default);
+
+        Assert.AreEqual(data, dest);
+    }
+}
diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs
index f881a70c8e78f..2789672df4976 100644
--- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs
+++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageDecodingStreamTests.cs
@@ -116,7 +116,7 @@ public async Task DecodesData(
             new Random().NextBytes(originalData);
             byte[] encodedData = StructuredMessageHelper.MakeEncodedData(originalData, segmentContentLength, flags);
 
-            Stream decodingStream = 
new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData)); // the DecodedData half of the tuple is discarded in these tests
             byte[] decodedData;
             using (MemoryStream dest = new())
             {
@@ -136,7 +136,7 @@ public void BadStreamBadVersion()
 
             encodedData[0] = byte.MaxValue;
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
@@ -154,7 +154,7 @@ public async Task BadSegmentCrcThrows()
 
             encodedData[badBytePos] = (byte)~encodedData[badBytePos];
             MemoryStream encodedDataStream = new(encodedData);
-            Stream decodingStream = new StructuredMessageDecodingStream(encodedDataStream);
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(encodedDataStream);
 
             // manual try/catch to validate the proccess failed mid-stream rather than the end
             const int copyBufferSize = 4;
@@ -183,7 +183,7 @@ public void BadStreamCrcThrows()
 
             encodedData[originalData.Length - 1] = (byte)~encodedData[originalData.Length - 1];
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
@@ -196,7 +196,7 @@ public void BadStreamWrongContentLength()
 
             BinaryPrimitives.WriteInt64LittleEndian(new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8), 123456789L);
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
@@ -216,7 +216,7 @@ public void BadStreamWrongSegmentCount(int difference)
             BinaryPrimitives.WriteInt16LittleEndian(
                 new Span(encodedData, V1_0.StreamHeaderSegmentCountOffset, 2), (short)(numSegments + difference));
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
@@ -230,7 +230,7 @@ public void BadStreamWrongSegmentNum()
             BinaryPrimitives.WriteInt16LittleEndian(
                 new Span(encodedData, V1_0.StreamHeaderLength + V1_0.SegmentHeaderNumOffset, 2), 123);
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(encodedData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(encodedData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
@@ -248,7 +248,7 @@ public async Task BadStreamWrongContentLength(
                 new Span(encodedData, V1_0.StreamHeaderMessageLengthOffset, 8),
                 encodedData.Length + difference);
 
-            Stream decodingStream = new StructuredMessageDecodingStream(
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream( // second argument is the optional expected stream length
                 new MemoryStream(encodedData),
                 lengthProvided ? (long?)encodedData.Length : default);
 
@@ -284,14 +284,14 @@ public void BadStreamMissingExpectedStreamFooter()
             byte[] brokenData = new byte[encodedData.Length - Crc64Length];
             new Span(encodedData, 0, encodedData.Length - Crc64Length).CopyTo(brokenData);
 
-            Stream decodingStream = new StructuredMessageDecodingStream(new MemoryStream(brokenData));
+            (Stream decodingStream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream(brokenData));
             Assert.That(async () => await CopyStream(decodingStream, Stream.Null), Throws.InnerException.TypeOf());
         }
 
         [Test]
         public void NoSeek()
         {
-            StructuredMessageDecodingStream stream = new(new MemoryStream());
+            (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream()); // typed as Stream: the constructor is private after this change
 
             Assert.That(stream.CanSeek, Is.False);
             Assert.That(() => stream.Length, Throws.TypeOf());
@@ -303,7 +303,7 @@ public void NoSeek()
         [Test]
         public void NoWrite()
         {
-            StructuredMessageDecodingStream stream = new(new MemoryStream());
+            (Stream stream, _) = StructuredMessageDecodingStream.WrapStream(new MemoryStream());
             byte[] data = new byte[1024];
             new Random().NextBytes(data);
 
diff --git a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs
index 633233db2e73c..61583aa1ebe4e 100644
--- a/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs
+++ b/sdk/storage/Azure.Storage.Common/tests/StructuredMessageStreamRoundtripTests.cs
@@ -113,8 +113,8 @@ public async Task RoundTrip(
 
             byte[] roundtripData;
             using (MemoryStream source = new(originalData))
-            using (StructuredMessageEncodingStream encode = new(source, segmentLength, flags))
-            using (StructuredMessageDecodingStream decode = new(encode))
+            using (Stream encode = new StructuredMessageEncodingStream(source, segmentLength, flags))
+            using (Stream decode = StructuredMessageDecodingStream.WrapStream(encode).DecodedStream) // NOTE(review): only the stream half of the tuple is kept, so this using does not dispose the DecodedData directly — confirm it is released elsewhere
             using (MemoryStream dest = new())
             {
                 await CopyStream(source, dest, readLen);