Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure FileStream.Position is correct after a failed|cancelled WriteAsync attempt #56716

Merged
merged 7 commits into from
Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/libraries/System.IO.FileSystem/tests/FileStream/ReadAsync.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ public async Task ReadAsyncBufferedCompletesSynchronously()
}
}

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public async Task ReadAsyncCanceledFile()
[ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
[InlineData(0, true)] // 0 == no buffering
[InlineData(4096, true)] // 4096 == default buffer size
[InlineData(0, false)]
[InlineData(4096, false)]
public async Task ReadAsyncCanceledFile(int bufferSize, bool isAsync)
{
string fileName = GetTestFilePath();
using (FileStream fs = new FileStream(fileName, FileMode.Create))
Expand All @@ -75,7 +79,7 @@ public async Task ReadAsyncCanceledFile()
fs.Write(TestBuffer, 0, TestBuffer.Length);
}

using (FileStream fs = new FileStream(fileName, FileMode.Open))
using (FileStream fs = new FileStream(fileName, FileMode.Open, FileAccess.Read, FileShare.None, bufferSize, isAsync))
{
byte[] buffer = new byte[fs.Length];
CancellationTokenSource cts = new CancellationTokenSource();
Expand All @@ -91,6 +95,8 @@ public async Task ReadAsyncCanceledFile()
// Ideally we'd be doing an Assert.Throws<OperationCanceledException>
// but since cancellation is a race condition we accept either outcome
Assert.Equal(cts.Token, oce.CancellationToken);

Assert.Equal(0, fs.Position); // if read was cancelled, the Position should remain unchanged
}
}
}
Expand Down
13 changes: 10 additions & 3 deletions src/libraries/System.IO.FileSystem/tests/FileStream/WriteAsync.cs
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,15 @@ public async Task SimpleWriteAsync()
}
}

