Skip to content

Commit

Permalink
fix: Check whether CIDs are empty (#2034)
Browse files Browse the repository at this point in the history
* fix: Check whether CIDs are empty

WIP

Fixes #1429

* Update neqo-transport/src/path.rs

Co-authored-by: Martin Thomson <mt@lowentropy.net>
Signed-off-by: Lars Eggert <lars@eggert.org>

* Suggestion from @martinthomson

* Update neqo-transport/src/qlog.rs

Co-authored-by: Max Inden <mail@max-inden.de>
Signed-off-by: Lars Eggert <lars@eggert.org>

* Suggestion from @mxinden

@mxinden, is `take()` the way to go here?

* Log error

* Fix test

* Simplify test

---------

Signed-off-by: Lars Eggert <lars@eggert.org>
Co-authored-by: Martin Thomson <mt@lowentropy.net>
Co-authored-by: Max Inden <mail@max-inden.de>
  • Loading branch information
3 people authored Aug 9, 2024
1 parent 78e5a5e commit 477a09a
Show file tree
Hide file tree
Showing 12 changed files with 85 additions and 55 deletions.
13 changes: 8 additions & 5 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,11 @@ impl Connection {
self.paths.make_permanent(path, None, cid);
Ok(())
} else if let Some(primary) = self.paths.primary() {
if primary.borrow().remote_cid().is_empty() {
if primary
.borrow()
.remote_cid()
.map_or(true, |id| id.is_empty())
{
self.paths
.make_permanent(path, None, ConnectionIdEntry::empty_remote());
Ok(())
Expand Down Expand Up @@ -1913,7 +1917,7 @@ impl Connection {
// a packet on a new path, we avoid sending (and the privacy risk) rather
// than reuse a connection ID.
let res = if path.borrow().is_temporary() {
assert!(!cfg!(test), "attempting to close with a temporary path");
qerror!([self], "Attempting to close with a temporary path");
Err(Error::InternalError)
} else {
self.output_path(&path, now, &Some(details))
Expand All @@ -1937,16 +1941,15 @@ impl Connection {
) -> (PacketType, PacketBuilder) {
let pt = PacketType::from(cspace);
let mut builder = if pt == PacketType::Short {
qdebug!("Building Short dcid {}", path.remote_cid());
qdebug!("Building Short dcid {:?}", path.remote_cid());
PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
} else {
qdebug!(
"Building {:?} dcid {} scid {}",
"Building {:?} dcid {:?} scid {:?}",
pt,
path.remote_cid(),
path.local_cid(),
);

PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid())
};
if builder.remaining() > 0 {
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ fn datagram_fill() {
let path = p.borrow();
// Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number,
// 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD.
path.plpmtu() - path.remote_cid().len() - 19
path.plpmtu() - path.remote_cid().unwrap().len() - 19
};
assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large.

Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/idle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ fn idle_caching() {
let mut client = default_client();
let mut server = default_server();
let start = now();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);

// Perform the first round trip, but drop the Initial from the server.
// The client then caches the Handshake packet.
Expand Down
19 changes: 18 additions & 1 deletion neqo-transport/src/connection/tests/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,6 @@ impl crate::connection::test_internal::FrameWriter for GarbageWriter {
/// Test the case that we run out of connection ID and receive an invalid frame
/// from a new path.
#[test]
#[should_panic(expected = "attempting to close with a temporary path")]
fn error_on_new_path_with_no_connection_id() {
let mut client = default_client();
let mut server = default_server();
Expand All @@ -967,5 +966,23 @@ fn error_on_new_path_with_no_connection_id() {

// See issue #1697. We had a crash when the client had a temporary path and
// process_output is called.
let closing_frames = client.stats().frame_tx.connection_close;
mem::drop(client.process_output(now()));
assert!(matches!(
client.state(),
State::Closing {
error: CloseReason::Transport(Error::UnknownFrameType),
..
}
));
// Wait until the connection is closed.
let mut now = now();
now += client.process(None, now).callback();
_ = client.process_output(now);
// No closing frames should be sent, and the connection should be closed.
assert_eq!(client.stats().frame_tx.connection_close, closing_frames);
assert!(matches!(
client.state(),
State::Closed(CloseReason::Transport(Error::UnknownFrameType))
));
}
2 changes: 1 addition & 1 deletion neqo-transport/src/fc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ mod test {
fc[StreamType::BiDi].add_retired(1);
fc[StreamType::BiDi].send_flowc_update();
// consume the frame
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
assert_eq!(tokens.len(), 1);
Expand Down
56 changes: 32 additions & 24 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,19 @@ impl PacketBuilder {
///
/// If, after calling this method, `remaining()` returns 0, then call `abort()` to get
/// the encoder back.
pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
pub fn short(mut encoder: Encoder, key_phase: bool, dcid: Option<impl AsRef<[u8]>>) -> Self {
let mut limit = Self::infer_limit(&encoder);
let header_start = encoder.len();
// Check that there is enough space for the header.
// 5 = 1 (first byte) + 4 (packet number)
if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() {
if limit > encoder.len()
&& 5 + dcid.as_ref().map_or(0, |d| d.as_ref().len()) < limit - encoder.len()
{
encoder
.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
encoder.encode(dcid.as_ref());
if let Some(dcid) = dcid {
encoder.encode(dcid.as_ref());
}
} else {
limit = 0;
}
Expand Down Expand Up @@ -185,20 +189,23 @@ impl PacketBuilder {
mut encoder: Encoder,
pt: PacketType,
version: Version,
dcid: impl AsRef<[u8]>,
scid: impl AsRef<[u8]>,
mut dcid: Option<impl AsRef<[u8]>>,
mut scid: Option<impl AsRef<[u8]>>,
) -> Self {
let mut limit = Self::infer_limit(&encoder);
let header_start = encoder.len();
// Check that there is enough space for the header.
// 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number)
if limit > encoder.len()
&& 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len()
&& 11
+ dcid.as_ref().map_or(0, |d| d.as_ref().len())
+ scid.as_ref().map_or(0, |d| d.as_ref().len())
< limit - encoder.len()
{
encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4);
encoder.encode_uint(4, version.wire_version());
encoder.encode_vec(1, dcid.as_ref());
encoder.encode_vec(1, scid.as_ref());
encoder.encode_vec(1, dcid.take().as_ref().map_or(&[], AsRef::as_ref));
encoder.encode_vec(1, scid.take().as_ref().map_or(&[], AsRef::as_ref));
} else {
limit = 0;
}
Expand Down Expand Up @@ -994,8 +1001,8 @@ mod tests {
Encoder::new(),
PacketType::Initial,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(SERVER_CID),
None::<&[u8]>,
Some(ConnectionId::from(SERVER_CID)),
);
builder.initial_token(&[]);
builder.pn(1, 2);
Expand Down Expand Up @@ -1058,7 +1065,7 @@ mod tests {
fn build_short() {
fixture_init();
let mut builder =
PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID));
PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID)));
builder.pn(0, 1);
builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
let packet = builder
Expand All @@ -1073,7 +1080,7 @@ mod tests {
let mut firsts = Vec::new();
for _ in 0..64 {
let mut builder =
PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID));
PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID)));
builder.scramble(true);
builder.pn(0, 1);
firsts.push(builder.as_ref()[0]);
Expand Down Expand Up @@ -1136,16 +1143,17 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(SERVER_CID),
ConnectionId::from(CLIENT_CID),
Some(ConnectionId::from(SERVER_CID)),
Some(ConnectionId::from(CLIENT_CID)),
);
builder.pn(0, 1);
builder.encode(&[0; 3]);
let encoder = builder.build(&mut prot).expect("build");
assert_eq!(encoder.len(), 45);
let first = encoder.clone();

let mut builder = PacketBuilder::short(encoder, false, ConnectionId::from(SERVER_CID));
let mut builder =
PacketBuilder::short(encoder, false, Some(ConnectionId::from(SERVER_CID)));
builder.pn(1, 3);
builder.encode(&[0]); // Minimal size (packet number is big enough).
let encoder = builder.build(&mut prot).expect("build");
Expand All @@ -1170,8 +1178,8 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(&[][..]),
None::<&[u8]>,
None::<&[u8]>,
);
builder.pn(0, 1);
builder.encode(&[1, 2, 3]);
Expand All @@ -1189,8 +1197,8 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(&[][..]),
None::<&[u8]>,
None::<&[u8]>,
);
builder.pn(0, 1);
builder.scramble(true);
Expand All @@ -1210,8 +1218,8 @@ mod tests {
Encoder::new(),
PacketType::Initial,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(SERVER_CID),
None::<&[u8]>,
Some(ConnectionId::from(SERVER_CID)),
);
assert_ne!(builder.remaining(), 0);
builder.initial_token(&[]);
Expand All @@ -1229,7 +1237,7 @@ mod tests {
let mut builder = PacketBuilder::short(
Encoder::with_capacity(100),
true,
ConnectionId::from(SERVER_CID),
Some(ConnectionId::from(SERVER_CID)),
);
builder.pn(0, 1);
// Pad, but not up to the full capacity. Leave enough space for the
Expand All @@ -1244,8 +1252,8 @@ mod tests {
encoder,
PacketType::Initial,
Version::default(),
ConnectionId::from(SERVER_CID),
ConnectionId::from(SERVER_CID),
Some(ConnectionId::from(SERVER_CID)),
Some(ConnectionId::from(SERVER_CID)),
);
assert_eq!(builder.remaining(), 0);
assert_eq!(builder.abort(), encoder_copy);
Expand Down
10 changes: 6 additions & 4 deletions neqo-transport/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,8 @@ impl Path {

/// Get the first local connection ID.
/// Only do this for the primary path during the handshake.
pub fn local_cid(&self) -> &ConnectionId {
self.local_cid.as_ref().unwrap()
pub const fn local_cid(&self) -> Option<&ConnectionId> {
self.local_cid.as_ref()
}

/// Set the remote connection ID based on the peer's choice.
Expand All @@ -674,8 +674,10 @@ impl Path {
}

/// Access the remote connection ID.
pub fn remote_cid(&self) -> &ConnectionId {
self.remote_cid.as_ref().unwrap().connection_id()
pub fn remote_cid(&self) -> Option<&ConnectionId> {
self.remote_cid
.as_ref()
.map(super::cid::ConnectionIdEntry::connection_id)
}

/// Set the stateless reset token for the connection ID that is currently in use.
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/pmtud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ mod tests {
let stats_before = stats.clone();

// Fake a packet number, so the builder logic works.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let pn = prot.next_pn();
builder.pn(pn, 4);
builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud);
Expand Down
4 changes: 2 additions & 2 deletions neqo-transport/src/qlog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ fn connection_started(qlog: &NeqoQlog, path: &PathRef) {
protocol: Some("QUIC".into()),
src_port: p.local_address().port().into(),
dst_port: p.remote_address().port().into(),
src_cid: Some(format!("{}", p.local_cid())),
dst_cid: Some(format!("{}", p.remote_cid())),
src_cid: p.local_cid().map(ToString::to_string),
dst_cid: p.remote_cid().map(ToString::to_string),
});

Some(ev_data)
Expand Down
8 changes: 4 additions & 4 deletions neqo-transport/src/recv_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ mod tests {
assert!(s.has_frames_to_write());

// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
s.write_frame(&mut builder, &mut token, &mut FrameStats::default());

Expand Down Expand Up @@ -1597,7 +1597,7 @@ mod tests {
s.read(&mut buf).unwrap();
assert!(session_fc.borrow().frame_needed());
// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
session_fc
.borrow_mut()
Expand All @@ -1618,7 +1618,7 @@ mod tests {
s.read(&mut buf).unwrap();
assert!(session_fc.borrow().frame_needed());
// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
session_fc
.borrow_mut()
Expand Down Expand Up @@ -1866,7 +1866,7 @@ mod tests {
assert!(s.fc().unwrap().frame_needed());

// Write the fc update frame
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
let mut stats = FrameStats::default();
fc.borrow_mut()
Expand Down
14 changes: 7 additions & 7 deletions neqo-transport/src/send_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2596,7 +2596,7 @@ mod tests {
ss.insert(StreamId::from(0), s);

let mut tokens = Vec::new();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);

// Write a small frame: no fin.
let written = builder.len();
Expand Down Expand Up @@ -2684,7 +2684,7 @@ mod tests {
ss.insert(StreamId::from(0), s);

let mut tokens = Vec::new();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
ss.write_frames(
TransmissionPriority::default(),
&mut builder,
Expand Down Expand Up @@ -2762,7 +2762,7 @@ mod tests {
assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..])));

// This doesn't report blocking yet.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_blocked_frame(
Expand Down Expand Up @@ -2815,7 +2815,7 @@ mod tests {
assert_eq!(s.send_atomic(b"abc").unwrap(), 0);

// Assert that STREAM_DATA_BLOCKED is sent.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_blocked_frame(
Expand Down Expand Up @@ -2902,7 +2902,7 @@ mod tests {
s.mark_as_lost(len_u64, 0, true);

// No frame should be sent here.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_stream_frame(
Expand Down Expand Up @@ -2962,7 +2962,7 @@ mod tests {
s.close();
}

let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let header_len = builder.len();
builder.set_limit(header_len + space);

Expand Down Expand Up @@ -3063,7 +3063,7 @@ mod tests {
s.send(data).unwrap();
s.close();

let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let header_len = builder.len();
// Add 2 for the frame type and stream ID, then add the extra.
builder.set_limit(header_len + data.len() + 2 + extra);
Expand Down
Loading

0 comments on commit 477a09a

Please sign in to comment.