Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion rln-prover/Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,5 @@ RUST_LOG=debug cargo run -p prover_cli -- --ip 127.0.0.1 --metrics-ip 127.0.0.1
### Unit tests

* cargo test
* cargo test --features anvil
* cargo test --features anvil

24 changes: 23 additions & 1 deletion rln-prover/prover/src/grpc_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub mod prover_proto {
pub(crate) const FILE_DESCRIPTOR_SET: &[u8] =
tonic::include_file_descriptor_set!("prover_descriptor");
}
use crate::user_db_2::UserDb2;
use crate::user_db_2::{UserDb2, UserTierInfo2};
use crate::user_db_types::RateLimit;
use prover_proto::{
GetUserTierInfoReply,
Expand Down Expand Up @@ -492,6 +492,28 @@ impl From<UserTierInfo> for UserTierInfoResult {
}
}

/// UserTierInfo2 to UserTierInfoResult (Grpc message) conversion
impl From<UserTierInfo2> for UserTierInfoResult {
fn from(tier_info: UserTierInfo2) -> Self {
let mut res = UserTierInfoResult {
current_epoch: tier_info.current_epoch.into(),
// current_epoch_slice: tier_info.current_epoch_slice.into(),
current_epoch_slice: 0,
tx_count: tier_info.epoch_tx_count,
tier: None,
};

if tier_info.tier_name.is_some() && tier_info.tier_limit.is_some() {
res.tier = Some(Tier {
name: tier_info.tier_name.unwrap().into(),
quota: tier_info.tier_limit.unwrap().into(),
})
}

res
}
}

