Skip to content

Commit

Permalink
Some refactors from @fogti
Browse files Browse the repository at this point in the history
  • Loading branch information
claucece committed Dec 15, 2023
1 parent 872de92 commit 9dc68d1
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "frodo-pir"
version = "0.0.1"
authors = ["Alex Davidson <coela@alxdavids.xyz>", "gpestana <g6pestana@gmail.com>", "Sofía Celi <cherenkov@riseup.net>"]
edition = "2018"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
7 changes: 4 additions & 3 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ impl Shard {
// Produces a serialized response (base64-encoded) to a serialized
// client query: c' = b' * DB
pub fn respond(&self, q: &Query) -> ResultBoxedError<Vec<u8>> {
let q = q.as_slice();
let resp = Response(
(0..self.db.get_matrix_width_self())
.map(|i| self.db.vec_mult(q.as_slice(), i))
.map(|i| self.db.vec_mult(q, i))
.collect(),
);
let se = bincode::serialize(&resp);
let ser = bincode::serialize(&resp);

Ok(se?)
Ok(ser?)
}

/// Returns the database
Expand Down
10 changes: 3 additions & 7 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,7 @@ impl Database {
}

pub fn vec_mult(&self, row: &[u32], col_idx: usize) -> u32 {
let mut acc = 0u32;
for (i, entry) in row.iter().enumerate() {
acc = acc.wrapping_add(entry.wrapping_mul(self.entries[col_idx][i]));
}
acc
vec_mult_u32_u32(row, &self.entries[col_idx]).unwrap()
}

pub fn write_to_file(&self, path: &str) -> ResultBoxedError<()> {
Expand All @@ -73,7 +69,7 @@ impl Database {
/// Returns the ith DB entry as a base64-encoded string
pub fn get_db_entry(&self, i: usize) -> String {
base64_from_u32_slice(
&swap_matrix_fmt(&self.entries)[i],
&get_matrix_second_at(&self.entries, i),
self.plaintext_bits,
self.elem_size,
)
Expand Down Expand Up @@ -152,7 +148,7 @@ impl BaseParams {
) -> Vec<Vec<u32>> {
let lhs =
swap_matrix_fmt(&generate_lwe_matrix_from_seed(public_seed, dim, m));
(0..Database::get_matrix_width(db.elem_size, db.plaintext_bits))
(0..db.get_matrix_width_self())
.map(|i| {
let mut col = Vec::with_capacity(m);
for r in &lhs {
Expand Down
24 changes: 8 additions & 16 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ pub struct ErrorUnexpectedInputSize {
}

impl ErrorUnexpectedInputSize {
pub fn new(msg: &str) -> Self {
Self {
details: msg.to_string(),
}
pub fn new(details: String) -> Self {
Self { details }
}
}

Expand All @@ -35,7 +33,8 @@ impl Error for ErrorUnexpectedInputSize {
// ErrorQueryParamsReused blocks attempts to reuse query parameters that
// were used already.
#[derive(Debug)]
pub struct ErrorQueryParamsReused {}
pub struct ErrorQueryParamsReused;

impl Display for ErrorQueryParamsReused {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
write!(
Expand All @@ -44,22 +43,15 @@ impl Display for ErrorQueryParamsReused {
)
}
}
impl Error for ErrorQueryParamsReused {
fn description(&self) -> &str {
""
}
}
impl Error for ErrorQueryParamsReused {}

// ErrorOverflownAdd blocks attempts to overflown addition.
#[derive(Debug)]
pub struct ErrorOverflownAdd {}
pub struct ErrorOverflownAdd;

impl Display for ErrorOverflownAdd {
fn fmt(&self, f: &mut Formatter) -> FmtResult {
write!(f, "Attempted to overflow addition")
}
}
impl Error for ErrorOverflownAdd {
fn description(&self) -> &str {
""
}
}
impl Error for ErrorOverflownAdd {}
12 changes: 9 additions & 3 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ pub mod matrices {
swapped_row
}

/// Takes a matrix and returns the [*][i] elements
/// equivalent to `swap_matrix_fmt(xys)[i]`, but much faster
pub fn get_matrix_second_at(matrix: &[Vec<u32>], secidx: usize) -> Vec<u32> {
matrix.iter().map(|y| y[secidx]).collect()
}

/// Generates an LWE matrix from a public seed
/// This corresponds to the generation of `A` in the paper.
pub fn generate_lwe_matrix_from_seed(
Expand All @@ -68,7 +74,7 @@ pub mod matrices {
if row.len() != col.len() {
//panic!("row_len: {}, col_len: {}", row.len(), col.len());

return Err(Box::new(ErrorUnexpectedInputSize::new(&format!(
return Err(Box::new(ErrorUnexpectedInputSize::new(format!(
"row_len: {}, col_len:{},",
row.len(),
col.len(),
Expand Down Expand Up @@ -182,7 +188,7 @@ pub mod format {
let u32_len = std::mem::size_of::<u32>();
let byte_len = bytes.len();
if byte_len > u32_len {
return Err(ErrorUnexpectedInputSize::new(&format!(
return Err(ErrorUnexpectedInputSize::new(format!(
"bytes are too long to parse as u16, length: {}",
byte_len
)));
Expand All @@ -199,7 +205,7 @@ pub mod format {
let sized_vec: [u8; 4] = match bytes.try_into() {
Ok(b) => b,
Err(e) => {
return Err(ErrorUnexpectedInputSize::new(&format!(
return Err(ErrorUnexpectedInputSize::new(format!(
"Unexpected vector size: {:?}",
e,
)))
Expand Down

0 comments on commit 9dc68d1

Please sign in to comment.