From 639325bfbc194a66385c126f0c1c07f37efc7981 Mon Sep 17 00:00:00 2001 From: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> Date: Wed, 23 Aug 2023 18:21:08 -0500 Subject: [PATCH] Optimize tests --- src/ristretto/bulletproofs_plus.rs | 158 +++++++++++++++-------------- 1 file changed, 80 insertions(+), 78 deletions(-) diff --git a/src/ristretto/bulletproofs_plus.rs b/src/ristretto/bulletproofs_plus.rs index 48e6f96b..fe302ff1 100644 --- a/src/ristretto/bulletproofs_plus.rs +++ b/src/ristretto/bulletproofs_plus.rs @@ -586,12 +586,13 @@ mod test { /// 'BulletproofsPlusService' initialization should only succeed when both bit length and aggregation size are a /// power of 2 and when bit_length <= 64 + // Initialize the range proof service, checking that it behaves correctly #[test] fn test_service_init() { for extension_degree in EXTENSION_DEGREE { let factory = ExtendedPedersenCommitmentFactory::new_with_extension_degree(extension_degree).unwrap(); - for bit_length in [1, 2, 4, 128] { - for aggregation_size in [1, 2, 64] { + for bit_length in [1, 2, 4, 5, 128] { + for aggregation_size in [1, 2, 3] { let bullet_proofs_plus_service = BulletproofsPlusService::init(bit_length, aggregation_size, factory.clone()); if bit_length.is_power_of_two() && aggregation_size.is_power_of_two() && bit_length <= 64 { @@ -604,32 +605,44 @@ mod test { } } - /// The 'BulletproofsPlusService' interface 'construct_proof' should only accept Pedersen generators of - /// 'ExtensionDegree::Zero' with 'aggregation_size == 1' and values proportional to the bit length + /// Test non-extended range proof service functionality + /// These proofs are not aggregated and do not use extension or batch verification + /// Using nontrivial aggregation or extension or an invalid value should fail #[test] - fn test_construct_verify_proof_no_recovery() { + fn test_range_proof_service() { let mut rng = rand::thread_rng(); + const BIT_LENGTH: usize = 4; + const AGGREGATION_FACTORS: [usize; 2] = [1, 2]; + for extension_degree in EXTENSION_DEGREE { let factory = ExtendedPedersenCommitmentFactory::new_with_extension_degree(extension_degree).unwrap(); - // bit length and aggregation size are chosen so that 'BulletProofsPlusService::init' will always succeed - for bit_length in [4, 64] { - for aggregation_size in [1, 16] { - let bulletproofs_plus_service = - BulletproofsPlusService::init(bit_length, aggregation_size, factory.clone()).unwrap(); - for value in [0, 1, u64::MAX] { - let key = RistrettoSecretKey(Scalar::random_not_zero(&mut rng)); - let proof = bulletproofs_plus_service.construct_proof(&key, value); - if extension_degree == CommitmentExtensionDegree::DefaultPedersen && - aggregation_size == 1 && - value >> (bit_length - 1) <= 1 - { - assert!(proof.is_ok()); - assert!( - bulletproofs_plus_service.verify(&proof.unwrap(), &factory.commit_value(&key, value)) - ); - } else { - assert!(proof.is_err()); - } + + for aggregation_factor in AGGREGATION_FACTORS { + let bulletproofs_plus_service = + BulletproofsPlusService::init(BIT_LENGTH, aggregation_factor, factory.clone()).unwrap(); + assert_eq!(bulletproofs_plus_service.range(), BIT_LENGTH); + + for value in [0, 1, u64::MAX] { + let key = RistrettoSecretKey(Scalar::random_not_zero(&mut rng)); + let proof = bulletproofs_plus_service.construct_proof(&key, value); + // This should only succeed with trivial aggregation and extension and a valid value + if extension_degree == CommitmentExtensionDegree::DefaultPedersen && + aggregation_factor == 1 && + value >> (BIT_LENGTH - 1) <= 1 + { + // The proof should succeed + let proof = proof.unwrap(); + + // Successful verification + assert!(bulletproofs_plus_service.verify(&proof, &factory.commit_value(&key, value))); + + // Failed verification (due to a bad mask) + assert!(!bulletproofs_plus_service.verify( + &proof, + &factory.commit_value(&RistrettoSecretKey(Scalar::random_not_zero(&mut rng)), value) + )); + } else { + assert!(proof.is_err()); } } } @@ -640,7 +653,7 @@ mod test { #[allow(clippy::too_many_lines)] fn test_construct_verify_extended_proof_with_recovery() { static BIT_LENGTH: [usize; 2] = [2, 64]; - static AGGREGATION_SIZE: [usize; 2] = [2, 64]; + static AGGREGATION_SIZE: [usize; 2] = [1, 2]; let mut rng = rand::thread_rng(); for extension_degree in [ CommitmentExtensionDegree::DefaultPedersen, @@ -781,68 +794,57 @@ mod test { } #[test] - fn test_simple_aggregated_extended_proof() { + // Test correctness of single aggregated proofs of varying extension degree + fn test_single_aggregated_extended_proof() { let mut rng = rand::thread_rng(); - let bit_length = 64; + + const BIT_LENGTH: usize = 4; + const AGGREGATION_FACTOR: usize = 2; for extension_degree in [ CommitmentExtensionDegree::DefaultPedersen, - CommitmentExtensionDegree::AddOneBasePoint, + CommitmentExtensionDegree::AddFiveBasePoints, ] { let factory = ExtendedPedersenCommitmentFactory::new_with_extension_degree(extension_degree).unwrap(); + let bulletproofs_plus_service = + BulletproofsPlusService::init(BIT_LENGTH, AGGREGATION_FACTOR, factory.clone()).unwrap(); + + let (value_min, value_max) = (0u64, (1u64 << BIT_LENGTH) - 1); + + let mut statements = Vec::with_capacity(AGGREGATION_FACTOR); + let mut extended_witnesses = Vec::with_capacity(AGGREGATION_FACTOR); + + // Set up the statements and witnesses + for _ in 0..AGGREGATION_FACTOR { + let value = rng.gen_range(value_min..value_max); + let minimum_value_promise = value / 3; + let secrets = vec![RistrettoSecretKey(Scalar::random_not_zero(&mut rng)); extension_degree as usize]; + let extended_mask = RistrettoExtendedMask::assign(extension_degree, secrets.clone()).unwrap(); + let commitment = factory.commit_value_extended(&secrets, value).unwrap(); + + statements.push(RistrettoStatement { + commitment: commitment.clone(), + minimum_value_promise, + }); + extended_witnesses.push(RistrettoExtendedWitness { + mask: extended_mask.clone(), + value, + minimum_value_promise, + }); + } - for aggregation_size in [2, 4] { - // 0. Batch data - let mut proofs = vec![]; - let mut statements_public = vec![]; - - #[allow(clippy::cast_possible_truncation)] - let (value_min, value_max) = (0u64, ((1u128 << bit_length) - 1) as u64); - - // 1. Prover's service - let bulletproofs_plus_service = - BulletproofsPlusService::init(bit_length, aggregation_size, factory.clone()).unwrap(); - - // 2. Create witness data - let mut statements = vec![]; - let mut extended_witnesses = vec![]; - for _m in 0..aggregation_size { - let value = rng.gen_range(value_min..value_max); - let minimum_value_promise = value / 3; - let secrets = - vec![RistrettoSecretKey(Scalar::random_not_zero(&mut rng)); extension_degree as usize]; - let extended_mask = RistrettoExtendedMask::assign(extension_degree, secrets.clone()).unwrap(); - let commitment = factory.commit_value_extended(&secrets, value).unwrap(); - statements.push(RistrettoStatement { - commitment: commitment.clone(), - minimum_value_promise, - }); - extended_witnesses.push(RistrettoExtendedWitness { - mask: extended_mask.clone(), - value, - minimum_value_promise, - }); - } - - // 3. Generate the statement - statements_public.push(RistrettoAggregatedPublicStatement::init(statements).unwrap()); + // Aggregate the statements + let aggregated_statement = RistrettoAggregatedPublicStatement::init(statements).unwrap(); - // 4. Create the aggregated proof - let seed_nonce = None; // This only has meaning for non-aggregated proofs - let proof = bulletproofs_plus_service.construct_extended_proof(extended_witnesses, seed_nonce); - proofs.push(proof.unwrap()); + // Generate an aggregate proof + let proof = bulletproofs_plus_service + .construct_extended_proof(extended_witnesses, None) + .unwrap(); - // 5. Verifier's service - let bulletproofs_plus_service = - BulletproofsPlusService::init(bit_length, aggregation_size, factory.clone()).unwrap(); - - // 7. Verify the aggregated proof as public entity - let proofs_ref = proofs.iter().collect::>(); - let statements_ref = statements_public.iter().collect::>(); - assert!(bulletproofs_plus_service - .verify_batch(proofs_ref, statements_ref) - .is_ok()); - } + // Verify the proof + assert!(bulletproofs_plus_service + .verify_batch(vec![&proof], vec![&aggregated_statement]) + .is_ok()); } }