Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Memory Overrides to Streams #47125

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions src/libraries/Common/src/System/IO/ChunkedMemoryStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using System;

namespace System.IO
{
Expand Down Expand Up @@ -35,24 +36,28 @@ public byte[] ToArray()

public override void Write(byte[] buffer, int offset, int count)
{
while (count > 0)
Write(new ReadOnlySpan<byte>(buffer, offset, count));
}

public override void Write(ReadOnlySpan<byte> buffer)
{
while (!buffer.IsEmpty)
{
if (_currentChunk != null)
{
int remaining = _currentChunk._buffer.Length - _currentChunk._freeOffset;
if (remaining > 0)
{
int toCopy = Math.Min(remaining, count);
Buffer.BlockCopy(buffer, offset, _currentChunk._buffer, _currentChunk._freeOffset, toCopy);
count -= toCopy;
offset += toCopy;
int toCopy = Math.Min(remaining, buffer.Length);
buffer.Slice(0, toCopy).CopyTo(new Span<byte>(_currentChunk._buffer, _currentChunk._freeOffset, toCopy));
buffer = buffer.Slice(toCopy);
_totalLength += toCopy;
_currentChunk._freeOffset += toCopy;
continue;
}
}

AppendChunk(count);
AppendChunk(buffer.Length);
}
}

Expand All @@ -67,6 +72,17 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return Task.CompletedTask;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
if (cancellationToken.IsCancellationRequested)
{
return ValueTask.FromCanceled(cancellationToken);
}

Write(buffer.Span);
return ValueTask.CompletedTask;
}

private void AppendChunk(long count)
{
int nextChunkLength = _currentChunk != null ? _currentChunk._buffer.Length * 2 : InitialChunkDefaultSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -907,25 +907,38 @@ public void CopyFromSourceToDestination()
}
}

public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
// Validate inputs
Debug.Assert(buffer != _arrayPoolBuffer);
_deflateStream.EnsureNotDisposed();
if (count <= 0)
{
return;
return Task.CompletedTask;
}
else if (count > buffer.Length - offset)
{
// The buffer stream is either malicious or poorly implemented and returned a number of
// bytes larger than the buffer supplied to it.
throw new InvalidDataException(SR.GenericInvalidData);
return Task.FromException(new InvalidDataException(SR.GenericInvalidData));
}

Debug.Assert(_deflateStream._inflater != null);
// Feed the data from base stream into the decompression engine.
_deflateStream._inflater.SetInput(buffer, offset, count);
return WriteAsyncCore(buffer.AsMemory(offset, count), cancellationToken).AsTask();
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
_deflateStream.EnsureNotDisposed();

return WriteAsyncCore(buffer, cancellationToken);
}

