diff --git a/mm-server/src/server.rs b/mm-server/src/server.rs index b9bb9a6..866c577 100644 --- a/mm-server/src/server.rs +++ b/mm-server/src/server.rs @@ -17,6 +17,7 @@ use std::net::SocketAddr; use std::sync::Arc; use tracing::debug_span; use tracing::trace; +use tracing::trace_span; use tracing::warn; use tracing::{debug, error}; @@ -66,7 +67,8 @@ pub struct ClientConnection { conn: quiche::Connection, timer: mio_timerfd::TimerFd, timeout_token: mio::Token, - partial: HashMap, + partial_reads: HashMap, + partial_writes: HashMap, in_flight: HashMap, dgram_recv: Receiver, dgram_send: WakingSender, @@ -264,6 +266,9 @@ impl Server { // Demux packets from in-flight requests and datagrams from attachments. for client in self.clients.values_mut() { + let conn_span = trace_span!("conn_write", conn_id = ?client.conn_id); + let _guard = conn_span.enter(); + if client.conn.is_draining() { continue; } @@ -302,7 +307,14 @@ impl Server { continue; } + if !client.flush_partial_write(sid)? { + continue; + } + loop { + let span = trace_span!("stream_write", sid); + let _guard = span.enter(); + match client .in_flight .get(&sid) @@ -311,7 +323,10 @@ impl Server { .try_recv() { Ok(msg) => { - client.write_message(sid, msg, false, &mut self.scratch)?; + if !client.write_message(sid, msg, false, &mut self.scratch)? { + // No more write capacity at the moment. + break; + } } Err(TryRecvError::Disconnected) => { client.conn.stream_send(sid, &[], true)?; @@ -439,7 +454,8 @@ impl Server { timer, timeout_token, in_flight: streams, - partial: HashMap::new(), + partial_reads: HashMap::new(), + partial_writes: HashMap::new(), dgram_recv, dgram_send, }; @@ -546,7 +562,7 @@ impl Server { // Clean up partial data for closed streams. client - .partial + .partial_reads .retain(|sid, _| !client.conn.stream_finished(*sid)); Ok(()) @@ -591,7 +607,7 @@ impl ClientConnection { ) -> anyhow::Result<(Vec, bool)> { // Start with partial data from the previous call to read_messages. scratch.truncate(0); - if let Some(partial) = self.partial.remove(&sid) { + if let Some(partial) = self.partial_reads.remove(&sid) { scratch.unsplit(partial); } @@ -629,7 +645,8 @@ impl ClientConnection { sid, n ); - self.partial.insert(sid, buf); + + self.partial_reads.insert(sid, buf); break; } Err(e) => return Err(e.into()), @@ -649,26 +666,73 @@ impl ClientConnection { Ok((messages, stream_fin)) } - /// Send a message on a stream. + /// Send a message on a stream. Returns Ok(false) if the stream is full. fn write_message( &mut self, sid: u64, msg: protocol::MessageType, fin: bool, scratch: &mut BytesMut, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { scratch.resize(protocol::MAX_MESSAGE_SIZE, 0); - let len = protocol::encode_message(&msg, scratch).unwrap(); + let len = protocol::encode_message(&msg, scratch)?; - trace!( - conn_id = ?self.conn_id, - stream_id = sid, - len, - "sending {}", msg - ); + trace!(len, "sending {}", msg); - self.conn.stream_send(sid, &scratch[..len], fin)?; - Ok(()) + match self.conn.stream_send(sid, &scratch[..len], fin) { + Ok(n) if n != len => { + // Partial write. + assert!(n < len); + trace!(n, "partial write"); + + let partial = scratch.split_to(len).split_off(n).freeze(); + let old = self.partial_writes.insert(sid, partial); + assert_eq!(None, old); + + Ok(false) + } + Err(quiche::Error::Done) => { + trace!("stream blocked"); + + let data = scratch.split_to(len).freeze(); + let old = self.partial_writes.insert(sid, data); + assert_eq!(None, old); + + Ok(false) + } + v => { + assert_eq!(len, v?); + Ok(true) + } + } + } + + /// Flushes previous partial writes. + fn flush_partial_write(&mut self, sid: u64) -> anyhow::Result { + use std::collections::hash_map::Entry; + + match self.partial_writes.entry(sid) { + Entry::Vacant(_) => Ok(true), + Entry::Occupied(mut entry) => { + let partial = entry.get().clone(); + trace!(len = partial.len(), "flushing previous partial"); + + match self.conn.stream_send(sid, &partial, false) { + Ok(n) if n != entry.get().len() => { + // Partial write. + entry.get_mut().advance(n); + trace!(len = entry.get().len(), "remaining partial"); + Ok(false) + } + Ok(_) => { + entry.remove(); + Ok(true) + } + Err(quiche::Error::Done) => Ok(false), + Err(e) => Err(anyhow!(e)), + } + } + } } /// Send a message as a datagram.