Skip to content

Commit

Permalink
fix read abort handling and revert CanRead/CanWrite to previous behav…
Browse files Browse the repository at this point in the history
…ior (#55341)

Co-authored-by: Geoffrey Kizer <geoffrek@windows.microsoft.com>
  • Loading branch information
geoffkizer and Geoffrey Kizer authored Jul 8, 2021
1 parent 31c2bed commit d9f1ade
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ internal override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellati
int bytesRead = await streamBuffer.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
if (bytesRead == 0)
{
long errorCode = _isInitiator ? _streamState._inboundErrorCode : _streamState._outboundErrorCode;
long errorCode = _isInitiator ? _streamState._inboundReadErrorCode : _streamState._outboundReadErrorCode;
if (errorCode != 0)
{
throw new QuicStreamAbortedException(errorCode);
Expand Down Expand Up @@ -121,6 +121,12 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
throw new NotSupportedException();
}

long errorCode = _isInitiator ? _streamState._inboundWriteErrorCode : _streamState._outboundWriteErrorCode;
if (errorCode != 0)
{
throw new QuicStreamAbortedException(errorCode);
}

using var registration = cancellationToken.UnsafeRegister(static s =>
{
var stream = (MockStream)s!;
Expand Down Expand Up @@ -171,18 +177,27 @@ internal override Task FlushAsync(CancellationToken cancellationToken)

internal override void AbortRead(long errorCode)
{
throw new NotImplementedException();
if (_isInitiator)
{
_streamState._outboundWriteErrorCode = errorCode;
}
else
{
_streamState._inboundWriteErrorCode = errorCode;
}

ReadStreamBuffer?.AbortRead();
}

internal override void AbortWrite(long errorCode)
{
if (_isInitiator)
{
_streamState._outboundErrorCode = errorCode;
_streamState._outboundReadErrorCode = errorCode;
}
else
{
_streamState._inboundErrorCode = errorCode;
_streamState._inboundReadErrorCode = errorCode;
}

WriteStreamBuffer?.EndWrite();
Expand Down Expand Up @@ -255,8 +270,10 @@ internal sealed class StreamState
public readonly long _streamId;
public StreamBuffer _outboundStreamBuffer;
public StreamBuffer? _inboundStreamBuffer;
public long _outboundErrorCode;
public long _inboundErrorCode;
public long _outboundReadErrorCode;
public long _inboundReadErrorCode;
public long _outboundWriteErrorCode;
public long _inboundWriteErrorCode;

private const int InitialBufferSize =
#if DEBUG
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ internal sealed class MsQuicStream : QuicStreamProvider

private readonly State _state = new State();

private readonly bool _canRead;
private readonly bool _canWrite;

// Backing for StreamId
private long _streamId = -1;

Expand Down Expand Up @@ -80,8 +83,10 @@ public void Cleanup()
internal MsQuicStream(MsQuicConnection.State connectionState, SafeMsQuicStreamHandle streamHandle, QUIC_STREAM_OPEN_FLAGS flags)
{
_state.Handle = streamHandle;
_canRead = true;
_canWrite = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
_started = true;
if (flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL))
if (!_canWrite)
{
_state.SendState = SendState.Closed;
}
Expand Down Expand Up @@ -122,8 +127,11 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
{
Debug.Assert(connectionState.Handle != null);

_canRead = !flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL);
_canWrite = true;

_state.StateGCHandle = GCHandle.Alloc(_state);
if (flags.HasFlag(QUIC_STREAM_OPEN_FLAGS.UNIDIRECTIONAL))
if (!_canRead)
{
_state.ReadState = ReadState.Closed;
}
Expand Down Expand Up @@ -167,9 +175,9 @@ internal MsQuicStream(MsQuicConnection.State connectionState, QUIC_STREAM_OPEN_F
}
}

internal override bool CanRead => _disposed == 0 && _state.ReadState < ReadState.Aborted;
internal override bool CanRead => _disposed == 0 && _canRead;

internal override bool CanWrite => _disposed == 0 && _state.SendState < SendState.Aborted;
internal override bool CanWrite => _disposed == 0 && _canWrite;

internal override long StreamId
{
Expand Down Expand Up @@ -242,6 +250,11 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
}
else if ( _state.SendState == SendState.Aborted)
{
if (_state.SendErrorCode != -1)
{
throw new QuicStreamAbortedException(_state.SendErrorCode);
}

throw new OperationCanceledException(cancellationToken);
}

Expand Down Expand Up @@ -292,6 +305,12 @@ private async ValueTask<CancellationTokenRegistration> HandleWriteStartState(Can
if (_state.SendState == SendState.Aborted)
{
cancellationToken.ThrowIfCancellationRequested();

if (_state.SendErrorCode != -1)
{
throw new QuicStreamAbortedException(_state.SendErrorCode);
}

throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,17 +437,15 @@ await Task.Run(async () =>
}

[Fact]
public async Task StreamAbortedWithoutWriting_ReadThrows()
public async Task WriteAbortedWithoutWriting_ReadThrows()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenUnidirectionalStream();
stream.AbortWrite(expectedErrorCode);

await stream.ShutdownCompleted();
},
serverFunction: async connection =>
{
Expand All @@ -458,15 +456,40 @@ await RunClientServer(
QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => ReadAll(stream, buffer));
Assert.Equal(expectedErrorCode, ex.ErrorCode);

await stream.ShutdownCompleted();
// We should still return true from CanRead, even though the read has been aborted.
Assert.True(stream.CanRead);
}
);
}

[Fact]
public async Task ReadAbortedWithoutReading_WriteThrows()
{
const long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
{
await using QuicStream stream = connection.OpenBidirectionalStream();
stream.AbortRead(expectedErrorCode);
},
serverFunction: async connection =>
{
await using QuicStream stream = await connection.AcceptStreamAsync();

QuicStreamAbortedException ex = await Assert.ThrowsAsync<QuicStreamAbortedException>(() => WriteForever(stream));
Assert.Equal(expectedErrorCode, ex.ErrorCode);

// We should still return true from CanWrite, even though the write has been aborted.
Assert.True(stream.CanWrite);
}
);
}

[Fact]
public async Task WritePreCanceled_Throws()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
Expand Down Expand Up @@ -502,7 +525,7 @@ await RunClientServer(
[Fact]
public async Task WriteCanceled_NextWriteThrows()
{
long expectedErrorCode = 1234;
const long expectedErrorCode = 1234;

await RunClientServer(
clientFunction: async connection =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,15 @@ internal static async Task<int> ReadAll(QuicStream stream, byte[] buffer)
return bytesRead;
}

internal static async Task<int> WriteForever(QuicStream stream)
{
Memory<byte> buffer = new byte[] { 123 };
while (true)
{
await stream.WriteAsync(buffer);
}
}

internal static void AssertArrayEqual(byte[] expected, byte[] actual)
{
for (int i = 0; i < expected.Length; ++i)
Expand Down

0 comments on commit d9f1ade

Please sign in to comment.