From 477a09ab2f22e4c9cbf031388726405b5355c0b8 Mon Sep 17 00:00:00 2001 From: Lars Eggert Date: Fri, 9 Aug 2024 14:54:07 +0300 Subject: [PATCH] fix: Check whether CIDs are empty (#2034) * fix: Check whether CIDs are empty WIP Fixes #1429 * Update neqo-transport/src/path.rs Co-authored-by: Martin Thomson Signed-off-by: Lars Eggert * Suggestion from @martinthomson * Update neqo-transport/src/qlog.rs Co-authored-by: Max Inden Signed-off-by: Lars Eggert * Suggestion from @mxinden @mxinden, is `take()` the way to go here? * Log error * Fix test * Simplify test --------- Signed-off-by: Lars Eggert Co-authored-by: Martin Thomson Co-authored-by: Max Inden --- neqo-transport/src/connection/mod.rs | 13 +++-- .../src/connection/tests/datagram.rs | 2 +- neqo-transport/src/connection/tests/idle.rs | 2 +- .../src/connection/tests/migration.rs | 19 ++++++- neqo-transport/src/fc.rs | 2 +- neqo-transport/src/packet/mod.rs | 56 +++++++++++-------- neqo-transport/src/path.rs | 10 ++-- neqo-transport/src/pmtud.rs | 2 +- neqo-transport/src/qlog.rs | 4 +- neqo-transport/src/recv_stream.rs | 8 +-- neqo-transport/src/send_stream.rs | 14 ++--- neqo-transport/src/tracking.rs | 8 +-- 12 files changed, 85 insertions(+), 55 deletions(-) diff --git a/neqo-transport/src/connection/mod.rs b/neqo-transport/src/connection/mod.rs index 76d570270d..276d407f17 100644 --- a/neqo-transport/src/connection/mod.rs +++ b/neqo-transport/src/connection/mod.rs @@ -1684,7 +1684,11 @@ impl Connection { self.paths.make_permanent(path, None, cid); Ok(()) } else if let Some(primary) = self.paths.primary() { - if primary.borrow().remote_cid().is_empty() { + if primary + .borrow() + .remote_cid() + .map_or(true, |id| id.is_empty()) + { self.paths .make_permanent(path, None, ConnectionIdEntry::empty_remote()); Ok(()) @@ -1913,7 +1917,7 @@ impl Connection { // a packet on a new path, we avoid sending (and the privacy risk) rather // than reuse a connection ID. let res = if path.borrow().is_temporary() { - assert!(!cfg!(test), "attempting to close with a temporary path"); + qerror!([self], "Attempting to close with a temporary path"); Err(Error::InternalError) } else { self.output_path(&path, now, &Some(details)) @@ -1937,16 +1941,15 @@ impl Connection { ) -> (PacketType, PacketBuilder) { let pt = PacketType::from(cspace); let mut builder = if pt == PacketType::Short { - qdebug!("Building Short dcid {}", path.remote_cid()); + qdebug!("Building Short dcid {:?}", path.remote_cid()); PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid()) } else { qdebug!( - "Building {:?} dcid {} scid {}", + "Building {:?} dcid {:?} scid {:?}", pt, path.remote_cid(), path.local_cid(), ); - PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid()) }; if builder.remaining() > 0 { diff --git a/neqo-transport/src/connection/tests/datagram.rs b/neqo-transport/src/connection/tests/datagram.rs index 6d02419fcd..ec2795a232 100644 --- a/neqo-transport/src/connection/tests/datagram.rs +++ b/neqo-transport/src/connection/tests/datagram.rs @@ -599,7 +599,7 @@ fn datagram_fill() { let path = p.borrow(); // Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number, // 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD. - path.plpmtu() - path.remote_cid().len() - 19 + path.plpmtu() - path.remote_cid().unwrap().len() - 19 }; assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large. diff --git a/neqo-transport/src/connection/tests/idle.rs b/neqo-transport/src/connection/tests/idle.rs index 336648f776..55d2ac8f16 100644 --- a/neqo-transport/src/connection/tests/idle.rs +++ b/neqo-transport/src/connection/tests/idle.rs @@ -287,7 +287,7 @@ fn idle_caching() { let mut client = default_client(); let mut server = default_server(); let start = now(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Perform the first round trip, but drop the Initial from the server. // The client then caches the Handshake packet. diff --git a/neqo-transport/src/connection/tests/migration.rs b/neqo-transport/src/connection/tests/migration.rs index e786e1e348..64c025f98b 100644 --- a/neqo-transport/src/connection/tests/migration.rs +++ b/neqo-transport/src/connection/tests/migration.rs @@ -946,7 +946,6 @@ impl crate::connection::test_internal::FrameWriter for GarbageWriter { /// Test the case that we run out of connection ID and receive an invalid frame /// from a new path. #[test] -#[should_panic(expected = "attempting to close with a temporary path")] fn error_on_new_path_with_no_connection_id() { let mut client = default_client(); let mut server = default_server(); @@ -967,5 +966,23 @@ fn error_on_new_path_with_no_connection_id() { // See issue #1697. We had a crash when the client had a temporary path and // process_output is called. + let closing_frames = client.stats().frame_tx.connection_close; mem::drop(client.process_output(now())); + assert!(matches!( + client.state(), + State::Closing { + error: CloseReason::Transport(Error::UnknownFrameType), + .. + } + )); + // Wait until the connection is closed. + let mut now = now(); + now += client.process(None, now).callback(); + _ = client.process_output(now); + // No closing frames should be sent, and the connection should be closed. + assert_eq!(client.stats().frame_tx.connection_close, closing_frames); + assert!(matches!( + client.state(), + State::Closed(CloseReason::Transport(Error::UnknownFrameType)) + )); } diff --git a/neqo-transport/src/fc.rs b/neqo-transport/src/fc.rs index 37bb3daf57..acc4d6582d 100644 --- a/neqo-transport/src/fc.rs +++ b/neqo-transport/src/fc.rs @@ -810,7 +810,7 @@ mod test { fc[StreamType::BiDi].add_retired(1); fc[StreamType::BiDi].send_flowc_update(); // consume the frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default()); assert_eq!(tokens.len(), 1); diff --git a/neqo-transport/src/packet/mod.rs b/neqo-transport/src/packet/mod.rs index 339800d700..09a4e19d26 100644 --- a/neqo-transport/src/packet/mod.rs +++ b/neqo-transport/src/packet/mod.rs @@ -149,15 +149,19 @@ impl PacketBuilder { /// /// If, after calling this method, `remaining()` returns 0, then call `abort()` to get /// the encoder back. - pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self { + pub fn short(mut encoder: Encoder, key_phase: bool, dcid: Option>) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 5 = 1 (first byte) + 4 (packet number) - if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() { + if limit > encoder.len() + && 5 + dcid.as_ref().map_or(0, |d| d.as_ref().len()) < limit - encoder.len() + { encoder .encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2)); - encoder.encode(dcid.as_ref()); + if let Some(dcid) = dcid { + encoder.encode(dcid.as_ref()); + } } else { limit = 0; } @@ -185,20 +189,23 @@ impl PacketBuilder { mut encoder: Encoder, pt: PacketType, version: Version, - dcid: impl AsRef<[u8]>, - scid: impl AsRef<[u8]>, + mut dcid: Option>, + mut scid: Option>, ) -> Self { let mut limit = Self::infer_limit(&encoder); let header_start = encoder.len(); // Check that there is enough space for the header. // 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number) if limit > encoder.len() - && 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len() + && 11 + + dcid.as_ref().map_or(0, |d| d.as_ref().len()) + + scid.as_ref().map_or(0, |d| d.as_ref().len()) + < limit - encoder.len() { encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4); encoder.encode_uint(4, version.wire_version()); - encoder.encode_vec(1, dcid.as_ref()); - encoder.encode_vec(1, scid.as_ref()); + encoder.encode_vec(1, dcid.take().as_ref().map_or(&[], AsRef::as_ref)); + encoder.encode_vec(1, scid.take().as_ref().map_or(&[], AsRef::as_ref)); } else { limit = 0; } @@ -994,8 +1001,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + None::<&[u8]>, + Some(ConnectionId::from(SERVER_CID)), ); builder.initial_token(&[]); builder.pn(1, 2); @@ -1058,7 +1065,7 @@ mod tests { fn build_short() { fixture_init(); let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.pn(0, 1); builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling. let packet = builder @@ -1073,7 +1080,7 @@ mod tests { let mut firsts = Vec::new(); for _ in 0..64 { let mut builder = - PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID)); + PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID))); builder.scramble(true); builder.pn(0, 1); firsts.push(builder.as_ref()[0]); @@ -1136,8 +1143,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(CLIENT_CID), + Some(ConnectionId::from(SERVER_CID)), + Some(ConnectionId::from(CLIENT_CID)), ); builder.pn(0, 1); builder.encode(&[0; 3]); @@ -1145,7 +1152,8 @@ mod tests { assert_eq!(encoder.len(), 45); let first = encoder.clone(); - let mut builder = PacketBuilder::short(encoder, false, ConnectionId::from(SERVER_CID)); + let mut builder = + PacketBuilder::short(encoder, false, Some(ConnectionId::from(SERVER_CID))); builder.pn(1, 3); builder.encode(&[0]); // Minimal size (packet number is big enough). let encoder = builder.build(&mut prot).expect("build"); @@ -1170,8 +1178,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + None::<&[u8]>, + None::<&[u8]>, ); builder.pn(0, 1); builder.encode(&[1, 2, 3]); @@ -1189,8 +1197,8 @@ mod tests { Encoder::new(), PacketType::Handshake, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(&[][..]), + None::<&[u8]>, + None::<&[u8]>, ); builder.pn(0, 1); builder.scramble(true); @@ -1210,8 +1218,8 @@ mod tests { Encoder::new(), PacketType::Initial, Version::default(), - ConnectionId::from(&[][..]), - ConnectionId::from(SERVER_CID), + None::<&[u8]>, + Some(ConnectionId::from(SERVER_CID)), ); assert_ne!(builder.remaining(), 0); builder.initial_token(&[]); @@ -1229,7 +1237,7 @@ mod tests { let mut builder = PacketBuilder::short( Encoder::with_capacity(100), true, - ConnectionId::from(SERVER_CID), + Some(ConnectionId::from(SERVER_CID)), ); builder.pn(0, 1); // Pad, but not up to the full capacity. Leave enough space for the @@ -1244,8 +1252,8 @@ mod tests { encoder, PacketType::Initial, Version::default(), - ConnectionId::from(SERVER_CID), - ConnectionId::from(SERVER_CID), + Some(ConnectionId::from(SERVER_CID)), + Some(ConnectionId::from(SERVER_CID)), ); assert_eq!(builder.remaining(), 0); assert_eq!(builder.abort(), encoder_copy); diff --git a/neqo-transport/src/path.rs b/neqo-transport/src/path.rs index ec56ebdb81..9523a392b6 100644 --- a/neqo-transport/src/path.rs +++ b/neqo-transport/src/path.rs @@ -660,8 +660,8 @@ impl Path { /// Get the first local connection ID. /// Only do this for the primary path during the handshake. - pub fn local_cid(&self) -> &ConnectionId { - self.local_cid.as_ref().unwrap() + pub const fn local_cid(&self) -> Option<&ConnectionId> { + self.local_cid.as_ref() } /// Set the remote connection ID based on the peer's choice. @@ -674,8 +674,10 @@ impl Path { } /// Access the remote connection ID. - pub fn remote_cid(&self) -> &ConnectionId { - self.remote_cid.as_ref().unwrap().connection_id() + pub fn remote_cid(&self) -> Option<&ConnectionId> { + self.remote_cid + .as_ref() + .map(super::cid::ConnectionIdEntry::connection_id) } /// Set the stateless reset token for the connection ID that is currently in use. diff --git a/neqo-transport/src/pmtud.rs b/neqo-transport/src/pmtud.rs index 5ee59e3dbf..9eec6b0eda 100644 --- a/neqo-transport/src/pmtud.rs +++ b/neqo-transport/src/pmtud.rs @@ -383,7 +383,7 @@ mod tests { let stats_before = stats.clone(); // Fake a packet number, so the builder logic works. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let pn = prot.next_pn(); builder.pn(pn, 4); builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud); diff --git a/neqo-transport/src/qlog.rs b/neqo-transport/src/qlog.rs index 29f17bf6b9..fa127212f0 100644 --- a/neqo-transport/src/qlog.rs +++ b/neqo-transport/src/qlog.rs @@ -104,8 +104,8 @@ fn connection_started(qlog: &NeqoQlog, path: &PathRef) { protocol: Some("QUIC".into()), src_port: p.local_address().port().into(), dst_port: p.remote_address().port().into(), - src_cid: Some(format!("{}", p.local_cid())), - dst_cid: Some(format!("{}", p.remote_cid())), + src_cid: p.local_cid().map(ToString::to_string), + dst_cid: p.remote_cid().map(ToString::to_string), }); Some(ev_data) diff --git a/neqo-transport/src/recv_stream.rs b/neqo-transport/src/recv_stream.rs index c4c716f676..7b46a386bc 100644 --- a/neqo-transport/src/recv_stream.rs +++ b/neqo-transport/src/recv_stream.rs @@ -1483,7 +1483,7 @@ mod tests { assert!(s.has_frames_to_write()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); s.write_frame(&mut builder, &mut token, &mut FrameStats::default()); @@ -1597,7 +1597,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1618,7 +1618,7 @@ mod tests { s.read(&mut buf).unwrap(); assert!(session_fc.borrow().frame_needed()); // consume it - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); session_fc .borrow_mut() @@ -1866,7 +1866,7 @@ mod tests { assert!(s.fc().unwrap().frame_needed()); // Write the fc update frame - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut token = Vec::new(); let mut stats = FrameStats::default(); fc.borrow_mut() diff --git a/neqo-transport/src/send_stream.rs b/neqo-transport/src/send_stream.rs index a6e42cfdaf..3f0002da13 100644 --- a/neqo-transport/src/send_stream.rs +++ b/neqo-transport/src/send_stream.rs @@ -2596,7 +2596,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // Write a small frame: no fin. let written = builder.len(); @@ -2684,7 +2684,7 @@ mod tests { ss.insert(StreamId::from(0), s); let mut tokens = Vec::new(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); ss.write_frames( TransmissionPriority::default(), &mut builder, @@ -2762,7 +2762,7 @@ mod tests { assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..]))); // This doesn't report blocking yet. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2815,7 +2815,7 @@ mod tests { assert_eq!(s.send_atomic(b"abc").unwrap(), 0); // Assert that STREAM_DATA_BLOCKED is sent. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_blocked_frame( @@ -2902,7 +2902,7 @@ mod tests { s.mark_as_lost(len_u64, 0, true); // No frame should be sent here. - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut tokens = Vec::new(); let mut stats = FrameStats::default(); s.write_stream_frame( @@ -2962,7 +2962,7 @@ mod tests { s.close(); } - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); builder.set_limit(header_len + space); @@ -3063,7 +3063,7 @@ mod tests { s.send(data).unwrap(); s.close(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let header_len = builder.len(); // Add 2 for the frame type and stream ID, then add the extra. builder.set_limit(header_len + data.len() + 2 + extra); diff --git a/neqo-transport/src/tracking.rs b/neqo-transport/src/tracking.rs index 90bbd0b54a..b7ab8bac50 100644 --- a/neqo-transport/src/tracking.rs +++ b/neqo-transport/src/tracking.rs @@ -797,7 +797,7 @@ mod tests { } fn write_frame_at(rp: &mut RecvdPackets, now: Instant) { - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); let mut stats = FrameStats::default(); let mut tokens = Vec::new(); rp.write_frame(now, RTT, &mut builder, &mut tokens, &mut stats); @@ -952,7 +952,7 @@ mod tests { #[test] fn drop_spaces() { let mut tracker = AckTracker::default(); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); tracker .get_mut(PacketNumberSpace::Initial) .unwrap() @@ -1017,7 +1017,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); builder.set_limit(10); let mut stats = FrameStats::default(); @@ -1048,7 +1048,7 @@ mod tests { .ack_time(now().checked_sub(Duration::from_millis(1)).unwrap()) .is_some()); - let mut builder = PacketBuilder::short(Encoder::new(), false, []); + let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>); // The code pessimistically assumes that each range needs 16 bytes to express. // So this won't be enough for a second range. builder.set_limit(RecvdPackets::USEFUL_ACK_LEN + 8);