From 9dc68d1046c15aad48391d5e6cee6a9005a15a7c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sof=C3=ADa=20Celi?= Date: Fri, 15 Dec 2023 19:32:47 +0000 Subject: [PATCH] Some refactors from @fogti --- Cargo.toml | 2 +- src/api.rs | 7 ++++--- src/db.rs | 10 +++------- src/errors.rs | 24 ++++++++---------------- src/utils.rs | 12 +++++++++--- 5 files changed, 25 insertions(+), 30 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 169d42b..b55594f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "frodo-pir" version = "0.0.1" authors = ["Alex Davidson ", "gpestana ", "SofĂ­a Celi "] -edition = "2018" +edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/api.rs b/src/api.rs index 29062c6..9b8bc45 100644 --- a/src/api.rs +++ b/src/api.rs @@ -65,14 +65,15 @@ impl Shard { // Produces a serialized response (base64-encoded) to a serialized // client query: c' = b' * DB pub fn respond(&self, q: &Query) -> ResultBoxedError> { + let q = q.as_slice(); let resp = Response( (0..self.db.get_matrix_width_self()) - .map(|i| self.db.vec_mult(q.as_slice(), i)) + .map(|i| self.db.vec_mult(q, i)) .collect(), ); - let se = bincode::serialize(&resp); + let ser = bincode::serialize(&resp); - Ok(se?) + Ok(ser?) } /// Returns the database diff --git a/src/db.rs b/src/db.rs index 1e60c56..7234c55 100644 --- a/src/db.rs +++ b/src/db.rs @@ -53,11 +53,7 @@ impl Database { } pub fn vec_mult(&self, row: &[u32], col_idx: usize) -> u32 { - let mut acc = 0u32; - for (i, entry) in row.iter().enumerate() { - acc = acc.wrapping_add(entry.wrapping_mul(self.entries[col_idx][i])); - } - acc + vec_mult_u32_u32(row, &self.entries[col_idx]).unwrap() } pub fn write_to_file(&self, path: &str) -> ResultBoxedError<()> { @@ -73,7 +69,7 @@ impl Database { /// Returns the ith DB entry as a base64-encoded string pub fn get_db_entry(&self, i: usize) -> String { base64_from_u32_slice( - &swap_matrix_fmt(&self.entries)[i], + &get_matrix_second_at(&self.entries, i), self.plaintext_bits, self.elem_size, ) @@ -152,7 +148,7 @@ impl BaseParams { ) -> Vec> { let lhs = swap_matrix_fmt(&generate_lwe_matrix_from_seed(public_seed, dim, m)); - (0..Database::get_matrix_width(db.elem_size, db.plaintext_bits)) + (0..db.get_matrix_width_self()) .map(|i| { let mut col = Vec::with_capacity(m); for r in &lhs { diff --git a/src/errors.rs b/src/errors.rs index cc7b06b..740f9c0 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -13,10 +13,8 @@ pub struct ErrorUnexpectedInputSize { } impl ErrorUnexpectedInputSize { - pub fn new(msg: &str) -> Self { - Self { - details: msg.to_string(), - } + pub fn new(details: String) -> Self { + Self { details } } } @@ -35,7 +33,8 @@ impl Error for ErrorUnexpectedInputSize { // ErrorQueryParamsReused blocks attempts to reuse query parameters that // were used already. #[derive(Debug)] -pub struct ErrorQueryParamsReused {} +pub struct ErrorQueryParamsReused; + impl Display for ErrorQueryParamsReused { fn fmt(&self, f: &mut Formatter) -> FmtResult { write!( @@ -44,22 +43,15 @@ impl Display for ErrorQueryParamsReused { ) } } -impl Error for ErrorQueryParamsReused { - fn description(&self) -> &str { - "" - } -} +impl Error for ErrorQueryParamsReused {} // ErrorOverflownAdd blocks attempts to overflown addition. #[derive(Debug)] -pub struct ErrorOverflownAdd {} +pub struct ErrorOverflownAdd; + impl Display for ErrorOverflownAdd { fn fmt(&self, f: &mut Formatter) -> FmtResult { write!(f, "Attempted to overflow addition") } } -impl Error for ErrorOverflownAdd { - fn description(&self) -> &str { - "" - } -} +impl Error for ErrorOverflownAdd {} diff --git a/src/utils.rs b/src/utils.rs index b133177..df52874 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -44,6 +44,12 @@ pub mod matrices { swapped_row } + /// Takes a matrix and returns the [*][i] elements + /// equivalent to `swap_matrix_fmt(xys)[i]`, but much faster + pub fn get_matrix_second_at(matrix: &[Vec], secidx: usize) -> Vec { + matrix.iter().map(|y| y[secidx]).collect() + } + /// Generates an LWE matrix from a public seed /// This corresponds to the generation of `A` in the paper. pub fn generate_lwe_matrix_from_seed( @@ -68,7 +74,7 @@ pub mod matrices { if row.len() != col.len() { //panic!("row_len: {}, col_len: {}", row.len(), col.len()); - return Err(Box::new(ErrorUnexpectedInputSize::new(&format!( + return Err(Box::new(ErrorUnexpectedInputSize::new(format!( "row_len: {}, col_len:{},", row.len(), col.len(), @@ -182,7 +188,7 @@ pub mod format { let u32_len = std::mem::size_of::(); let byte_len = bytes.len(); if byte_len > u32_len { - return Err(ErrorUnexpectedInputSize::new(&format!( + return Err(ErrorUnexpectedInputSize::new(format!( "bytes are too long to parse as u16, length: {}", byte_len ))); @@ -199,7 +205,7 @@ pub mod format { let sized_vec: [u8; 4] = match bytes.try_into() { Ok(b) => b, Err(e) => { - return Err(ErrorUnexpectedInputSize::new(&format!( + return Err(ErrorUnexpectedInputSize::new(format!( "Unexpected vector size: {:?}", e, )))