-
Notifications
You must be signed in to change notification settings - Fork 256
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce API to safely initialize Packets #3533
base: master
Are you sure you want to change the base?
Changes from all commits
140909f
ad3760f
792f154
29814b2
111a86b
adda120
bf90368
c904a9b
0856264
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,6 @@ use { | |
solana_runtime_transaction::runtime_transaction::RuntimeTransaction, | ||
solana_sdk::{ | ||
hash::Hash, | ||
packet::Meta, | ||
transaction::{ | ||
Result, SanitizedTransaction, Transaction, TransactionError, | ||
TransactionVerificationMode, VersionedTransaction, | ||
|
@@ -548,26 +547,22 @@ fn start_verify_transactions_gpu( | |
num_transactions, | ||
"entry-sig-verify", | ||
); | ||
// We use set_len here instead of resize(num_txs, Packet::default()), to save | ||
// memory bandwidth and avoid writing a large amount of data that will be overwritten | ||
// soon afterwards. As well, Packet::default() actually leaves the packet data | ||
// uninitialized, so the initialization would simply write junk into | ||
// the vector anyway. | ||
unsafe { | ||
packet_batch.set_len(num_transactions); | ||
} | ||
|
||
let uninitialized_packets = packet_batch.spare_capacity_mut().iter_mut(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. don't we need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We do not,
|
||
let transaction_iter = transaction_chunk | ||
.iter() | ||
.map(|tx| tx.to_versioned_transaction()); | ||
|
||
let res = packet_batch | ||
.iter_mut() | ||
.zip(transaction_iter) | ||
.all(|(packet, tx)| { | ||
*packet.meta_mut() = Meta::default(); | ||
Packet::populate_packet(packet, None, &tx).is_ok() | ||
}); | ||
if res { | ||
let all_packets_initialized = | ||
uninitialized_packets | ||
.zip(transaction_iter) | ||
.all(|(uninit_packet, tx)| { | ||
Packet::init_packet_from_data(uninit_packet, &tx, None).is_ok() | ||
}); | ||
|
||
if all_packets_initialized { | ||
// SAFETY: All packets have been successfully initialized | ||
unsafe { packet_batch.set_len(num_transactions) }; | ||
Ok(packet_batch) | ||
} else { | ||
Err(TransactionError::SanitizeFailure) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,10 @@ use { | |
bitflags::bitflags, | ||
std::{ | ||
fmt, | ||
io::{self, Write}, | ||
mem, | ||
net::{IpAddr, Ipv4Addr, SocketAddr}, | ||
ptr, | ||
slice::SliceIndex, | ||
}, | ||
}; | ||
|
@@ -158,11 +161,76 @@ impl Packet { | |
&mut self.meta | ||
} | ||
|
||
#[cfg(feature = "bincode")] | ||
/// Initializes a std::mem::MaybeUninit<Packet> such that the Packet can | ||
/// be safely extracted via methods such as MaybeUninit::assume_init() | ||
pub fn init_packet_from_data<T: serde::Serialize>( | ||
packet: &mut mem::MaybeUninit<Packet>, | ||
data: &T, | ||
addr: Option<&SocketAddr>, | ||
) -> Result<()> { | ||
let mut writer = PacketWriter::new_from_uninit_packet(packet); | ||
bincode::serialize_into(&mut writer, data)?; | ||
|
||
let serialized_size = writer.position(); | ||
let (ip, port) = if let Some(addr) = addr { | ||
(addr.ip(), addr.port()) | ||
} else { | ||
(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) | ||
}; | ||
Self::init_packet_meta( | ||
packet, | ||
Meta { | ||
size: serialized_size, | ||
addr: ip, | ||
port, | ||
flags: PacketFlags::empty(), | ||
}, | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
pub fn init_packet_from_bytes( | ||
packet: &mut mem::MaybeUninit<Packet>, | ||
bytes: &[u8], | ||
addr: Option<&SocketAddr>, | ||
) -> io::Result<()> { | ||
let mut writer = PacketWriter::new_from_uninit_packet(packet); | ||
let num_bytes_written = writer.write(bytes)?; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Yeah, that's fair and also simplifies things on the caller side (ie no longer need the |
||
debug_assert_eq!(bytes.len(), num_bytes_written); | ||
|
||
let size = writer.position(); | ||
let (ip, port) = if let Some(addr) = addr { | ||
(addr.ip(), addr.port()) | ||
} else { | ||
(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0) | ||
}; | ||
Self::init_packet_meta( | ||
packet, | ||
Meta { | ||
size, | ||
addr: ip, | ||
port, | ||
flags: PacketFlags::empty(), | ||
}, | ||
); | ||
|
||
Ok(()) | ||
} | ||
|
||
fn init_packet_meta(packet: &mut mem::MaybeUninit<Packet>, meta: Meta) { | ||
// SAFETY: Access the field by pointer as creating a reference to | ||
// and/or within the uninitialized Packet is undefined behavior | ||
unsafe { ptr::addr_of_mut!((*packet.as_mut_ptr()).meta).write(meta) }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wasn't there a new syntax for getting ptr of a field introduced in 1.82? Iirc all this concern came from the upgrade to that version There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Oh nice, didn't know about this; thanks for mentioning and will read up on it
Yep, validator was panicking with 1.82. The panic was addressed with #3325, so we could hypothetically go back to 1.82 to take advantage of the new syntax ( |
||
} | ||
|
||
#[cfg(feature = "bincode")] | ||
pub fn from_data<T: serde::Serialize>(dest: Option<&SocketAddr>, data: T) -> Result<Self> { | ||
let mut packet = Self::default(); | ||
Self::populate_packet(&mut packet, dest, &data)?; | ||
Ok(packet) | ||
let mut packet = mem::MaybeUninit::uninit(); | ||
Self::init_packet_from_data(&mut packet, &data, dest)?; | ||
// SAFETY: init_packet_from_data() just initialized the packet | ||
unsafe { Ok(packet.assume_init()) } | ||
} | ||
|
||
#[cfg(feature = "bincode")] | ||
|
@@ -224,6 +292,61 @@ impl PartialEq for Packet { | |
} | ||
} | ||
|
||
/// A custom implementation of io::Write to facilitate safe (non-UB) | ||
/// initialization of a MaybeUninit<Packet> | ||
struct PacketWriter { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not simple wrapper around MaybeUninit packet? We know the capacity of the buffer, and can determine remaining bytes from the current length and the fixed capacity There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The main motivation for writing this wrapper was to have something that implements If you drill down into bincode, |
||
// A pointer to the current write position | ||
position: *mut u8, | ||
// The number of remaining bytes that can be written to | ||
spare_capacity: usize, | ||
} | ||
|
||
impl PacketWriter { | ||
fn new_from_uninit_packet(packet: &mut mem::MaybeUninit<Packet>) -> Self { | ||
// SAFETY: Access the field by pointer as creating a reference to | ||
// and/or within the uninitialized Packet is undefined behavior | ||
let position = unsafe { ptr::addr_of_mut!((*packet.as_mut_ptr()).buffer) as *mut u8 }; | ||
let spare_capacity = PACKET_DATA_SIZE; | ||
|
||
Self { | ||
position, | ||
spare_capacity, | ||
} | ||
} | ||
|
||
/// The offset of the write pointer within the buffer, which is also the | ||
/// number of bytes that have been written | ||
fn position(&self) -> usize { | ||
PACKET_DATA_SIZE.saturating_sub(self.spare_capacity) | ||
} | ||
} | ||
|
||
impl io::Write for PacketWriter { | ||
#[inline] | ||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { | ||
if buf.len() > self.spare_capacity { | ||
return Err(io::Error::from(io::ErrorKind::WriteZero)); | ||
} | ||
|
||
// SAFETY: We previously verified that buf.len() <= self.spare_capacity | ||
// so this write will not push us past the end of the buffer. Likewise, | ||
// we can update self.spare_capacity without fear of overflow | ||
unsafe { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. all these new instances of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that we should be very stingy with our use of |
||
ptr::copy_nonoverlapping(buf.as_ptr(), self.position, buf.len()); | ||
// Update position and spare_capacity for the next call to write() | ||
self.position = self.position.add(buf.len()); | ||
self.spare_capacity = self.spare_capacity.saturating_sub(buf.len()); | ||
} | ||
|
||
Ok(buf.len()) | ||
} | ||
|
||
#[inline] | ||
fn flush(&mut self) -> io::Result<()> { | ||
Ok(()) | ||
} | ||
} | ||
|
||
impl Meta { | ||
pub fn socket_addr(&self) -> SocketAddr { | ||
SocketAddr::new(self.addr, self.port) | ||
|
@@ -239,6 +362,11 @@ impl Meta { | |
.set(PacketFlags::FROM_STAKED_NODE, from_staked_node); | ||
} | ||
|
||
#[inline] | ||
pub fn set_flags(&mut self, flags: PacketFlags) { | ||
self.flags = flags; | ||
} | ||
|
||
#[inline] | ||
pub fn discard(&self) -> bool { | ||
self.flags.contains(PacketFlags::DISCARD) | ||
|
@@ -309,7 +437,7 @@ impl Default for Meta { | |
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
use {super::*, std::io::Write}; | ||
|
||
#[test] | ||
fn test_deserialize_slice() { | ||
|
@@ -349,4 +477,35 @@ mod tests { | |
Err("the size limit has been reached".to_string()), | ||
); | ||
} | ||
|
||
#[test] | ||
fn test_packet_buffer_writer() { | ||
let mut packet = mem::MaybeUninit::<Packet>::uninit(); | ||
let mut writer = PacketWriter::new_from_uninit_packet(&mut packet); | ||
let total_capacity = writer.spare_capacity; | ||
assert_eq!(total_capacity, PACKET_DATA_SIZE); | ||
let payload: [u8; PACKET_DATA_SIZE] = std::array::from_fn(|i| i as u8 % 255); | ||
|
||
// Write 1200 bytes (1200 total) | ||
let num_to_write = PACKET_DATA_SIZE - 32; | ||
assert_eq!( | ||
num_to_write, | ||
writer.write(&payload[..num_to_write]).unwrap() | ||
); | ||
assert_eq!(num_to_write, writer.position()); | ||
// Write 28 bytes (1228 total) | ||
assert_eq!(28, writer.write(&payload[1200..1200 + 28]).unwrap()); | ||
assert_eq!(1200 + 28, writer.position()); | ||
// Attempt to write 5 bytes (1233 total) which exceeds buffer capacity | ||
assert!(writer | ||
.write(&payload[1200 + 28 - 1..PACKET_DATA_SIZE]) | ||
.is_err()); | ||
// writer.position() remains unchanged | ||
assert_eq!(1200 + 28, writer.position()); | ||
// Write 4 bytes (1232 total) to fill buffer | ||
assert_eq!(4, writer.write(&payload[1200 + 28..]).unwrap()); | ||
assert_eq!(PACKET_DATA_SIZE, writer.position()); | ||
// Writing any amount of bytes will fail on the already full buffer | ||
assert!(writer.write(&[0]).is_err()); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is removing one `unsafe` but then adding two new ones. Why is the new code better than the old one?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://doc.rust-lang.org/std/vec/struct.Vec.html#method.set_len
We are not initializing the data first, so we are violating the second bullet of the `set_len` safety requirements: the elements at `old_len..new_len` must be initialized.