Skip to content

Commit

Permalink
Implement balance table (#285)
Browse files Browse the repository at this point in the history
* Implement balance table

* Update hamt doc test and port over go actor tests

* Add more tests for fun

* headers

* merge conflict import fixes
  • Loading branch information
austinabell authored Mar 19, 2020
1 parent f2027f0 commit 7671255
Show file tree
Hide file tree
Showing 11 changed files with 356 additions and 7 deletions.
33 changes: 32 additions & 1 deletion ipld/hamt/src/hamt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ where
V: DeserializeOwned,
{
match self.root.get(k, self.store, self.bit_width)? {
Some(v) => Ok(Some(from_ipld(v).map_err(Error::Encoding)?)),
Some(v) => Ok(Some(from_ipld(&v).map_err(Error::Encoding)?)),
None => Ok(None),
}
}
Expand Down Expand Up @@ -201,4 +201,35 @@ where
pub fn is_empty(&self) -> bool {
self.root.is_empty()
}

/// Iterates over each KV in the Hamt and runs a function on the values.
///
/// This function will constrain all values to be of the same type
///
/// # Examples
///
/// ```
/// use ipld_hamt::Hamt;
///
/// let store = db::MemoryDB::default();
///
/// let mut map: Hamt<usize, _> = Hamt::new(&store);
/// map.set(1, 1).unwrap();
/// map.set(4, 2).unwrap();
///
/// let mut total = 0;
/// map.for_each(&mut |_, v: u64| {
/// total += v;
/// Ok(())
/// }).unwrap();
/// assert_eq!(total, 3);
/// ```
#[inline]
pub fn for_each<F, V>(&self, f: &mut F) -> Result<(), String>
where
V: DeserializeOwned,
F: FnMut(&K, V) -> Result<(), String>,
{
self.root.for_each(self.store, f)
}
}
28 changes: 25 additions & 3 deletions ipld/hamt/src/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use super::hash_bits::HashBits;
use super::pointer::Pointer;
use super::{Error, Hash, HashedKey, KeyValuePair, MAX_ARRAY_WIDTH};
use cid::multihash::Blake2b256;
use forest_encoding::{de::Deserializer, ser::Serializer};
use forest_ipld::Ipld;
use forest_ipld::{from_ipld, Ipld};
use ipld_blockstore::BlockStore;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::borrow::Borrow;
use std::fmt::Debug;

Expand Down Expand Up @@ -129,6 +128,29 @@ where
self.pointers.is_empty()
}

pub(crate) fn for_each<V, S, F>(&self, store: &S, f: &mut F) -> Result<(), String>
where
V: DeserializeOwned,
F: FnMut(&K, V) -> Result<(), String>,
S: BlockStore,
{
for p in &self.pointers {
match p {
Pointer::Link(cid) => match store.get::<Node<K>>(cid)? {
Some(node) => node.for_each(store, f)?,
None => return Err(format!("Node with cid {} not found", cid)),
},
Pointer::Cache(n) => n.for_each(store, f)?,
Pointer::Values(kvs) => {
for kv in kvs {
f(kv.0.borrow(), from_ipld(&kv.1).map_err(Error::Encoding)?)?;
}
}
}
}
Ok(())
}

/// Search for a key.
fn search<Q: ?Sized, S: BlockStore>(
&self,
Expand Down
4 changes: 2 additions & 2 deletions ipld/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ where

/// Convert a `Ipld` structure into a type `T`
/// Currently converts using a byte buffer with serde_cbor
pub fn from_ipld<T>(value: Ipld) -> Result<T, String>
pub fn from_ipld<T>(value: &Ipld) -> Result<T, String>
where
T: DeserializeOwned,
{
// TODO find a way to convert without going through byte buffer
let buf = to_vec(&value).map_err(|e| e.to_string())?;
let buf = to_vec(value).map_err(|e| e.to_string())?;
from_slice(buf.as_slice()).map_err(|e| e.to_string())
}
1 change: 1 addition & 0 deletions node/clock/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl ChainEpochClock {
}
}

// TODO revisit usage of the Sub impls, these will panic (checked sub or floored sub would be safer)
impl Sub for ChainEpoch {
type Output = ChainEpoch;

Expand Down
6 changes: 6 additions & 0 deletions node/db/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ impl fmt::Display for Error {
}
}