[Fact]
public async Task WriteAsyncCancelledFile()
[Theory]
[InlineData(0, true)] // 0 == no buffering
[InlineData(4096, true)] // 4096 == default buffer size
[InlineData(0, false)]
[InlineData(4096, false)]
public async Task WriteAsyncCancelledFile(int bufferSize, bool isAsync)
{
const int writeSize = 1024 * 1024;
using (FileStream fs = new FileStream(GetTestFilePath(), FileMode.Create))
using (FileStream fs = new FileStream(GetTestFilePath(), FileMode.CreateNew, FileAccess.Write, FileShare.None, bufferSize, isAsync))
{
byte[] buffer = new byte[writeSize];
CancellationTokenSource cts = new CancellationTokenSource();
Expand All @@ -119,6 +123,9 @@ public async Task WriteAsyncCancelledFile()
// Ideally we'd be doing an Assert.Throws<OperationCanceledException>
// but since cancellation is a race condition we accept either outcome
Assert.Equal(cts.Token, oce.CancellationToken);

Assert.Equal(0, fs.Length); // if write was cancelled, the file should be empty
Assert.Equal(0, fs.Position); // if write was cancelled, the Position should remain unchanged
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ internal sealed unsafe class OverlappedValueTaskSource : IValueTaskSource<int>,

internal readonly PreAllocatedOverlapped _preallocatedOverlapped;
internal readonly SafeFileHandle _fileHandle;
private AsyncWindowsFileStreamStrategy? _strategy;
private OSFileStreamStrategy? _strategy;
internal MemoryHandle _memoryHandle;
private int _bufferSize;
internal ManualResetValueTaskSourceCore<int> _source; // mutable struct; do not make this readonly
Expand Down Expand Up @@ -77,8 +77,10 @@ internal static Exception GetIOError(int errorCode, string? path)
? ThrowHelper.CreateEndOfFileException()
: Win32Marshal.GetExceptionForWin32Error(errorCode, path);

internal NativeOverlapped* PrepareForOperation(ReadOnlyMemory<byte> memory, long fileOffset, AsyncWindowsFileStreamStrategy? strategy = null)
internal NativeOverlapped* PrepareForOperation(ReadOnlyMemory<byte> memory, long fileOffset, OSFileStreamStrategy? strategy = null)
{
Debug.Assert(strategy is null || strategy is AsyncWindowsFileStreamStrategy, $"Strategy was expected to be null or async, got {strategy}.");

_result = 0;
_strategy = strategy;
_bufferSize = memory.Length;
Expand Down Expand Up @@ -195,12 +197,12 @@ internal void Complete(uint errorCode, uint numBytes)
{
Debug.Assert(errorCode == Interop.Errors.ERROR_SUCCESS || numBytes == 0, $"Callback returned {errorCode} error and {numBytes} bytes");

AsyncWindowsFileStreamStrategy? strategy = _strategy;
OSFileStreamStrategy? strategy = _strategy;
ReleaseResources();

if (strategy is not null && _bufferSize != numBytes) // true only for incomplete reads
if (strategy is not null && _bufferSize != numBytes) // true only for incomplete operations
{
strategy.OnIncompleteRead(_bufferSize, (int)numBytes);
strategy.OnIncompleteOperation(_bufferSize, (int)numBytes);
}

switch (errorCode)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Strategies;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -32,6 +33,7 @@ internal sealed class ThreadPoolValueTaskSource : IThreadPoolWorkItem, IValueTas
private ManualResetValueTaskSourceCore<long> _source;
private Operation _operation = Operation.None;
private ExecutionContext? _context;
private OSFileStreamStrategy? _strategy;
adamsitnik marked this conversation as resolved.
Show resolved Hide resolved

// These fields store the parameters for the operation.
// The first two are common for all kinds of operations.
Expand Down Expand Up @@ -116,8 +118,22 @@ private void ExecuteInternal()
}
finally
{
if (_strategy is not null)
{
// WriteAtOffset returns void, so we need to fix position only in case of an exception
if (exception is not null)
{
_strategy.OnIncompleteOperation(_singleSegment.Length, 0);
}
else if (_operation == Operation.Read && result != _singleSegment.Length)
{
_strategy.OnIncompleteOperation(_singleSegment.Length, (int)result);
}
}

_operation = Operation.None;
_context = null;
_strategy = null;
_cancellationToken = default;
_singleSegment = default;
_readScatterBuffers = null;
Expand Down Expand Up @@ -152,27 +168,29 @@ private void QueueToThreadPool()
ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: true);
}

public ValueTask<int> QueueRead(Memory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
public ValueTask<int> QueueRead(Memory<byte> buffer, long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
ValidateInvariants();

_operation = Operation.Read;
_singleSegment = buffer;
_fileOffset = fileOffset;
_cancellationToken = cancellationToken;
_strategy = strategy;
QueueToThreadPool();

return new ValueTask<int>(this, _source.Version);
}

public ValueTask QueueWrite(ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
public ValueTask QueueWrite(ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
ValidateInvariants();

_operation = Operation.Write;
_singleSegment = buffer;
_fileOffset = fileOffset;
_cancellationToken = cancellationToken;
_strategy = strategy;
QueueToThreadPool();

return new ValueTask(this, _source.Version);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ internal static unsafe long ReadScatterAtOffset(SafeFileHandle handle, IReadOnly
return FileStreamHelpers.CheckFileCall(result, handle.Path);
}

internal static ValueTask<int> ReadAtOffsetAsync(SafeFileHandle handle, Memory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
=> ScheduleSyncReadAtOffsetAsync(handle, buffer, fileOffset, cancellationToken);
internal static ValueTask<int> ReadAtOffsetAsync(SafeFileHandle handle, Memory<byte> buffer, long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy = null)
=> ScheduleSyncReadAtOffsetAsync(handle, buffer, fileOffset, cancellationToken, strategy);

private static ValueTask<long> ReadScatterAtOffsetAsync(SafeFileHandle handle, IReadOnlyList<Memory<byte>> buffers,
long fileOffset, CancellationToken cancellationToken)
Expand Down Expand Up @@ -202,8 +202,8 @@ internal static unsafe void WriteGatherAtOffset(SafeFileHandle handle, IReadOnly
}
}

internal static ValueTask WriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
=> ScheduleSyncWriteAtOffsetAsync(handle, buffer, fileOffset, cancellationToken);
internal static ValueTask WriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy = null)
=> ScheduleSyncWriteAtOffsetAsync(handle, buffer, fileOffset, cancellationToken, strategy);

private static ValueTask WriteGatherAtOffsetAsync(SafeFileHandle handle, IReadOnlyList<ReadOnlyMemory<byte>> buffers,
long fileOffset, CancellationToken cancellationToken)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,12 @@ private static unsafe void WriteSyncUsingAsyncHandle(SafeFileHandle handle, Read
}
}

internal static ValueTask<int> ReadAtOffsetAsync(SafeFileHandle handle, Memory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
internal static ValueTask<int> ReadAtOffsetAsync(SafeFileHandle handle, Memory<byte> buffer, long fileOffset,
CancellationToken cancellationToken, OSFileStreamStrategy? strategy = null)
{
if (handle.IsAsync)
{
(SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) = QueueAsyncReadFile(handle, buffer, fileOffset, cancellationToken);
(SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) = QueueAsyncReadFile(handle, buffer, fileOffset, cancellationToken, strategy);

if (vts is not null)
{
Expand All @@ -236,19 +237,19 @@ internal static ValueTask<int> ReadAtOffsetAsync(SafeFileHandle handle, Memory<b
return ValueTask.FromResult(0);
}

return ValueTask.FromException<int>(Win32Marshal.GetExceptionForWin32Error(errorCode));
return ValueTask.FromException<int>(Win32Marshal.GetExceptionForWin32Error(errorCode, handle.Path));
}

return ScheduleSyncReadAtOffsetAsync(handle, buffer, fileOffset, cancellationToken);
return ScheduleSyncReadAtOffsetAsync(handle, buffer, fileOffset, cancellationToken, strategy);
}

internal static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) QueueAsyncReadFile(SafeFileHandle handle, Memory<byte> buffer, long fileOffset,
CancellationToken cancellationToken, AsyncWindowsFileStreamStrategy? strategy = null)
private static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) QueueAsyncReadFile(SafeFileHandle handle, Memory<byte> buffer, long fileOffset,
CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
handle.EnsureThreadPoolBindingInitialized();

SafeFileHandle.OverlappedValueTaskSource vts = handle.GetOverlappedValueTaskSource();
int errorCode = 0;
int errorCode = Interop.Errors.ERROR_SUCCESS;
stephentoub marked this conversation as resolved.
Show resolved Hide resolved
try
{
NativeOverlapped* nativeOverlapped = vts.PrepareForOperation(buffer, fileOffset, strategy);
Expand Down Expand Up @@ -292,7 +293,7 @@ internal static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int error
{
if (errorCode != Interop.Errors.ERROR_IO_PENDING && errorCode != Interop.Errors.ERROR_SUCCESS)
{
strategy?.OnIncompleteRead(buffer.Length, 0);
strategy?.OnIncompleteOperation(buffer.Length, 0);
}
}

Expand All @@ -301,11 +302,12 @@ internal static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int error
return (vts, -1);
}

internal static ValueTask WriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
internal static ValueTask WriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset,
CancellationToken cancellationToken, OSFileStreamStrategy? strategy = null)
{
if (handle.IsAsync)
{
(SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) = QueueAsyncWriteFile(handle, buffer, fileOffset, cancellationToken);
(SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) = QueueAsyncWriteFile(handle, buffer, fileOffset, cancellationToken, strategy);

if (vts is not null)
{
Expand All @@ -317,27 +319,29 @@ internal static ValueTask WriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemo
return ValueTask.CompletedTask;
}

return ValueTask.FromException(Win32Marshal.GetExceptionForWin32Error(errorCode));
return ValueTask.FromException(Win32Marshal.GetExceptionForWin32Error(errorCode, handle.Path));
}

return ScheduleSyncWriteAtOffsetAsync(handle, buffer, fileOffset, cancellationToken);
return ScheduleSyncWriteAtOffsetAsync(handle, buffer, fileOffset, cancellationToken, strategy);
}

internal static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) QueueAsyncWriteFile(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset, CancellationToken cancellationToken)
private static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int errorCode) QueueAsyncWriteFile(SafeFileHandle handle, ReadOnlyMemory<byte> buffer, long fileOffset,
CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
handle.EnsureThreadPoolBindingInitialized();

SafeFileHandle.OverlappedValueTaskSource vts = handle.GetOverlappedValueTaskSource();
int errorCode = Interop.Errors.ERROR_SUCCESS;
try
{
NativeOverlapped* nativeOverlapped = vts.PrepareForOperation(buffer, fileOffset);
NativeOverlapped* nativeOverlapped = vts.PrepareForOperation(buffer, fileOffset, strategy);
Debug.Assert(vts._memoryHandle.Pointer != null);

// Queue an async WriteFile operation.
if (Interop.Kernel32.WriteFile(handle, (byte*)vts._memoryHandle.Pointer, buffer.Length, IntPtr.Zero, nativeOverlapped) == 0)
{
// The operation failed, or it's pending.
int errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle);
errorCode = FileStreamHelpers.GetLastWin32ErrorAndDisposeHandleIfInvalid(handle);
switch (errorCode)
{
case Interop.Errors.ERROR_IO_PENDING:
Expand All @@ -360,6 +364,13 @@ internal static unsafe (SafeFileHandle.OverlappedValueTaskSource? vts, int error
vts.Dispose();
throw;
}
finally
{
if (errorCode != Interop.Errors.ERROR_IO_PENDING && errorCode != Interop.Errors.ERROR_SUCCESS)
{
strategy?.OnIncompleteOperation(buffer.Length, 0);
}
}

// Completion handled by callback.
vts.FinishedScheduling();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.IO.Strategies;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
Expand Down Expand Up @@ -264,9 +265,9 @@ private static void ValidateBuffers<T>(IReadOnlyList<T> buffers)
}

private static ValueTask<int> ScheduleSyncReadAtOffsetAsync(SafeFileHandle handle, Memory<byte> buffer,
long fileOffset, CancellationToken cancellationToken)
long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
return handle.GetThreadPoolValueTaskSource().QueueRead(buffer, fileOffset, cancellationToken);
return handle.GetThreadPoolValueTaskSource().QueueRead(buffer, fileOffset, cancellationToken, strategy);
}

private static ValueTask<long> ScheduleSyncReadScatterAtOffsetAsync(SafeFileHandle handle, IReadOnlyList<Memory<byte>> buffers,
Expand All @@ -276,9 +277,9 @@ private static ValueTask<long> ScheduleSyncReadScatterAtOffsetAsync(SafeFileHand
}

private static ValueTask ScheduleSyncWriteAtOffsetAsync(SafeFileHandle handle, ReadOnlyMemory<byte> buffer,
long fileOffset, CancellationToken cancellationToken)
long fileOffset, CancellationToken cancellationToken, OSFileStreamStrategy? strategy)
{
return handle.GetThreadPoolValueTaskSource().QueueWrite(buffer, fileOffset, cancellationToken);
return handle.GetThreadPoolValueTaskSource().QueueWrite(buffer, fileOffset, cancellationToken, strategy);
}

private static ValueTask ScheduleSyncWriteGatherAtOffsetAsync(SafeFileHandle handle, IReadOnlyList<ReadOnlyMemory<byte>> buffers,
Expand Down
Loading