Skip to content

Commit

Permalink
Fix Connection FC Handling on Stream Abort (#2066)
Browse files Browse the repository at this point in the history
  • Loading branch information
nibanks authored Oct 8, 2021
1 parent 2cbfc30 commit 86edb13
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 10 deletions.
36 changes: 28 additions & 8 deletions src/core/stream_recv.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ QuicStreamReceiveComplete(
_In_ uint64_t BufferLength
);

_IRQL_requires_max_(PASSIVE_LEVEL)
void
QuicStreamProcessResetFrame(
_In_ QUIC_STREAM* Stream,
_In_ uint64_t FinalSize,
_In_ QUIC_VAR_INT ErrorCode
);

_IRQL_requires_max_(PASSIVE_LEVEL)
void
QuicStreamRecvShutdown(
Expand Down Expand Up @@ -67,21 +75,20 @@ QuicStreamRecvShutdown(
Stream->Flags.ReceiveDataPending = FALSE;
Stream->Flags.ReceiveCallPending = FALSE;

Stream->RecvShutdownErrorCode = ErrorCode;
Stream->Flags.SentStopSending = TRUE;

if (Stream->RecvMaxLength != UINT64_MAX) {
//
// The peer has already gracefully closed, but we just haven't drained
// the receives to that point. Ignore this abort from the app and jump
// right to the closed state.
// the receives to that point. Just treat the shutdown as if it was
// already acknowledged by a reset frame.
//
Stream->Flags.RemoteCloseFin = TRUE;
Stream->Flags.RemoteCloseAcked = TRUE;
QuicStreamProcessResetFrame(Stream, Stream->RecvMaxLength, 0);
Silent = TRUE; // To indicate we try to shutdown complete.
goto Exit;
}

Stream->RecvShutdownErrorCode = ErrorCode;
Stream->Flags.SentStopSending = TRUE;

//
// Queue up a stop sending frame to be sent.
//
Expand Down Expand Up @@ -191,7 +198,6 @@ QuicStreamProcessResetFrame(
// have actually received. Make sure to update our flow control
// accounting so we stay in sync with the peer.
//

uint64_t FlowControlIncrease = FinalSize - TotalRecvLength;
Stream->Connection->Send.OrderedStreamBytesReceived += FlowControlIncrease;
if (Stream->Connection->Send.OrderedStreamBytesReceived < FlowControlIncrease ||
Expand All @@ -209,6 +215,20 @@ QuicStreamProcessResetFrame(
}
}

uint64_t TotalReadLength = Stream->RecvBuffer.BaseOffset;
if (TotalReadLength < FinalSize) {
//
// The final offset is indicating that more data was sent than the
// app has completely read. Make sure to give the peer more credit
// as a result.
//
uint64_t FlowControlIncrease = FinalSize - TotalReadLength;
Stream->Connection->Send.MaxData += FlowControlIncrease;
QuicSendSetSendFlag(
&Stream->Connection->Send,
QUIC_CONN_SEND_FLAG_MAX_DATA);
}

QuicTraceEvent(
StreamRecvState,
"[strm][%p] Recv State: %hhu",
Expand Down
1 change: 1 addition & 0 deletions src/inc/msquic.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ typedef enum QUIC_STREAM_SHUTDOWN_FLAGS {
QUIC_STREAM_SHUTDOWN_FLAG_ABORT = 0x0006, // Abruptly closes both send and receive paths.
QUIC_STREAM_SHUTDOWN_FLAG_IMMEDIATE = 0x0008, // Immediately sends completion events to app.
QUIC_STREAM_SHUTDOWN_FLAG_INLINE = 0x0010, // Process the shutdown immediately inline. Only for calls on callbacks.
// WARNING: Can cause reentrant callbacks!
} QUIC_STREAM_SHUTDOWN_FLAGS;

DEFINE_ENUM_FLAG_OPERATORS(QUIC_STREAM_SHUTDOWN_FLAGS)
Expand Down
1 change: 1 addition & 0 deletions src/inc/msquic.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ class MsQuicSettings : public QUIC_SETTINGS {
MsQuicSettings& SetMtuDiscoverySearchCompleteTimeoutUs(uint64_t Time) { MtuDiscoverySearchCompleteTimeoutUs = Time; IsSet.MtuDiscoverySearchCompleteTimeoutUs = TRUE; return *this; }
MsQuicSettings& SetMtuDiscoveryMissingProbeCount(uint8_t Count) { MtuDiscoveryMissingProbeCount = Count; IsSet.MtuDiscoveryMissingProbeCount = TRUE; return *this; }
MsQuicSettings& SetKeepAlive(uint32_t Time) { KeepAliveIntervalMs = Time; IsSet.KeepAliveIntervalMs = TRUE; return *this; }
MsQuicSettings& SetConnFlowControlWindow(uint32_t Window) { ConnFlowControlWindow = Window; IsSet.ConnFlowControlWindow = TRUE; return *this; }

QUIC_STATUS
SetGlobal() const noexcept {
Expand Down
9 changes: 8 additions & 1 deletion src/test/MsQuicTests.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,10 @@ void
QuicTestStreamAbortRecvFinRace(
);

void
QuicTestStreamAbortConnFlowControl(
);

//
// QuicDrill tests
//
Expand Down Expand Up @@ -940,4 +944,7 @@ typedef struct {
#define IOCTL_QUIC_RUN_STREAM_ABORT_RECV_FIN_RACE \
QUIC_CTL_CODE(78, METHOD_BUFFERED, FILE_WRITE_DATA)

#define QUIC_MAX_IOCTL_FUNC_CODE 78
#define IOCTL_QUIC_RUN_STREAM_ABORT_CONN_FLOW_CONTROL \
QUIC_CTL_CODE(79, METHOD_BUFFERED, FILE_WRITE_DATA)

#define QUIC_MAX_IOCTL_FUNC_CODE 79
2 changes: 1 addition & 1 deletion src/test/bin/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ target_link_libraries(msquictest inc gtest logging base_link)

# At least /W3 must be used on all windows builds to pass compliance
if(MSVC)
target_compile_options(msquictest PRIVATE /W3)
target_compile_options(msquictest PRIVATE /W3 /bigobj)
endif()

add_test(NAME msquictest
Expand Down
9 changes: 9 additions & 0 deletions src/test/bin/quic_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1496,6 +1496,15 @@ TEST(Misc, StreamAbortRecvFinRace) {
}
}

TEST(Misc, StreamAbortConnFlowControl) {
TestLogger Logger("StreamAbortConnFlowControl");
if (TestingKernelMode) {
ASSERT_TRUE(DriverClient.Run(IOCTL_QUIC_RUN_STREAM_ABORT_CONN_FLOW_CONTROL));
} else {
QuicTestStreamAbortConnFlowControl();
}
}

TEST(Drill, VarIntEncoder) {
TestLogger Logger("QuicDrillTestVarIntEncoder");
if (TestingKernelMode) {
Expand Down
5 changes: 5 additions & 0 deletions src/test/bin/winkernel/control.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ size_t QUIC_IOCTL_BUFFER_SIZES[] =
sizeof(INT32),
0,
0,
0,
};

CXPLAT_STATIC_ASSERT(
Expand Down Expand Up @@ -1120,6 +1121,10 @@ QuicTestCtlEvtIoDeviceControl(
QuicTestCtlRun(QuicTestStreamAbortRecvFinRace());
break;

case IOCTL_QUIC_RUN_STREAM_ABORT_CONN_FLOW_CONTROL:
QuicTestCtlRun(QuicTestStreamAbortConnFlowControl());
break;

default:
Status = STATUS_NOT_IMPLEMENTED;
break;
Expand Down
70 changes: 70 additions & 0 deletions src/test/lib/DataTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2544,3 +2544,73 @@ QuicTestStreamAbortRecvFinRace(

TEST_TRUE(Context.ClientStreamShutdownComplete.WaitTimeout(TestWaitTimeout));
}

struct StreamAbortConnFlowControl {
CxPlatEvent ClientStreamShutdownComplete;
uint32_t StreamCount {0};

static QUIC_STATUS ClientStreamCallback(_In_ MsQuicStream*, _In_opt_ void* Context, _Inout_ QUIC_STREAM_EVENT* Event) {
auto TestContext = (StreamAbortConnFlowControl*)Context;
if (Event->Type == QUIC_STREAM_EVENT_SHUTDOWN_COMPLETE) {
TestContext->ClientStreamShutdownComplete.Set();
}
return QUIC_STATUS_SUCCESS;
}

static QUIC_STATUS ServerStreamCallback(_In_ MsQuicStream* Stream, _In_opt_ void*, _Inout_ QUIC_STREAM_EVENT* Event) {
if (Event->Type == QUIC_STREAM_EVENT_RECEIVE) {
Event->RECEIVE.TotalBufferLength = 0;
Stream->Shutdown(0, QUIC_STREAM_SHUTDOWN_FLAG_ABORT_RECEIVE);
}
return QUIC_STATUS_SUCCESS;
}

static QUIC_STATUS ConnCallback(_In_ MsQuicConnection*, _In_opt_ void* Context, _Inout_ QUIC_CONNECTION_EVENT* Event) {
auto TestContext = (StreamAbortConnFlowControl*)Context;
if (Event->Type == QUIC_CONNECTION_EVENT_PEER_STREAM_STARTED) {
new(std::nothrow) MsQuicStream(Event->PEER_STREAM_STARTED.Stream, CleanUpAutoDelete, TestContext->StreamCount++ == 0 ? ServerStreamCallback : MsQuicStream::NoOpCallback, Context);
}
return QUIC_STATUS_SUCCESS;
}
};

void
QuicTestStreamAbortConnFlowControl(
)
{
MsQuicRegistration Registration(true);
TEST_QUIC_SUCCEEDED(Registration.GetInitStatus());

MsQuicConfiguration ServerConfiguration(Registration, "MsQuicTest", MsQuicSettings().SetPeerUnidiStreamCount(1).SetConnFlowControlWindow(100), ServerSelfSignedCredConfig);
TEST_QUIC_SUCCEEDED(ServerConfiguration.GetInitStatus());

MsQuicConfiguration ClientConfiguration(Registration, "MsQuicTest", MsQuicCredentialConfig());
TEST_QUIC_SUCCEEDED(ClientConfiguration.GetInitStatus());

StreamAbortConnFlowControl Context;
MsQuicAutoAcceptListener Listener(Registration, ServerConfiguration, StreamAbortConnFlowControl::ConnCallback, &Context);
TEST_QUIC_SUCCEEDED(Listener.GetInitStatus());
TEST_QUIC_SUCCEEDED(Listener.Start("MsQuicTest"));
QuicAddr ServerLocalAddr;
TEST_QUIC_SUCCEEDED(Listener.GetLocalAddr(ServerLocalAddr));

MsQuicConnection Connection(Registration);
TEST_QUIC_SUCCEEDED(Connection.GetInitStatus());

uint8_t RawBuffer[100];
QUIC_BUFFER Buffer { sizeof(RawBuffer), RawBuffer };

MsQuicStream Stream1(Connection, QUIC_STREAM_OPEN_FLAG_UNIDIRECTIONAL);
TEST_QUIC_SUCCEEDED(Stream1.GetInitStatus());
TEST_QUIC_SUCCEEDED(Stream1.Send(&Buffer, 1, QUIC_SEND_FLAG_START | QUIC_SEND_FLAG_FIN));

MsQuicStream Stream2(Connection, QUIC_STREAM_OPEN_FLAG_UNIDIRECTIONAL, CleanUpManual, StreamAbortConnFlowControl::ClientStreamCallback, &Context);
TEST_QUIC_SUCCEEDED(Stream2.GetInitStatus());
TEST_QUIC_SUCCEEDED(Stream2.Send(&Buffer, 1, QUIC_SEND_FLAG_START | QUIC_SEND_FLAG_FIN));

TEST_QUIC_SUCCEEDED(Connection.StartLocalhost(ClientConfiguration, ServerLocalAddr));
TEST_TRUE(Connection.HandshakeCompleteEvent.WaitTimeout(TestWaitTimeout));
TEST_TRUE(Connection.HandshakeComplete);

TEST_TRUE(Context.ClientStreamShutdownComplete.WaitTimeout(TestWaitTimeout));
}

0 comments on commit 86edb13

Please sign in to comment.