Introduce API to safely initialize Packets #3533

Open
wants to merge 9 commits into
base: master
from
2 changes: 1 addition & 1 deletion ci/test-miri.sh
Original file line number Diff line number Diff line change
@@ -6,8 +6,8 @@ source ci/_
source ci/rust-version.sh nightly

# miri is very slow; so only run very few of selective tests!
_ cargo "+${rust_nightly}" miri test -p solana-packet -- test_packet_buffer_writer
_ cargo "+${rust_nightly}" miri test -p solana-program -- hash:: account_info::

_ cargo "+${rust_nightly}" miri test -p solana-unified-scheduler-logic

# run intentionally-#[ignored] ub triggering tests for each to make sure they fail
20 changes: 7 additions & 13 deletions core/src/shred_fetch_stage.rs
@@ -14,7 +14,7 @@ use {
clock::{Slot, DEFAULT_MS_PER_SLOT},
epoch_schedule::EpochSchedule,
genesis_config::ClusterType,
packet::{Meta, PACKET_DATA_SIZE},
packet::{Packet, PACKET_DATA_SIZE},
pubkey::Pubkey,
},
solana_streamer::streamer::{self, PacketBatchReceiver, StreamerReceiveStats},
@@ -350,29 +350,23 @@ pub(crate) fn receive_quic_datagrams(
};
let mut packet_batch =
PacketBatch::new_with_recycler(&recycler, PACKETS_PER_BATCH, "receive_quic_datagrams");
unsafe {
packet_batch.set_len(PACKETS_PER_BATCH);
};
Comment on lines -353 to -355

This is removing one unsafe but then adding two new ones.
Why is the new code better than the old one?

Author

https://doc.rust-lang.org/std/vec/struct.Vec.html#method.set_len

Safety

  • new_len must be less than or equal to capacity().
  • The elements at old_len..new_len must be initialized.

The old code called set_len() without initializing the packets first, so it violated the second requirement.
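
As a rough sketch of the pattern the new code follows, here is the same idea on a plain `Vec<T>` instead of `PacketBatch`. The helper name `fill_spare` is hypothetical and only for illustration: it writes into the spare capacity first and only then calls `set_len()` with the number of elements actually written.

```rust
use std::mem::MaybeUninit;

/// Hypothetical helper (not part of this PR): moves items from `src` into the
/// spare capacity of `dst` and returns how many were written.
fn fill_spare<T>(dst: &mut Vec<T>, src: impl Iterator<Item = T>) -> usize {
    let written = src
        // zip() stops as soon as either the source or the spare capacity runs out
        .zip(dst.spare_capacity_mut().iter_mut())
        .map(|(item, slot): (T, &mut MaybeUninit<T>)| {
            // Initialize the slot in place; no reference to uninitialized T is created
            slot.write(item);
        })
        .count();

    let new_len = dst.len() + written;
    // SAFETY: new_len <= capacity() because zip() is bounded by the spare
    // capacity, and every element in old_len..new_len was written above.
    unsafe { dst.set_len(new_len) };
    written
}

fn main() {
    let mut values: Vec<u64> = Vec::with_capacity(4);
    let written = fill_spare(&mut values, [10u64, 20, 30].into_iter());
    assert_eq!(written, 3);
    assert_eq!(values, vec![10, 20, 30]);
}
```

Both `set_len()` requirements are visible at the single `unsafe` call site, which is the property the old code lacked.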

let deadline = Instant::now() + PACKET_COALESCE_DURATION;
let entries = std::iter::once(entry).chain(
std::iter::repeat_with(|| quic_datagrams_receiver.recv_deadline(deadline).ok())
.while_some(),
);
let size = entries
.filter(|(_, _, bytes)| bytes.len() <= PACKET_DATA_SIZE)
.zip(packet_batch.iter_mut())
.zip(packet_batch.spare_capacity_mut().iter_mut())
.map(|((_pubkey, addr, bytes), packet)| {
*packet.meta_mut() = Meta {
size: bytes.len(),
addr: addr.ip(),
port: addr.port(),
flags,
};
packet.buffer_mut()[..bytes.len()].copy_from_slice(&bytes);
Packet::init_packet_from_bytes(packet, &bytes, Some(&addr)).unwrap();
// SAFETY: Packet::init_packet_from_bytes() just initialized the packet
unsafe { packet.assume_init_mut().meta_mut().set_flags(flags) };
})
.count();
if size > 0 {
packet_batch.truncate(size);
// SAFETY: By now, size packets have been initialized
unsafe { packet_batch.set_len(size) };
if sender.send(packet_batch).is_err() {
return; // The receiver end of the channel is disconnected.
}
29 changes: 12 additions & 17 deletions entry/src/entry.rs
@@ -24,7 +24,6 @@ use {
solana_runtime_transaction::runtime_transaction::RuntimeTransaction,
solana_sdk::{
hash::Hash,
packet::Meta,
transaction::{
Result, SanitizedTransaction, Transaction, TransactionError,
TransactionVerificationMode, VersionedTransaction,
@@ -548,26 +547,22 @@ fn start_verify_transactions_gpu(
num_transactions,
"entry-sig-verify",
);
// We use set_len here instead of resize(num_txs, Packet::default()), to save
// memory bandwidth and avoid writing a large amount of data that will be overwritten
// soon afterwards. As well, Packet::default() actually leaves the packet data
// uninitialized, so the initialization would simply write junk into
// the vector anyway.
unsafe {
packet_batch.set_len(num_transactions);
}

let uninitialized_packets = packet_batch.spare_capacity_mut().iter_mut();

Don't we need assume_init somewhere below?

Author

We do not; set_len() does that work for us at the end:

  • assume_init() yields a T from a MaybeUninit<T>.
  • We start with a Vec<Packet> (the type under the hood for PacketBatch).
  • packet_batch.spare_capacity_mut() gives us access to the elements at index i where vec_length <= i < vec_capacity.
  • We initialize the elements in place, so calling set_len() says "these are valid elements of the Vec now and can be accessed normally; also drop them normally when dropping the Vec".
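
The points above can be sketched with a `Vec<String>` standing in for the packet batch (an assumption made for illustration; `String` is used because it has a destructor, which makes the "drop them normally" part observable in principle):

```rust
use std::mem::MaybeUninit;

fn main() {
    let mut names: Vec<String> = Vec::with_capacity(2);

    // Elements at len..capacity are exposed as MaybeUninit<String>
    let spare: &mut [MaybeUninit<String>] = names.spare_capacity_mut();
    spare[0].write(String::from("alpha"));
    spare[1].write(String::from("beta"));

    // SAFETY: 2 <= capacity(), and elements 0..2 were initialized in place above
    unsafe { names.set_len(2) };

    // The elements are now ordinary Vec contents; no assume_init() call was needed
    assert_eq!(names, ["alpha", "beta"]);

    // When `names` goes out of scope, both Strings are dropped as usual
}
```

`assume_init()` would only be needed to move a `T` out of an individual `MaybeUninit<T>`, which never happens here because the values stay in the Vec's own allocation.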

let transaction_iter = transaction_chunk
.iter()
.map(|tx| tx.to_versioned_transaction());

let res = packet_batch
.iter_mut()
.zip(transaction_iter)
.all(|(packet, tx)| {
*packet.meta_mut() = Meta::default();
Packet::populate_packet(packet, None, &tx).is_ok()
});
if res {
let all_packets_initialized =
uninitialized_packets
.zip(transaction_iter)
.all(|(uninit_packet, tx)| {
Packet::init_packet_from_data(uninit_packet, &tx, None).is_ok()
});

if all_packets_initialized {
// SAFETY: All packets have been successfully initialized
unsafe { packet_batch.set_len(num_transactions) };
Ok(packet_batch)
} else {
Err(TransactionError::SanitizeFailure)
5 changes: 5 additions & 0 deletions perf/src/cuda_runtime.rs
@@ -14,6 +14,7 @@ use {
rayon::prelude::*,
serde::{Deserialize, Serialize},
std::{
mem::MaybeUninit,
ops::{Index, IndexMut},
os::raw::c_int,
slice::{Iter, IterMut, SliceIndex},
@@ -129,6 +130,10 @@ impl<T: Clone + Default + Sized> PinnedVec<T> {
self.x.iter_mut()
}

pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<T>] {
self.x.spare_capacity_mut()
}

pub fn capacity(&self) -> usize {
self.x.capacity()
}
5 changes: 5 additions & 0 deletions perf/src/packet.rs
@@ -7,6 +7,7 @@ use {
serde::{de::DeserializeOwned, Deserialize, Serialize},
std::{
io::Read,
mem::MaybeUninit,
net::SocketAddr,
ops::{Index, IndexMut},
slice::{Iter, IterMut, SliceIndex},
@@ -152,6 +153,10 @@ impl PacketBatch {
self.packets.iter_mut()
}

pub fn spare_capacity_mut(&mut self) -> &mut [MaybeUninit<Packet>] {
self.packets.spare_capacity_mut()
}

/// See Vector::set_len() for more details
///
/// # Safety
167 changes: 163 additions & 4 deletions sdk/packet/src/lib.rs
@@ -10,7 +10,10 @@ use {
bitflags::bitflags,
std::{
fmt,
io::{self, Write},
mem,
net::{IpAddr, Ipv4Addr, SocketAddr},
ptr,
slice::SliceIndex,
},
};
@@ -158,11 +161,76 @@ impl Packet {
&mut self.meta
}

#[cfg(feature = "bincode")]
/// Initializes a std::mem::MaybeUninit<Packet> such that the Packet can
/// be safely extracted via methods such as MaybeUninit::assume_init()
pub fn init_packet_from_data<T: serde::Serialize>(
packet: &mut mem::MaybeUninit<Packet>,
data: &T,
addr: Option<&SocketAddr>,
) -> Result<()> {
let mut writer = PacketWriter::new_from_uninit_packet(packet);
bincode::serialize_into(&mut writer, data)?;

let serialized_size = writer.position();
let (ip, port) = if let Some(addr) = addr {
(addr.ip(), addr.port())
} else {
(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
};
Self::init_packet_meta(
packet,
Meta {
size: serialized_size,
addr: ip,
port,
flags: PacketFlags::empty(),
},
);

Ok(())
}

pub fn init_packet_from_bytes(
packet: &mut mem::MaybeUninit<Packet>,
bytes: &[u8],
addr: Option<&SocketAddr>,
) -> io::Result<()> {
let mut writer = PacketWriter::new_from_uninit_packet(packet);
let num_bytes_written = writer.write(bytes)?;

This should probably use Write::write_all.
Write::write is not guaranteed to write the entire buffer.
And there is no need to rely on the implementation details of PacketWriter.

Author

And there is no need to rely on the implementation details of PacketWriter.

Yeah, that's fair, and it also simplifies things on the caller side (i.e. the debug_assert is no longer needed); I will make this change.
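
To illustrate the difference with a standard-library writer rather than `PacketWriter` (whose `write` happens to be all-or-nothing), here is a small sketch using the `Write` impl for `&mut [u8]`, where `write` may accept only part of the input while `write_all` reports an error if the whole buffer does not fit:

```rust
use std::io::{self, Write};

fn main() -> io::Result<()> {
    // write(): a short write is a valid outcome, so the caller must check the count
    let mut storage = [0u8; 4];
    let mut writer: &mut [u8] = &mut storage;
    let written = writer.write(&[1, 2, 3, 4, 5, 6])?;
    assert_eq!(written, 4);

    // write_all(): either everything is written or an error is returned
    let mut storage = [0u8; 4];
    let mut writer: &mut [u8] = &mut storage;
    let err = writer.write_all(&[1, 2, 3, 4, 5, 6]).unwrap_err();
    assert_eq!(err.kind(), io::ErrorKind::WriteZero);

    // With enough room, write_all() succeeds and no length check is needed
    let mut storage = [0u8; 8];
    let mut writer: &mut [u8] = &mut storage;
    writer.write_all(&[1, 2, 3, 4, 5, 6])?;
    assert_eq!(storage, [1, 2, 3, 4, 5, 6, 0, 0]);
    Ok(())
}
```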

debug_assert_eq!(bytes.len(), num_bytes_written);

let size = writer.position();
let (ip, port) = if let Some(addr) = addr {
(addr.ip(), addr.port())
} else {
(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
};
Self::init_packet_meta(
packet,
Meta {
size,
addr: ip,
port,
flags: PacketFlags::empty(),
},
);

Ok(())
}

fn init_packet_meta(packet: &mut mem::MaybeUninit<Packet>, meta: Meta) {
// SAFETY: Access the field by pointer as creating a reference to
// and/or within the uninitialized Packet is undefined behavior
unsafe { ptr::addr_of_mut!((*packet.as_mut_ptr()).meta).write(meta) };

Wasn't there a new syntax for getting ptr of a field introduced in 1.82?

Iirc all this concern came from the upgrade to that version

Author

Wasn't there a new syntax for getting ptr of a field introduced in 1.82?

Oh nice, didn't know about this; thanks for mentioning and will read up on it

Iirc all this concern came from the upgrade to that version

Yep, the validator was panicking with 1.82. The panic was addressed with #3325, so we could hypothetically go back to 1.82 and take advantage of its new syntax (&raw).
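
For reference, a minimal sketch of field-by-field initialization through raw pointers. `Record` is a hypothetical stand-in for `Packet`, not a type from this PR. The code uses `ptr::addr_of_mut!` so it builds on older toolchains; the `&raw mut` form available since Rust 1.82 is shown in a comment only.

```rust
use std::{mem::MaybeUninit, ptr};

// Hypothetical stand-in for Packet: a byte buffer plus some metadata
struct Record {
    payload: [u8; 8],
    len: usize,
}

fn main() {
    let mut slot = MaybeUninit::<Record>::uninit();
    let base: *mut Record = slot.as_mut_ptr();

    // SAFETY: `base` points to properly aligned, writable memory for a Record.
    // Only raw pointers to the fields are formed, never references, so no
    // reference to uninitialized data exists at any point.
    unsafe {
        ptr::addr_of_mut!((*base).payload).write([7u8; 8]);
        // On Rust 1.82 or newer the next line could instead be written as:
        //     (&raw mut (*base).len).write(8);
        ptr::addr_of_mut!((*base).len).write(8);
    }

    // SAFETY: every field of Record was written above
    let record = unsafe { slot.assume_init() };
    assert_eq!(record.len, 8);
    assert_eq!(record.payload, [7u8; 8]);
}
```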

}

#[cfg(feature = "bincode")]
pub fn from_data<T: serde::Serialize>(dest: Option<&SocketAddr>, data: T) -> Result<Self> {
let mut packet = Self::default();
Self::populate_packet(&mut packet, dest, &data)?;
Ok(packet)
let mut packet = mem::MaybeUninit::uninit();
Self::init_packet_from_data(&mut packet, &data, dest)?;
// SAFETY: init_packet_from_data() just initialized the packet
unsafe { Ok(packet.assume_init()) }
}

#[cfg(feature = "bincode")]
@@ -224,6 +292,61 @@ impl PartialEq for Packet {
}
}

/// A custom implementation of io::Write to facilitate safe (non-UB)
/// initialization of a MaybeUninit<Packet>
struct PacketWriter {

Why not a simple wrapper around the MaybeUninit packet?

We know the capacity of the buffer, and can determine remaining bytes from the current length and the fixed capacity

Author

We know the capacity of the buffer, and can determine remaining bytes from the current length and the fixed capacity

The main motivation for writing this wrapper was to have something that implements std::io::Write that we could pass to bincode::serialize_into().

If you drill down into bincode, writer.write() might get called repeatedly for one invocation of bincode::serialize_into(). Thus, we need to track how many bytes we have written after each call to write; we don't have the ability to update packet.meta (which may not have been initialized yet) as we go.
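
A simplified sketch of that bookkeeping, using only the standard library. `UninitWriter` is a hypothetical name and works on a `&mut [MaybeUninit<u8>]` slice rather than on a `Packet`, so it needs no `unsafe`; the repeated `write_all` calls stand in for a serializer that emits its output in several pieces.

```rust
use std::io::{self, Write};
use std::mem::MaybeUninit;

/// Hypothetical writer over an uninitialized byte buffer that remembers how
/// many bytes have been written across calls.
struct UninitWriter<'a> {
    buf: &'a mut [MaybeUninit<u8>],
    written: usize,
}

impl<'a> UninitWriter<'a> {
    fn new(buf: &'a mut [MaybeUninit<u8>]) -> Self {
        Self { buf, written: 0 }
    }

    /// Total number of bytes written so far
    fn position(&self) -> usize {
        self.written
    }
}

impl Write for UninitWriter<'_> {
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        let spare = &mut self.buf[self.written..];
        // All-or-nothing: refuse input that does not fit in the remaining space
        if data.len() > spare.len() {
            return Err(io::Error::from(io::ErrorKind::WriteZero));
        }
        for (slot, byte) in spare.iter_mut().zip(data) {
            slot.write(*byte);
        }
        self.written += data.len();
        Ok(data.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

fn main() -> io::Result<()> {
    let mut storage = [MaybeUninit::<u8>::uninit(); 8];
    let mut writer = UninitWriter::new(&mut storage);

    // Several writes for one logical value, as a serializer might issue
    writer.write_all(&[1, 2, 3])?;
    writer.write_all(&[4, 5])?;
    assert_eq!(writer.position(), 5);

    // A write that does not fit fails and leaves the position untouched
    assert!(writer.write_all(&[0; 4]).is_err());
    assert_eq!(writer.position(), 5);
    Ok(())
}
```

The final position is what would be stored as the size in the metadata once serialization has finished.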

// A pointer to the current write position
position: *mut u8,
// The number of remaining bytes that can be written to
spare_capacity: usize,
}

impl PacketWriter {
fn new_from_uninit_packet(packet: &mut mem::MaybeUninit<Packet>) -> Self {
// SAFETY: Access the field by pointer as creating a reference to
// and/or within the uninitialized Packet is undefined behavior
let position = unsafe { ptr::addr_of_mut!((*packet.as_mut_ptr()).buffer) as *mut u8 };
let spare_capacity = PACKET_DATA_SIZE;

Self {
position,
spare_capacity,
}
}

/// The offset of the write pointer within the buffer, which is also the
/// number of bytes that have been written
fn position(&self) -> usize {
PACKET_DATA_SIZE.saturating_sub(self.spare_capacity)
}
}

impl io::Write for PacketWriter {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if buf.len() > self.spare_capacity {
return Err(io::Error::from(io::ErrorKind::WriteZero));
}

// SAFETY: We previously verified that buf.len() <= self.spare_capacity
// so this write will not push us past the end of the buffer. Likewise,
// we can update self.spare_capacity without fear of overflow
unsafe {

All these new instances of unsafe are not ideal.

Author

I agree that we should be very stingy with our use of unsafe. However, our current code has the potential for UB, which is even less ideal than unsafe.

ptr::copy_nonoverlapping(buf.as_ptr(), self.position, buf.len());
// Update position and spare_capacity for the next call to write()
self.position = self.position.add(buf.len());
self.spare_capacity = self.spare_capacity.saturating_sub(buf.len());
}

Ok(buf.len())
}

#[inline]
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}

impl Meta {
pub fn socket_addr(&self) -> SocketAddr {
SocketAddr::new(self.addr, self.port)
@@ -239,6 +362,11 @@ impl Meta {
.set(PacketFlags::FROM_STAKED_NODE, from_staked_node);
}

#[inline]
pub fn set_flags(&mut self, flags: PacketFlags) {
self.flags = flags;
}

#[inline]
pub fn discard(&self) -> bool {
self.flags.contains(PacketFlags::DISCARD)
@@ -309,7 +437,7 @@ impl Default for Meta {

#[cfg(test)]
mod tests {
use super::*;
use {super::*, std::io::Write};

#[test]
fn test_deserialize_slice() {
@@ -349,4 +477,35 @@ mod tests {
Err("the size limit has been reached".to_string()),
);
}

#[test]
fn test_packet_buffer_writer() {
let mut packet = mem::MaybeUninit::<Packet>::uninit();
let mut writer = PacketWriter::new_from_uninit_packet(&mut packet);
let total_capacity = writer.spare_capacity;
assert_eq!(total_capacity, PACKET_DATA_SIZE);
let payload: [u8; PACKET_DATA_SIZE] = std::array::from_fn(|i| i as u8 % 255);

// Write 1200 bytes (1200 total)
let num_to_write = PACKET_DATA_SIZE - 32;
assert_eq!(
num_to_write,
writer.write(&payload[..num_to_write]).unwrap()
);
assert_eq!(num_to_write, writer.position());
// Write 28 bytes (1228 total)
assert_eq!(28, writer.write(&payload[1200..1200 + 28]).unwrap());
assert_eq!(1200 + 28, writer.position());
// Attempt to write 5 bytes (1233 total) which exceeds buffer capacity
assert!(writer
.write(&payload[1200 + 28 - 1..PACKET_DATA_SIZE])
.is_err());
// writer.position() remains unchanged
assert_eq!(1200 + 28, writer.position());
// Write 4 bytes (1232 total) to fill buffer
assert_eq!(4, writer.write(&payload[1200 + 28..]).unwrap());
assert_eq!(PACKET_DATA_SIZE, writer.position());
// Writing any amount of bytes will fail on the already full buffer
assert!(writer.write(&[0]).is_err());
}
}