Skip to content

Commit

Permalink
feat: Add overwrites to VMPoolState
Browse files Browse the repository at this point in the history
Add capabilities and block_lasting_overwrites as attributes as well
Implement first version of get_overwrites, get_token_overwrites, get_balance_overwrites

--- don't change below this line ---
ENG-3757 Took 1 hour 19 minutes
  • Loading branch information
dianacarvalho1 committed Oct 29, 2024
1 parent 1e25aa1 commit 20aba87
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/protocol/vm/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::protocol::vm::errors::ProtosimError;
use ethers::abi::Uint;

#[allow(dead_code)]
#[derive(Eq, PartialEq, Hash, Debug)]
#[derive(Eq, PartialEq, Hash, Debug, Clone)]
pub enum Capability {
SellSide = 1,
BuySide = 2,
Expand Down
122 changes: 109 additions & 13 deletions src/protocol/vm/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{
};

use chrono::Utc;
use ethabi::Hash;
use ethers::{
abi::{decode, Address as EthAddress, ParamType},
prelude::U256,
Expand All @@ -31,12 +32,18 @@ use itertools::Itertools;
use revm::{
precompile::{Address, Bytes},
primitives::{
alloy_primitives::Keccak256, keccak256, AccountInfo, Bytecode, B256, KECCAK_EMPTY,
alloy_primitives::Keccak256, keccak256, AccountInfo, Bytecode, B256, KECCAK_EMPTY, U256 as rU256
},
DatabaseRef,
};
use std::{collections::HashMap, fmt::Debug, sync::Arc};
use std::cmp::max;
use std::collections::HashSet;
use tokio::sync::RwLock;
use crate::protocol::vm::erc20_overwrite_factory::{ERC20OverwriteFactory, Overwrites};
use crate::protocol::vm::models::Capability;
use crate::protocol::vm::utils::SlotHash;


#[derive(Clone)]
pub struct VMPoolState<D: DatabaseRef + EngineDatabaseInterface + Clone> {
Expand All @@ -51,9 +58,14 @@ pub struct VMPoolState<D: DatabaseRef + EngineDatabaseInterface + Clone> {
/// The contract address for where protocol balances are stored (i.e. a vault contract).
/// If given, balances will be overwritten here instead of on the pool contract during
/// simulations
pub balance_owner: Option<H160>, // TODO: implement this in ENG-3758
pub balance_owner: Option<H160>,
/// Spot prices of the pool by token pair
pub spot_prices: HashMap<(ERC20Token, ERC20Token), f64>,
/// The supported capabilities of this pool
pub capabilities: HashSet<Capability>,
/// Storage overwrites that will be applied to all simulations. They will be cleared
// when ``clear_all_cache`` is called, i.e. usually at each block. Hence, the name.
pub block_lasting_overwrites: HashMap<H160, Overwrites>,
/// The address to bytecode map of all stateless contracts used by the protocol
/// for simulations. If the bytecode is None, an RPC call is done to get the code from our node
pub stateless_contracts: HashMap<String, Option<Vec<u8>>>,
Expand All @@ -73,6 +85,8 @@ impl VMPoolState<PreCachedDB> {
balances: HashMap<H160, U256>,
spot_prices: HashMap<(ERC20Token, ERC20Token), f64>,
adapter_contract_path: String,
capabilities: HashSet<Capability>,
block_lasting_overwrites: HashMap<H160, Overwrites>,
stateless_contracts: HashMap<String, Option<Vec<u8>>>,
trace: bool,
) -> Result<Self, ProtosimError> {
Expand All @@ -84,6 +98,8 @@ impl VMPoolState<PreCachedDB> {
balances,
balance_owner: None,
spot_prices,
capabilities,
block_lasting_overwrites,
stateless_contracts,
trace,
engine: None,
Expand Down Expand Up @@ -276,7 +292,7 @@ impl VMPoolState<PreCachedDB> {
// Manually unpack the inner vector
if let [t0, t1] = &tokens_pair[..] {
let sell_amount_limit = self
.get_sell_amount_limit(t0.address, t1.address)
.get_sell_amount_limit((*t0).clone(), (*t1).clone())
.await;
println!("Sell amount limit: {}", sell_amount_limit);
let price_result = self
Expand All @@ -289,11 +305,11 @@ impl VMPoolState<PreCachedDB> {
t1.address,
vec![sell_amount_limit],
self.block.number,
None, // TODO: in 3758 add overwrites here
self.block_lasting_overwrites,
)
.await?;
println!("Price result: {:?}", price_result);
// TODO: handle scaled price here
// TODO: handle scaled price here when we have capabilities
self.spot_prices.insert(
((*t0).clone(), (*t1).clone()),
price_result
Expand All @@ -306,23 +322,99 @@ impl VMPoolState<PreCachedDB> {
Ok(())
}

async fn get_sell_amount_limit(&self, sell_token: EthAddress, buy_token: EthAddress) -> U256 {
async fn get_sell_amount_limit(&self, sell_token: ERC20Token, buy_token: ERC20Token) -> U256 {
let binding = self
.adapter_contract
.clone()
.expect("Adapter contract not set");
let limits = binding
.get_limits(
self.id.clone()[2..].to_string(),
sell_token,
buy_token,
sell_token.address,
buy_token.address,
self.block.number,
None, // TODO: in 3758 add overwrites here
)
.await;
Some(self.get_overwrites(&sell_token, &buy_token, Some(U256::from_big_endian(&(*MAX_BALANCE / rU256::from(100)).to_be_bytes()))).await),
).await;

let sell_amount_limit = limits.expect("Expected a (u64, u64)").0;
sell_amount_limit
}

pub async fn get_overwrites(
&self,
sell_token: &ERC20Token,
buy_token: &ERC20Token,
max_amount: Option<U256>,
) -> HashMap<H160, Overwrites> {
let token_overwrites = self.get_token_overwrites(sell_token, buy_token, max_amount).await;

// Merge `block_lasting_overwrites` with `token_overwrites`
let mut overwrites =self.block_lasting_overwrites.clone();
// TODO: is this merge enough?? See here for python version protosim_py.python.protosim_py.evm.pool_state._merge
for (address, inner_map) in token_overwrites {
overwrites.entry(address)
.or_insert_with(HashMap::new)
.extend(inner_map)
}
overwrites
}

async fn get_token_overwrites(
&self,
sell_token: &ERC20Token,
buy_token: &ERC20Token,
max_amount: Option<U256>,
) -> HashMap<H160, Overwrites> {
let mut res: Vec<HashMap<H160, Overwrites>> = Vec::new();
if !self.capabilities.contains(&Capability::TokenBalanceIndependent) {
res.push(self.get_balance_overwrites());
}
let max_amount = if max_amount.is_none() {
self.get_sell_amount_limit((*sell_token).clone(), (*buy_token).clone()).await
} else {
max_amount.expect("Failed to get max amount")
};

let mut overwrites = ERC20OverwriteFactory::new(sell_token.address.clone(),(SlotHash::from_low_u64_be(0), SlotHash::from_low_u64_be(1)));

overwrites.set_balance(max_amount, H160::from_slice(&*EXTERNAL_ACCOUNT.0));

// Set allowance for ADAPTER_ADDRESS to max_amount
overwrites.set_allowance(max_amount, H160::from_slice(&*EXTERNAL_ACCOUNT.0), H160::from_slice(&*ADAPTER_ADDRESS.0));

res.push(overwrites.get_protosim_overwrites());

// Merge all overwrites into a single HashMap
// TODO: is this merge enough?? See here for python version protosim_py.python.protosim_py.evm.pool_state._merge
res.into_iter().fold(HashMap::new(), |mut acc, overwrite| {
for (address, inner_map) in overwrite {
acc.entry(address)
.or_insert_with(HashMap::new)
.extend(inner_map);
}
acc
})
}

fn get_balance_overwrites(&self) -> HashMap<H160, Overwrites>{
let mut balance_overwrites: HashMap<H160, Overwrites> = HashMap::new();
let address = self.balance_owner.unwrap_or(self.id.parse().expect("Pool ID is not an address"));

for token in &self.tokens {
let mut overwrites = ERC20OverwriteFactory::new(token.address.clone(), (SlotHash::from_low_u64_be(0), SlotHash::from_low_u64_be(1)));
overwrites.set_balance(
self.balances
.get(&token.address)
.cloned()
.unwrap_or_default(),

address,
);
balance_overwrites.extend(overwrites.get_protosim_overwrites());
}

balance_overwrites
}
}

#[cfg(test)]
Expand Down Expand Up @@ -379,6 +471,8 @@ mod tests {
HashMap::new(),
HashMap::new(),
"src/protocol/vm/assets/BalancerV2SwapAdapter.evm.runtime".to_string(),
HashSet::new(),
HashMap::new(),
HashMap::new(),
true,
)
Expand Down Expand Up @@ -471,6 +565,8 @@ mod tests {
]),
HashMap::new(),
"src/protocol/vm/assets/BalancerV2SwapAdapter.evm.runtime".to_string(),
HashSet::new(),
HashMap::new(),
HashMap::new(),
false,
)
Expand Down Expand Up @@ -520,12 +616,12 @@ mod tests {
async fn test_get_sell_amount_limit() {
let pool_state = setup_pool_state().await;
let dai_limit = pool_state
.get_sell_amount_limit(pool_state.tokens[0].address, pool_state.tokens[1].address)
.get_sell_amount_limit(pool_state.tokens[0].clone(), pool_state.tokens[1].clone())
.await;
assert_eq!(dai_limit, U256::from_dec_str("100279494253364362835").unwrap());

let bal_limit = pool_state
.get_sell_amount_limit(pool_state.tokens[1].address, pool_state.tokens[0].address)
.get_sell_amount_limit(pool_state.tokens[1].clone(), pool_state.tokens[0].clone())
.await;
assert_eq!(bal_limit, U256::from_dec_str("13997408640689987484").unwrap());
}
Expand Down

0 comments on commit 20aba87

Please sign in to comment.