diff --git a/teos/build.rs b/teos/build.rs index aa9031fd..8d0f131c 100644 --- a/teos/build.rs +++ b/teos/build.rs @@ -2,7 +2,7 @@ fn main() -> Result<(), Box> { tonic_build::configure() .extern_path(".common.teos.v2", "::teos-common::protos") .type_attribute(".", "#[derive(serde::Serialize, serde::Deserialize)]") - .field_attribute("user_id", "#[serde(with = \"hex::serde\")]") + .field_attribute("GetUserRequest.user_id", "#[serde(with = \"hex::serde\")]") .field_attribute("tower_id", "#[serde(with = \"hex::serde\")]") .field_attribute( "user_ids", diff --git a/teos/proto/teos/v2/appointment.proto b/teos/proto/teos/v2/appointment.proto index 67c66797..ecb873ee 100644 --- a/teos/proto/teos/v2/appointment.proto +++ b/teos/proto/teos/v2/appointment.proto @@ -4,9 +4,10 @@ package teos.v2; import "common/teos/v2/appointment.proto"; message GetAppointmentsRequest { - // Request the information of appointments with specific locator. + // Request the information of appointments with specific locator and user_id (optional) . bytes locator = 1; + optional bytes user_id = 2; } message GetAppointmentsResponse { diff --git a/teos/src/api/internal.rs b/teos/src/api/internal.rs index 38396bc0..fdfd13dd 100644 --- a/teos/src/api/internal.rs +++ b/teos/src/api/internal.rs @@ -2,7 +2,6 @@ use std::sync::{Arc, Condvar, Mutex}; use tonic::{Code, Request, Response, Status}; use triggered::Trigger; -use crate::extended_appointment::UUID; use crate::protos as msgs; use crate::protos::private_tower_services_server::PrivateTowerServices; use crate::protos::public_tower_services_server::PublicTowerServices; @@ -280,31 +279,44 @@ impl PrivateTowerServices for Arc { .map_or("an unknown address".to_owned(), |a| a.to_string()) ); - let mut matching_appointments = vec![]; - let locator = Locator::from_slice(&request.into_inner().locator).map_err(|_| { + let req_data = request.into_inner(); + let locator = Locator::from_slice(&req_data.locator).map_err(|_| { Status::new( Code::InvalidArgument, "The provided locator does not match the expected format (16-byte hexadecimal string)", ) })?; - for (_, appointment) in self + let user_id = req_data + .user_id + .map(|id| UserId::from_slice(&id)) + .transpose() + .map_err(|_| { + Status::new( + Code::InvalidArgument, + "The Provided user_id does not match expected format (33-byte hex string)", + ) + })?; + + let appointments: Vec = self .watcher - .get_watcher_appointments_with_locator(locator) + .get_watcher_appointments_with_locator(locator, user_id) + .into_values() + .map(|appointment| appointment.inner) + .collect(); + + let mut matching_appointments: Vec = appointments .into_iter() - { - matching_appointments.push(common_msgs::AppointmentData { + .map(|appointment| common_msgs::AppointmentData { appointment_data: Some( - common_msgs::appointment_data::AppointmentData::Appointment( - appointment.inner.into(), - ), + common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()), ), }) - } + .collect(); for (_, tracker) in self .watcher - .get_responder_trackers_with_locator(locator) + .get_responder_trackers_with_locator(locator, user_id) .into_iter() { matching_appointments.push(common_msgs::AppointmentData { @@ -390,7 +402,6 @@ impl PrivateTowerServices for Arc { Some((info, locators)) => Ok(Response::new(msgs::GetUserResponse { available_slots: info.available_slots, subscription_expiry: info.subscription_expiry, - // TODO: Should make `get_appointments` queryable using the (user_id, locator) pair for consistency. appointments: locators .into_iter() .map(|locator| locator.to_vec()) @@ -511,7 +522,10 @@ mod tests_private_api { let locator = Locator::new(get_random_tx().txid()).to_vec(); let response = internal_api - .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator })) + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator, + user_id: None, + })) .await .unwrap() .into_inner(); @@ -548,6 +562,7 @@ mod tests_private_api { let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), + user_id: None, })) .await .unwrap() @@ -599,6 +614,7 @@ mod tests_private_api { let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), + user_id: None, })) .await .unwrap() @@ -747,7 +763,10 @@ mod tests_private_api { assert_eq!(response.available_slots, SLOTS - 1); assert_eq!(response.subscription_expiry, START_HEIGHT as u32 + DURATION); - assert_eq!(response.appointments, Vec::from([appointment.inner.locator.to_vec()])); + assert_eq!( + response.appointments, + Vec::from([appointment.inner.locator.to_vec()]) + ); } #[tokio::test] diff --git a/teos/src/cli.rs b/teos/src/cli.rs index 3ef1d9ff..2e710c2b 100644 --- a/teos/src/cli.rs +++ b/teos/src/cli.rs @@ -75,20 +75,29 @@ async fn main() { println!("{}", pretty_json(&appointments.into_inner()).unwrap()); } Command::GetAppointments(appointments_data) => { - match Locator::from_hex(&appointments_data.locator) { - Ok(locator) => { - match client - .get_appointments(Request::new(msgs::GetAppointmentsRequest { - locator: locator.to_vec(), - })) - .await - { - Ok(appointments) => { - println!("{}", pretty_json(&appointments.into_inner()).unwrap()) + match appointments_data + .user_id + .map(|id| UserId::from_str(id.as_str())) + .map(|id| id.map(|user_id| user_id.to_vec())) + .transpose() + { + Ok(user_id) => match Locator::from_hex(&appointments_data.locator) { + Ok(locator) => { + match client + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.to_vec(), + user_id, + })) + .await + { + Ok(appointments) => { + println!("{}", pretty_json(&appointments.into_inner()).unwrap()) + } + Err(status) => handle_error(status.message()), } - Err(status) => handle_error(status.message()), } - } + Err(e) => handle_error(e), + }, Err(e) => handle_error(e), }; } diff --git a/teos/src/cli_config.rs b/teos/src/cli_config.rs index ba085b8d..8c23c467 100644 --- a/teos/src/cli_config.rs +++ b/teos/src/cli_config.rs @@ -31,6 +31,8 @@ pub struct GetUserData { pub struct GetAppointmentsData { /// The locator of the appointments (16-byte hexadecimal string). pub locator: String, + /// The user identifier (33-byte compressed public key). + pub user_id: Option, } /// Holds all the command line options and commands. diff --git a/teos/src/dbm.rs b/teos/src/dbm.rs index ca6dabff..9b8f6c22 100644 --- a/teos/src/dbm.rs +++ b/teos/src/dbm.rs @@ -330,23 +330,30 @@ impl DBM { /// matching this locator. If no locator is given, all the appointments in the database would be returned. pub(crate) fn load_appointments( &self, - locator: Option, + extra_params: Option<(Locator, Option)>, ) -> HashMap { let mut appointments = HashMap::new(); let mut sql = "SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string(); - // If a locator was passed, filter based on it. - if locator.is_some() { - sql.push_str(" AND a.locator=(?)"); - } + + // If a locator and an optional user_id were passed, filter based on it. + if let Some((_, user_id)) = extra_params { + sql.push_str(" AND a.locator=(?1)"); + if user_id.is_some() { + sql.push_str(" AND a.user_id=(?2)"); + } + }; + let mut stmt = self.connection.prepare(&sql).unwrap(); - let mut rows = if let Some(locator) = locator { - stmt.query([locator.to_vec()]).unwrap() - } else { - stmt.query([]).unwrap() + let mut rows = match extra_params { + Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(), + Some((locator, Some(user_id))) => { + stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap() + } + _ => stmt.query([]).unwrap(), }; while let Ok(Some(row)) = rows.next() { @@ -596,23 +603,30 @@ impl DBM { /// matching this locator. If no locator is given, all the trackers in the database would be returned. pub(crate) fn load_trackers( &self, - locator: Option, + extra_params: Option<(Locator, Option)>, ) -> HashMap { let mut trackers = HashMap::new(); let mut sql = "SELECT t.UUID, t.dispute_tx, t.penalty_tx, t.height, t.confirmed, a.user_id FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID" .to_string(); - // If a locator was passed, filter based on it. - if locator.is_some() { - sql.push_str(" WHERE a.locator=(?)"); + + // If a locator and an optional user_id were passed, filter based on it. + if let Some((_, user_id)) = extra_params { + sql.push_str(" AND a.locator=(?1)"); + if user_id.is_some() { + sql.push_str(" AND a.user_id=(?2)"); + } } + let mut stmt = self.connection.prepare(&sql).unwrap(); - let mut rows = if let Some(locator) = locator { - stmt.query([locator.to_vec()]).unwrap() - } else { - stmt.query([]).unwrap() + let mut rows = match extra_params { + Some((locator, None)) => stmt.query([locator.to_vec()]).unwrap(), + Some((locator, Some(user_id))) => { + stmt.query([locator.to_vec(), user_id.to_vec()]).unwrap() + } + _ => stmt.query([]).unwrap(), }; while let Ok(Some(row)) = rows.next() { @@ -1157,7 +1171,7 @@ mod tests { } // Validate that no other appointments than the ones with our locator are returned. - assert_eq!(dbm.load_appointments(Some(locator)), appointments); + assert_eq!(dbm.load_appointments(Some((locator, None))), appointments); // If an appointment has an associated tracker, it should not be loaded since it is seen // as a triggered appointment @@ -1175,7 +1189,52 @@ mod tests { dbm.store_tracker(uuid, &tracker).unwrap(); // We should get all the appointments matching our locator back except from the triggered one - assert_eq!(dbm.load_appointments(Some(locator)), appointments); + assert_eq!(dbm.load_appointments(Some((locator, None))), appointments); + } + + #[test] + fn test_load_appointments_with_locator_and_user_id() { + let dbm = DBM::in_memory().unwrap(); + + // create two appointment maps for two userId + let mut user_id1_appointments = HashMap::new(); + let mut user_id2_appointments = HashMap::new(); + + let dispute_tx = get_random_tx(); + let dispute_txid = dispute_tx.txid(); + let locator = Locator::new(dispute_txid); + + // generate two user ids + let user_id1 = get_random_user_id(); + let user_id2 = get_random_user_id(); + + let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); + dbm.store_user(user_id1, &user).unwrap(); + dbm.store_user(user_id2, &user).unwrap(); + + let (uuid, appointment) = + generate_dummy_appointment_with_user(user_id1, Some(&dispute_txid)); + dbm.store_appointment(uuid, &appointment).unwrap(); + // Store the appointment for the first user_id made using our dispute tx. + user_id1_appointments.insert(uuid, appointment); + + let (uuid, appointment) = + generate_dummy_appointment_with_user(user_id2, Some(&dispute_txid)); + dbm.store_appointment(uuid, &appointment).unwrap(); + // Store the appointment for the second user_id made using our dispute tx. + user_id2_appointments.insert(uuid, appointment); + + // Validate that the first user_id appointment map matches the fetched appointments. + assert_eq!( + dbm.load_appointments(Some((locator, Some(user_id1))),), + user_id1_appointments + ); + + // Validate that the second user_id appointment map matches the fetched appointments. + assert_eq!( + dbm.load_appointments(Some((locator, Some(user_id2)))), + user_id2_appointments + ); } #[test] @@ -1550,7 +1609,57 @@ mod tests { dbm.store_tracker(uuid, &tracker).unwrap(); } - assert_eq!(dbm.load_trackers(Some(locator)), trackers); + assert_eq!(dbm.load_trackers(Some((locator, None))), trackers); + } + + #[test] + fn test_load_trackers_with_locator_and_user_id() { + let dbm = DBM::in_memory().unwrap(); + + // create two tracker maps for two userId + let mut user_id1_trackers = HashMap::new(); + let mut user_id2_trackers = HashMap::new(); + + let dispute_tx = get_random_tx(); + let dispute_txid = dispute_tx.txid(); + let locator = Locator::new(dispute_txid); + let status = ConfirmationStatus::InMempoolSince(42); + + // generate two user ids + let user_id1 = get_random_user_id(); + let user_id2 = get_random_user_id(); + + let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); + dbm.store_user(user_id1, &user).unwrap(); + dbm.store_user(user_id2, &user).unwrap(); + + // Store the tracker for the first user_id created with a new appointment . + let (uuid, appointment) = + generate_dummy_appointment_with_user(user_id1, Some(&dispute_txid)); + let tracker = get_random_tracker(user_id1, status); + dbm.store_appointment(uuid, &appointment).unwrap(); + dbm.store_tracker(uuid, &tracker).unwrap(); + user_id1_trackers.insert(uuid, tracker); + + // Store the tracker for the second user_id created with a new appointment . + let (uuid, appointment) = + generate_dummy_appointment_with_user(user_id2, Some(&dispute_txid)); + let tracker = get_random_tracker(user_id2, status); + dbm.store_appointment(uuid, &appointment).unwrap(); + dbm.store_tracker(uuid, &tracker).unwrap(); + user_id2_trackers.insert(uuid, tracker); + + // Validate that the first user_id tracker map matches the fetched trackers. + assert_eq!( + dbm.load_trackers(Some((locator, Some(user_id1)))), + user_id1_trackers + ); + + // Validate that the second user_id tracker map matches the fetched trackers. + assert_eq!( + dbm.load_trackers(Some((locator, Some(user_id2)))), + user_id2_trackers + ); } #[test] diff --git a/teos/src/watcher.rs b/teos/src/watcher.rs index 90606fcf..3df048f0 100644 --- a/teos/src/watcher.rs +++ b/teos/src/watcher.rs @@ -423,12 +423,16 @@ impl Watcher { self.dbm.lock().unwrap().load_appointments(None) } - /// Gets all the appointments matching a specific locator from the [Watcher] (from the database). + /// Gets all the appointments matching a specific locator and an optional user id from the [Watcher] (from the database). pub(crate) fn get_watcher_appointments_with_locator( &self, locator: Locator, + user_id: Option, ) -> HashMap { - self.dbm.lock().unwrap().load_appointments(Some(locator)) + self.dbm + .lock() + .unwrap() + .load_appointments(Some((locator, user_id))) } /// Gets all the trackers stored in the [Responder] (from the database). @@ -436,12 +440,16 @@ impl Watcher { self.dbm.lock().unwrap().load_trackers(None) } - /// Gets all the trackers matching s specific locator from the [Responder] (from the database). + /// Gets all the trackers matching a specific locator and an optional user id from the [Responder] (from the database). pub(crate) fn get_responder_trackers_with_locator( &self, locator: Locator, + user_id: Option, ) -> HashMap { - self.dbm.lock().unwrap().load_trackers(Some(locator)) + self.dbm + .lock() + .unwrap() + .load_trackers(Some((locator, user_id))) } /// Gets the list of all registered user ids.