private async ValueTask WriteAsyncCore(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken)
{
Debug.Assert(_deflateStream._inflater is not null);

// Feed the data from base stream into decompression engine.
_deflateStream._inflater.SetInput(buffer);

// While there's more decompressed data available, forward it to the buffer stream.
while (!_deflateStream._inflater.Finished())
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
Expand All @@ -20,7 +21,7 @@ internal sealed class Inflater : IDisposable
private bool _isDisposed; // Prevents multiple disposals
private readonly int _windowBits; // The WindowBits parameter passed to Inflater construction
private ZLibNative.ZLibStreamHandle _zlibStream; // The handle to the primary underlying zlib stream
private GCHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream
private MemoryHandle _inputBufferHandle; // The handle to the buffer that provides input to _zlibStream
private readonly long _uncompressedSize;
private long _currentInflatedCount;

Expand Down Expand Up @@ -110,7 +111,7 @@ public unsafe int InflateVerified(byte* bufPtr, int length)
finally
{
// Before returning, make sure to release input buffer if necessary:
if (0 == _zlibStream.AvailIn && _inputBufferHandle.IsAllocated)
if (0 == _zlibStream.AvailIn && IsInputBufferHandleAllocated)
{
DeallocateInputBufferHandle();
}
Expand All @@ -121,7 +122,7 @@ private unsafe void ReadOutput(byte* bufPtr, int length, out int bytesRead)
{
if (ReadInflateOutput(bufPtr, length, ZLibNative.FlushCode.NoFlush, out bytesRead) == ZLibNative.ErrorCode.StreamEnd)
{
if (!NeedsInput() && IsGzipStream() && _inputBufferHandle.IsAllocated)
if (!NeedsInput() && IsGzipStream() && IsInputBufferHandleAllocated)
{
_finished = ResetStreamForLeftoverInput();
}
Expand All @@ -142,7 +143,7 @@ private unsafe bool ResetStreamForLeftoverInput()
{
Debug.Assert(!NeedsInput());
Debug.Assert(IsGzipStream());
Debug.Assert(_inputBufferHandle.IsAllocated);
Debug.Assert(IsInputBufferHandleAllocated);

lock (SyncLock)
{
Expand Down Expand Up @@ -180,16 +181,24 @@ public void SetInput(byte[] inputBuffer, int startIndex, int count)
Debug.Assert(NeedsInput(), "We have something left in previous input!");
Debug.Assert(inputBuffer != null);
Debug.Assert(startIndex >= 0 && count >= 0 && count + startIndex <= inputBuffer.Length);
Debug.Assert(!_inputBufferHandle.IsAllocated);
Debug.Assert(!IsInputBufferHandleAllocated);

if (0 == count)
SetInput(inputBuffer.AsMemory(startIndex, count));
}

public unsafe void SetInput(ReadOnlyMemory<byte> inputBuffer)
{
Debug.Assert(NeedsInput(), "We have something left in previous input!");
Debug.Assert(!IsInputBufferHandleAllocated);

if (inputBuffer.IsEmpty)
return;

lock (SyncLock)
{
_inputBufferHandle = GCHandle.Alloc(inputBuffer, GCHandleType.Pinned);
_zlibStream.NextIn = _inputBufferHandle.AddrOfPinnedObject() + startIndex;
_zlibStream.AvailIn = (uint)count;
_inputBufferHandle = inputBuffer.Pin();
_zlibStream.NextIn = (IntPtr)_inputBufferHandle.Pointer;
_zlibStream.AvailIn = (uint)inputBuffer.Length;
_finished = false;
}
}
Expand All @@ -201,7 +210,7 @@ private void Dispose(bool disposing)
if (disposing)
_zlibStream.Dispose();

if (_inputBufferHandle.IsAllocated)
if (IsInputBufferHandleAllocated)
DeallocateInputBufferHandle();

_isDisposed = true;
Expand Down Expand Up @@ -313,14 +322,16 @@ private ZLibNative.ErrorCode Inflate(ZLibNative.FlushCode flushCode)
/// </summary>
private void DeallocateInputBufferHandle()
{
Debug.Assert(_inputBufferHandle.IsAllocated);
Debug.Assert(IsInputBufferHandleAllocated);

lock (SyncLock)
{
_zlibStream.AvailIn = 0;
_zlibStream.NextIn = ZLibNative.ZNullPtr;
_inputBufferHandle.Free();
_inputBufferHandle.Dispose();
}
}

private unsafe bool IsInputBufferHandleAllocated => _inputBufferHandle.Pointer != default;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,11 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
return _networkStream.ReadAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _networkStream.ReadAsync(buffer, cancellationToken);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int size, AsyncCallback? callback, object? state)
{
return _networkStream.BeginWrite(buffer, offset, size, callback, state);
Expand All @@ -204,6 +209,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return _networkStream.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _networkStream.WriteAsync(buffer, cancellationToken);
}

public override void Flush()
{
_networkStream.Flush();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ public override Task WriteAsync(byte[] buffer, int offset, int count, Cancellati
return _buffer.WriteAsync(buffer, offset, count, cancellationToken);
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _buffer.WriteAsync(buffer, cancellationToken);
}

public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback? asyncCallback, object? asyncState)
{
ValidateBufferArguments(buffer, offset, count);
Expand Down