From 64f76fad9cc938a74d5a5853eb842e42c60052d0 Mon Sep 17 00:00:00 2001 From: Christopher Patton Date: Wed, 10 Jan 2024 18:13:57 -0800 Subject: [PATCH] Improve FLP testing and add tooling to "test-util" Implement a new struct, `flp::test_util::FlpTest`, for driving generic tests for FLPs. This tool subsumes most of the functionality of `flp_validity_test()`. The irregular length tests have been dropped such that the `valid()` is no longer required to do bounds checking itself. Remove `flp_validity_test()` and replace each use with `FlpTest`. --- src/flp.rs | 224 +++++++++++++++++- src/flp/types.rs | 418 +++------------------------------ src/flp/types/fixedpoint_l2.rs | 72 +++--- 3 files changed, 289 insertions(+), 425 deletions(-) diff --git a/src/flp.rs b/src/flp.rs index f3abc28f9..f2dc85a86 100644 --- a/src/flp.rs +++ b/src/flp.rs @@ -111,11 +111,6 @@ pub enum FlpError { /// An error happened during noising. #[error("differential privacy error: {0}")] DifferentialPrivacy(#[from] crate::dp::DpError), - - /// Unit test error. - #[cfg(test)] - #[error("test failed: {0}")] - Test(String), } /// A type. Implementations of this trait specify how a particular kind of measurement is encoded @@ -752,6 +747,225 @@ pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usi gadget_degree * (wire_poly_len - 1) + 1 } +/// Utilities for testing FLPs. +#[cfg(feature = "test-util")] +pub mod test_utils { + use super::*; + use crate::field::{random_vector, FieldElement, FieldElementWithInteger}; + + /// Various tests for an FLP. + pub struct FlpTest<'a, T: Type> { + /// The FLP. + pub flp: &'a T, + + /// Optional test name. + pub name: Option<&'a str>, + + /// The input to use for the tests. + pub input: &'a [T::Field], + + /// If set, the expected result of truncating the input. + pub expected_output: Option<&'a [T::Field]>, + + /// Whether the input is expected to be valid. + pub expect_valid: bool, + } + + impl FlpTest<'_, T> { + /// Construct a test and run it. Expect the input to be valid and compare the truncated + /// output to the provided value. + pub fn expect_valid( + flp: &T, + input: &[T::Field], + expected_output: &[T::Field], + ) { + FlpTest { + flp, + name: None, + input, + expected_output: Some(expected_output), + expect_valid: true, + } + .run::() + } + + /// Construct a test and run it. Expect the input to be invalid. + pub fn expect_invalid(flp: &T, input: &[T::Field]) { + FlpTest { + flp, + name: None, + input, + expect_valid: false, + expected_output: None, + } + .run::() + } + + /// Construct a test and run it. Expect the input to be valid. + pub fn expect_valid_no_output(flp: &T, input: &[T::Field]) { + FlpTest { + flp, + name: None, + input, + expect_valid: true, + expected_output: None, + } + .run::() + } + + /// Run the tests. + pub fn run(&self) { + let name = self.name.unwrap_or("unnamed test"); + + assert_eq!( + self.input.len(), + self.flp.input_len(), + "{name}: unexpected input length" + ); + + let mut gadgets = self.flp.gadget(); + let joint_rand = random_vector(self.flp.joint_rand_len()).unwrap(); + let prove_rand = random_vector(self.flp.prove_rand_len()).unwrap(); + let query_rand = random_vector(self.flp.query_rand_len()).unwrap(); + assert_eq!( + self.flp.query_rand_len(), + gadgets.len(), + "{name}: unexpected number of gadgets" + ); + assert_eq!( + self.flp.joint_rand_len(), + joint_rand.len(), + "{name}: unexpected joint rand length" + ); + assert_eq!( + self.flp.prove_rand_len(), + prove_rand.len(), + "{name}: unexpected prove rand length", + ); + assert_eq!( + self.flp.query_rand_len(), + query_rand.len(), + "{name}: unexpected query rand length", + ); + + // Run the validity circuit. + let v = self + .flp + .valid(&mut gadgets, self.input, &joint_rand, 1) + .unwrap(); + assert_eq!( + v == T::Field::zero(), + self.expect_valid, + "{name}: unexpected output of valid() returned {v}", + ); + + // Generate the proof. + let proof = self + .flp + .prove(self.input, &prove_rand, &joint_rand) + .unwrap(); + assert_eq!( + proof.len(), + self.flp.proof_len(), + "{name}: unexpected proof length" + ); + + // Query the proof. + let verifier = self + .flp + .query(self.input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!( + verifier.len(), + self.flp.verifier_len(), + "{name}: unexpected verifier length" + ); + + // Decide if the input is valid. + let res = self.flp.decide(&verifier).unwrap(); + assert_eq!(res, self.expect_valid, "{name}: unexpected decision"); + + // Run distributed FLP. + let input_shares = split_vector::<_, SHARES>(self.input); + let proof_shares = split_vector::<_, SHARES>(&proof); + let verifier: Vec = (0..SHARES) + .map(|i| { + self.flp + .query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + + let res = self.flp.decide(&verifier).unwrap(); + assert_eq!( + res, self.expect_valid, + "{name}: unexpected distributed decision" + ); + + // Try verifying various proof mutants. + for i in 0..std::cmp::min(proof.len(), 10) { + let mut mutated_proof = proof.clone(); + mutated_proof[i] *= T::Field::from( + ::Integer::try_from(23).unwrap(), + ); + let verifier = self + .flp + .query(self.input, &mutated_proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert!( + !self.flp.decide(&verifier).unwrap(), + "{name}: proof mutant {} deemed valid", + i + ); + } + + // Try truncating the input. + if let Some(ref expected_output) = self.expected_output { + let output = self.flp.truncate(self.input.to_vec()).unwrap(); + + assert_eq!( + output.len(), + self.flp.output_len(), + "{name}: unexpected output length of truncate()" + ); + + assert_eq!( + &output, expected_output, + "{name}: unexpected output of truncate()" + ); + } + } + } + + fn split_vector(inp: &[F]) -> [Vec; SHARES] { + let mut outp = Vec::with_capacity(SHARES); + outp.push(inp.to_vec()); + + for _ in 1..SHARES { + let share: Vec = + random_vector(inp.len()).expect("failed to generate a random vector"); + for (x, y) in outp[0].iter_mut().zip(&share) { + *x -= *y; + } + outp.push(share); + } + + outp.try_into().unwrap() + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/flp/types.rs b/src/flp/types.rs index 922e0aef0..bca88a36c 100644 --- a/src/flp/types.rs +++ b/src/flp/types.rs @@ -775,7 +775,7 @@ mod tests { use crate::flp::gadgets::ParallelSum; #[cfg(feature = "multithreaded")] use crate::flp::gadgets::ParallelSumMultithreaded; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use crate::flp::test_utils::FlpTest; use std::cmp; #[test] @@ -798,39 +798,11 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( - &count, - &count.encode_measurement(&true).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &count, - &count.encode_measurement(&false).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&true).unwrap(), &[one]); + FlpTest::expect_valid::<3>(&count, &count.encode_measurement(&false).unwrap(), &[zero]); // Test FLP on invalid input. - flp_validity_test( - &count, - &[TestField::from(1337)], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&count, &[TestField::from(1337)]); // Try running the validity circuit on an input that's too short. count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); @@ -857,72 +829,22 @@ mod tests { ); // Test FLP on valid input. - flp_validity_test( + FlpTest::expect_valid::<3>( &sum, &sum.encode_measurement(&1337).unwrap(), - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![TestField::from(1337)]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(0).unwrap(), - &[], - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![zero]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(2).unwrap(), - &[one, zero], - &ValidityTestCase { - expect_valid: true, - expected_output: Some(vec![one]), - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( + &[TestField::from(1337)], + ); + FlpTest::expect_valid::<3>(&Sum::new(0).unwrap(), &[], &[zero]); + FlpTest::expect_valid::<3>(&Sum::new(2).unwrap(), &[one, zero], &[one]); + FlpTest::expect_valid::<3>( &Sum::new(9).unwrap(), &[one, zero, one, one, zero, one, one, one, zero], - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![TestField::from(237)]), - num_shares: 3, - }, - ) - .unwrap(); + &[TestField::from(237)], + ); // Test FLP on invalid input. - flp_validity_test( - &Sum::new(3).unwrap(), - &[one, nine, zero], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &Sum::new(5).unwrap(), - &[zero, zero, zero, zero, nine], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&Sum::new(3).unwrap(), &[one, nine, zero]); + FlpTest::expect_invalid::<3>(&Sum::new(5).unwrap(), &[zero, zero, zero, zero, nine]); } #[test] @@ -992,83 +914,29 @@ mod tests { ); // Test valid inputs. - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&0).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![one, zero, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[one, zero, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&1).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![zero, one, zero]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, one, zero], + ); - flp_validity_test( + FlpTest::expect_valid::<3>( &hist, &hist.encode_measurement(&2).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![zero, zero, one]), - num_shares: 3, - }, - ) - .unwrap(); + &[zero, zero, one], + ); // Test invalid inputs. - flp_validity_test( - &hist, - &[zero, zero, nine], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, one, one], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[one, one, one], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); - - flp_validity_test( - &hist, - &[zero, zero, zero], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, nine]); + FlpTest::expect_invalid::<3>(&hist, &[zero, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[one, one, one]); + FlpTest::expect_invalid::<3>(&hist, &[zero, zero, zero]); } #[test] @@ -1096,72 +964,38 @@ mod tests { for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( + FlpTest::expect_valid_no_output::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + ); } let len = 100; let sum_vec = f(1, len, 10).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![1; len]).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![one; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![one; len], + ); let len = 23; let sum_vec = f(4, len, 4).unwrap(); - flp_validity_test( + FlpTest::expect_valid::<3>( &sum_vec, &sum_vec.encode_measurement(&vec![9; len]).unwrap(), - &ValidityTestCase:: { - expect_valid: true, - expected_output: Some(vec![nine; len]), - num_shares: 3, - }, - ) - .unwrap(); + &vec![nine; len], + ); // Test on invalid inputs. for len in 1..10 { let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); let sum_vec = f(1, len, chunk_length).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; len], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; len]); } let len = 23; let sum_vec = f(2, len, 4).unwrap(); - flp_validity_test( - &sum_vec, - &vec![nine; 2 * len], - &ValidityTestCase:: { - expect_valid: false, - expected_output: None, - num_shares: 3, - }, - ) - .unwrap(); + FlpTest::expect_invalid::<3>(&sum_vec, &vec![nine; 2 * len]); // Round trip let want = vec![1; len]; @@ -1224,184 +1058,6 @@ mod tests { } } -#[cfg(test)] -mod test_utils { - use super::*; - use crate::field::{random_vector, split_vector, FieldElement}; - - pub(crate) struct ValidityTestCase { - pub(crate) expect_valid: bool, - pub(crate) expected_output: Option>, - // Number of shares to split input and proofs into in `flp_test`. - pub(crate) num_shares: usize, - } - - pub(crate) fn flp_validity_test( - typ: &T, - input: &[T::Field], - t: &ValidityTestCase, - ) -> Result<(), FlpError> { - let mut gadgets = typ.gadget(); - - if input.len() != typ.input_len() { - return Err(FlpError::Test(format!( - "unexpected input length: got {}; want {}", - input.len(), - typ.input_len() - ))); - } - - if typ.query_rand_len() != gadgets.len() { - return Err(FlpError::Test(format!( - "query rand length: got {}; want {}", - typ.query_rand_len(), - gadgets.len() - ))); - } - - let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); - let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); - let query_rand = random_vector(typ.query_rand_len()).unwrap(); - - // Run the validity circuit. - let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; - if v != T::Field::zero() && t.expect_valid { - return Err(FlpError::Test(format!( - "expected valid input: valid() returned {v}" - ))); - } - if v == T::Field::zero() && !t.expect_valid { - return Err(FlpError::Test(format!( - "expected invalid input: valid() returned {v}" - ))); - } - - // Generate the proof. - let proof = typ.prove(input, &prove_rand, &joint_rand)?; - if proof.len() != typ.proof_len() { - return Err(FlpError::Test(format!( - "unexpected proof length: got {}; want {}", - proof.len(), - typ.proof_len() - ))); - } - - // Query the proof. - let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; - if verifier.len() != typ.verifier_len() { - return Err(FlpError::Test(format!( - "unexpected verifier length: got {}; want {}", - verifier.len(), - typ.verifier_len() - ))); - } - - // Decide if the input is valid. - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Run distributed FLP. - let input_shares: Vec> = split_vector(input, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let proof_shares: Vec> = split_vector(&proof, t.num_shares) - .unwrap() - .into_iter() - .collect(); - - let verifier: Vec = (0..t.num_shares) - .map(|i| { - typ.query( - &input_shares[i], - &proof_shares[i], - &query_rand, - &joint_rand, - t.num_shares, - ) - .unwrap() - }) - .reduce(|mut left, right| { - for (x, y) in left.iter_mut().zip(right.iter()) { - *x += *y; - } - left - }) - .unwrap(); - - let res = typ.decide(&verifier)?; - if res != t.expect_valid { - return Err(FlpError::Test(format!( - "distributed decision is {}; want {}", - res, t.expect_valid, - ))); - } - - // Try verifying various proof mutants. - for i in 0..proof.len() { - let mut mutated_proof = proof.clone(); - mutated_proof[i] += T::Field::one(); - let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; - if typ.decide(&verifier)? { - return Err(FlpError::Test(format!( - "decision for proof mutant {} is {}; want {}", - i, true, false, - ))); - } - } - - // Try verifying a proof that is too short. - let mut mutated_proof = proof.clone(); - mutated_proof.truncate(gadgets[0].arity() - 1); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on short proof succeeded; want failure".to_string(), - )); - } - - // Try verifying a proof that is too long. - let mut mutated_proof = proof; - mutated_proof.extend_from_slice(&[T::Field::one(); 17]); - if typ - .query(input, &mutated_proof, &query_rand, &joint_rand, 1) - .is_ok() - { - return Err(FlpError::Test( - "query on long proof succeeded; want failure".to_string(), - )); - } - - if let Some(ref want) = t.expected_output { - let got = typ.truncate(input.to_vec())?; - - if got.len() != typ.output_len() { - return Err(FlpError::Test(format!( - "unexpected output length: got {}; want {}", - got.len(), - typ.output_len() - ))); - } - - if &got != want { - return Err(FlpError::Test(format!( - "unexpected output: got {got:?}; want {want:?}" - ))); - } - } - - Ok(()) - } -} - #[cfg(feature = "experimental")] #[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] pub mod fixedpoint_l2; diff --git a/src/flp/types/fixedpoint_l2.rs b/src/flp/types/fixedpoint_l2.rs index 22828a428..8766c035b 100644 --- a/src/flp/types/fixedpoint_l2.rs +++ b/src/flp/types/fixedpoint_l2.rs @@ -685,7 +685,7 @@ mod tests { use crate::dp::{Rational, ZCdpBudget}; use crate::field::{random_vector, Field128, FieldElement}; use crate::flp::gadgets::ParallelSum; - use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use crate::flp::test_utils::FlpTest; use crate::vdaf::xof::SeedStreamTurboShake128; use fixed::types::extra::{U127, U14, U63}; use fixed::{FixedI128, FixedI16, FixedI64}; @@ -792,52 +792,46 @@ mod tests { let mut input: Vec = vsum.encode_measurement(&fp_vec).unwrap(); assert_eq!(input[0], Field128::zero()); input[0] = one; // it was zero - flp_validity_test( - &vsum, - &input, - &ValidityTestCase:: { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // encoding contains entries that are not zero or one let mut input2: Vec = vsum.encode_measurement(&fp_vec).unwrap(); input2[0] = one + one; - flp_validity_test( - &vsum, - &input2, - &ValidityTestCase:: { - expect_valid: false, - expected_output: Some(vec![ - Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 - Field128::from(enc_vec[1]), - Field128::from(enc_vec[2]), - ]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &input2, + expected_output: Some(&[ + Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + expect_valid: false, + } + .run::<3>(); // norm is too big // 2^n - 1, the field element encoded by the all-1 vector let one_enc = Field128::from(((2_u128) << (n - 1)) - 1); - flp_validity_test( - &vsum, - &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors - &ValidityTestCase:: { - expect_valid: false, - expected_output: Some(vec![one_enc; 3]), - num_shares: 3, - }, - ) - .unwrap(); + FlpTest { + name: None, + flp: &vsum, + input: &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors + expected_output: Some(&[one_enc; 3]), + expect_valid: false, + } + .run::<3>(); // invalid submission length, should be 3n + (2*n - 2) for a // 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm.