Skip to content

Commit

Permalink
added user_id as an option parameter to getappointments cli command
Browse files Browse the repository at this point in the history
Signed-off-by: aruokhai <joshuaaruokhaitech@gmail.com>

removed todo

Signed-off-by: aruokhai <joshuaaruokhaitech@gmail.com>
  • Loading branch information
aruokhai committed Jan 8, 2024
1 parent 818f756 commit 332f3c4
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 53 deletions.
2 changes: 1 addition & 1 deletion teos/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::configure()
.extern_path(".common.teos.v2", "::teos-common::protos")
.type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]")
.field_attribute("user_id", "#[serde(with = \"hex::serde\")]")
.field_attribute("GetUserRequest.user_id", "#[serde(with = \"hex::serde\")]")
.field_attribute("tower_id", "#[serde(with = \"hex::serde\")]")
.field_attribute(
"user_ids",
Expand Down
3 changes: 2 additions & 1 deletion teos/proto/teos/v2/appointment.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@ package teos.v2;
import "common/teos/v2/appointment.proto";

message GetAppointmentsRequest {
// Request the information of appointments with specific locator.
// Request the information of appointments with specific locator and user_id (optional) .

bytes locator = 1;
optional bytes user_id = 2;
}

message GetAppointmentsResponse {
Expand Down
49 changes: 34 additions & 15 deletions teos/src/api/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::sync::{Arc, Condvar, Mutex};
use tonic::{Code, Request, Response, Status};
use triggered::Trigger;

use crate::extended_appointment::UUID;
use crate::protos as msgs;
use crate::protos::private_tower_services_server::PrivateTowerServices;
use crate::protos::public_tower_services_server::PublicTowerServices;
Expand Down Expand Up @@ -280,31 +279,44 @@ impl PrivateTowerServices for Arc<InternalAPI> {
.map_or("an unknown address".to_owned(), |a| a.to_string())
);

let mut matching_appointments = vec![];
let locator = Locator::from_slice(&request.into_inner().locator).map_err(|_| {
let req_data = request.into_inner();
let locator = Locator::from_slice(&req_data.locator).map_err(|_| {
Status::new(
Code::InvalidArgument,
"The provided locator does not match the expected format (16-byte hexadecimal string)",
)
})?;

for (_, appointment) in self
let user_id = req_data
.user_id
.map(|id| UserId::from_slice(&id))
.transpose()
.map_err(|_| {
Status::new(
Code::InvalidArgument,
"The Provided user_id does not match expected format (33-byte hex string)",
)
})?;

let appointments: Vec<Appointment> = self
.watcher
.get_watcher_appointments_with_locator(locator)
.get_watcher_appointments_with_locator(locator, user_id)
.into_values()
.map(|appointment| appointment.inner)
.collect();

let mut matching_appointments: Vec<common_msgs::AppointmentData> = appointments
.into_iter()
{
matching_appointments.push(common_msgs::AppointmentData {
.map(|appointment| common_msgs::AppointmentData {
appointment_data: Some(
common_msgs::appointment_data::AppointmentData::Appointment(
appointment.inner.into(),
),
common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()),
),
})
}
.collect();

for (_, tracker) in self
.watcher
.get_responder_trackers_with_locator(locator)
.get_responder_trackers_with_locator(locator, user_id)
.into_iter()
{
matching_appointments.push(common_msgs::AppointmentData {
Expand Down Expand Up @@ -390,7 +402,6 @@ impl PrivateTowerServices for Arc<InternalAPI> {
Some((info, locators)) => Ok(Response::new(msgs::GetUserResponse {
available_slots: info.available_slots,
subscription_expiry: info.subscription_expiry,
// TODO: Should make `get_appointments` queryable using the (user_id, locator) pair for consistency.
appointments: locators
.into_iter()
.map(|locator| locator.to_vec())
Expand Down Expand Up @@ -511,7 +522,10 @@ mod tests_private_api {

let locator = Locator::new(get_random_tx().txid()).to_vec();
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest { locator }))
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator,
user_id: None,
}))
.await
.unwrap()
.into_inner();
Expand Down Expand Up @@ -548,6 +562,7 @@ mod tests_private_api {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: None,
}))
.await
.unwrap()
Expand Down Expand Up @@ -599,6 +614,7 @@ mod tests_private_api {
let response = internal_api
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id: None,
}))
.await
.unwrap()
Expand Down Expand Up @@ -747,7 +763,10 @@ mod tests_private_api {

assert_eq!(response.available_slots, SLOTS - 1);
assert_eq!(response.subscription_expiry, START_HEIGHT as u32 + DURATION);
assert_eq!(response.appointments, Vec::from([appointment.inner.locator.to_vec()]));
assert_eq!(
response.appointments,
Vec::from([appointment.inner.locator.to_vec()])
);
}

#[tokio::test]
Expand Down
33 changes: 21 additions & 12 deletions teos/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,29 @@ async fn main() {
println!("{}", pretty_json(&appointments.into_inner()).unwrap());
}
Command::GetAppointments(appointments_data) => {
match Locator::from_hex(&appointments_data.locator) {
Ok(locator) => {
match client
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
}))
.await
{
Ok(appointments) => {
println!("{}", pretty_json(&appointments.into_inner()).unwrap())
match appointments_data
.user_id
.map(|id| UserId::from_str(id.as_str()))
.map(|id| id.map(|user_id| user_id.to_vec()))
.transpose()
{
Ok(user_id) => match Locator::from_hex(&appointments_data.locator) {
Ok(locator) => {
match client
.get_appointments(Request::new(msgs::GetAppointmentsRequest {
locator: locator.to_vec(),
user_id,
}))
.await
{
Ok(appointments) => {
println!("{}", pretty_json(&appointments.into_inner()).unwrap())
}
Err(status) => handle_error(status.message()),
}
Err(status) => handle_error(status.message()),
}
}
Err(e) => handle_error(e),
},
Err(e) => handle_error(e),
};
}
Expand Down
2 changes: 2 additions & 0 deletions teos/src/cli_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub struct GetUserData {
pub struct GetAppointmentsData {
/// The locator of the appointments (16-byte hexadecimal string).
pub locator: String,
/// The user identifier (33-byte compressed public key).
pub user_id: Option<String>,
}

/// Holds all the command line options and commands.
Expand Down
149 changes: 129 additions & 20 deletions teos/src/dbm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,23 +330,30 @@ impl DBM {
/// matching this locator. If no locator is given, all the appointments in the database would be returned.
pub(crate) fn load_appointments(
&self,
locator: Option<Locator>,
extra_params: Option<(Locator, Option<UserId>)>,
) -> HashMap<UUID, ExtendedAppointment> {
let mut appointments = HashMap::new();

let mut sql =
"SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id
FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string();
// If a locator was passed, filter based on it.
if locator.is_some() {
sql.push_str(" AND a.locator=(?)");
}

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = extra_params {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
};

let mut stmt = self.connection.prepare(&sql).unwrap();

let mut rows = if let Some(locator) = locator {
stmt.query([locator.to_vec()]).unwrap()
} else {
stmt.query([]).unwrap()
let mut rows = match extra_params {
Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(),
Some((locator, Some(user_id))) => {
stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap()
}
_ => stmt.query([]).unwrap(),
};

while let Ok(Some(row)) = rows.next() {
Expand Down Expand Up @@ -596,23 +603,30 @@ impl DBM {
/// matching this locator. If no locator is given, all the trackers in the database would be returned.
pub(crate) fn load_trackers(
&self,
locator: Option<Locator>,
extra_params: Option<(Locator, Option<UserId>)>,
) -> HashMap<UUID, TransactionTracker> {
let mut trackers = HashMap::new();

let mut sql = "SELECT t.UUID, t.dispute_tx, t.penalty_tx, t.height, t.confirmed, a.user_id
FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID"
.to_string();
// If a locator was passed, filter based on it.
if locator.is_some() {
sql.push_str(" WHERE a.locator=(?)");

// If a locator and an optional user_id were passed, filter based on it.
if let Some((_, user_id)) = extra_params {
sql.push_str(" AND a.locator=(?1)");
if user_id.is_some() {
sql.push_str(" AND a.user_id=(?2)");
}
}

let mut stmt = self.connection.prepare(&sql).unwrap();

let mut rows = if let Some(locator) = locator {
stmt.query([locator.to_vec()]).unwrap()
} else {
stmt.query([]).unwrap()
let mut rows = match extra_params {
Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(),
Some((locator, Some(user_id))) => {
stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap()
}
_ => stmt.query([]).unwrap(),
};

while let Ok(Some(row)) = rows.next() {
Expand Down Expand Up @@ -1157,7 +1171,7 @@ mod tests {
}

// Validate that no other appointments than the ones with our locator are returned.
assert_eq!(dbm.load_appointments(Some(locator)), appointments);
assert_eq!(dbm.load_appointments(Some((locator, None))), appointments);

// If an appointment has an associated tracker, it should not be loaded since it is seen
// as a triggered appointment
Expand All @@ -1175,7 +1189,52 @@ mod tests {
dbm.store_tracker(uuid, &tracker).unwrap();

// We should get all the appointments matching our locator back except from the triggered one
assert_eq!(dbm.load_appointments(Some(locator)), appointments);
assert_eq!(dbm.load_appointments(Some((locator, None))), appointments);
}

#[test]
fn test_load_appointments_with_locator_and_user_id() {
let dbm = DBM::in_memory().unwrap();

// create two appointment maps for two userId
let mut user_id1_appointments = HashMap::new();
let mut user_id2_appointments = HashMap::new();

let dispute_tx = get_random_tx();
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);

// generate two user ids
let user_id1 = get_random_user_id();
let user_id2 = get_random_user_id();

let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id1, &user).unwrap();
dbm.store_user(user_id2, &user).unwrap();

let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id1, Some(&dispute_txid));
dbm.store_appointment(uuid, &appointment).unwrap();
// Store the appointment for the first user_id made using our dispute tx.
user_id1_appointments.insert(uuid, appointment);

let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id2, Some(&dispute_txid));
dbm.store_appointment(uuid, &appointment).unwrap();
// Store the appointment for the second user_id made using our dispute tx.
user_id2_appointments.insert(uuid, appointment);

// Validate that the first user_id appointment map matches the fetched appointments.
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id1))),),
user_id1_appointments
);

// Validate that the second user_id appointment map matches the fetched appointments.
assert_eq!(
dbm.load_appointments(Some((locator, Some(user_id2)))),
user_id2_appointments
);
}

#[test]
Expand Down Expand Up @@ -1550,7 +1609,57 @@ mod tests {
dbm.store_tracker(uuid, &tracker).unwrap();
}

assert_eq!(dbm.load_trackers(Some(locator)), trackers);
assert_eq!(dbm.load_trackers(Some((locator, None))), trackers);
}

#[test]
fn test_load_trackers_with_locator_and_user_id() {
let dbm = DBM::in_memory().unwrap();

// create two tracker maps for two userId
let mut user_id1_trackers = HashMap::new();
let mut user_id2_trackers = HashMap::new();

let dispute_tx = get_random_tx();
let dispute_txid = dispute_tx.txid();
let locator = Locator::new(dispute_txid);
let status = ConfirmationStatus::InMempoolSince(42);

// generate two user ids
let user_id1 = get_random_user_id();
let user_id2 = get_random_user_id();

let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY);
dbm.store_user(user_id1, &user).unwrap();
dbm.store_user(user_id2, &user).unwrap();

// Store the tracker for the first user_id created with a new appointment .
let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id1, Some(&dispute_txid));
let tracker = get_random_tracker(user_id1, status);
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
user_id1_trackers.insert(uuid, tracker);

// Store the tracker for the second user_id created with a new appointment .
let (uuid, appointment) =
generate_dummy_appointment_with_user(user_id2, Some(&dispute_txid));
let tracker = get_random_tracker(user_id2, status);
dbm.store_appointment(uuid, &appointment).unwrap();
dbm.store_tracker(uuid, &tracker).unwrap();
user_id2_trackers.insert(uuid, tracker);

// Validate that the first user_id tracker map matches the fetched trackers.
assert_eq!(
dbm.load_trackers(Some((locator, Some(user_id1)))),
user_id1_trackers
);

// Validate that the second user_id tracker map matches the fetched trackers.
assert_eq!(
dbm.load_trackers(Some((locator, Some(user_id2)))),
user_id2_trackers
);
}

#[test]
Expand Down
Loading

0 comments on commit 332f3c4

Please sign in to comment.