/// UserTierInfoError to UserTierInfoError (Grpc message) conversion
impl<E> From<crate::user_db_error::UserTierInfoError<E>> for UserTierInfoError
where
Expand Down
2 changes: 1 addition & 1 deletion rln-prover/prover/src/user_db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ mod tests {
user_db.register(addr).unwrap();

let (ec, ecs) = user_db.get_tx_counter(&addr).unwrap();
assert_eq!(ec, 0u64.into());
assert_eq!(ec, EpochCounter::from(0));
assert_eq!(ecs, EpochSliceCounter::from(0u64));

let ecs_2 = user_db.incr_tx_counter(&addr, Some(42)).unwrap();
Expand Down
103 changes: 63 additions & 40 deletions rln-prover/prover/src/user_db_2.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::fmt::Formatter;
use std::sync::Arc;
// third-party
use alloy::primitives::Address;
use alloy::primitives::{Address, U256};
use ark_bn254::Fr;
use parking_lot::RwLock;
use tokio::sync::RwLock as TokioRwLock;
Expand All @@ -15,7 +15,7 @@ use sea_orm::{
};
// internal
use crate::epoch_service::{Epoch, EpochSlice};
use crate::tier::{TierLimit, TierLimits, TierMatch};
use crate::tier::{TierLimit, TierLimits, TierMatch, TierName};
use crate::user_db::UserTierInfo;
use crate::user_db_error::{
GetMerkleTreeProofError2, RegisterError2, SetTierLimitsError2, TxCounterError2,
Expand All @@ -36,6 +36,16 @@ const TIER_LIMITS_NEXT_KEY: &str = "NEXT";

type ProverMerkleTree = MerkleTree<MemoryDb, ProverPoseidonHash, PersistentDb, MerkleTreeError>;

#[derive(Debug, PartialEq)]
pub struct UserTierInfo2 {
pub(crate) current_epoch: Epoch,
pub(crate) current_epoch_slice: EpochSlice,
pub(crate) epoch_tx_count: u64,
pub(crate) karma_amount: U256,
pub(crate) tier_name: Option<TierName>,
pub(crate) tier_limit: Option<TierLimit>,
}

#[derive(Clone, Debug)]
pub struct UserDb2Config {
pub(crate) tree_count: u64,
Expand Down Expand Up @@ -202,9 +212,9 @@ impl UserDb2 {
&self,
address: &Address,
incr_value: Option<i64>,
) -> Result<EpochSliceCounter, DbErr> {
) -> Result<EpochCounter, DbErr> {
let incr_value = incr_value.unwrap_or(1);
let (epoch, epoch_slice) = *self.epoch_store.read();
let (epoch, _epoch_slice) = *self.epoch_store.read();

let txn = self.db.begin().await?;

Expand All @@ -218,32 +228,24 @@ impl UserDb2 {

// unwrap safe: res_active.epoch/epoch_slice cannot be null
let model_epoch = res_active.epoch.clone().unwrap();
let model_epoch_slice = res_active.epoch_slice.clone().unwrap();
let model_epoch_counter = res_active.epoch_counter.clone().unwrap();
let model_epoch_slice_counter = res_active.epoch_slice_counter.clone().unwrap();
// let model_epoch_slice = res_active.epoch_slice.clone().unwrap();
// let model_epoch_slice_counter = res_active.epoch_slice_counter.clone().unwrap();

if model_epoch == 0 && model_epoch_slice == 0 {
if model_epoch == 0 {
res_active.epoch = Set(epoch.into());
res_active.epoch_slice = Set(epoch_slice.into());
res_active.epoch_counter = Set(incr_value);
res_active.epoch_slice_counter = Set(incr_value);
// res_active.epoch_slice = Set(epoch_slice.into());
// res_active.epoch_slice_counter = Set(incr_value);
} else if epoch != Epoch::from(model_epoch) {
// New epoch
res_active.epoch = Set(epoch.into());
res_active.epoch_slice = Set(0);
res_active.epoch_counter = Set(incr_value);
res_active.epoch_slice_counter = Set(incr_value);
} else if epoch_slice != EpochSlice::from(model_epoch_slice) {
// New epoch slice
res_active.epoch = Set(epoch.into());
res_active.epoch_slice = Set(epoch_slice.into());
res_active.epoch_counter = Set(model_epoch_counter.saturating_add(incr_value));
res_active.epoch_slice_counter = Set(incr_value);
// res_active.epoch_slice = Set(0);
// res_active.epoch_slice_counter = Set(incr_value);
} else {
// Same epoch & epoch slice
// Same epoch
res_active.epoch_counter = Set(model_epoch_counter.saturating_add(incr_value));
res_active.epoch_slice_counter =
Set(model_epoch_slice_counter.saturating_add(incr_value));
}

// res_active.update(&txn).await?;
Expand All @@ -253,9 +255,9 @@ impl UserDb2 {
let new_tx_counter = tx_counter::ActiveModel {
address: Set(address.to_string()),
epoch: Set(epoch.into()),
epoch_slice: Set(epoch_slice.into()),
epoch_counter: Set(incr_value),
epoch_slice_counter: Set(incr_value),
// epoch_slice: Set(epoch_slice.into()),
// epoch_slice_counter: Set(incr_value),
..Default::default()
};

Expand All @@ -267,13 +269,13 @@ impl UserDb2 {

txn.commit().await?;
// FIXME: no 'as'
Ok((new_tx_counter.epoch_slice_counter as u64).into())
Ok((new_tx_counter.epoch_counter as u64).into())
}

pub(crate) async fn get_tx_counter(
&self,
address: &Address,
) -> Result<(EpochCounter, EpochSliceCounter), TxCounterError2> {
) -> Result<EpochCounter, TxCounterError2> {
let res = tx_counter::Entity::find()
.filter(tx_counter::Column::Address.eq(address.to_string()))
.one(&self.db)
Expand All @@ -285,6 +287,26 @@ impl UserDb2 {
}
}

fn counters_from_key(&self, model: tx_counter::Model) -> EpochCounter {
let (epoch, _epoch_slice) = *self.epoch_store.read();
let cmp = (model.epoch == i64::from(epoch));

match cmp {
true => {
// EpochCounter stored in DB == epoch store
// We query for an epoch and this is what is stored in the Db
(model.epoch_counter as u64).into()
}
false => {
// EpochCounter.epoch (stored in DB) != epoch_store.epoch
// We query for an epoch after what is stored in Db
// This can happen if no Tx has updated the epoch counter (yet)
EpochCounter::from(0)
}
}
}

/*
fn counters_from_key(&self, model: tx_counter::Model) -> (EpochCounter, EpochSliceCounter) {
let (epoch, epoch_slice) = *self.epoch_store.read();
let cmp = (
Expand Down Expand Up @@ -327,6 +349,7 @@ impl UserDb2 {
}
}
}
*/

// user register & delete (with app logic)

Expand Down Expand Up @@ -509,12 +532,12 @@ impl UserDb2 {
&self,
address: &Address,
incr_value: Option<i64>,
) -> Result<EpochSliceCounter, TxCounterError2> {
) -> Result<EpochCounter, TxCounterError2> {
let has_user = self.has_user(address).await?;

if has_user {
let epoch_slice_counter = self.incr_tx_counter(address, incr_value).await?;
Ok(epoch_slice_counter)
let epoch_counter = self.incr_tx_counter(address, incr_value).await?;
Ok(epoch_counter)
} else {
Err(TxCounterError2::NotRegistered(*address))
}
Expand All @@ -535,7 +558,7 @@ impl UserDb2 {
&self,
address: &Address,
karma_sc: &KSC,
) -> Result<UserTierInfo, UserTierInfoError2<E>> {
) -> Result<UserTierInfo2, UserTierInfoError2<E>> {
let has_user = self
.has_user(address)
.await
Expand All @@ -551,18 +574,18 @@ impl UserDb2 {
.map_err(|e| UserTierInfoError2::Contract(e))?;

// TODO
let (epoch_tx_count, epoch_slice_tx_count) = self.get_tx_counter(address).await?;
let epoch_tx_count = self.get_tx_counter(address).await?;
// TODO: avoid db query the tier limits (keep it in memory)
let tier_limits = self.get_tier_limits().await?;
let tier_match = tier_limits.get_tier_by_karma(&karma_amount);

let user_tier_info = {
let (current_epoch, current_epoch_slice) = *self.epoch_store.read();
let mut t = UserTierInfo {
let mut t = UserTierInfo2 {
current_epoch,
current_epoch_slice,
epoch_tx_count: epoch_tx_count.into(),
epoch_slice_tx_count: epoch_slice_tx_count.into(),
// epoch_slice_tx_count: epoch_slice_tx_count.into(),
karma_amount,
tier_name: None,
tier_limit: None,
Expand Down Expand Up @@ -709,21 +732,21 @@ mod tests {
assert!(user_db.get_user_identity(&addr).await.is_some());
assert_eq!(
user_db.get_tx_counter(&addr).await.unwrap(),
(0.into(), 0.into())
EpochCounter::from(0)
);

assert!(user_db.get_user_identity(&ADDR_1).await.is_none());
user_db.register_user(ADDR_1).await.unwrap();
assert!(user_db.get_user_identity(&ADDR_1).await.is_some());
assert_eq!(
user_db.get_tx_counter(&addr).await.unwrap(),
(0.into(), 0.into())
EpochCounter::from(0)
);

user_db.incr_tx_counter(&addr, Some(42)).await.unwrap();
assert_eq!(
user_db.get_tx_counter(&addr).await.unwrap(),
(42.into(), 42.into())
EpochCounter::from(42)
);
}

Expand Down Expand Up @@ -753,13 +776,13 @@ mod tests {

user_db.register_user(addr).await.unwrap();

let (ec, ecs) = user_db.get_tx_counter(&addr).await.unwrap();
assert_eq!(ec, 0u64.into());
assert_eq!(ecs, EpochSliceCounter::from(0u64));
let ec = user_db.get_tx_counter(&addr).await.unwrap();
assert_eq!(ec, EpochCounter::from(0));
// assert_eq!(ecs, EpochSliceCounter::from(0u64));

let ecs_2 = user_db.incr_tx_counter(&addr, Some(42)).await.unwrap();
// TODO
assert_eq!(ecs_2, EpochSliceCounter::from(42));
assert_eq!(ecs_2, EpochCounter::from(42));
}

#[tokio::test]
Expand Down Expand Up @@ -803,14 +826,14 @@ mod tests {
// Now update user tx counter
assert_eq!(
user_db.on_new_tx(&addr, None).await,
Ok(EpochSliceCounter::from(1))
Ok(EpochCounter::from(1))
);
let tier_info = user_db
.user_tier_info(&addr, &MockKarmaSc {})
.await
.unwrap();
assert_eq!(tier_info.epoch_tx_count, 1);
assert_eq!(tier_info.epoch_slice_tx_count, 1);
// assert_eq!(tier_info.epoch_slice_tx_count, 1);
}

#[tokio::test]
Expand Down
26 changes: 13 additions & 13 deletions rln-prover/prover/src/user_db_2_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,45 +93,45 @@ mod tests {

assert_eq!(
user_db.get_tx_counter(&ADDR_1).await,
Ok((EpochCounter::from(0), EpochSliceCounter::from(0)))
Ok(EpochCounter::from(0))
);
assert_eq!(
user_db.get_tx_counter(&ADDR_2).await,
Ok((EpochCounter::from(0), EpochSliceCounter::from(0)))
Ok(EpochCounter::from(0))
);

// Now update user tx counter
assert_eq!(
user_db.on_new_tx(&ADDR_1, None).await,
Ok(EpochSliceCounter::from(1))
Ok(EpochCounter::from(1))
);
assert_eq!(
user_db.on_new_tx(&ADDR_1, None).await,
Ok(EpochSliceCounter::from(2))
Ok(EpochCounter::from(2))
);
assert_eq!(
user_db.on_new_tx(&ADDR_1, Some(2)).await,
Ok(EpochSliceCounter::from(4))
Ok(EpochCounter::from(4))
);

assert_eq!(
user_db.on_new_tx(&ADDR_2, None).await,
Ok(EpochSliceCounter::from(1))
Ok(EpochCounter::from(1))
);

assert_eq!(
user_db.on_new_tx(&ADDR_2, None).await,
Ok(EpochSliceCounter::from(2))
Ok(EpochCounter::from(2))
);

assert_eq!(
user_db.get_tx_counter(&ADDR_1).await,
Ok((EpochCounter::from(4), EpochSliceCounter::from(4)))
Ok(EpochCounter::from(4))
);

assert_eq!(
user_db.get_tx_counter(&ADDR_2).await,
Ok((EpochCounter::from(2), EpochSliceCounter::from(2)))
Ok(EpochCounter::from(2))
);
}

Expand Down Expand Up @@ -179,11 +179,11 @@ mod tests {

assert_eq!(
user_db.on_new_tx(&ADDR_1, Some(2)).await,
Ok(EpochSliceCounter::from(2))
Ok(EpochCounter::from(2))
);
assert_eq!(
user_db.on_new_tx(&ADDR_2, Some(1000)).await,
Ok(EpochSliceCounter::from(1000))
Ok(EpochCounter::from(1000))
);

db_conn.close().await.unwrap();
Expand Down Expand Up @@ -213,11 +213,11 @@ mod tests {
assert!(user_db.has_user(&ADDR_2).await.unwrap());
assert_eq!(
user_db.get_tx_counter(&ADDR_1).await.unwrap(),
(2.into(), 2.into())
EpochCounter::from(2)
);
assert_eq!(
user_db.get_tx_counter(&ADDR_2).await.unwrap(),
(1000.into(), 1000.into())
EpochCounter::from(1000)
);

let user_model = user_db.get_user(&ADDR_1).await.unwrap().unwrap();
Expand Down
Loading
Loading