diff --git a/accounts-db/src/tiered_storage/byte_block.rs b/accounts-db/src/tiered_storage/byte_block.rs index e0fa8b4b136b3b..8fe2b6417181fc 100644 --- a/accounts-db/src/tiered_storage/byte_block.rs +++ b/accounts-db/src/tiered_storage/byte_block.rs @@ -53,11 +53,31 @@ impl ByteBlockWriter { self.len } + /// Write plain ol' data to the internal buffer of the ByteBlockWriter instance + /// + /// Prefer this over `write_type()`, as it prevents some undefined behavior. + pub fn write_pod(&mut self, value: &T) -> IoResult { + // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes. + unsafe { self.write_type(value) } + } + /// Write the specified typed instance to the internal buffer of /// the ByteBlockWriter instance. - pub fn write_type(&mut self, value: &T) -> std::io::Result { + /// + /// Prefer `write_pod()` when possible, because `write_type()` may cause + /// undefined behavior if `value` contains uninitialized bytes. + /// + /// # Safety + /// + /// Caller must ensure casting T to bytes is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and NoUninit for more information. + pub unsafe fn write_type(&mut self, value: &T) -> IoResult { let size = mem::size_of::(); let ptr = value as *const _ as *const u8; + // SAFETY: The caller ensures that `value` contains no uninitialized bytes, + // we ensure the size is safe by querying T directly, + // and Rust ensures all values are at least byte-aligned. let slice = unsafe { std::slice::from_raw_parts(ptr, size) }; self.write(slice)?; Ok(size) @@ -73,10 +93,10 @@ impl ByteBlockWriter { ) -> std::io::Result { let mut size = 0; if let Some(rent_epoch) = opt_fields.rent_epoch { - size += self.write_type(&rent_epoch)?; + size += self.write_pod(&rent_epoch)?; } if let Some(hash) = opt_fields.account_hash { - size += self.write_type(&hash)?; + size += self.write_pod(&hash)?; } debug_assert_eq!(size, opt_fields.size()); @@ -112,18 +132,40 @@ impl ByteBlockWriter { /// The util struct for reading byte blocks. pub struct ByteBlockReader; +/// Reads the raw part of the input byte_block, at the specified offset, as type T. +/// +/// Returns None if `offset` + size_of::() exceeds the size of the input byte_block. +/// +/// Type T must be plain ol' data to ensure no undefined behavior. +pub fn read_pod(byte_block: &[u8], offset: usize) -> Option<&T> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { read_type(byte_block, offset) } +} + /// Reads the raw part of the input byte_block at the specified offset /// as type T. /// /// If `offset` + size_of::() exceeds the size of the input byte_block, /// then None will be returned. -pub fn read_type(byte_block: &[u8], offset: usize) -> Option<&T> { +/// +/// Prefer `read_pod()` when possible, because `read_type()` may cause +/// undefined behavior. +/// +/// # Safety +/// +/// Caller must ensure casting bytes to T is safe. +/// Refer to the Safety sections in std::slice::from_raw_parts() +/// and bytemuck's Pod and AnyBitPattern for more information. +pub unsafe fn read_type(byte_block: &[u8], offset: usize) -> Option<&T> { let (next, overflow) = offset.overflowing_add(std::mem::size_of::()); if overflow || next > byte_block.len() { return None; } let ptr = byte_block[offset..].as_ptr() as *const T; debug_assert!(ptr as usize % std::mem::align_of::() == 0); + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and we just checked above to ensure the ptr is aligned for T. Some(unsafe { &*ptr }) } @@ -169,7 +211,7 @@ mod tests { let mut writer = ByteBlockWriter::new(format); let value: u32 = 42; - writer.write_type(&value).unwrap(); + writer.write_pod(&value).unwrap(); assert_eq!(writer.raw_len(), mem::size_of::()); let buffer = writer.finish().unwrap(); @@ -231,12 +273,14 @@ mod tests { let test_data3 = [33u8; 300]; // Write the above meta and data in an interleaving way. - writer.write_type(&test_metas[0]).unwrap(); - writer.write_type(&test_data1).unwrap(); - writer.write_type(&test_metas[1]).unwrap(); - writer.write_type(&test_data2).unwrap(); - writer.write_type(&test_metas[2]).unwrap(); - writer.write_type(&test_data3).unwrap(); + unsafe { + writer.write_type(&test_metas[0]).unwrap(); + writer.write_type(&test_data1).unwrap(); + writer.write_type(&test_metas[1]).unwrap(); + writer.write_type(&test_data2).unwrap(); + writer.write_type(&test_metas[2]).unwrap(); + writer.write_type(&test_data3).unwrap(); + } assert_eq!( writer.raw_len(), mem::size_of::() * 3 @@ -346,13 +390,13 @@ mod tests { let mut offset = 0; for opt_fields in &opt_fields_vec { if let Some(expected_rent_epoch) = opt_fields.rent_epoch { - let rent_epoch = read_type::(&decoded_buffer, offset).unwrap(); + let rent_epoch = read_pod::(&decoded_buffer, offset).unwrap(); assert_eq!(*rent_epoch, expected_rent_epoch); verified_count += 1; offset += std::mem::size_of::(); } if let Some(expected_hash) = opt_fields.account_hash { - let hash = read_type::(&decoded_buffer, offset).unwrap(); + let hash = read_pod::(&decoded_buffer, offset).unwrap(); assert_eq!(hash, &expected_hash); verified_count += 1; offset += std::mem::size_of::(); diff --git a/accounts-db/src/tiered_storage/file.rs b/accounts-db/src/tiered_storage/file.rs index d2227fe2fa4870..f909e287b721f5 100644 --- a/accounts-db/src/tiered_storage/file.rs +++ b/accounts-db/src/tiered_storage/file.rs @@ -1,8 +1,11 @@ -use std::{ - fs::{File, OpenOptions}, - io::{Read, Seek, SeekFrom, Write}, - mem, - path::Path, +use { + bytemuck::{AnyBitPattern, NoUninit}, + std::{ + fs::{File, OpenOptions}, + io::{Read, Result as IoResult, Seek, SeekFrom, Write}, + mem, + path::Path, + }, }; #[derive(Debug)] @@ -33,14 +36,53 @@ impl TieredStorageFile { )) } - pub fn write_type(&self, value: &T) -> Result { + /// Writes `value` to the file. + /// + /// `value` must be plain ol' data. + pub fn write_pod(&self, value: &T) -> IoResult { + // SAFETY: Since T is NoUninit, it does not contain any uninitialized bytes. + unsafe { self.write_type(value) } + } + + /// Writes `value` to the file. + /// + /// Prefer `write_pod` when possible, because `write_value` may cause + /// undefined behavior if `value` contains uninitialized bytes. + /// + /// # Safety + /// + /// Caller must ensure casting T to bytes is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and NoUninit for more information. + pub unsafe fn write_type(&self, value: &T) -> IoResult { let ptr = value as *const _ as *const u8; let bytes = unsafe { std::slice::from_raw_parts(ptr, mem::size_of::()) }; self.write_bytes(bytes) } - pub fn read_type(&self, value: &mut T) -> Result<(), std::io::Error> { + /// Reads a value of type `T` from the file. + /// + /// Type T must be plain ol' data. + pub fn read_pod(&self, value: &mut T) -> IoResult<()> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { self.read_type(value) } + } + + /// Reads a value of type `T` from the file. + /// + /// Prefer `read_pod()` when possible, because `read_type()` may cause + /// undefined behavior. + /// + /// # Safety + /// + /// Caller must ensure casting bytes to T is safe. + /// Refer to the Safety sections in std::slice::from_raw_parts() + /// and bytemuck's Pod and AnyBitPattern for more information. + pub unsafe fn read_type(&self, value: &mut T) -> IoResult<()> { let ptr = value as *mut _ as *mut u8; + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and Rust ensures ptr is aligned. let bytes = unsafe { std::slice::from_raw_parts_mut(ptr, mem::size_of::()) }; self.read_bytes(bytes) } diff --git a/accounts-db/src/tiered_storage/footer.rs b/accounts-db/src/tiered_storage/footer.rs index 1dc82ebae0cbe5..f3b8fba4d20a57 100644 --- a/accounts-db/src/tiered_storage/footer.rs +++ b/accounts-db/src/tiered_storage/footer.rs @@ -1,7 +1,10 @@ use { crate::tiered_storage::{ - error::TieredStorageError, file::TieredStorageFile, index::IndexBlockFormat, - mmap_utils::get_type, TieredStorageResult as TsResult, + error::TieredStorageError, + file::TieredStorageFile, + index::IndexBlockFormat, + mmap_utils::{get_pod, get_type}, + TieredStorageResult, }, bytemuck::{Pod, Zeroable}, memmap2::Mmap, @@ -200,17 +203,25 @@ impl TieredStorageFooter { Self::new_from_footer_block(&file) } - pub fn write_footer_block(&self, file: &TieredStorageFile) -> TsResult<()> { - file.write_type(self)?; - file.write_type(&TieredStorageMagicNumber::default())?; + pub fn write_footer_block(&self, file: &TieredStorageFile) -> TieredStorageResult<()> { + // SAFETY: The footer does not contain any uninitialized bytes. + unsafe { file.write_type(self)? }; + file.write_pod(&TieredStorageMagicNumber::default())?; Ok(()) } - pub fn new_from_footer_block(file: &TieredStorageFile) -> TsResult { - let mut footer_size: u64 = 0; + pub fn new_from_footer_block(file: &TieredStorageFile) -> TieredStorageResult { file.seek_from_end(-(FOOTER_TAIL_SIZE as i64))?; - file.read_type(&mut footer_size)?; + + let mut footer_version: u64 = 0; + file.read_pod(&mut footer_version)?; + if footer_version != FOOTER_FORMAT_VERSION { + return Err(TieredStorageError::InvalidFooterVersion(footer_version)); + } + + let mut footer_size: u64 = 0; + file.read_pod(&mut footer_size)?; if footer_size != FOOTER_SIZE as u64 { return Err(TieredStorageError::InvalidFooterSize( footer_size, @@ -218,14 +229,8 @@ impl TieredStorageFooter { )); } - let mut footer_version: u64 = 0; - file.read_type(&mut footer_version)?; - if footer_version != FOOTER_FORMAT_VERSION { - return Err(TieredStorageError::InvalidFooterVersion(footer_version)); - } - let mut magic_number = TieredStorageMagicNumber::zeroed(); - file.read_type(&mut magic_number)?; + file.read_pod(&mut magic_number)?; if magic_number != TieredStorageMagicNumber::default() { return Err(TieredStorageError::MagicNumberMismatch( TieredStorageMagicNumber::default().0, @@ -235,7 +240,9 @@ impl TieredStorageFooter { let mut footer = Self::default(); file.seek_from_end(-(footer_size as i64))?; - file.read_type(&mut footer)?; + // SAFETY: We sanitize the footer to ensure all the bytes are + // actually safe to interpret as a TieredStorageFooter. + unsafe { file.read_type(&mut footer)? }; Self::sanitize(&footer)?; Ok(footer) @@ -243,7 +250,13 @@ impl TieredStorageFooter { pub fn new_from_mmap(mmap: &Mmap) -> TieredStorageResult<&TieredStorageFooter> { let offset = mmap.len().saturating_sub(FOOTER_TAIL_SIZE); - let (&footer_size, offset) = get_type::(mmap, offset)?; + + let (footer_version, offset) = get_pod::(mmap, offset)?; + if *footer_version != FOOTER_FORMAT_VERSION { + return Err(TieredStorageError::InvalidFooterVersion(*footer_version)); + } + + let (&footer_size, offset) = get_pod::(mmap, offset)?; if footer_size != FOOTER_SIZE as u64 { return Err(TieredStorageError::InvalidFooterSize( footer_size, @@ -251,12 +264,7 @@ impl TieredStorageFooter { )); } - let (footer_version, offset) = get_type::(mmap, offset)?; - if *footer_version != FOOTER_FORMAT_VERSION { - return Err(TieredStorageError::InvalidFooterVersion(*footer_version)); - } - - let (magic_number, _offset) = get_type::(mmap, offset)?; + let (magic_number, _offset) = get_pod::(mmap, offset)?; if *magic_number != TieredStorageMagicNumber::default() { return Err(TieredStorageError::MagicNumberMismatch( TieredStorageMagicNumber::default().0, @@ -265,7 +273,9 @@ impl TieredStorageFooter { } let footer_offset = mmap.len().saturating_sub(footer_size as usize); - let (footer, _offset) = get_type::(mmap, footer_offset)?; + // SAFETY: We sanitize the footer to ensure all the bytes are + // actually safe to interpret as a TieredStorageFooter. + let (footer, _offset) = unsafe { get_type::(mmap, footer_offset)? }; Self::sanitize(footer)?; Ok(footer) diff --git a/accounts-db/src/tiered_storage/hot.rs b/accounts-db/src/tiered_storage/hot.rs index 62e25a56eb22d4..8d3172799b4b7b 100644 --- a/accounts-db/src/tiered_storage/hot.rs +++ b/accounts-db/src/tiered_storage/hot.rs @@ -11,7 +11,7 @@ use { }, index::{AccountOffset, IndexBlockFormat, IndexOffset}, meta::{AccountMetaFlags, AccountMetaOptionalFields, TieredAccountMeta}, - mmap_utils::get_type, + mmap_utils::get_pod, owners::{OwnerOffset, OwnersBlock}, TieredStorageError, TieredStorageFormat, TieredStorageResult, }, @@ -195,7 +195,7 @@ impl TieredAccountMeta for HotAccountMeta { .then(|| { let offset = self.optional_fields_offset(account_block) + AccountMetaOptionalFields::rent_epoch_offset(self.flags()); - byte_block::read_type::(account_block, offset).copied() + byte_block::read_pod::(account_block, offset).copied() }) .flatten() } @@ -208,7 +208,7 @@ impl TieredAccountMeta for HotAccountMeta { .then(|| { let offset = self.optional_fields_offset(account_block) + AccountMetaOptionalFields::account_hash_offset(self.flags()); - byte_block::read_type::(account_block, offset) + byte_block::read_pod::(account_block, offset) }) .flatten() } @@ -450,13 +450,16 @@ pub mod tests { .with_flags(&flags); let mut writer = ByteBlockWriter::new(AccountBlockFormat::AlignedRaw); - writer.write_type(&expected_meta).unwrap(); - writer.write_type(&account_data).unwrap(); - writer.write_type(&padding).unwrap(); + writer.write_pod(&expected_meta).unwrap(); + // SAFETY: These values are POD, so they are safe to write. + unsafe { + writer.write_type(&account_data).unwrap(); + writer.write_type(&padding).unwrap(); + } writer.write_optional_fields(&optional_fields).unwrap(); let buffer = writer.finish().unwrap(); - let meta = byte_block::read_type::(&buffer, 0).unwrap(); + let meta = byte_block::read_pod::(&buffer, 0).unwrap(); assert_eq!(expected_meta, *meta); assert!(meta.flags().has_rent_epoch()); assert!(meta.flags().has_account_hash()); @@ -546,7 +549,7 @@ pub mod tests { .iter() .map(|meta| { let prev_offset = current_offset; - current_offset += file.write_type(meta).unwrap(); + current_offset += file.write_pod(meta).unwrap(); HotAccountOffset::new(prev_offset).unwrap() }) .collect(); diff --git a/accounts-db/src/tiered_storage/index.rs b/accounts-db/src/tiered_storage/index.rs index 5752d1230c2697..c04d026cdb2a74 100644 --- a/accounts-db/src/tiered_storage/index.rs +++ b/accounts-db/src/tiered_storage/index.rs @@ -1,6 +1,6 @@ use { crate::tiered_storage::{ - file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_type, + file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_pod, TieredStorageResult, }, memmap2::Mmap, @@ -58,10 +58,10 @@ impl IndexBlockFormat { Self::AddressAndBlockOffsetOnly => { let mut bytes_written = 0; for index_entry in index_entries { - bytes_written += file.write_type(index_entry.address)?; + bytes_written += file.write_pod(index_entry.address)?; } for index_entry in index_entries { - bytes_written += file.write_type(&index_entry.offset)?; + bytes_written += file.write_pod(&index_entry.offset)?; } Ok(bytes_written) } @@ -107,7 +107,7 @@ impl IndexBlockFormat { let offset = footer.index_block_offset as usize + std::mem::size_of::() * footer.account_entry_count as usize + std::mem::size_of::() * index_offset.0 as usize; - let (account_offset, _) = get_type::(mmap, offset)?; + let (account_offset, _) = get_pod::(mmap, offset)?; Ok(*account_offset) } diff --git a/accounts-db/src/tiered_storage/mmap_utils.rs b/accounts-db/src/tiered_storage/mmap_utils.rs index a1e70a1e617949..56513473bdbc98 100644 --- a/accounts-db/src/tiered_storage/mmap_utils.rs +++ b/accounts-db/src/tiered_storage/mmap_utils.rs @@ -4,10 +4,30 @@ use { memmap2::Mmap, }; -pub fn get_type(map: &Mmap, offset: usize) -> std::io::Result<(&T, usize)> { - let (data, next) = get_slice(map, offset, std::mem::size_of::())?; +/// Borrows a value of type `T` from `mmap` +/// +/// Type T must be plain ol' data to ensure no undefined behavior. +pub fn get_pod(mmap: &Mmap, offset: usize) -> IoResult<(&T, usize)> { + // SAFETY: Since T is AnyBitPattern, it is safe to cast bytes to T. + unsafe { get_type::(mmap, offset) } +} + +/// Borrows a value of type `T` from `mmap` +/// +/// Prefer `get_pod()` when possible, because `get_type()` may cause undefined behavior. +/// +/// # Safety +/// +/// Caller must ensure casting bytes to T is safe. +/// Refer to the Safety sections in std::slice::from_raw_parts() +/// and bytemuck's Pod and AnyBitPattern for more information. +pub unsafe fn get_type(mmap: &Mmap, offset: usize) -> IoResult<(&T, usize)> { + let (data, next) = get_slice(mmap, offset, std::mem::size_of::())?; let ptr = data.as_ptr() as *const T; debug_assert!(ptr as usize % std::mem::align_of::() == 0); + // SAFETY: The caller ensures it is safe to cast bytes to T, + // we ensure the size is safe by querying T directly, + // and we just checked above to ensure the ptr is aligned for T. Ok((unsafe { &*ptr }, next)) } @@ -33,5 +53,7 @@ pub fn get_slice(map: &Mmap, offset: usize, size: usize) -> std::io::Result<(&[u let next = u64_align!(next); let ptr = data.as_ptr(); + // SAFETY: The Mmap ensures the bytes are safe the read, and we just checked + // to ensure we don't read past the end of the internal buffer. Ok((unsafe { std::slice::from_raw_parts(ptr, size) }, next)) } diff --git a/accounts-db/src/tiered_storage/owners.rs b/accounts-db/src/tiered_storage/owners.rs index 7cd548e3a00c8d..41e1f8a6715a3f 100644 --- a/accounts-db/src/tiered_storage/owners.rs +++ b/accounts-db/src/tiered_storage/owners.rs @@ -1,6 +1,6 @@ use { crate::tiered_storage::{ - file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_type, + file::TieredStorageFile, footer::TieredStorageFooter, mmap_utils::get_pod, TieredStorageResult, }, memmap2::Mmap, @@ -32,7 +32,7 @@ impl OwnersBlock { ) -> TieredStorageResult { let mut bytes_written = 0; for address in addresses { - bytes_written += file.write_type(address)?; + bytes_written += file.write_pod(address)?; } Ok(bytes_written) @@ -47,7 +47,7 @@ impl OwnersBlock { ) -> TieredStorageResult<&'a Pubkey> { let offset = footer.owners_block_offset as usize + (std::mem::size_of::() * owner_offset.0 as usize); - let (pubkey, _) = get_type::(mmap, offset)?; + let (pubkey, _) = get_pod::(mmap, offset)?; Ok(pubkey) }