Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stake-pool: Add tolerance for stake accounts at minimum #3839

Merged
merged 8 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions stake-pool/program/src/big_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find<T: Pack>(&self, data: &[u8], predicate: fn(&[u8], &[u8]) -> bool) -> Option<&T> {
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
}
current_index = end_index;
Expand All @@ -165,18 +165,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find_mut<T: Pack>(
&mut self,
data: &[u8],
predicate: fn(&[u8], &[u8]) -> bool,
) -> Option<&mut T> {
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
}
current_index = end_index;
Expand Down Expand Up @@ -242,10 +238,7 @@ impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {

#[cfg(test)]
mod tests {
use {
super::*,
solana_program::{program_memory::sol_memcmp, program_pack::Sealed},
};
use {super::*, solana_program::program_pack::Sealed};

#[derive(Debug, PartialEq)]
struct TestStruct {
Expand Down Expand Up @@ -317,11 +310,11 @@ mod tests {
check_big_vec_eq(&v, &[2, 4]);
}

fn find_predicate(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
fn find_predicate(a: &[u8], b: u64) -> bool {
if a.len() != 8 {
false
} else {
sol_memcmp(a, b, a.len()) == 0
u64::try_from_slice(&a[0..8]).unwrap() == b
}
}

Expand All @@ -330,32 +323,26 @@ mod tests {
let mut data = [0u8; 4 + 8 * 4];
let v = from_slice(&mut data, &[1, 2, 3, 4]);
assert_eq!(
v.find::<TestStruct>(&1u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
Some(&TestStruct::new(1))
);
assert_eq!(
v.find::<TestStruct>(&4u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
Some(&TestStruct::new(4))
);
assert_eq!(
v.find::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
fn find_mut() {
let mut data = [0u8; 4 + 8 * 4];
let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
let mut test_struct = v
.find_mut::<TestStruct>(&1u64.to_le_bytes(), find_predicate)
.find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
.unwrap();
test_struct.value = 0;
check_big_vec_eq(&v, &[0, 2, 3, 4]);
assert_eq!(
v.find_mut::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
Expand Down
116 changes: 72 additions & 44 deletions stake-pool/program/src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,9 @@ impl Processor {
if header.max_validators == validator_list.len() {
return Err(ProgramError::AccountDataTooSmall);
}
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
validator_vote_info.key.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, validator_vote_info.key)
});
if maybe_validator_stake_info.is_some() {
return Err(StakePoolError::ValidatorAlreadyAdded.into());
}
Expand Down Expand Up @@ -994,10 +993,9 @@ impl Processor {

let (meta, stake) = get_stake_state(stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1154,10 +1152,9 @@ impl Processor {
let (meta, stake) = get_stake_state(validator_stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1316,10 +1313,9 @@ impl Processor {

let vote_account_address = validator_vote_account_info.key;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1481,10 +1477,9 @@ impl Processor {
}

if let Some(vote_account_address) = vote_account_address {
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
match maybe_validator_stake_info {
Some(vsi) => {
if vsi.status != StakeStatus::Active {
Expand Down Expand Up @@ -2031,10 +2026,9 @@ impl Processor {
}

let mut validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
check_validator_stake_address(
program_id,
Expand Down Expand Up @@ -2428,7 +2422,7 @@ impl Processor {
.checked_sub(pool_tokens_fee)
.ok_or(StakePoolError::CalculationFailure)?;

let withdraw_lamports = stake_pool
let mut withdraw_lamports = stake_pool
.calc_lamports_withdraw_amount(pool_tokens_burnt)
.ok_or(StakePoolError::CalculationFailure)?;

Expand All @@ -2442,17 +2436,27 @@ impl Processor {
let meta = stake_state.meta().ok_or(StakePoolError::WrongStakeState)?;
let required_lamports = minimum_stake_lamports(&meta, stake_minimum_delegation);

let lamports_per_pool_token = stake_pool
.get_lamports_per_pool_token()
.ok_or(StakePoolError::CalculationFailure)?;
let minimum_lamports_with_tolerance =
required_lamports.saturating_add(lamports_per_pool_token);

let has_active_stake = validator_list
.find::<ValidatorStakeInfo>(
&required_lamports.to_le_bytes(),
ValidatorStakeInfo::active_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::active_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();
let has_transient_stake = validator_list
.find::<ValidatorStakeInfo>(
&0u64.to_le_bytes(),
ValidatorStakeInfo::transient_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::transient_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();

let validator_list_item_info = if *stake_split_from.key == stake_pool.reserve_stake {
Expand All @@ -2478,10 +2482,9 @@ impl Processor {
stake_pool.preferred_withdraw_validator_vote_address
{
let preferred_validator_info = validator_list
.find::<ValidatorStakeInfo>(
preferred_withdraw_validator.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &preferred_withdraw_validator)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
let available_lamports = preferred_validator_info
.active_stake_lamports
Expand All @@ -2493,10 +2496,9 @@ impl Processor {
}

let validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;

let withdraw_source = if has_active_stake {
Expand Down Expand Up @@ -2548,11 +2550,37 @@ impl Processor {
}
}
StakeWithdrawSource::ValidatorRemoval => {
if withdraw_lamports != stake_split_from.lamports() {
msg!("Cannot withdraw a whole account worth {} lamports, must withdraw exactly {} lamports worth of pool tokens",
withdraw_lamports, stake_split_from.lamports());
let split_from_lamports = stake_split_from.lamports();
// The upper bound for reasonable tolerance is twice the lamports per
// pool token because we have two sources of rounding. The first happens
// when reducing the stake account to as close to the minimum as possible,
// and the second happens on this withdrawal.
//
// For example, if the pool token is extremely valuable, it might only
// be possible to reduce the stake account to a minimum of
// `stake_rent + minimum_delegation + lamports_per_pool_token - 1`.
//
// After that, the minimum amount of pool tokens to get to this amount
// may actually be worth
// `stake_rent + minimum_delegation + lamports_per_pool_token * 2 - 2`.
// We give an extra grace on this check of two lamports, which should be
// reasonable. At worst, it just means that a withdrawer is losing out
// on two lamports.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am curious, I think there is something I am missing here. We are comparing against the actual lamports in the stake account here, so a tolerance of +lamports_per_token should always be enough.

For example, let's assume we only reduce the stake account to minimum_stake_lamports + (lamports_per_token - 1). But then split_from_lamports = minimum_stake_lamports +(lamports_per_token - 1) already, and we can always find a token burn amount so that we are within +lamports_per_token of that. No need to double, as far as I understand.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I just wrote the test poorly during my first pass. No need for the limit, thanks for taking the time to notice this!

let upper_bound = split_from_lamports
.saturating_add(lamports_per_pool_token.saturating_mul(2));
if withdraw_lamports < split_from_lamports || withdraw_lamports > upper_bound {
msg!(
"Cannot withdraw a whole account worth {} lamports, \
must withdraw at least {} lamports worth of pool tokens \
with a margin of {} lamports",
withdraw_lamports,
split_from_lamports,
lamports_per_pool_token
);
return Err(StakePoolError::StakeLamportsNotEqualToMinimum.into());
}
// truncate the lamports down to the amount in the account
withdraw_lamports = split_from_lamports;
}
}
Some((validator_stake_info, withdraw_source))
Expand Down
29 changes: 19 additions & 10 deletions stake-pool/program/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,15 @@ impl StakePool {
}
}

/// Get the current value of pool tokens, rounded up
#[inline]
pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
self.total_lamports
.checked_add(self.pool_token_supply)?
.checked_sub(1)?
.checked_div(self.pool_token_supply)
}

/// Checks that the withdraw or deposit authority is valid
fn check_program_derived_authority(
authority_address: &Pubkey,
Expand Down Expand Up @@ -660,24 +669,24 @@ impl ValidatorStakeInfo {

/// Performs a very cheap comparison, for checking if this validator stake
/// info matches the vote account address
pub fn memcmp_pubkey(data: &[u8], vote_address_bytes: &[u8]) -> bool {
pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
sol_memcmp(
&data[41..41 + PUBKEY_BYTES],
vote_address_bytes,
vote_address.as_ref(),
PUBKEY_BYTES,
) == 0
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have active lamports equal to the given bytes
pub fn active_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[0..8], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more active lamports than some limit
pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[0..8]).unwrap() > *lamports
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have lamports equal to the given bytes
pub fn transient_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[8..16], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more transient lamports than some limit
pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[8..16]).unwrap() > *lamports
}

/// Check that the validator stake info is valid
Expand Down
2 changes: 1 addition & 1 deletion stake-pool/program/tests/huge_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use {
},
};

const HUGE_POOL_SIZE: u32 = 2_000;
const HUGE_POOL_SIZE: u32 = 3_300;
const STAKE_AMOUNT: u64 = 200_000_000_000;

async fn setup(
Expand Down
Loading