Skip to content

Commit

Permalink
fix(server): handle full send queues more gracefully
Browse files Browse the repository at this point in the history
With large bursts in stream data, we sometimes need to wait before sending
everything. This handles storing and flushing partial sends.
  • Loading branch information
colinmarc committed Apr 30, 2024
1 parent e1dc976 commit face877
Showing 1 changed file with 81 additions and 17 deletions.
98 changes: 81 additions & 17 deletions mm-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use tracing::debug_span;
use tracing::trace;
use tracing::trace_span;
use tracing::warn;
use tracing::{debug, error};

Expand Down Expand Up @@ -66,7 +67,8 @@ pub struct ClientConnection {
conn: quiche::Connection,
timer: mio_timerfd::TimerFd,
timeout_token: mio::Token,
partial: HashMap<u64, BytesMut>,
partial_reads: HashMap<u64, BytesMut>,
partial_writes: HashMap<u64, Bytes>,
in_flight: HashMap<u64, StreamWorker>,
dgram_recv: Receiver<protocol::MessageType>,
dgram_send: WakingSender<protocol::MessageType>,
Expand Down Expand Up @@ -264,6 +266,9 @@ impl Server {

// Demux packets from in-flight requests and datagrams from attachments.
for client in self.clients.values_mut() {
let conn_span = trace_span!("conn_write", conn_id = ?client.conn_id);
let _guard = conn_span.enter();

if client.conn.is_draining() {
continue;
}
Expand Down Expand Up @@ -302,7 +307,14 @@ impl Server {
continue;
}

if !client.flush_partial_write(sid)? {
continue;
}

loop {
let span = trace_span!("stream_write", sid);
let _guard = span.enter();

match client
.in_flight
.get(&sid)
Expand All @@ -311,7 +323,10 @@ impl Server {
.try_recv()
{
Ok(msg) => {
client.write_message(sid, msg, false, &mut self.scratch)?;
if !client.write_message(sid, msg, false, &mut self.scratch)? {
// No more write capacity at the moment.
break;
}
}
Err(TryRecvError::Disconnected) => {
client.conn.stream_send(sid, &[], true)?;
Expand Down Expand Up @@ -439,7 +454,8 @@ impl Server {
timer,
timeout_token,
in_flight: streams,
partial: HashMap::new(),
partial_reads: HashMap::new(),
partial_writes: HashMap::new(),
dgram_recv,
dgram_send,
};
Expand Down Expand Up @@ -546,7 +562,7 @@ impl Server {

// Clean up partial data for closed streams.
client
.partial
.partial_reads
.retain(|sid, _| !client.conn.stream_finished(*sid));

Ok(())
Expand Down Expand Up @@ -591,7 +607,7 @@ impl ClientConnection {
) -> anyhow::Result<(Vec<protocol::MessageType>, bool)> {
// Start with partial data from the previous call to read_messages.
scratch.truncate(0);
if let Some(partial) = self.partial.remove(&sid) {
if let Some(partial) = self.partial_reads.remove(&sid) {
scratch.unsplit(partial);
}

Expand Down Expand Up @@ -629,7 +645,8 @@ impl ClientConnection {
sid,
n
);
self.partial.insert(sid, buf);

self.partial_reads.insert(sid, buf);
break;
}
Err(e) => return Err(e.into()),
Expand All @@ -649,26 +666,73 @@ impl ClientConnection {
Ok((messages, stream_fin))
}

/// Send a message on a stream.
/// Send a message on a stream. Returns Ok(false) if the stream is full.
fn write_message(
&mut self,
sid: u64,
msg: protocol::MessageType,
fin: bool,
scratch: &mut BytesMut,
) -> anyhow::Result<()> {
) -> anyhow::Result<bool> {
scratch.resize(protocol::MAX_MESSAGE_SIZE, 0);
let len = protocol::encode_message(&msg, scratch).unwrap();
let len = protocol::encode_message(&msg, scratch)?;

trace!(
conn_id = ?self.conn_id,
stream_id = sid,
len,
"sending {}", msg
);
trace!(len, "sending {}", msg);

self.conn.stream_send(sid, &scratch[..len], fin)?;
Ok(())
match self.conn.stream_send(sid, &scratch[..len], fin) {
Ok(n) if n != len => {
// Partial write.
assert!(n < len);
trace!(n, "partial write");

let partial = scratch.split_to(len).split_off(n).freeze();
let old = self.partial_writes.insert(sid, partial);
assert_eq!(None, old);

Ok(false)
}
Err(quiche::Error::Done) => {
trace!("stream blocked");

let data = scratch.split_to(len).freeze();
let old = self.partial_writes.insert(sid, data);
assert_eq!(None, old);

Ok(false)
}
v => {
assert_eq!(len, v?);
Ok(true)
}
}
}

/// Flushes previous partial writes.
fn flush_partial_write(&mut self, sid: u64) -> anyhow::Result<bool> {
use std::collections::hash_map::Entry;

match self.partial_writes.entry(sid) {
Entry::Vacant(_) => Ok(true),
Entry::Occupied(mut entry) => {
let partial = entry.get().clone();
trace!(len = partial.len(), "flushing previous partial");

match self.conn.stream_send(sid, &partial, false) {
Ok(n) if n != entry.get().len() => {
// Partial write.
entry.get_mut().advance(n);
trace!(len = entry.get().len(), "remaining partial");
Ok(false)
}
Ok(_) => {
entry.remove();
Ok(true)
}
Err(quiche::Error::Done) => Ok(false),
Err(e) => Err(anyhow!(e)),
}
}
}
}

/// Send a message as a datagram.
Expand Down

0 comments on commit face877

Please sign in to comment.