Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve FLP testing and add tooling to "test-util" #900

Merged
merged 2 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
//! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that
//! generates a multiplicative subgroup of order `2^n` for some `n`.

#[cfg(feature = "crypto-dependencies")]
use crate::prng::{Prng, PrngError};
use crate::{
codec::{CodecError, Decode, Encode},
Expand Down Expand Up @@ -828,7 +827,7 @@ pub(crate) fn merge_vector<F: FieldElement>(
}

/// Outputs an additive secret sharing of the input.
#[cfg(all(feature = "crypto-dependencies", test))]
#[cfg(test)]
pub(crate) fn split_vector<F: FieldElement>(
inp: &[F],
num_shares: usize,
Expand All @@ -852,8 +851,6 @@ pub(crate) fn split_vector<F: FieldElement>(
}

/// Generate a vector of uniformly distributed random field elements.
#[cfg(feature = "crypto-dependencies")]
#[cfg_attr(docsrs, doc(cfg(feature = "crypto-dependencies")))]
pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> {
Ok(Prng::new()?.take(len).collect())
}
Expand Down
224 changes: 219 additions & 5 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ pub enum FlpError {
/// An error happened during noising.
#[error("differential privacy error: {0}")]
DifferentialPrivacy(#[from] crate::dp::DpError),

/// Unit test error.
#[cfg(test)]
#[error("test failed: {0}")]
Test(String),
cjpatton marked this conversation as resolved.
Show resolved Hide resolved
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
Expand Down Expand Up @@ -752,6 +747,225 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi
gadget_degree * (wire_poly_len - 1) + 1
}

/// Utilities for testing FLPs.
#[cfg(feature = "test-util")]
pub mod test_utils {
tgeoghegan marked this conversation as resolved.
Show resolved Hide resolved
use super::*;
use crate::field::{random_vector, FieldElement, FieldElementWithInteger};

/// Various tests for an FLP.
pub struct FlpTest<'a, T: Type> {
/// The FLP.
pub flp: &'a T,

/// Optional test name.
pub name: Option<&'a str>,

/// The input to use for the tests.
pub input: &'a [T::Field],

/// If set, the expected result of truncating the input.
pub expected_output: Option<&'a [T::Field]>,

/// Whether the input is expected to be valid.
pub expect_valid: bool,
}

impl<T: Type> FlpTest<'_, T> {
/// Construct a test and run it. Expect the input to be valid and compare the truncated
/// output to the provided value.
pub fn expect_valid<const SHARES: usize>(
flp: &T,
input: &[T::Field],
expected_output: &[T::Field],
) {
FlpTest {
flp,
name: None,
input,
expected_output: Some(expected_output),
expect_valid: true,
}
.run::<SHARES>()
}

/// Construct a test and run it. Expect the input to be invalid.
pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
flp,
name: None,
input,
expect_valid: false,
expected_output: None,
}
.run::<SHARES>()
}

/// Construct a test and run it. Expect the input to be valid.
pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
flp,
name: None,
input,
expect_valid: true,
expected_output: None,
}
.run::<SHARES>()
}

/// Run the tests.
pub fn run<const SHARES: usize>(&self) {
let name = self.name.unwrap_or("unnamed test");

assert_eq!(
self.input.len(),
self.flp.input_len(),
"{name}: unexpected input length"
);

let mut gadgets = self.flp.gadget();
let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap();
let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap();
let query_rand = random_vector(self.flp.query_rand_len()).unwrap();
assert_eq!(
self.flp.query_rand_len(),
gadgets.len(),
"{name}: unexpected number of gadgets"
);
assert_eq!(
self.flp.joint_rand_len(),
joint_rand.len(),
"{name}: unexpected joint rand length"
);
assert_eq!(
self.flp.prove_rand_len(),
prove_rand.len(),
"{name}: unexpected prove rand length",
);
assert_eq!(
self.flp.query_rand_len(),
query_rand.len(),
"{name}: unexpected query rand length",
);

// Run the validity circuit.
let v = self
.flp
.valid(&mut gadgets, self.input, &joint_rand, 1)
.unwrap();
assert_eq!(
v == T::Field::zero(),
self.expect_valid,
"{name}: unexpected output of valid() returned {v}",
);

// Generate the proof.
let proof = self
.flp
.prove(self.input, &prove_rand, &joint_rand)
.unwrap();
assert_eq!(
proof.len(),
self.flp.proof_len(),
"{name}: unexpected proof length"
);

// Query the proof.
let verifier = self
.flp
.query(self.input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
assert_eq!(
verifier.len(),
self.flp.verifier_len(),
"{name}: unexpected verifier length"
);

// Decide if the input is valid.
let res = self.flp.decide(&verifier).unwrap();
assert_eq!(res, self.expect_valid, "{name}: unexpected decision");

// Run distributed FLP.
let input_shares = split_vector::<_, SHARES>(self.input);
let proof_shares = split_vector::<_, SHARES>(&proof);
let verifier: Vec<T::Field> = (0..SHARES)
.map(|i| {
self.flp
.query(
&input_shares[i],
&proof_shares[i],
&query_rand,
&joint_rand,
SHARES,
)
.unwrap()
})
.reduce(|mut left, right| {
for (x, y) in left.iter_mut().zip(right.iter()) {
*x += *y;
}
left
})
.unwrap();

let res = self.flp.decide(&verifier).unwrap();
assert_eq!(
res, self.expect_valid,
"{name}: unexpected distributed decision"
);

// Try verifying various proof mutants.
for i in 0..std::cmp::min(proof.len(), 10) {
let mut mutated_proof = proof.clone();
mutated_proof[i] *= T::Field::from(
<T::Field as FieldElementWithInteger>::Integer::try_from(23).unwrap(),
);
let verifier = self
.flp
.query(self.input, &mutated_proof, &query_rand, &joint_rand, 1)
.unwrap();
assert!(
!self.flp.decide(&verifier).unwrap(),
"{name}: proof mutant {} deemed valid",
i
);
}

// Try truncating the input.
if let Some(ref expected_output) = self.expected_output {
let output = self.flp.truncate(self.input.to_vec()).unwrap();

assert_eq!(
output.len(),
self.flp.output_len(),
"{name}: unexpected output length of truncate()"
);

assert_eq!(
&output, expected_output,
"{name}: unexpected output of truncate()"
);
}
}
}

fn split_vector<F: FieldElement, const SHARES: usize>(inp: &[F]) -> [Vec<F>; SHARES] {
let mut outp = Vec::with_capacity(SHARES);
outp.push(inp.to_vec());

for _ in 1..SHARES {
let share: Vec<F> =
random_vector(inp.len()).expect("failed to generate a random vector");
for (x, y) in outp[0].iter_mut().zip(&share) {
*x -= *y;
}
outp.push(share);
}

outp.try_into().unwrap()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading