Skip to content

Commit

Permalink
stake-pool: Ceiling all fee calculations (HAL-01) (#6153)
Browse files Browse the repository at this point in the history
stake-pool: Ceiling all fee calculations
  • Loading branch information
joncinque authored Jan 24, 2024
1 parent 0d6832e commit a17fffe
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 24 deletions.
10 changes: 7 additions & 3 deletions stake-pool/program/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -942,9 +942,13 @@ impl Fee {
if self.denominator == 0 {
return Some(0);
}
(amt as u128)
.checked_mul(self.numerator as u128)?
.checked_div(self.denominator as u128)
let numerator = (amt as u128).checked_mul(self.numerator as u128)?;
// ceiling the calculation by adding (denominator - 1) to the numerator
let denominator = self.denominator as u128;
numerator
.checked_add(denominator)?
.checked_sub(1)?
.checked_div(denominator)
}

/// Withdrawal fees have some additional restrictions,
Expand Down
11 changes: 7 additions & 4 deletions stake-pool/program/tests/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,15 +897,17 @@ impl StakePoolAccounts {
}

pub fn calculate_fee(&self, amount: u64) -> u64 {
amount * self.epoch_fee.numerator / self.epoch_fee.denominator
(amount * self.epoch_fee.numerator + self.epoch_fee.denominator - 1)
/ self.epoch_fee.denominator
}

pub fn calculate_withdrawal_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.withdrawal_fee.numerator / self.withdrawal_fee.denominator
(pool_tokens * self.withdrawal_fee.numerator + self.withdrawal_fee.denominator - 1)
/ self.withdrawal_fee.denominator
}

pub fn calculate_inverse_withdrawal_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.withdrawal_fee.denominator
(pool_tokens * self.withdrawal_fee.denominator + self.withdrawal_fee.denominator - 1)
/ (self.withdrawal_fee.denominator - self.withdrawal_fee.numerator)
}

Expand All @@ -914,7 +916,8 @@ impl StakePoolAccounts {
}

pub fn calculate_sol_deposit_fee(&self, pool_tokens: u64) -> u64 {
pool_tokens * self.sol_deposit_fee.numerator / self.sol_deposit_fee.denominator
(pool_tokens * self.sol_deposit_fee.numerator + self.sol_deposit_fee.denominator - 1)
/ self.sol_deposit_fee.denominator
}

pub fn calculate_sol_referral_fee(&self, deposit_fee_collected: u64) -> u64 {
Expand Down
4 changes: 2 additions & 2 deletions stake-pool/program/tests/withdraw_sol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,15 @@ async fn fail_overdraw_reserve() {
.await;
assert!(error.is_none(), "{:?}", error);

// try to withdraw one lamport, will overdraw
// try to withdraw one lamport after fees, will overdraw
let error = stake_pool_accounts
.withdraw_sol(
&mut context.banks_client,
&context.payer,
&context.last_blockhash,
&user,
&pool_token_account,
1,
2,
None,
)
.await
Expand Down
18 changes: 3 additions & 15 deletions stake-pool/program/tests/withdraw_with_fee.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ mod helpers;
use {
bincode::deserialize,
helpers::*,
solana_program::{borsh0_10::try_from_slice_unchecked, pubkey::Pubkey, stake},
solana_program::{pubkey::Pubkey, stake},
solana_program_test::*,
solana_sdk::signature::{Keypair, Signer},
spl_stake_pool::{minimum_stake_lamports, state},
spl_stake_pool::minimum_stake_lamports,
};

#[tokio::test]
Expand Down Expand Up @@ -183,20 +183,8 @@ async fn success_empty_out_stake_with_fee() {
.await;
let lamports_to_withdraw =
validator_stake_account.lamports - minimum_stake_lamports(&meta, stake_minimum_delegation);
let stake_pool_account = get_account(
&mut context.banks_client,
&stake_pool_accounts.stake_pool.pubkey(),
)
.await;
let stake_pool =
try_from_slice_unchecked::<state::StakePool>(stake_pool_account.data.as_slice()).unwrap();
let fee = stake_pool.stake_withdrawal_fee;
let inverse_fee = state::Fee {
numerator: fee.denominator - fee.numerator,
denominator: fee.denominator,
};
let pool_tokens_to_withdraw =
lamports_to_withdraw * inverse_fee.denominator / inverse_fee.numerator;
stake_pool_accounts.calculate_inverse_withdrawal_fee(lamports_to_withdraw);

let last_blockhash = context
.banks_client
Expand Down

0 comments on commit a17fffe

Please sign in to comment.