From 67ff4778c670bf96ebf86fa28575e708b9997765 Mon Sep 17 00:00:00 2001 From: Marc Brinkmann Date: Wed, 28 Feb 2024 21:24:34 +0100 Subject: [PATCH] Prevent races in multi-frame sends causing protocol violations --- src/io.rs | 49 ++++++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/io.rs b/src/io.rs index 4e13195..746ddf8 100644 --- a/src/io.rs +++ b/src/io.rs @@ -261,8 +261,9 @@ pub struct IoCore { /// The maximum time allowed for a peer to receive an error. error_timeout: Duration, - /// The frame in the process of being sent, which may be partially transferred already. - current_frame: Option, + /// The frame in the process of being sent, which may be partially transferred already. Also + /// indicates if the current frame is the final frame of a message. + current_frame: Option<(OutgoingFrame, bool)>, /// The headers of active current multi-frame transfers. active_multi_frame: [Option
; N], /// Frames waiting to be sent. @@ -543,19 +544,34 @@ where tokio::select! { biased; // We actually like the bias, avoid the randomness overhead. - write_result = write_all_buf_if_some(&mut self.writer, self.current_frame.as_mut()) + write_result = write_all_buf_if_some(&mut self.writer, + self.current_frame.as_mut() + .map(|(ref mut frame, _)| frame)) , if self.current_frame.is_some() => { write_result.map_err(CoreError::WriteFailed)?; // Clear `current_frame` via `Option::take` and examine what was sent. - if let Some(frame_sent) = self.current_frame.take() { + if let Some((frame_sent, was_final)) = self.current_frame.take() { #[cfg(feature = "tracing")] tracing::trace!(frame=%frame_sent, "sent"); - if frame_sent.header().is_error() { + let header_sent = frame_sent.header(); + + // If we finished the active multi frame send, clear it. + if was_final { + let channel_idx = header_sent.channel().get() as usize; + if let Some(ref active_multi_frame) = + self.active_multi_frame[channel_idx] { + if header_sent == *active_multi_frame { + self.active_multi_frame[channel_idx] = None; + } + } + } + + if header_sent.is_error() { // We finished sending an error frame, time to exit. - return Err(CoreError::RemoteProtocolViolation(frame_sent.header())); + return Err(CoreError::RemoteProtocolViolation(header_sent)); } // TODO: We should restrict the dirty-queue processing here a little bit @@ -563,7 +579,7 @@ where // A message has completed sending, process the wait queue in case we have // to start sending a multi-frame message like a response that was delayed // only because of the one-multi-frame-per-channel restriction. - self.process_wait_queue(frame_sent.header().channel())?; + self.process_wait_queue(header_sent.channel())?; } else { #[cfg(feature = "tracing")] tracing::error!("current frame should not disappear"); @@ -798,23 +814,14 @@ where .next_owned(self.juliet.max_frame_size()); // If there are more frames after this one, schedule the remainder. - if let Some(next_frame_iter) = additional_frames { + let is_final = if let Some(next_frame_iter) = additional_frames { self.ready_queue.push_back(next_frame_iter); + false } else { - // No additional frames. Check if sending the next frame finishes a multi-frame message. - let about_to_finish = frame.header(); - if let Some(ref active_multi) = - self.active_multi_frame[about_to_finish.channel().get() as usize] - { - if about_to_finish == *active_multi { - // Once the scheduled frame is processed, we will finished the multi-frame - // transfer, so we can allow for the next multi-frame transfer to be scheduled. - self.active_multi_frame[about_to_finish.channel().get() as usize] = None; - } - } - } + true + }; - self.current_frame = Some(frame); + self.current_frame = Some((frame, is_final)); Ok(()) }