Skip to content

Commit

Permalink
Improve performance of UnmanagedMemoryStream (#93766)
Browse files Browse the repository at this point in the history
* Improve performance of UnmanagedMemoryStream

UnmanagedMemoryStream used Interlocked operations to update its state to prevent tearing of 64-bit values on 32-bit platforms. This pattern is expensive in general and it was found to be prohibitively expensive on recent hardware..

This change removes the expensive Interlocked operations and addresses
the tearing issues in alternative way:
- The _length field is converted to nuint that is guaranteed to be
  updated atomically.
- Writes to _length field are volatile to guaranteed the
  unininitialized memory cannot be read.
- The _position field remains long and it has a risk of tearing. It is
  not a problem since tearing of this field cannot lead to buffer
  overruns.

Fixes #93624

* Add comment
  • Loading branch information
jkotas authored Oct 20, 2023
1 parent 8a5959d commit 33ee130
Showing 1 changed file with 63 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,9 @@ namespace System.IO
* of the UnmanagedMemoryStream.
* 3) You clean up the memory when appropriate. The UnmanagedMemoryStream
* currently will do NOTHING to free this memory.
* 4) All calls to Write and WriteByte may not be threadsafe currently.
*
* It may become necessary to add in some sort of
* DeallocationMode enum, specifying whether we unmap a section of memory,
* call free, run a user-provided delegate to free the memory, etc.
* We'll suggest user write a subclass of UnmanagedMemoryStream that uses
* a SafeHandle subclass to hold onto the memory.
*
* 4) This type is not thread safe. However, the implementation should prevent buffer
* overruns or returning uninitialized memory when Reads and Writes are called
* concurrently in thread unsafe manner.
*/

/// <summary>
Expand All @@ -40,10 +35,10 @@ public class UnmanagedMemoryStream : Stream
{
private SafeBuffer? _buffer;
private unsafe byte* _mem;
private long _length;
private long _capacity;
private long _position;
private long _offset;
private nuint _capacity;
private nuint _offset;
private nuint _length; // nuint to guarantee atomic access on 32-bit platforms
private long _position; // long to allow seeking to any location beyond the length of the stream.
private FileAccess _access;
private bool _isOpen;
private CachedCompletedInt32Task _lastReadTask; // The last successful task returned from ReadAsync
Expand Down Expand Up @@ -123,10 +118,10 @@ protected void Initialize(SafeBuffer buffer, long offset, long length, FileAcces
}
}

_offset = offset;
_offset = (nuint)offset;
_buffer = buffer;
_length = length;
_capacity = length;
_length = (nuint)length;
_capacity = (nuint)length;
_access = access;
_isOpen = true;
}
Expand Down Expand Up @@ -171,8 +166,8 @@ protected unsafe void Initialize(byte* pointer, long length, long capacity, File

_mem = pointer;
_offset = 0;
_length = length;
_capacity = capacity;
_length = (nuint)length;
_capacity = (nuint)capacity;
_access = access;
_isOpen = true;
}
Expand Down Expand Up @@ -259,7 +254,7 @@ public override long Length
get
{
EnsureNotClosed();
return Interlocked.Read(ref _length);
return (long)_length;
}
}

Expand All @@ -271,7 +266,7 @@ public long Capacity
get
{
EnsureNotClosed();
return _capacity;
return (long)_capacity;
}
}

Expand All @@ -283,14 +278,14 @@ public override long Position
get
{
if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null);
return Interlocked.Read(ref _position);
return _position;
}
set
{
ArgumentOutOfRangeException.ThrowIfNegative(value);
if (!CanSeek) ThrowHelper.ThrowObjectDisposedException_StreamClosed(null);

Interlocked.Exchange(ref _position, value);
_position = value;
}
}

Expand All @@ -308,11 +303,10 @@ public unsafe byte* PositionPointer
EnsureNotClosed();

// Use a temp to avoid a race
long pos = Interlocked.Read(ref _position);
if (pos > _capacity)
long pos = _position;
if (pos > (long)_capacity)
throw new IndexOutOfRangeException(SR.IndexOutOfRange_UMSPosition);
byte* ptr = _mem + pos;
return ptr;
return _mem + pos;
}
set
{
Expand All @@ -327,7 +321,7 @@ public unsafe byte* PositionPointer
if (newPosition < 0)
throw new ArgumentOutOfRangeException(nameof(value), SR.ArgumentOutOfRange_UnmanagedMemStreamLength);

Interlocked.Exchange(ref _position, newPosition);
_position = newPosition;
}
}

Expand Down Expand Up @@ -367,8 +361,13 @@ internal int ReadCore(Span<byte> buffer)

// Use a local variable to avoid a race where another thread
// changes our position after we decide we can read some bytes.
long pos = Interlocked.Read(ref _position);
long len = Interlocked.Read(ref _length);
long pos = _position;

// Use a volatile read to prevent reading of the uninitialized memory. This volatile read
// and matching volatile write that set _length avoids reordering of NativeMemory.Clear
// operations with reading of the buffer below.
long len = (long)Volatile.Read(ref _length);

long n = Math.Min(len - pos, buffer.Length);
if (n <= 0)
{
Expand Down Expand Up @@ -407,7 +406,7 @@ internal int ReadCore(Span<byte> buffer)
}
}

Interlocked.Exchange(ref _position, pos + n);
_position = pos + n;
return nInt;
}

