From a2c6fbe15585c8c9961760e4a0935d29ffae601c Mon Sep 17 00:00:00 2001 From: Mathew Date: Fri, 7 Jun 2024 00:22:52 +1000 Subject: [PATCH] Fix infinite loop on blocked stream --- src/quic/session.cc | 61 +++++++++++++++++++++++++++++++++++---------- src/quic/stream.cc | 10 ++++++++ src/quic/stream.h | 4 +++ 3 files changed, 62 insertions(+), 13 deletions(-) diff --git a/src/quic/session.cc b/src/quic/session.cc index 3e5978a65aa360..92a3a5789123f1 100644 --- a/src/quic/session.cc +++ b/src/quic/session.cc @@ -2209,8 +2209,11 @@ void Session::StreamDataBlocked(stream_id id) { IncrementStat(&SessionStats::block_count); BaseObjectPtr stream = FindStream(id); - if (stream) + if (stream) { stream->OnBlocked(); + } else { + Debug(this, "Stream %" PRId64 " not found to block", id); + } } void Session::IncrementConnectionCloseAttempts() { @@ -3268,26 +3271,34 @@ bool Session::Application::SendPendingData() { uint8_t* pos = nullptr; size_t packets_sent = 0; int err; + std::vector blocking; + bool ret = false; + BaseObjectPtr stream; Debug(session(), "Start sending pending data"); for (;;) { - ssize_t ndatalen; StreamData stream_data; + ssize_t ndatalen = 0; err = GetStreamData(&stream_data); if (err < 0) { session()->set_last_error(kQuicInternalError); - return false; + goto end; } // If the packet was sent previously, then packet will have been reset. if (!pos) { packet = CreateStreamDataPacket(); + if (!packet) { + Debug(session(), "Failed to create packet for stream data"); + session()->set_last_error(kQuicInternalError); + goto end; + } pos = packet->data(); } // If stream_data.id is -1, then we're not serializing any data for any // specific stream. We still need to process QUIC session packets tho. - if (stream_data.id > -1) { + if (stream_data.id >= 0) { Debug(session(), "Serializing packets for stream id %" PRId64, stream_data.id); packet->AddRetained(stream_data.stream->GetOutboundSource()); @@ -3322,10 +3333,11 @@ bool Session::Application::SendPendingData() { // CONNECTION_CLOSE since even those require a // packet number. session()->Close(Session::SessionCloseFlags::SILENT); - return false; + goto end; case NGTCP2_ERR_STREAM_DATA_BLOCKED: - Debug(session(), "Stream %lld blocked", stream_data.id); + Debug(session(), "Stream %lld blocked session data left %lld", stream_data.id, session()->max_data_left()); session()->StreamDataBlocked(stream_data.id); + blocking.push_back(stream_data.id); if (session()->max_data_left() == 0) { if (stream_data.id >= 0) { Debug(session(), "Resuming %llu after block", stream_data.id); @@ -3344,6 +3356,7 @@ bool Session::Application::SendPendingData() { CHECK_LE(ndatalen, 0); continue; case NGTCP2_ERR_STREAM_NOT_FOUND: + Debug(session(), "Stream %lld no found", stream_data.id); continue; case NGTCP2_ERR_WRITE_MORE: CHECK_GT(ndatalen, 0); @@ -3354,7 +3367,7 @@ bool Session::Application::SendPendingData() { if(nwrite != 0){ // -ve response i.e error packet.reset(); session()->set_last_error(kQuicInternalError); - return false; + goto end; } // 0 bytes in this sending operation @@ -3368,9 +3381,10 @@ bool Session::Application::SendPendingData() { Debug(session(), "Congestion limited, but %" PRIu64 " bytes pending", packet->length()); if (!session()->SendPacket(std::move(packet), path)) - return false; + goto end; } - return true; + ret = true; + goto end; } pos += nwrite; @@ -3384,7 +3398,7 @@ bool Session::Application::SendPendingData() { Debug(session(), "Sending %" PRIu64 " bytes in serialized packet", nwrite); if (!session()->SendPacket(std::move(packet), path)) { Debug(session(), "-- Failed to send packet"); - return false; + goto end; } pos = nullptr; if (++packets_sent == kMaxPackets) { @@ -3392,8 +3406,19 @@ bool Session::Application::SendPendingData() { break; } Debug(session(), "-- Looping"); + } // end for + + + ret = true; +end: + for(stream_id id : blocking) { + stream = session()->FindStream(id); + if (stream) { + stream->Unblock(); + } } - return true; + + return ret; } void Session::Application::StreamClose( @@ -3576,13 +3601,24 @@ bool DefaultApplication::ReceiveStreamData( } int DefaultApplication::GetStreamData(StreamData* stream_data) { + if (stream_queue_.IsEmpty()) { + stream_data->id = -1; + return 0; + } + Stream* stream = stream_queue_.PopFront(); - stream_data->stream.reset(stream); if (stream == nullptr) { stream_data->id = -1; return 0; } CHECK(!stream->is_destroyed()); + stream_data->stream.reset(stream); + + if(stream->IsBlocked()){ + stream_data->id = -1; + return 0; + } + stream_data->id = stream->id(); auto next = [&]( int status, @@ -3600,7 +3636,6 @@ int DefaultApplication::GetStreamData(StreamData* stream_data) { stream_data->count = count; if (count > 0) { - stream->Schedule(&stream_queue_); stream_data->remaining = get_length(data, count); } else { diff --git a/src/quic/stream.cc b/src/quic/stream.cc index 2b6e1cfa391c39..754ca650d7314c 100644 --- a/src/quic/stream.cc +++ b/src/quic/stream.cc @@ -269,6 +269,8 @@ void Stream::OnBlocked() { HandleScope handle_scope(env()->isolate()); Context::Scope context_scope(env()->context()); + blocked = true; + BaseObjectPtr ptr(this); USE(state->stream_blocked_callback()->Call( env()->context(), @@ -276,6 +278,14 @@ void Stream::OnBlocked() { 0, nullptr)); } +bool Stream::IsBlocked(){ + return blocked; +} + +void Stream::Unblock() { + blocked = false; +} + void Stream::OnReset(error_code app_error_code) { BindingState* state = BindingState::Get(env()); HandleScope scope(env()->isolate()); diff --git a/src/quic/stream.h b/src/quic/stream.h index 82ebf6e0763be0..8cf1f361b1b9ce 100644 --- a/src/quic/stream.h +++ b/src/quic/stream.h @@ -193,6 +193,9 @@ class Stream final : public AsyncWrap, void BeginHeaders(HeadersKind kind); void OnBlocked(); + bool IsBlocked(); + void Unblock(); + void OnReset(error_code app_error_code); void Commit(size_t ammount); @@ -295,6 +298,7 @@ class Stream final : public AsyncWrap, BaseObjectPtr session_; AliasedStruct state_; stream_id id_; + bool blocked = false; bool destroyed_ = false; bool destroying_ = false;