diff --git a/quinn-udp/src/fallback.rs b/quinn-udp/src/fallback.rs index e3855c6b8..14556142a 100644 --- a/quinn-udp/src/fallback.rs +++ b/quinn-udp/src/fallback.rs @@ -77,8 +77,10 @@ impl UdpSocket { debug_assert!(!bufs.is_empty()); let mut buf = ReadBuf::new(&mut bufs[0]); let addr = ready!(self.io.poll_recv_from(cx, &mut buf))?; + let len = buf.filled().len(); meta[0] = RecvMeta { - len: buf.filled().len(), + len, + stride: len, addr, ecn: None, dst_ip: None, @@ -95,6 +97,7 @@ impl UdpSocket { pub fn udp_state() -> super::UdpState { super::UdpState { max_gso_segments: std::sync::atomic::AtomicUsize::new(1), + gro_segments: 1, } } diff --git a/quinn-udp/src/lib.rs b/quinn-udp/src/lib.rs index fd232b023..05071102a 100644 --- a/quinn-udp/src/lib.rs +++ b/quinn-udp/src/lib.rs @@ -37,6 +37,7 @@ pub const BATCH_SIZE: usize = imp::BATCH_SIZE; #[derive(Debug)] pub struct UdpState { max_gso_segments: AtomicUsize, + gro_segments: usize, } impl UdpState { @@ -53,6 +54,15 @@ impl UdpState { pub fn max_gso_segments(&self) -> usize { self.max_gso_segments.load(Ordering::Relaxed) } + + /// The number of segments to read when GRO is enabled. Used as a factor to + /// compute the receive buffer size. + /// + /// Returns 1 if the platform doesn't support GRO. + #[inline] + pub fn gro_segments(&self) -> usize { + self.gro_segments + } } impl Default for UdpState { @@ -65,6 +75,7 @@ impl Default for UdpState { pub struct RecvMeta { pub addr: SocketAddr, pub len: usize, + pub stride: usize, pub ecn: Option, /// The destination IP address which was encoded in this datagram pub dst_ip: Option, @@ -76,6 +87,7 @@ impl Default for RecvMeta { Self { addr: SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0), len: 0, + stride: 0, ecn: None, dst_ip: None, } diff --git a/quinn-udp/src/unix.rs b/quinn-udp/src/unix.rs index a3746ac1f..1244dce8d 100644 --- a/quinn-udp/src/unix.rs +++ b/quinn-udp/src/unix.rs @@ -52,6 +52,7 @@ pub struct UdpSocket { impl UdpSocket { pub fn from_std(socket: std::net::UdpSocket) -> io::Result { socket.set_nonblocking(true)?; + init(&socket)?; let now = Instant::now(); Ok(UdpSocket { @@ -134,6 +135,18 @@ fn init(io: &std::net::UdpSocket) -> io::Result<()> { } #[cfg(target_os = "linux")] { + // opportunistically try to enable GRO. See gro::gro_segments(). + let on: libc::c_int = 1; + unsafe { + libc::setsockopt( + io.as_raw_fd(), + libc::SOL_UDP, + libc::UDP_GRO, + &on as *const _ as _, + mem::size_of_val(&on) as _, + ) + }; + if addr.is_ipv4() { let rc = unsafe { libc::setsockopt( @@ -407,6 +420,7 @@ fn recv( pub fn udp_state() -> UdpState { UdpState { max_gso_segments: AtomicUsize::new(gso::max_gso_segments()), + gro_segments: gro::gro_segments(), } } @@ -500,6 +514,7 @@ fn decode_recv( let name = unsafe { name.assume_init() }; let mut ecn_bits = 0; let mut dst_ip = None; + let mut stride = len; let cmsg_iter = unsafe { cmsg::Iter::new(hdr) }; for cmsg in cmsg_iter { @@ -527,6 +542,10 @@ fn decode_recv( let pktinfo = cmsg::decode::(cmsg); dst_ip = Some(IpAddr::V6(ptr::read(&pktinfo.ipi6_addr as *const _ as _))); }, + #[cfg(target_os = "linux")] + (libc::SOL_UDP, libc::UDP_GRO) => unsafe { + stride = cmsg::decode::(cmsg) as usize; + }, _ => {} } } @@ -539,6 +558,7 @@ fn decode_recv( RecvMeta { len, + stride, addr, ecn: EcnCodepoint::from_bits(ecn_bits), dst_ip, @@ -602,3 +622,39 @@ mod gso { panic!("Setting a segment size is not supported on current platform"); } } + +#[cfg(target_os = "linux")] +mod gro { + use super::*; + + pub fn gro_segments() -> usize { + let socket = match std::net::UdpSocket::bind("[::]:0") { + Ok(socket) => socket, + Err(_) => return 1, + }; + + let on: libc::c_int = 1; + let rc = unsafe { + libc::setsockopt( + socket.as_raw_fd(), + libc::SOL_UDP, + libc::UDP_GRO, + &on as *const _ as _, + mem::size_of_val(&on) as _, + ) + }; + + if rc != -1 { + 10 + } else { + 1 + } + } +} + +#[cfg(not(target_os = "linux"))] +mod gro { + pub fn gro_segments() -> usize { + 1 + } +} diff --git a/quinn/src/endpoint.rs b/quinn/src/endpoint.rs index 698b706e7..98d817376 100644 --- a/quinn/src/endpoint.rs +++ b/quinn/src/endpoint.rs @@ -12,7 +12,7 @@ use std::{ time::Instant, }; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use proto::{ self as proto, ClientConfig, ConnectError, ConnectionHandle, DatagramEvent, ServerConfig, }; @@ -346,27 +346,32 @@ impl EndpointInner { Poll::Ready(Ok(msgs)) => { self.recv_limiter.record_work(msgs); for (meta, buf) in metas.iter().zip(iovs.iter()).take(msgs) { - let data = buf[0..meta.len].into(); - match self - .inner - .handle(now, meta.addr, meta.dst_ip, meta.ecn, data) - { - Some((handle, DatagramEvent::NewConnection(conn))) => { - let conn = - self.connections - .insert(handle, conn, self.udp_state.clone()); - self.incoming.push_back(conn); + let mut data: BytesMut = buf[0..meta.len].into(); + while !data.is_empty() { + let buf = data.split_to(meta.stride.min(data.len())); + match self + .inner + .handle(now, meta.addr, meta.dst_ip, meta.ecn, buf) + { + Some((handle, DatagramEvent::NewConnection(conn))) => { + let conn = self.connections.insert( + handle, + conn, + self.udp_state.clone(), + ); + self.incoming.push_back(conn); + } + Some((handle, DatagramEvent::ConnectionEvent(event))) => { + // Ignoring errors from dropped connections that haven't yet been cleaned up + let _ = self + .connections + .senders + .get_mut(&handle) + .unwrap() + .send(ConnectionEvent::Proto(event)); + } + None => {} } - Some((handle, DatagramEvent::ConnectionEvent(event))) => { - // Ignoring errors from dropped connections that haven't yet been cleaned up - let _ = self - .connections - .senders - .get_mut(&handle) - .unwrap() - .send(ConnectionEvent::Proto(event)); - } - None => {} } } } @@ -565,12 +570,17 @@ pub(crate) struct EndpointRef(Arc>); impl EndpointRef { pub(crate) fn new(socket: UdpSocket, inner: proto::Endpoint, ipv6: bool) -> Self { - let recv_buf = - vec![0; inner.config().get_max_udp_payload_size().min(64 * 1024) as usize * BATCH_SIZE]; + let udp_state = Arc::new(UdpState::new()); + let recv_buf = vec![ + 0; + inner.config().get_max_udp_payload_size().min(64 * 1024) as usize + * udp_state.gro_segments() + * BATCH_SIZE + ]; let (sender, events) = mpsc::unbounded_channel(); Self(Arc::new(Mutex::new(EndpointInner { socket, - udp_state: Arc::new(UdpState::new()), + udp_state, inner, ipv6, events,