Expand Down Expand Up @@ -484,11 +483,16 @@ public override int ReadByte()
EnsureNotClosed();
EnsureReadable();

long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition
long len = Interlocked.Read(ref _length);
long pos = _position; // Use a local to avoid a race condition

// Use a volatile read to prevent reading of the uninitialized memory. This volatile read
// and matching volatile write that set _length avoids reordering of NativeMemory.Clear
// operations with reading of the buffer below.
long len = (long)Volatile.Read(ref _length);

if (pos >= len)
return -1;
Interlocked.Exchange(ref _position, pos + 1);
_position = pos + 1;
int result;
if (_buffer != null)
{
Expand Down Expand Up @@ -529,35 +533,33 @@ public override long Seek(long offset, SeekOrigin loc)
{
EnsureNotClosed();

long newPosition;
switch (loc)
{
case SeekOrigin.Begin:
if (offset < 0)
newPosition = offset;
if (newPosition < 0)
throw new IOException(SR.IO_SeekBeforeBegin);
Interlocked.Exchange(ref _position, offset);
break;

case SeekOrigin.Current:
long pos = Interlocked.Read(ref _position);
if (offset + pos < 0)
newPosition = _position + offset;
if (newPosition < 0)
throw new IOException(SR.IO_SeekBeforeBegin);
Interlocked.Exchange(ref _position, offset + pos);
break;

case SeekOrigin.End:
long len = Interlocked.Read(ref _length);
if (len + offset < 0)
newPosition = (long)_length + offset;
if (newPosition < 0)
throw new IOException(SR.IO_SeekBeforeBegin);
Interlocked.Exchange(ref _position, len + offset);
break;

default:
throw new ArgumentException(SR.Argument_InvalidSeekOrigin);
}

long finalPos = Interlocked.Read(ref _position);
Debug.Assert(finalPos >= 0, "_position >= 0");
return finalPos;
_position = newPosition;
return newPosition;
}

/// <summary>
Expand All @@ -573,22 +575,22 @@ public override void SetLength(long value)
EnsureNotClosed();
EnsureWriteable();

if (value > _capacity)
if (value > (long)_capacity)
throw new IOException(SR.IO_FixedCapacity);

long pos = Interlocked.Read(ref _position);
long len = Interlocked.Read(ref _length);
long len = (long)_length;
if (value > len)
{
unsafe
{
NativeMemory.Clear(_mem + len, (nuint)(value - len));
}
}
Interlocked.Exchange(ref _length, value);
if (pos > value)
Volatile.Write(ref _length, (nuint)value); // volatile to prevent reading of uninitialized memory

if (_position > value)
{
Interlocked.Exchange(ref _position, value);
_position = value;
}
}

Expand Down Expand Up @@ -625,16 +627,16 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
EnsureNotClosed();
EnsureWriteable();

long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition
long len = Interlocked.Read(ref _length);
long pos = _position; // Use a local to avoid a race condition
long len = (long)_length;
long n = pos + buffer.Length;
// Check for overflow
if (n < 0)
{
throw new IOException(SR.IO_StreamTooLong);
}

if (n > _capacity)
if (n > (long)_capacity)
{
throw new NotSupportedException(SR.IO_FixedCapacity);
}
Expand All @@ -648,16 +650,16 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
NativeMemory.Clear(_mem + len, (nuint)(pos - len));
}

// set length after zeroing memory to avoid race condition of accessing unzeroed memory
// set length after zeroing memory to avoid race condition of accessing uninitialized memory
if (n > len)
{
Interlocked.Exchange(ref _length, n);
Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory
}
}

if (_buffer != null)
{
long bytesLeft = _capacity - pos;
long bytesLeft = (long)_capacity - pos;
if (bytesLeft < buffer.Length)
{
throw new ArgumentException(SR.Arg_BufferTooSmall);
Expand All @@ -682,8 +684,7 @@ internal unsafe void WriteCore(ReadOnlySpan<byte> buffer)
Buffer.Memmove(ref *(_mem + pos), ref MemoryMarshal.GetReference(buffer), (nuint)buffer.Length);
}

Interlocked.Exchange(ref _position, n);
return;
_position = n;
}

/// <summary>
Expand Down Expand Up @@ -754,16 +755,16 @@ public override void WriteByte(byte value)
EnsureNotClosed();
EnsureWriteable();

long pos = Interlocked.Read(ref _position); // Use a local to avoid a race condition
long len = Interlocked.Read(ref _length);
long pos = _position; // Use a local to avoid a race condition
long len = (long)_length;
long n = pos + 1;
if (pos >= len)
{
// Check for overflow
if (n < 0)
throw new IOException(SR.IO_StreamTooLong);

if (n > _capacity)
if (n > (long)_capacity)
throw new NotSupportedException(SR.IO_FixedCapacity);

// Check to see whether we are now expanding the stream and must
Expand All @@ -779,8 +780,7 @@ public override void WriteByte(byte value)
}
}

// set length after zeroing memory to avoid race condition of accessing unzeroed memory
Interlocked.Exchange(ref _length, n);
Volatile.Write(ref _length, (nuint)n); // volatile to prevent reading of uninitialized memory
}
}

Expand Down Expand Up @@ -810,7 +810,7 @@ public override void WriteByte(byte value)
_mem[pos] = value;
}
}
Interlocked.Exchange(ref _position, n);
_position = n;
}
}
}

0 comments on commit 33ee130

Please sign in to comment.