From ded278fb57e0b204ef67e408aca59ff7767d2b72 Mon Sep 17 00:00:00 2001 From: samkim-crypto Date: Wed, 22 Nov 2023 06:50:24 +0900 Subject: [PATCH] [zk-token-sdk] Add range proof generation error types (#34065) * replace assert statements with `VectorLengthMismatch` error variant * add a condition to check that the bit lengths are in the correct range * replace assert statements with `GeneratorLengthMismatch` * remove unchecked arithmetic * add `InnerProductLengthMismatch` error * fix typo * add a clarifying comment on unwrap safety * fix typo --- .../batched_range_proof_u128.rs | 2 +- zk-token-sdk/src/range_proof/errors.rs | 10 +++ zk-token-sdk/src/range_proof/inner_product.rs | 77 ++++++++++++------- zk-token-sdk/src/range_proof/mod.rs | 40 +++++++--- zk-token-sdk/src/range_proof/util.rs | 18 ++--- 5 files changed, 101 insertions(+), 46 deletions(-) diff --git a/zk-token-sdk/src/instruction/batched_range_proof/batched_range_proof_u128.rs b/zk-token-sdk/src/instruction/batched_range_proof/batched_range_proof_u128.rs index 916245c2f31611..4036be9a94c940 100644 --- a/zk-token-sdk/src/instruction/batched_range_proof/batched_range_proof_u128.rs +++ b/zk-token-sdk/src/instruction/batched_range_proof/batched_range_proof_u128.rs @@ -47,7 +47,7 @@ impl BatchedRangeProofU128Data { .try_fold(0_usize, |acc, &x| acc.checked_add(x)) .ok_or(ProofGenerationError::IllegalAmountBitLength)?; - // `u64::BITS` is 128, which fits in a single byte and should not overflow to `usize` for + // `u128::BITS` is 128, which fits in a single byte and should not overflow to `usize` for // an overwhelming number of platforms. However, to be extra cautious, use `try_from` and // `unwrap` here. A simple case `u128::BITS as usize` can silently overflow. let expected_bit_length = usize::try_from(u128::BITS).unwrap(); diff --git a/zk-token-sdk/src/range_proof/errors.rs b/zk-token-sdk/src/range_proof/errors.rs index 25ae0ed8764692..f0c872f7aa3494 100644 --- a/zk-token-sdk/src/range_proof/errors.rs +++ b/zk-token-sdk/src/range_proof/errors.rs @@ -5,6 +5,14 @@ use {crate::errors::TranscriptError, thiserror::Error}; pub enum RangeProofGenerationError { #[error("maximum generator length exceeded")] MaximumGeneratorLengthExceeded, + #[error("amounts, commitments, openings, or bit lengths vectors have different lengths")] + VectorLengthMismatch, + #[error("invalid bit size")] + InvalidBitSize, + #[error("insufficient generators for the proof")] + GeneratorLengthMismatch, + #[error("inner product length mismatch")] + InnerProductLengthMismatch, } #[derive(Error, Clone, Debug, Eq, PartialEq)] @@ -25,6 +33,8 @@ pub enum RangeProofVerificationError { InvalidGeneratorsLength, #[error("maximum generator length exceeded")] MaximumGeneratorLengthExceeded, + #[error("commitments and bit lengths vectors have different lengths")] + VectorLengthMismatch, } #[derive(Error, Clone, Debug, Eq, PartialEq)] diff --git a/zk-token-sdk/src/range_proof/inner_product.rs b/zk-token-sdk/src/range_proof/inner_product.rs index baecef78d7b076..44e8e0674a3d6a 100644 --- a/zk-token-sdk/src/range_proof/inner_product.rs +++ b/zk-token-sdk/src/range_proof/inner_product.rs @@ -1,6 +1,9 @@ use { crate::{ - range_proof::{errors::RangeProofVerificationError, util}, + range_proof::{ + errors::{RangeProofGenerationError, RangeProofVerificationError}, + util, + }, transcript::TranscriptProtocol, }, core::iter, @@ -45,7 +48,7 @@ impl InnerProductProof { mut a_vec: Vec, mut b_vec: Vec, transcript: &mut Transcript, - ) -> Self { + ) -> Result { // Create slices G, H, a, b backed by their respective // vectors. This lets us reslice as we compress the lengths // of the vectors in the main loop below. @@ -57,15 +60,20 @@ impl InnerProductProof { let mut n = G.len(); // All of the input vectors must have the same length. - assert_eq!(G.len(), n); - assert_eq!(H.len(), n); - assert_eq!(a.len(), n); - assert_eq!(b.len(), n); - assert_eq!(G_factors.len(), n); - assert_eq!(H_factors.len(), n); + if G.len() != n + || H.len() != n + || a.len() != n + || b.len() != n + || G_factors.len() != n + || H_factors.len() != n + { + return Err(RangeProofGenerationError::GeneratorLengthMismatch); + } // All of the input vectors must have a length that is a power of two. - assert!(n.is_power_of_two()); + if !n.is_power_of_two() { + return Err(RangeProofGenerationError::InvalidBitSize); + } transcript.innerproduct_domain_separator(n as u64); @@ -76,18 +84,21 @@ impl InnerProductProof { // If it's the first iteration, unroll the Hprime = H*y_inv scalar mults // into multiscalar muls, for performance. if n != 1 { - n /= 2; + n = n.checked_div(2).unwrap(); let (a_L, a_R) = a.split_at_mut(n); let (b_L, b_R) = b.split_at_mut(n); let (G_L, G_R) = G.split_at_mut(n); let (H_L, H_R) = H.split_at_mut(n); - let c_L = util::inner_product(a_L, b_R); - let c_R = util::inner_product(a_R, b_L); + let c_L = util::inner_product(a_L, b_R) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + let c_R = util::inner_product(a_R, b_L) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; let L = RistrettoPoint::multiscalar_mul( a_L.iter() - .zip(G_factors[n..2 * n].iter()) + // `n` was previously divided in half and therefore, it cannot overflow. + .zip(G_factors[n..n.checked_mul(2).unwrap()].iter()) .map(|(a_L_i, g)| a_L_i * g) .chain( b_R.iter() @@ -105,7 +116,7 @@ impl InnerProductProof { .map(|(a_R_i, g)| a_R_i * g) .chain( b_L.iter() - .zip(H_factors[n..2 * n].iter()) + .zip(H_factors[n..n.checked_mul(2).unwrap()].iter()) .map(|(b_L_i, h)| b_L_i * h), ) .chain(iter::once(c_R)), @@ -126,11 +137,17 @@ impl InnerProductProof { a_L[i] = a_L[i] * u + u_inv * a_R[i]; b_L[i] = b_L[i] * u_inv + u * b_R[i]; G_L[i] = RistrettoPoint::multiscalar_mul( - &[u_inv * G_factors[i], u * G_factors[n + i]], + &[ + u_inv * G_factors[i], + u * G_factors[n.checked_add(i).unwrap()], + ], &[G_L[i], G_R[i]], ); H_L[i] = RistrettoPoint::multiscalar_mul( - &[u * H_factors[i], u_inv * H_factors[n + i]], + &[ + u * H_factors[i], + u_inv * H_factors[n.checked_add(i).unwrap()], + ], &[H_L[i], H_R[i]], ) } @@ -142,14 +159,16 @@ impl InnerProductProof { } while n != 1 { - n /= 2; + n = n.checked_div(2).unwrap(); let (a_L, a_R) = a.split_at_mut(n); let (b_L, b_R) = b.split_at_mut(n); let (G_L, G_R) = G.split_at_mut(n); let (H_L, H_R) = H.split_at_mut(n); - let c_L = util::inner_product(a_L, b_R); - let c_R = util::inner_product(a_R, b_L); + let c_L = util::inner_product(a_L, b_R) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; + let c_R = util::inner_product(a_R, b_L) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; let L = RistrettoPoint::multiscalar_mul( a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)), @@ -185,12 +204,12 @@ impl InnerProductProof { H = H_L; } - InnerProductProof { + Ok(InnerProductProof { L_vec, R_vec, a: a[0], b: b[0], - } + }) } /// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and @@ -210,7 +229,7 @@ impl InnerProductProof { // and this check prevents overflow in 1< = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect(); - let c = util::inner_product(&a, &b); + let c = util::inner_product(&a, &b).unwrap(); let G_factors: Vec = iter::repeat(Scalar::one()).take(n).collect(); @@ -451,7 +473,8 @@ mod tests { a.clone(), b.clone(), &mut prover_transcript, - ); + ) + .unwrap(); assert!(proof .verify( diff --git a/zk-token-sdk/src/range_proof/mod.rs b/zk-token-sdk/src/range_proof/mod.rs index 86754dbf61b73b..6658c350495473 100644 --- a/zk-token-sdk/src/range_proof/mod.rs +++ b/zk-token-sdk/src/range_proof/mod.rs @@ -75,12 +75,23 @@ impl RangeProof { ) -> Result { // amounts, bit-lengths, openings must be same length vectors let m = amounts.len(); - assert_eq!(bit_lengths.len(), m); - assert_eq!(openings.len(), m); + if bit_lengths.len() != m || openings.len() != m { + return Err(RangeProofGenerationError::VectorLengthMismatch); + } + + // each bit length must be greater than 0 for the proof to make sense + if bit_lengths + .iter() + .any(|bit_length| *bit_length == 0 || *bit_length > u64::BITS as usize) + { + return Err(RangeProofGenerationError::InvalidBitSize); + } // total vector dimension to compute the ultimate inner product proof for let nm: usize = bit_lengths.iter().sum(); - assert!(nm.is_power_of_two()); + if !nm.is_power_of_two() { + return Err(RangeProofGenerationError::VectorLengthMismatch); + } let bp_gens = BulletproofGens::new(nm) .map_err(|_| RangeProofGenerationError::MaximumGeneratorLengthExceeded)?; @@ -93,7 +104,10 @@ impl RangeProof { for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) { for j in 0..(*n_i) { let (G_ij, H_ij) = gens_iter.next().unwrap(); - let v_ij = Choice::from(((amount_i >> j) & 1) as u8); + + // `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore, + // casting is lossless and right shift can be safely unwrapped + let v_ij = Choice::from((amount_i.checked_shr(j as u32).unwrap() & 1) as u8); let mut point = -H_ij; point.conditional_assign(G_ij, v_ij); A += point; @@ -138,7 +152,9 @@ impl RangeProof { let mut exp_2 = Scalar::one(); for j in 0..(*n_i) { - let a_L_j = Scalar::from((amount_i >> j) & 1); + // `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore, + // casting is lossless and right shift can be safely unwrapped + let a_L_j = Scalar::from(amount_i.checked_shr(j as u32).unwrap() & 1); let a_R_j = a_L_j - Scalar::one(); l_poly.0[i] = a_L_j - z; @@ -148,13 +164,17 @@ impl RangeProof { exp_y *= y; exp_2 = exp_2 + exp_2; - i += 1; + + // `i` is capped by the sum of vectors in `bit_lengths` + i = i.checked_add(1).unwrap(); } exp_z *= z; } // define t(x) = = t_0 + t_1*x + t_2*x - let t_poly = l_poly.inner_product(&r_poly); + let t_poly = l_poly + .inner_product(&r_poly) + .ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?; // generate Pedersen commitment for the coefficients t_1 and t_2 let (T_1, t_1_blinding) = Pedersen::new(t_poly.1); @@ -216,7 +236,7 @@ impl RangeProof { l_vec, r_vec, transcript, - ); + )?; Ok(RangeProof { A, @@ -238,7 +258,9 @@ impl RangeProof { transcript: &mut Transcript, ) -> Result<(), RangeProofVerificationError> { // commitments and bit-lengths must be same length vectors - assert_eq!(comms.len(), bit_lengths.len()); + if comms.len() != bit_lengths.len() { + return Err(RangeProofVerificationError::VectorLengthMismatch); + } let m = bit_lengths.len(); let nm: usize = bit_lengths.iter().sum(); diff --git a/zk-token-sdk/src/range_proof/util.rs b/zk-token-sdk/src/range_proof/util.rs index c551abd8f3a15c..4a76543d475bc0 100644 --- a/zk-token-sdk/src/range_proof/util.rs +++ b/zk-token-sdk/src/range_proof/util.rs @@ -11,20 +11,20 @@ impl VecPoly1 { VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n]) } - pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 { + pub fn inner_product(&self, rhs: &VecPoly1) -> Option { // Uses Karatsuba's method let l = self; let r = rhs; - let t0 = inner_product(&l.0, &r.0); - let t2 = inner_product(&l.1, &r.1); + let t0 = inner_product(&l.0, &r.0)?; + let t2 = inner_product(&l.1, &r.1)?; let l0_plus_l1 = add_vec(&l.0, &l.1); let r0_plus_r1 = add_vec(&r.0, &r.1); - let t1 = inner_product(&l0_plus_l1, &r0_plus_r1) - t0 - t2; + let t1 = inner_product(&l0_plus_l1, &r0_plus_r1)? - t0 - t2; - Poly2(t0, t1, t2) + Some(Poly2(t0, t1, t2)) } pub fn eval(&self, x: Scalar) -> Vec { @@ -98,16 +98,16 @@ pub fn read32(data: &[u8]) -> [u8; 32] { /// \\[ /// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i. /// \\] -/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal. -pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar { +/// Errors if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal. +pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Option { let mut out = Scalar::zero(); if a.len() != b.len() { - panic!("inner_product(a,b): lengths of vectors do not match"); + return None; } for i in 0..a.len() { out += a[i] * b[i]; } - out + Some(out) } /// Takes the sum of all the powers of `x`, up to `n`