Skip to content

Commit

Permalink
Improve FLP testing and add tooling to "test-util"
Browse files Browse the repository at this point in the history
Implement a new struct, `flp::test_util::FlpTest`, for driving generic
tests for FLPs. This tool subsumes most of the functionality of
`flp_validity_test()`. The irregular length tests have been dropped such
that the `valid()` is no longer required to do bounds checking itself.

Remove `flp_validity_test()` and replace each use with `FlpTest`.
  • Loading branch information
cjpatton committed Jan 11, 2024
1 parent b489d33 commit 64f76fa
Show file tree
Hide file tree
Showing 3 changed files with 289 additions and 425 deletions.
224 changes: 219 additions & 5 deletions src/flp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,6 @@ pub enum FlpError {
/// An error happened during noising.
#[error("differential privacy error: {0}")]
DifferentialPrivacy(#[from] crate::dp::DpError),

/// Unit test error.
#[cfg(test)]
#[error("test failed: {0}")]
Test(String),
}

/// A type. Implementations of this trait specify how a particular kind of measurement is encoded
Expand Down Expand Up @@ -752,6 +747,225 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi
gadget_degree * (wire_poly_len - 1) + 1
}

/// Utilities for testing FLPs.
#[cfg(feature = "test-util")]
pub mod test_utils {
use super::*;
use crate::field::{random_vector, FieldElement, FieldElementWithInteger};

/// Various tests for an FLP.
pub struct FlpTest<'a, T: Type> {
/// The FLP.
pub flp: &'a T,

/// Optional test name.
pub name: Option<&'a str>,

/// The input to use for the tests.
pub input: &'a [T::Field],

/// If set, the expected result of truncating the input.
pub expected_output: Option<&'a [T::Field]>,

/// Whether the input is expected to be valid.
pub expect_valid: bool,
}

impl<T: Type> FlpTest<'_, T> {
/// Construct a test and run it. Expect the input to be valid and compare the truncated
/// output to the provided value.
pub fn expect_valid<const SHARES: usize>(
flp: &T,
input: &[T::Field],
expected_output: &[T::Field],
) {
FlpTest {
flp,
name: None,
input,
expected_output: Some(expected_output),
expect_valid: true,
}
.run::<SHARES>()
}

/// Construct a test and run it. Expect the input to be invalid.
pub fn expect_invalid<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
flp,
name: None,
input,
expect_valid: false,
expected_output: None,
}
.run::<SHARES>()
}

/// Construct a test and run it. Expect the input to be valid.
pub fn expect_valid_no_output<const SHARES: usize>(flp: &T, input: &[T::Field]) {
FlpTest {
flp,
name: None,
input,
expect_valid: true,
expected_output: None,
}
.run::<SHARES>()
}

/// Run the tests.
pub fn run<const SHARES: usize>(&self) {
let name = self.name.unwrap_or("unnamed test");

assert_eq!(
self.input.len(),
self.flp.input_len(),
"{name}: unexpected input length"
);

let mut gadgets = self.flp.gadget();
let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap();
let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap();
let query_rand = random_vector(self.flp.query_rand_len()).unwrap();
assert_eq!(
self.flp.query_rand_len(),
gadgets.len(),
"{name}: unexpected number of gadgets"
);
assert_eq!(
self.flp.joint_rand_len(),
joint_rand.len(),
"{name}: unexpected joint rand length"
);
assert_eq!(
self.flp.prove_rand_len(),
prove_rand.len(),
"{name}: unexpected prove rand length",
);
assert_eq!(
self.flp.query_rand_len(),
query_rand.len(),
"{name}: unexpected query rand length",
);

// Run the validity circuit.
let v = self
.flp
.valid(&mut gadgets, self.input, &joint_rand, 1)
.unwrap();
assert_eq!(
v == T::Field::zero(),
self.expect_valid,
"{name}: unexpected output of valid() returned {v}",
);

// Generate the proof.
let proof = self
.flp
.prove(self.input, &prove_rand, &joint_rand)
.unwrap();
assert_eq!(
proof.len(),
self.flp.proof_len(),
"{name}: unexpected proof length"
);

// Query the proof.
let verifier = self
.flp
.query(self.input, &proof, &query_rand, &joint_rand, 1)
.unwrap();
assert_eq!(
verifier.len(),
self.flp.verifier_len(),
"{name}: unexpected verifier length"
);

// Decide if the input is valid.
let res = self.flp.decide(&verifier).unwrap();
assert_eq!(res, self.expect_valid, "{name}: unexpected decision");

// Run distributed FLP.
let input_shares = split_vector::<_, SHARES>(self.input);
let proof_shares = split_vector::<_, SHARES>(&proof);
let verifier: Vec<T::Field> = (0..SHARES)
.map(|i| {
self.flp
.query(
&input_shares[i],
&proof_shares[i],
&query_rand,
&joint_rand,
SHARES,
)
.unwrap()
})
.reduce(|mut left, right| {
for (x, y) in left.iter_mut().zip(right.iter()) {
*x += *y;
}
left
})
.unwrap();

let res = self.flp.decide(&verifier).unwrap();
assert_eq!(
res, self.expect_valid,
"{name}: unexpected distributed decision"
);

// Try verifying various proof mutants.
for i in 0..std::cmp::min(proof.len(), 10) {
let mut mutated_proof = proof.clone();
mutated_proof[i] *= T::Field::from(
<T::Field as FieldElementWithInteger>::Integer::try_from(23).unwrap(),
);
let verifier = self
.flp
.query(self.input, &mutated_proof, &query_rand, &joint_rand, 1)
.unwrap();
assert!(
!self.flp.decide(&verifier).unwrap(),
"{name}: proof mutant {} deemed valid",
i
);
}

// Try truncating the input.
if let Some(ref expected_output) = self.expected_output {
let output = self.flp.truncate(self.input.to_vec()).unwrap();

assert_eq!(
output.len(),
self.flp.output_len(),
"{name}: unexpected output length of truncate()"
);

assert_eq!(
&output, expected_output,
"{name}: unexpected output of truncate()"
);
}
}
}

fn split_vector<F: FieldElement, const SHARES: usize>(inp: &[F]) -> [Vec<F>; SHARES] {
let mut outp = Vec::with_capacity(SHARES);
outp.push(inp.to_vec());

for _ in 1..SHARES {
let share: Vec<F> =
random_vector(inp.len()).expect("failed to generate a random vector");
for (x, y) in outp[0].iter_mut().zip(&share) {
*x -= *y;
}
outp.push(share);
}

outp.try_into().unwrap()
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 64f76fa

Please sign in to comment.