diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 1d68880f2c..9036d2180b 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1679,7 +1679,11 @@ impl Connection { self.paths.make_permanent(path, None, cid); Ok(()) } else if let Some(primary) = self.paths.primary() { - if primary.borrow().remote_cid().is_empty() { + if primary + .borrow() + .remote_cid() + .map_or(true, |id| id.is_empty()) + { self.paths .make_permanent(path, None, ConnectionIdEntry::empty_remote()); Ok(()) @@ -1908,7 +1912,7 @@ impl Connection { // a packet on a new path, we avoid sending (and the privacy risk) rather // than reuse a connection ID. let res = if path.borrow().is_temporary() { - assert!(!cfg!(test), "attempting to close with a temporary path"); + debug_assert!(!cfg!(test), "attempting to close with a temporary path"); Err(Error::InternalError) } else { self.output_path(&path, now, &Some(details)) @@ -1932,17 +1936,16 @@ impl Connection { ) -> (PacketType, PacketBuilder) { let pt = PacketType::from(cspace); let mut builder = if pt == PacketType::Short { - qdebug!("Building Short dcid {}", path.remote_cid()); + qdebug!("Building Short dcid {:?}", path.remote_cid()); PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid()) } else { qdebug!( - "Building {:?} dcid {} scid {}", + "Building {:?} dcid {:?} scid {:?}", pt, path.remote_cid(), path.local_cid(), ); - - PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid()) + PacketBuilder::long(encoder, pt, version, &path.remote_cid(), &path.local_cid()) }; if builder.remaining() > 0 { builder.scramble(grease_quic_bit); diff --git a/neqo-transport/src/connection/tests/datagram.rs b/neqo-transport/src/connection/tests/datagram.rs index 6d02419fcd..ec2795a232 100644 --- a/neqo-transport/src/connection/tests/datagram.rs +++ b/neqo-transport/src/connection/tests/datagram.rs @@ -599,7 +599,7 @@ fn datagram_fill() { let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. - path.plpmtu() - path.remote_cid().len() - 19 + path.plpmtu() - path.remote_cid().unwrap().len() - 19 }; assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large. diff --git a/neqo-transport/src/connection/tests/idle.rs b/neqo-transport/src/connection/tests/idle.rs index 336648f776..55d2ac8f16 100644 --- a/neqo-transport/src/connection/tests/idle.rs +++ b/neqo-transport/src/connection/tests/idle.rs @@ -287,7 +287,7 @@ fn idle_caching() { let mut client = default_client(); let mut server = default_server(); let start = now(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Perform the first round trip, but drop the Initial from the server. // The client then caches the Handshake packet. diff --git a/neqo-transport/src/fc.rs b/neqo-transport/src/fc.rs index 37bb3daf57..acc4d6582d 100644 --- a/neqo-transport/src/fc.rs +++ b/neqo-transport/src/fc.rs @@ -810,7 +810,7 @@ mod test { fc[StreamType::BiDi].add_retired(1); fc[StreamType::BiDi].send_flowc_update(); // consume the frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); assert_eq!(tokens.len(), 1); diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 339800d700..6bfc40f22e 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -149,15 +149,19 @@ impl PacketBuilder { /// /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get /// the encoder back. - pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: Option>) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 5 = 1 (first byte) + 4 (packet number) - if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() { + if limit > encoder.len() + && 5 + dcid.as_ref().map_or(0, |d| d.as_ref().len()) < limit - encoder.len() + { encoder .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); - encoder.encode(dcid.as_ref()); + if let Some(dcid) = dcid { + encoder.encode(dcid.as_ref()); + } } else { limit = 0; } @@ -185,20 +189,23 @@ impl PacketBuilder { mut encoder: Encoder, pt: PacketType, version: Version, - dcid: impl AsRef<[u8]>, - scid: impl AsRef<[u8]>, + dcid: &Option>, + scid: &Option>, ) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number) if limit > encoder.len() - && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len() + && 11 + + dcid.as_ref().map_or(0, |d| d.as_ref().len()) + + scid.as_ref().map_or(0, |d| d.as_ref().len()) + < limit - encoder.len() { encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4); encoder.encode_uint(4, version.wire_version()); - encoder.encode_vec(1, dcid.as_ref()); - encoder.encode_vec(1, scid.as_ref()); + encoder.encode_vec(1, dcid.as_ref().map_or(&[], AsRef::as_ref)); + encoder.encode_vec(1, scid.as_ref().map_or(&[], AsRef::as_ref)); } else { limit = 0; } @@ -994,8 +1001,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + &None::<&[u8]>, + &Some(ConnectionId::from(SERVER_CID)), ); builder.initial_token(&[]); builder.pn(1, 2); @@ -1058,7 +1065,7 @@ mod tests { fn build_short() { fixture_init(); let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.pn(0, 1); builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. let packet = builder @@ -1073,7 +1080,7 @@ mod tests { let mut firsts = Vec::new(); for _ in 0..64 { let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.scramble(true); builder.pn(0, 1); firsts.push(builder.as_ref()[0]); @@ -1136,8 +1143,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(CLIENT_CID), + &Some(ConnectionId::from(SERVER_CID)), + &Some(ConnectionId::from(CLIENT_CID)), ); builder.pn(0, 1); builder.encode(&[0; 3]); @@ -1145,7 +1152,8 @@ mod tests { assert_eq!(encoder.len(), 45); let first = encoder.clone(); - let mut builder = PacketBuilder::short(encoder, false, ConnectionId::from(SERVER_CID)); + let mut builder = + PacketBuilder::short(encoder, false, Some(ConnectionId::from(SERVER_CID))); builder.pn(1, 3); builder.encode(&[0]); // Minimal size (packet number is big enough). let encoder = builder.build(&mut prot).expect("build"); @@ -1170,8 +1178,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + &None::<&[u8]>, + &None::<&[u8]>, ); builder.pn(0, 1); builder.encode(&[1, 2, 3]); @@ -1189,8 +1197,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + &None::<&[u8]>, + &None::<&[u8]>, ); builder.pn(0, 1); builder.scramble(true); @@ -1210,8 +1218,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + &None::<&[u8]>, + &Some(ConnectionId::from(SERVER_CID)), ); assert_ne!(builder.remaining(), 0); builder.initial_token(&[]); @@ -1229,7 +1237,7 @@ mod tests { let mut builder = PacketBuilder::short( Encoder::with_capacity(100), true, - ConnectionId::from(SERVER_CID), + Some(ConnectionId::from(SERVER_CID)), ); builder.pn(0, 1); // Pad, but not up to the full capacity. Leave enough space for the @@ -1244,8 +1252,8 @@ mod tests { encoder, PacketType::Initial, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(SERVER_CID), + &Some(ConnectionId::from(SERVER_CID)), + &Some(ConnectionId::from(SERVER_CID)), ); assert_eq!(builder.remaining(), 0); assert_eq!(builder.abort(), encoder_copy); diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index ec56ebdb81..9523a392b6 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -660,8 +660,8 @@ impl Path { /// Get the first local connection ID. /// Only do this for the primary path during the handshake. - pub fn local_cid(&self) -> &ConnectionId { - self.local_cid.as_ref().unwrap() + pub const fn local_cid(&self) -> Option<&ConnectionId> { + self.local_cid.as_ref() } /// Set the remote connection ID based on the peer's choice. @@ -674,8 +674,10 @@ impl Path { } /// Access the remote connection ID. - pub fn remote_cid(&self) -> &ConnectionId { - self.remote_cid.as_ref().unwrap().connection_id() + pub fn remote_cid(&self) -> Option<&ConnectionId> { + self.remote_cid + .as_ref() + .map(super::cid::ConnectionIdEntry::connection_id) } /// Set the stateless reset token for the connection ID that is currently in use. diff --git a/neqo-transport/src/pmtud.rs b/neqo-transport/src/pmtud.rs index 5ee59e3dbf..9eec6b0eda 100644 --- a/neqo-transport/src/pmtud.rs +++ b/neqo-transport/src/pmtud.rs @@ -383,7 +383,7 @@ mod tests { let stats_before = stats.clone(); // Fake a packet number, so the builder logic works. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let pn = prot.next_pn(); builder.pn(pn, 4); builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud); diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 29f17bf6b9..fa127212f0 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -104,8 +104,8 @@ fn connection_started(qlog: &NeqoQlog, path: &PathRef) { protocol: Some("QUIC".into()), src_port: p.local_address().port().into(), dst_port: p.remote_address().port().into(), - src_cid: Some(format!("{}", p.local_cid())), - dst_cid: Some(format!("{}", p.remote_cid())), + src_cid: p.local_cid().map(ToString::to_string), + dst_cid: p.remote_cid().map(ToString::to_string), }); Some(ev_data) diff --git a/neqo-transport/src/recv_stream.rs b/neqo-transport/src/recv_stream.rs index c022f5fbd0..bb4f923114 100644 --- a/neqo-transport/src/recv_stream.rs +++ b/neqo-transport/src/recv_stream.rs @@ -1483,7 +1483,7 @@ mod tests { assert!(s.has_frames_to_write()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); s.write_frame(&mut builder, &mut token, &mut FrameStats::default()); @@ -1597,7 +1597,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1618,7 +1618,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1866,7 +1866,7 @@ mod tests { assert!(s.fc().unwrap().frame_needed()); // Write the fc update frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); let mut stats = FrameStats::default(); fc.borrow_mut() diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index a6e42cfdaf..3f0002da13 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -2596,7 +2596,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Write a small frame: no fin. let written = builder.len(); @@ -2684,7 +2684,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); ss.write_frames( TransmissionPriority::default(), &mut builder, @@ -2762,7 +2762,7 @@ mod tests { assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..]))); // This doesn't report blocking yet. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2815,7 +2815,7 @@ mod tests { assert_eq!(s.send_atomic(b"abc").unwrap(), 0); // Assert that STREAM_DATA_BLOCKED is sent. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2902,7 +2902,7 @@ mod tests { s.mark_as_lost(len_u64, 0, true); // No frame should be sent here. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_stream_frame( @@ -2962,7 +2962,7 @@ mod tests { s.close(); } - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); builder.set_limit(header_len + space); @@ -3063,7 +3063,7 @@ mod tests { s.send(data).unwrap(); s.close(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); // Add 2 for the frame type and stream ID, then add the extra. builder.set_limit(header_len + data.len() + 2 + extra); diff --git a/neqo-transport/src/tracking.rs b/neqo-transport/src/tracking.rs index 90bbd0b54a..b7ab8bac50 100644 --- a/neqo-transport/src/tracking.rs +++ b/neqo-transport/src/tracking.rs @@ -797,7 +797,7 @@ mod tests { } fn write_frame_at(rp: &mut RecvdPackets, now: Instant) { - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut stats = FrameStats::default(); let mut tokens = Vec::new(); rp.write_frame(now, RTT, &mut builder, &mut tokens, &mut stats); @@ -952,7 +952,7 @@ mod tests { #[test] fn drop_spaces() { let mut tracker = AckTracker::default(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); tracker .get_mut(PacketNumberSpace::Initial) .unwrap() @@ -1017,7 +1017,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); builder.set_limit(10); let mut stats = FrameStats::default(); @@ -1048,7 +1048,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // The code pessimistically assumes that each range needs 16 bytes to express. // So this won't be enough for a second range. builder.set_limit(RecvdPackets::USEFUL_ACK_LEN + 8);