Skip to content

Commit

Permalink
stake-pool: Remove unsafe pointer casts via Pod types (#5185)
Browse files Browse the repository at this point in the history
* stake-pool: Force BigVec to work with Pod types

* Remove all unsafe through an enum wrapper struct

* Also fix the CLI
  • Loading branch information
joncinque authored Sep 14, 2023
1 parent 908ea3f commit 38212ea
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 204 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions stake-pool/cli/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ use {
solana_cli_output::{QuietDisplay, VerboseDisplay},
solana_sdk::native_token::Sol,
solana_sdk::{pubkey::Pubkey, stake::state::Lockup},
spl_stake_pool::state::{Fee, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo},
spl_stake_pool::state::{
Fee, PodStakeStatus, StakePool, StakeStatus, ValidatorList, ValidatorStakeInfo,
},
std::fmt::{Display, Formatter, Result, Write},
};

Expand Down Expand Up @@ -384,8 +386,9 @@ impl From<ValidatorStakeInfo> for CliStakePoolValidator {
}
}

impl From<StakeStatus> for CliStakePoolValidatorStakeStatus {
fn from(s: StakeStatus) -> CliStakePoolValidatorStakeStatus {
impl From<PodStakeStatus> for CliStakePoolValidatorStakeStatus {
fn from(s: PodStakeStatus) -> CliStakePoolValidatorStakeStatus {
let s = StakeStatus::try_from(s).unwrap();
match s {
StakeStatus::Active => CliStakePoolValidatorStakeStatus::Active,
StakeStatus::DeactivatingTransient => {
Expand Down
1 change: 1 addition & 0 deletions stake-pool/program/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ test-sbf = []
[dependencies]
arrayref = "0.3.7"
borsh = "0.10"
bytemuck = "1.13"
num-derive = "0.4"
num-traits = "0.2"
num_enum = "0.7.0"
Expand Down
175 changes: 52 additions & 123 deletions stake-pool/program/src/big_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
use {
arrayref::array_ref,
borsh::BorshDeserialize,
solana_program::{
program_error::ProgramError, program_memory::sol_memmove, program_pack::Pack,
},
std::marker::PhantomData,
bytemuck::Pod,
solana_program::{program_error::ProgramError, program_memory::sol_memmove},
std::mem,
};

/// Contains easy to use utilities for a big vector of Borsh-compatible types,
Expand All @@ -32,7 +31,7 @@ impl<'data> BigVec<'data> {
}

/// Retain all elements that match the provided function, discard all others
pub fn retain<T: Pack, F: Fn(&[u8]) -> bool>(
pub fn retain<T: Pod, F: Fn(&[u8]) -> bool>(
&mut self,
predicate: F,
) -> Result<(), ProgramError> {
Expand All @@ -42,12 +41,12 @@ impl<'data> BigVec<'data> {

let data_start_index = VEC_SIZE_BYTES;
let data_end_index =
data_start_index.saturating_add((vec_len as usize).saturating_mul(T::LEN));
for start_index in (data_start_index..data_end_index).step_by(T::LEN) {
let end_index = start_index + T::LEN;
data_start_index.saturating_add((vec_len as usize).saturating_mul(mem::size_of::<T>()));
for start_index in (data_start_index..data_end_index).step_by(mem::size_of::<T>()) {
let end_index = start_index + mem::size_of::<T>();
let slice = &self.data[start_index..end_index];
if !predicate(slice) {
let gap = removals_found * T::LEN;
let gap = removals_found * mem::size_of::<T>();
if removals_found > 0 {
// In case the compute budget is ever bumped up, allowing us
// to use this safe code instead:
Expand All @@ -68,7 +67,7 @@ impl<'data> BigVec<'data> {

// final memmove
if removals_found > 0 {
let gap = removals_found * T::LEN;
let gap = removals_found * mem::size_of::<T>();
// In case the compute budget is ever bumped up, allowing us
// to use this safe code instead:
//self.data.copy_within(dst_start_index + gap..data_end_index, dst_start_index);
Expand All @@ -88,11 +87,11 @@ impl<'data> BigVec<'data> {
}

/// Extracts a slice of the data types
pub fn deserialize_mut_slice<T: Pack>(
pub fn deserialize_mut_slice<T: Pod>(
&mut self,
skip: usize,
len: usize,
) -> Result<Vec<&'data mut T>, ProgramError> {
) -> Result<&mut [T], ProgramError> {
let vec_len = self.len();
let last_item_index = skip
.checked_add(len)
Expand All @@ -101,66 +100,60 @@ impl<'data> BigVec<'data> {
return Err(ProgramError::AccountDataTooSmall);
}

let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(T::LEN));
let end_index = start_index.saturating_add(len.saturating_mul(T::LEN));
let mut deserialized = vec![];
for slice in self.data[start_index..end_index].chunks_exact_mut(T::LEN) {
deserialized.push(unsafe { &mut *(slice.as_ptr() as *mut T) });
let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
bytemuck::try_cast_slice_mut(&mut self.data[start_index..end_index])
.map_err(|_| ProgramError::InvalidAccountData)
}

/// Extracts a slice of the data types
pub fn deserialize_slice<T: Pod>(&self, skip: usize, len: usize) -> Result<&[T], ProgramError> {
let vec_len = self.len();
let last_item_index = skip
.checked_add(len)
.ok_or(ProgramError::AccountDataTooSmall)?;
if last_item_index > vec_len as usize {
return Err(ProgramError::AccountDataTooSmall);
}
Ok(deserialized)

let start_index = VEC_SIZE_BYTES.saturating_add(skip.saturating_mul(mem::size_of::<T>()));
let end_index = start_index.saturating_add(len.saturating_mul(mem::size_of::<T>()));
bytemuck::try_cast_slice(&self.data[start_index..end_index])
.map_err(|_| ProgramError::InvalidAccountData)
}

/// Add new element to the end
pub fn push<T: Pack>(&mut self, element: T) -> Result<(), ProgramError> {
pub fn push<T: Pod>(&mut self, element: T) -> Result<(), ProgramError> {
let mut vec_len_ref = &mut self.data[0..VEC_SIZE_BYTES];
let mut vec_len = u32::try_from_slice(vec_len_ref)?;

let start_index = VEC_SIZE_BYTES + vec_len as usize * T::LEN;
let end_index = start_index + T::LEN;
let start_index = VEC_SIZE_BYTES + vec_len as usize * mem::size_of::<T>();
let end_index = start_index + mem::size_of::<T>();

vec_len += 1;
borsh::to_writer(&mut vec_len_ref, &vec_len)?;

if self.data.len() < end_index {
return Err(ProgramError::AccountDataTooSmall);
}
let element_ref = &mut self.data[start_index..start_index + T::LEN];
element.pack_into_slice(element_ref);
let element_ref = bytemuck::try_from_bytes_mut(
&mut self.data[start_index..start_index + mem::size_of::<T>()],
)
.map_err(|_| ProgramError::InvalidAccountData)?;
*element_ref = element;
Ok(())
}

/// Get an iterator for the type provided
pub fn iter<'vec, T: Pack>(&'vec self) -> Iter<'data, 'vec, T> {
Iter {
len: self.len() as usize,
current: 0,
current_index: VEC_SIZE_BYTES,
inner: self,
phantom: PhantomData,
}
}

/// Get a mutable iterator for the type provided
pub fn iter_mut<'vec, T: Pack>(&'vec mut self) -> IterMut<'data, 'vec, T> {
IterMut {
len: self.len() as usize,
current: 0,
current_index: VEC_SIZE_BYTES,
inner: self,
phantom: PhantomData,
}
}

/// Find matching data in the array
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
pub fn find<T: Pod, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let end_index = current_index + mem::size_of::<T>();
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice) {
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
return Some(bytemuck::from_bytes(current_slice));
}
current_index = end_index;
current += 1;
Expand All @@ -169,15 +162,17 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
pub fn find_mut<T: Pod, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let end_index = current_index + mem::size_of::<T>();
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice) {
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
return Some(bytemuck::from_bytes_mut(
&mut self.data[current_index..end_index],
));
}
current_index = end_index;
current += 1;
Expand All @@ -186,84 +181,16 @@ impl<'data> BigVec<'data> {
}
}

/// Iterator wrapper over a BigVec
pub struct Iter<'data, 'vec, T> {
len: usize,
current: usize,
current_index: usize,
inner: &'vec BigVec<'data>,
phantom: PhantomData<T>,
}

impl<'data, 'vec, T: Pack + 'data> Iterator for Iter<'data, 'vec, T> {
type Item = &'data T;

fn next(&mut self) -> Option<Self::Item> {
if self.current == self.len {
None
} else {
let end_index = self.current_index + T::LEN;
let value = Some(unsafe {
&*(self.inner.data[self.current_index..end_index].as_ptr() as *const T)
});
self.current += 1;
self.current_index = end_index;
value
}
}
}

/// Iterator wrapper over a BigVec
pub struct IterMut<'data, 'vec, T> {
len: usize,
current: usize,
current_index: usize,
inner: &'vec mut BigVec<'data>,
phantom: PhantomData<T>,
}

impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {
type Item = &'data mut T;

fn next(&mut self) -> Option<Self::Item> {
if self.current == self.len {
None
} else {
let end_index = self.current_index + T::LEN;
let value = Some(unsafe {
&mut *(self.inner.data[self.current_index..end_index].as_ptr() as *mut T)
});
self.current += 1;
self.current_index = end_index;
value
}
}
}

#[cfg(test)]
mod tests {
use {super::*, solana_program::program_pack::Sealed};
use {super::*, bytemuck::Zeroable};

#[derive(Debug, PartialEq)]
#[repr(C)]
#[derive(Debug, Copy, Clone, PartialEq, Pod, Zeroable)]
struct TestStruct {
value: [u8; 8],
}

impl Sealed for TestStruct {}

impl Pack for TestStruct {
const LEN: usize = 8;
fn pack_into_slice(&self, data: &mut [u8]) {
let mut data = data;
borsh::to_writer(&mut data, &self.value).unwrap();
}
fn unpack_from_slice(src: &[u8]) -> Result<Self, ProgramError> {
Ok(TestStruct {
value: src.try_into().unwrap(),
})
}
}

impl TestStruct {
fn new(value: u8) -> Self {
let value = [value, 0, 0, 0, 0, 0, 0, 0];
Expand All @@ -281,7 +208,9 @@ mod tests {

fn check_big_vec_eq(big_vec: &BigVec, slice: &[u8]) {
assert!(big_vec
.iter::<TestStruct>()
.deserialize_slice::<TestStruct>(0, big_vec.len() as usize)
.unwrap()
.iter()
.map(|x| &x.value[0])
.zip(slice.iter())
.all(|(a, b)| a == b));
Expand Down
Loading

0 comments on commit 38212ea

Please sign in to comment.