impl From<Error> for String {
fn from(e: Error) -> Self {
e.to_string()
}
}

impl From<rocksdb::Error> for Error {
fn from(e: rocksdb::Error) -> Error {
Error::Database(String::from(e))
Expand Down
3 changes: 3 additions & 0 deletions vm/actor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,6 @@ ipld_blockstore = { path = "../../ipld/blockstore" }
ipld_hamt = { path = "../../ipld/hamt" }
forest_ipld = { path = "../../ipld" }
message = { package = "forest_message", path = "../message" }

[dev-dependencies]
db = { path = "../../node/db" }
2 changes: 2 additions & 0 deletions vm/actor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
extern crate lazy_static;

mod builtin;
mod util;

pub use self::builtin::*;
pub use self::util::*;
pub use vm::{ActorID, ActorState, Serialized};

use ipld_blockstore::BlockStore;
Expand Down
132 changes: 132 additions & 0 deletions vm/actor/src/util/balance_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright 2020 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

use crate::HAMT_BIT_WIDTH;
use address::Address;
use cid::Cid;
use ipld_blockstore::BlockStore;
use ipld_hamt::{Error, Hamt};
use num_traits::CheckedSub;
use vm::TokenAmount;

/// Balance table which handles getting and updating token balances specifically
pub struct BalanceTable<'a, BS>(Hamt<'a, String, BS>);
impl<'a, BS> BalanceTable<'a, BS>
where
BS: BlockStore,
{
/// Initializes a new empty balance table
pub fn new_empty(bs: &'a BS) -> Self {
Self(Hamt::new_with_bit_width(bs, HAMT_BIT_WIDTH))
}

/// Initializes a balance table from a root Cid
pub fn from_root(bs: &'a BS, cid: &Cid) -> Result<Self, Error> {
Ok(Self(Hamt::load_with_bit_width(cid, bs, HAMT_BIT_WIDTH)?))
}

/// Retrieve root from balance table
#[inline]
pub fn root(&mut self) -> Result<Cid, Error> {
self.0.flush()
}

/// Gets token amount for given address in balance table
#[inline]
pub fn get(&self, key: &Address) -> Result<TokenAmount, String> {
Ok(self
.0
.get(&key.hash_key())?
// TODO investigate whether it's worth it to cache root to give better error details
.ok_or("no key {} in map root")?)
}

/// Checks if a balance for an address exists
#[inline]
pub fn has(&self, key: &Address) -> Result<bool, Error> {
match self.0.get::<_, TokenAmount>(&key.hash_key())? {
Some(_) => Ok(true),
None => Ok(false),
}
}

/// Sets the balance for the address, overwriting previous value
#[inline]
pub fn set(&mut self, key: &Address, value: TokenAmount) -> Result<(), Error> {
self.0.set(key.hash_key(), value)
}

/// Adds token amount to previously initialized account.
pub fn add(&mut self, key: &Address, value: TokenAmount) -> Result<(), String> {
let prev = self.get(key)?;
Ok(self.0.set(key.hash_key(), prev + value)?)
}

/// Adds an amount to a balance. Creates entry if not exists
pub fn add_create(&mut self, key: &Address, value: TokenAmount) -> Result<(), String> {
let new_val = match self.0.get::<_, TokenAmount>(&key.hash_key())? {
Some(v) => v + value,
None => value,
};
Ok(self.0.set(key.hash_key(), new_val)?)
}

/// Subtracts up to the specified amount from a balance, without reducing the balance
/// below some minimum.
/// Returns the amount subtracted (always positive or zero).
pub fn subtract_with_minimum(
&mut self,
key: &Address,
req: &TokenAmount,
floor: &TokenAmount,
) -> Result<TokenAmount, String> {
let prev = self.get(key)?;
let res = prev.checked_sub(req).unwrap_or_else(|| TokenAmount::new(0));
let new_val: TokenAmount = std::cmp::max(&res, floor).clone();

if prev > new_val {
// Subtraction needed, set new value and return change
self.0.set(key.hash_key(), new_val.clone())?;
Ok(prev - new_val)
} else {
// New value is same as previous, no change needed
Ok(TokenAmount::default())
}
}

/// Subtracts value from a balance, and errors if full amount was not substracted.
pub fn must_subtract(&mut self, key: &Address, req: &TokenAmount) -> Result<(), String> {
let sub_amt = self.subtract_with_minimum(key, req, &TokenAmount::new(0))?;
if &sub_amt != req {
return Err(format!(
"Couldn't subtract value from address {} (req: {}, available: {})",
key, req, sub_amt
));
}

Ok(())
}

/// Removes an entry from the table, returning the prior value. The entry must have been previously initialized.
pub fn remove(&mut self, key: &Address) -> Result<TokenAmount, String> {
// Ensure entry exists and get previous value
let prev = self.get(key)?;

// Remove entry from table
self.0.delete(&key.hash_key())?;

Ok(prev)
}

/// Returns total balance held by this balance table
pub fn total(&self) -> Result<TokenAmount, String> {
let mut total = TokenAmount::default();

self.0.for_each(&mut |_, v| {
total += v;
Ok(())
})?;

Ok(total)
}
}
6 changes: 6 additions & 0 deletions vm/actor/src/util/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// Copyright 2020 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

mod balance_table;

pub use self::balance_table::BalanceTable;
109 changes: 109 additions & 0 deletions vm/actor/tests/balance_table_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2020 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

use actor::BalanceTable;
use address::Address;
use vm::TokenAmount;

// Ported test from specs-actors
#[test]
fn add_create() {
let addr = Address::new_id(100).unwrap();
let store = db::MemoryDB::default();
let mut bt = BalanceTable::new_empty(&store);

assert_eq!(bt.has(&addr), Ok(false));

bt.add_create(&addr, TokenAmount::new(10)).unwrap();
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(10)));

