Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Check whether CIDs are empty #2034

Merged
merged 9 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions neqo-transport/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1684,7 +1684,11 @@ impl Connection {
self.paths.make_permanent(path, None, cid);
Ok(())
} else if let Some(primary) = self.paths.primary() {
if primary.borrow().remote_cid().is_empty() {
if primary
.borrow()
.remote_cid()
.map_or(true, |id| id.is_empty())
{
self.paths
.make_permanent(path, None, ConnectionIdEntry::empty_remote());
Ok(())
Expand Down Expand Up @@ -1913,7 +1917,7 @@ impl Connection {
// a packet on a new path, we avoid sending (and the privacy risk) rather
// than reuse a connection ID.
let res = if path.borrow().is_temporary() {
assert!(!cfg!(test), "attempting to close with a temporary path");
qerror!([self], "Attempting to close with a temporary path");
Err(Error::InternalError)
} else {
self.output_path(&path, now, &Some(details))
Expand All @@ -1937,16 +1941,15 @@ impl Connection {
) -> (PacketType, PacketBuilder) {
let pt = PacketType::from(cspace);
let mut builder = if pt == PacketType::Short {
qdebug!("Building Short dcid {}", path.remote_cid());
qdebug!("Building Short dcid {:?}", path.remote_cid());
PacketBuilder::short(encoder, tx.key_phase(), path.remote_cid())
} else {
qdebug!(
"Building {:?} dcid {} scid {}",
"Building {:?} dcid {:?} scid {:?}",
pt,
path.remote_cid(),
path.local_cid(),
);

PacketBuilder::long(encoder, pt, version, path.remote_cid(), path.local_cid())
};
if builder.remaining() > 0 {
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,7 @@ fn datagram_fill() {
let path = p.borrow();
// Minimum overhead is connection ID length, 1 byte short header, 1 byte packet number,
// 1 byte for the DATAGRAM frame type, and 16 bytes for the AEAD.
path.plpmtu() - path.remote_cid().len() - 19
path.plpmtu() - path.remote_cid().unwrap().len() - 19
};
assert!(space >= 64); // Unlikely, but this test depends on the datagram being this large.

Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/connection/tests/idle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ fn idle_caching() {
let mut client = default_client();
let mut server = default_server();
let start = now();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);

// Perform the first round trip, but drop the Initial from the server.
// The client then caches the Handshake packet.
Expand Down
19 changes: 18 additions & 1 deletion neqo-transport/src/connection/tests/migration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,6 @@ impl crate::connection::test_internal::FrameWriter for GarbageWriter {
/// Test the case that we run out of connection ID and receive an invalid frame
/// from a new path.
#[test]
#[should_panic(expected = "attempting to close with a temporary path")]
fn error_on_new_path_with_no_connection_id() {
let mut client = default_client();
let mut server = default_server();
Expand All @@ -967,5 +966,23 @@ fn error_on_new_path_with_no_connection_id() {

// See issue #1697. We had a crash when the client had a temporary path and
// process_output is called.
let closing_frames = client.stats().frame_tx.connection_close;
mem::drop(client.process_output(now()));
assert!(matches!(
client.state(),
State::Closing {
error: CloseReason::Transport(Error::UnknownFrameType),
..
}
));
// Wait until the connection is closed.
let mut now = now();
now += client.process(None, now).callback();
_ = client.process_output(now);
// No closing frames should be sent, and the connection should be closed.
assert_eq!(client.stats().frame_tx.connection_close, closing_frames);
assert!(matches!(
client.state(),
State::Closed(CloseReason::Transport(Error::UnknownFrameType))
));
}
2 changes: 1 addition & 1 deletion neqo-transport/src/fc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ mod test {
fc[StreamType::BiDi].add_retired(1);
fc[StreamType::BiDi].send_flowc_update();
// consume the frame
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
fc[StreamType::BiDi].write_frames(&mut builder, &mut tokens, &mut FrameStats::default());
assert_eq!(tokens.len(), 1);
Expand Down
56 changes: 32 additions & 24 deletions neqo-transport/src/packet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,19 @@ impl PacketBuilder {
///
/// If, after calling this method, `remaining()` returns 0, then call `abort()` to get
/// the encoder back.
pub fn short(mut encoder: Encoder, key_phase: bool, dcid: impl AsRef<[u8]>) -> Self {
pub fn short(mut encoder: Encoder, key_phase: bool, dcid: Option<impl AsRef<[u8]>>) -> Self {
let mut limit = Self::infer_limit(&encoder);
let header_start = encoder.len();
// Check that there is enough space for the header.
// 5 = 1 (first byte) + 4 (packet number)
if limit > encoder.len() && 5 + dcid.as_ref().len() < limit - encoder.len() {
if limit > encoder.len()
&& 5 + dcid.as_ref().map_or(0, |d| d.as_ref().len()) < limit - encoder.len()
{
encoder
.encode_byte(PACKET_BIT_SHORT | PACKET_BIT_FIXED_QUIC | (u8::from(key_phase) << 2));
encoder.encode(dcid.as_ref());
if let Some(dcid) = dcid {
encoder.encode(dcid.as_ref());
}
} else {
limit = 0;
}
Expand Down Expand Up @@ -185,20 +189,23 @@ impl PacketBuilder {
mut encoder: Encoder,
pt: PacketType,
version: Version,
dcid: impl AsRef<[u8]>,
scid: impl AsRef<[u8]>,
mut dcid: Option<impl AsRef<[u8]>>,
mut scid: Option<impl AsRef<[u8]>>,
) -> Self {
let mut limit = Self::infer_limit(&encoder);
let header_start = encoder.len();
// Check that there is enough space for the header.
// 11 = 1 (first byte) + 4 (version) + 2 (dcid+scid length) + 4 (packet number)
if limit > encoder.len()
&& 11 + dcid.as_ref().len() + scid.as_ref().len() < limit - encoder.len()
&& 11
+ dcid.as_ref().map_or(0, |d| d.as_ref().len())
+ scid.as_ref().map_or(0, |d| d.as_ref().len())
< limit - encoder.len()
{
encoder.encode_byte(PACKET_BIT_LONG | PACKET_BIT_FIXED_QUIC | pt.to_byte(version) << 4);
encoder.encode_uint(4, version.wire_version());
encoder.encode_vec(1, dcid.as_ref());
encoder.encode_vec(1, scid.as_ref());
encoder.encode_vec(1, dcid.take().as_ref().map_or(&[], AsRef::as_ref));
encoder.encode_vec(1, scid.take().as_ref().map_or(&[], AsRef::as_ref));
} else {
limit = 0;
}
Expand Down Expand Up @@ -994,8 +1001,8 @@ mod tests {
Encoder::new(),
PacketType::Initial,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(SERVER_CID),
None::<&[u8]>,
Some(ConnectionId::from(SERVER_CID)),
);
builder.initial_token(&[]);
builder.pn(1, 2);
Expand Down Expand Up @@ -1058,7 +1065,7 @@ mod tests {
fn build_short() {
fixture_init();
let mut builder =
PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID));
PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID)));
builder.pn(0, 1);
builder.encode(SAMPLE_SHORT_PAYLOAD); // Enough payload for sampling.
let packet = builder
Expand All @@ -1073,7 +1080,7 @@ mod tests {
let mut firsts = Vec::new();
for _ in 0..64 {
let mut builder =
PacketBuilder::short(Encoder::new(), true, ConnectionId::from(SERVER_CID));
PacketBuilder::short(Encoder::new(), true, Some(ConnectionId::from(SERVER_CID)));
builder.scramble(true);
builder.pn(0, 1);
firsts.push(builder.as_ref()[0]);
Expand Down Expand Up @@ -1136,16 +1143,17 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(SERVER_CID),
ConnectionId::from(CLIENT_CID),
Some(ConnectionId::from(SERVER_CID)),
Some(ConnectionId::from(CLIENT_CID)),
);
builder.pn(0, 1);
builder.encode(&[0; 3]);
let encoder = builder.build(&mut prot).expect("build");
assert_eq!(encoder.len(), 45);
let first = encoder.clone();

let mut builder = PacketBuilder::short(encoder, false, ConnectionId::from(SERVER_CID));
let mut builder =
PacketBuilder::short(encoder, false, Some(ConnectionId::from(SERVER_CID)));
builder.pn(1, 3);
builder.encode(&[0]); // Minimal size (packet number is big enough).
let encoder = builder.build(&mut prot).expect("build");
Expand All @@ -1170,8 +1178,8 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(&[][..]),
None::<&[u8]>,
None::<&[u8]>,
);
builder.pn(0, 1);
builder.encode(&[1, 2, 3]);
Expand All @@ -1189,8 +1197,8 @@ mod tests {
Encoder::new(),
PacketType::Handshake,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(&[][..]),
None::<&[u8]>,
None::<&[u8]>,
);
builder.pn(0, 1);
builder.scramble(true);
Expand All @@ -1210,8 +1218,8 @@ mod tests {
Encoder::new(),
PacketType::Initial,
Version::default(),
ConnectionId::from(&[][..]),
ConnectionId::from(SERVER_CID),
None::<&[u8]>,
Some(ConnectionId::from(SERVER_CID)),
);
assert_ne!(builder.remaining(), 0);
builder.initial_token(&[]);
Expand All @@ -1229,7 +1237,7 @@ mod tests {
let mut builder = PacketBuilder::short(
Encoder::with_capacity(100),
true,
ConnectionId::from(SERVER_CID),
Some(ConnectionId::from(SERVER_CID)),
);
builder.pn(0, 1);
// Pad, but not up to the full capacity. Leave enough space for the
Expand All @@ -1244,8 +1252,8 @@ mod tests {
encoder,
PacketType::Initial,
Version::default(),
ConnectionId::from(SERVER_CID),
ConnectionId::from(SERVER_CID),
Some(ConnectionId::from(SERVER_CID)),
Some(ConnectionId::from(SERVER_CID)),
);
assert_eq!(builder.remaining(), 0);
assert_eq!(builder.abort(), encoder_copy);
Expand Down
10 changes: 6 additions & 4 deletions neqo-transport/src/path.rs
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,8 @@ impl Path {

/// Get the first local connection ID.
/// Only do this for the primary path during the handshake.
pub fn local_cid(&self) -> &ConnectionId {
self.local_cid.as_ref().unwrap()
pub const fn local_cid(&self) -> Option<&ConnectionId> {
self.local_cid.as_ref()
}

/// Set the remote connection ID based on the peer's choice.
Expand All @@ -674,8 +674,10 @@ impl Path {
}

/// Access the remote connection ID.
pub fn remote_cid(&self) -> &ConnectionId {
self.remote_cid.as_ref().unwrap().connection_id()
pub fn remote_cid(&self) -> Option<&ConnectionId> {
self.remote_cid
.as_ref()
.map(super::cid::ConnectionIdEntry::connection_id)
}

/// Set the stateless reset token for the connection ID that is currently in use.
Expand Down
2 changes: 1 addition & 1 deletion neqo-transport/src/pmtud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ mod tests {
let stats_before = stats.clone();

// Fake a packet number, so the builder logic works.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let pn = prot.next_pn();
builder.pn(pn, 4);
builder.set_initial_limit(&SendProfile::new_limited(pmtud.plpmtu()), 16, pmtud);
Expand Down
4 changes: 2 additions & 2 deletions neqo-transport/src/qlog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ fn connection_started(qlog: &NeqoQlog, path: &PathRef) {
protocol: Some("QUIC".into()),
src_port: p.local_address().port().into(),
dst_port: p.remote_address().port().into(),
src_cid: Some(format!("{}", p.local_cid())),
dst_cid: Some(format!("{}", p.remote_cid())),
src_cid: p.local_cid().map(ToString::to_string),
dst_cid: p.remote_cid().map(ToString::to_string),
});

Some(ev_data)
Expand Down
8 changes: 4 additions & 4 deletions neqo-transport/src/recv_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1483,7 +1483,7 @@ mod tests {
assert!(s.has_frames_to_write());

// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
s.write_frame(&mut builder, &mut token, &mut FrameStats::default());

Expand Down Expand Up @@ -1597,7 +1597,7 @@ mod tests {
s.read(&mut buf).unwrap();
assert!(session_fc.borrow().frame_needed());
// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
session_fc
.borrow_mut()
Expand All @@ -1618,7 +1618,7 @@ mod tests {
s.read(&mut buf).unwrap();
assert!(session_fc.borrow().frame_needed());
// consume it
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
session_fc
.borrow_mut()
Expand Down Expand Up @@ -1866,7 +1866,7 @@ mod tests {
assert!(s.fc().unwrap().frame_needed());

// Write the fc update frame
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut token = Vec::new();
let mut stats = FrameStats::default();
fc.borrow_mut()
Expand Down
14 changes: 7 additions & 7 deletions neqo-transport/src/send_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2596,7 +2596,7 @@ mod tests {
ss.insert(StreamId::from(0), s);

let mut tokens = Vec::new();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);

// Write a small frame: no fin.
let written = builder.len();
Expand Down Expand Up @@ -2684,7 +2684,7 @@ mod tests {
ss.insert(StreamId::from(0), s);

let mut tokens = Vec::new();
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
ss.write_frames(
TransmissionPriority::default(),
&mut builder,
Expand Down Expand Up @@ -2762,7 +2762,7 @@ mod tests {
assert_eq!(s.next_bytes(false), Some((0, &b"ab"[..])));

// This doesn't report blocking yet.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_blocked_frame(
Expand Down Expand Up @@ -2815,7 +2815,7 @@ mod tests {
assert_eq!(s.send_atomic(b"abc").unwrap(), 0);

// Assert that STREAM_DATA_BLOCKED is sent.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_blocked_frame(
Expand Down Expand Up @@ -2902,7 +2902,7 @@ mod tests {
s.mark_as_lost(len_u64, 0, true);

// No frame should be sent here.
let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let mut tokens = Vec::new();
let mut stats = FrameStats::default();
s.write_stream_frame(
Expand Down Expand Up @@ -2962,7 +2962,7 @@ mod tests {
s.close();
}

let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let header_len = builder.len();
builder.set_limit(header_len + space);

Expand Down Expand Up @@ -3063,7 +3063,7 @@ mod tests {
s.send(data).unwrap();
s.close();

let mut builder = PacketBuilder::short(Encoder::new(), false, []);
let mut builder = PacketBuilder::short(Encoder::new(), false, None::<&[u8]>);
let header_len = builder.len();
// Add 2 for the frame type and stream ID, then add the extra.
builder.set_limit(header_len + data.len() + 2 + extra);
Expand Down
Loading
Loading