bt.add_create(&addr, TokenAmount::new(20)).unwrap();
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(30)));
}

// Ported test from specs-actors
#[test]
fn total() {
let addr1 = Address::new_id(100).unwrap();
let addr2 = Address::new_id(101).unwrap();
let store = db::MemoryDB::default();
let mut bt = BalanceTable::new_empty(&store);

assert_eq!(bt.total(), Ok(TokenAmount::new(0)));

struct TotalTestCase<'a> {
amount: u64,
addr: &'a Address,
total: u64,
}
let test_vectors = [
TotalTestCase {
amount: 10,
addr: &addr1,
total: 10,
},
TotalTestCase {
amount: 20,
addr: &addr1,
total: 30,
},
TotalTestCase {
amount: 40,
addr: &addr2,
total: 70,
},
TotalTestCase {
amount: 50,
addr: &addr2,
total: 120,
},
];

for t in test_vectors.iter() {
bt.add_create(t.addr, TokenAmount::new(t.amount)).unwrap();

assert_eq!(bt.total(), Ok(TokenAmount::new(t.total)));
}
}

#[test]
fn balance_subtracts() {
let addr = Address::new_id(100).unwrap();
let store = db::MemoryDB::default();
let mut bt = BalanceTable::new_empty(&store);

bt.set(&addr, TokenAmount::new(80)).unwrap();
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(80)));
// Test subtracting past minimum only subtracts correct amount
assert_eq!(
bt.subtract_with_minimum(&addr, &TokenAmount::new(20), &TokenAmount::new(70)),
Ok(TokenAmount::new(10))
);
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(70)));

// Test subtracting to limit
assert_eq!(
bt.subtract_with_minimum(&addr, &TokenAmount::new(10), &TokenAmount::new(60)),
Ok(TokenAmount::new(10))
);
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(60)));

// Test must subtract success
bt.must_subtract(&addr, &TokenAmount::new(10)).unwrap();
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(50)));

// Test subtracting more than available
assert!(bt.must_subtract(&addr, &TokenAmount::new(100)).is_err());
}

#[test]
fn remove() {
let addr = Address::new_id(100).unwrap();
let store = db::MemoryDB::default();
let mut bt = BalanceTable::new_empty(&store);

bt.set(&addr, TokenAmount::new(1)).unwrap();
assert_eq!(bt.get(&addr), Ok(TokenAmount::new(1)));
bt.remove(&addr).unwrap();
assert!(bt.get(&addr).is_err());
}
Loading

0 comments on commit 7671255

Please sign in to comment.