Skip to content

Commit

Permalink
[zk-token-sdk] Add range proof generation error types (solana-labs#34065
Browse files Browse the repository at this point in the history
)

* replace assert statements with `VectorLengthMismatch` error variant

* add a condition to check that the bit lengths are in the correct range

* replace assert statements with `GeneratorLengthMismatch`

* remove unchecked arithmetic

* add `InnerProductLengthMismatch` error

* fix typo

* add a clarifying comment on unwrap safety

* fix typo
  • Loading branch information
samkim-crypto authored Nov 21, 2023
1 parent ecc067f commit ded278f
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl BatchedRangeProofU128Data {
.try_fold(0_usize, |acc, &x| acc.checked_add(x))
.ok_or(ProofGenerationError::IllegalAmountBitLength)?;

// `u64::BITS` is 128, which fits in a single byte and should not overflow to `usize` for
// `u128::BITS` is 128, which fits in a single byte and should not overflow to `usize` for
// an overwhelming number of platforms. However, to be extra cautious, use `try_from` and
// `unwrap` here. A simple case `u128::BITS as usize` can silently overflow.
let expected_bit_length = usize::try_from(u128::BITS).unwrap();
Expand Down
10 changes: 10 additions & 0 deletions zk-token-sdk/src/range_proof/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ use {crate::errors::TranscriptError, thiserror::Error};
pub enum RangeProofGenerationError {
#[error("maximum generator length exceeded")]
MaximumGeneratorLengthExceeded,
#[error("amounts, commitments, openings, or bit lengths vectors have different lengths")]
VectorLengthMismatch,
#[error("invalid bit size")]
InvalidBitSize,
#[error("insufficient generators for the proof")]
GeneratorLengthMismatch,
#[error("inner product length mismatch")]
InnerProductLengthMismatch,
}

#[derive(Error, Clone, Debug, Eq, PartialEq)]
Expand All @@ -25,6 +33,8 @@ pub enum RangeProofVerificationError {
InvalidGeneratorsLength,
#[error("maximum generator length exceeded")]
MaximumGeneratorLengthExceeded,
#[error("commitments and bit lengths vectors have different lengths")]
VectorLengthMismatch,
}

#[derive(Error, Clone, Debug, Eq, PartialEq)]
Expand Down
77 changes: 50 additions & 27 deletions zk-token-sdk/src/range_proof/inner_product.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use {
crate::{
range_proof::{errors::RangeProofVerificationError, util},
range_proof::{
errors::{RangeProofGenerationError, RangeProofVerificationError},
util,
},
transcript::TranscriptProtocol,
},
core::iter,
Expand Down Expand Up @@ -45,7 +48,7 @@ impl InnerProductProof {
mut a_vec: Vec<Scalar>,
mut b_vec: Vec<Scalar>,
transcript: &mut Transcript,
) -> Self {
) -> Result<Self, RangeProofGenerationError> {
// Create slices G, H, a, b backed by their respective
// vectors. This lets us reslice as we compress the lengths
// of the vectors in the main loop below.
Expand All @@ -57,15 +60,20 @@ impl InnerProductProof {
let mut n = G.len();

// All of the input vectors must have the same length.
assert_eq!(G.len(), n);
assert_eq!(H.len(), n);
assert_eq!(a.len(), n);
assert_eq!(b.len(), n);
assert_eq!(G_factors.len(), n);
assert_eq!(H_factors.len(), n);
if G.len() != n
|| H.len() != n
|| a.len() != n
|| b.len() != n
|| G_factors.len() != n
|| H_factors.len() != n
{
return Err(RangeProofGenerationError::GeneratorLengthMismatch);
}

// All of the input vectors must have a length that is a power of two.
assert!(n.is_power_of_two());
if !n.is_power_of_two() {
return Err(RangeProofGenerationError::InvalidBitSize);
}

transcript.innerproduct_domain_separator(n as u64);

Expand All @@ -76,18 +84,21 @@ impl InnerProductProof {
// If it's the first iteration, unroll the Hprime = H*y_inv scalar mults
// into multiscalar muls, for performance.
if n != 1 {
n /= 2;
n = n.checked_div(2).unwrap();
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);

let c_L = util::inner_product(a_L, b_R);
let c_R = util::inner_product(a_R, b_L);
let c_L = util::inner_product(a_L, b_R)
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
let c_R = util::inner_product(a_R, b_L)
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;

let L = RistrettoPoint::multiscalar_mul(
a_L.iter()
.zip(G_factors[n..2 * n].iter())
// `n` was previously divided in half and therefore, it cannot overflow.
.zip(G_factors[n..n.checked_mul(2).unwrap()].iter())
.map(|(a_L_i, g)| a_L_i * g)
.chain(
b_R.iter()
Expand All @@ -105,7 +116,7 @@ impl InnerProductProof {
.map(|(a_R_i, g)| a_R_i * g)
.chain(
b_L.iter()
.zip(H_factors[n..2 * n].iter())
.zip(H_factors[n..n.checked_mul(2).unwrap()].iter())
.map(|(b_L_i, h)| b_L_i * h),
)
.chain(iter::once(c_R)),
Expand All @@ -126,11 +137,17 @@ impl InnerProductProof {
a_L[i] = a_L[i] * u + u_inv * a_R[i];
b_L[i] = b_L[i] * u_inv + u * b_R[i];
G_L[i] = RistrettoPoint::multiscalar_mul(
&[u_inv * G_factors[i], u * G_factors[n + i]],
&[
u_inv * G_factors[i],
u * G_factors[n.checked_add(i).unwrap()],
],
&[G_L[i], G_R[i]],
);
H_L[i] = RistrettoPoint::multiscalar_mul(
&[u * H_factors[i], u_inv * H_factors[n + i]],
&[
u * H_factors[i],
u_inv * H_factors[n.checked_add(i).unwrap()],
],
&[H_L[i], H_R[i]],
)
}
Expand All @@ -142,14 +159,16 @@ impl InnerProductProof {
}

while n != 1 {
n /= 2;
n = n.checked_div(2).unwrap();
let (a_L, a_R) = a.split_at_mut(n);
let (b_L, b_R) = b.split_at_mut(n);
let (G_L, G_R) = G.split_at_mut(n);
let (H_L, H_R) = H.split_at_mut(n);

let c_L = util::inner_product(a_L, b_R);
let c_R = util::inner_product(a_R, b_L);
let c_L = util::inner_product(a_L, b_R)
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;
let c_R = util::inner_product(a_R, b_L)
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;

let L = RistrettoPoint::multiscalar_mul(
a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)),
Expand Down Expand Up @@ -185,12 +204,12 @@ impl InnerProductProof {
H = H_L;
}

InnerProductProof {
Ok(InnerProductProof {
L_vec,
R_vec,
a: a[0],
b: b[0],
}
})
}

/// Computes three vectors of verification scalars \\([u\_{i}^{2}]\\), \\([u\_{i}^{-2}]\\) and
Expand All @@ -210,7 +229,7 @@ impl InnerProductProof {
// and this check prevents overflow in 1<<lg_n below.
return Err(RangeProofVerificationError::InvalidBitSize);
}
if n != (1 << lg_n) {
if n != (1_usize.checked_shl(lg_n as u32).unwrap()) {
return Err(RangeProofVerificationError::InvalidBitSize);
}

Expand Down Expand Up @@ -244,11 +263,14 @@ impl InnerProductProof {
let mut s = Vec::with_capacity(n);
s.push(allinv);
for i in 1..n {
let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
let k = 1 << lg_i;
let lg_i = 31_u32.checked_sub((i as u32).leading_zeros()).unwrap() as usize;
let k = 1_usize.checked_shl(lg_i as u32).unwrap();
// The challenges are stored in "creation order" as [u_k,...,u_1],
// so u_{lg(i)+1} = is indexed by (lg_n-1) - lg_i
let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
let u_lg_i_sq = challenges_sq[lg_n
.checked_sub(1)
.and_then(|x| x.checked_sub(lg_i))
.unwrap()];
s.push(s[i - k] * u_lg_i_sq);
}

Expand Down Expand Up @@ -418,7 +440,7 @@ mod tests {

let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut OsRng)).collect();
let c = util::inner_product(&a, &b);
let c = util::inner_product(&a, &b).unwrap();

let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(n).collect();

Expand Down Expand Up @@ -451,7 +473,8 @@ mod tests {
a.clone(),
b.clone(),
&mut prover_transcript,
);
)
.unwrap();

assert!(proof
.verify(
Expand Down
40 changes: 31 additions & 9 deletions zk-token-sdk/src/range_proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,23 @@ impl RangeProof {
) -> Result<Self, RangeProofGenerationError> {
// amounts, bit-lengths, openings must be same length vectors
let m = amounts.len();
assert_eq!(bit_lengths.len(), m);
assert_eq!(openings.len(), m);
if bit_lengths.len() != m || openings.len() != m {
return Err(RangeProofGenerationError::VectorLengthMismatch);
}

// each bit length must be greater than 0 for the proof to make sense
if bit_lengths
.iter()
.any(|bit_length| *bit_length == 0 || *bit_length > u64::BITS as usize)
{
return Err(RangeProofGenerationError::InvalidBitSize);
}

// total vector dimension to compute the ultimate inner product proof for
let nm: usize = bit_lengths.iter().sum();
assert!(nm.is_power_of_two());
if !nm.is_power_of_two() {
return Err(RangeProofGenerationError::VectorLengthMismatch);
}

let bp_gens = BulletproofGens::new(nm)
.map_err(|_| RangeProofGenerationError::MaximumGeneratorLengthExceeded)?;
Expand All @@ -93,7 +104,10 @@ impl RangeProof {
for (amount_i, n_i) in amounts.iter().zip(bit_lengths.iter()) {
for j in 0..(*n_i) {
let (G_ij, H_ij) = gens_iter.next().unwrap();
let v_ij = Choice::from(((amount_i >> j) & 1) as u8);

// `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore,
// casting is lossless and right shift can be safely unwrapped
let v_ij = Choice::from((amount_i.checked_shr(j as u32).unwrap() & 1) as u8);
let mut point = -H_ij;
point.conditional_assign(G_ij, v_ij);
A += point;
Expand Down Expand Up @@ -138,7 +152,9 @@ impl RangeProof {
let mut exp_2 = Scalar::one();

for j in 0..(*n_i) {
let a_L_j = Scalar::from((amount_i >> j) & 1);
// `j` is guaranteed to be at most `u64::BITS` (a 6-bit number) and therefore,
// casting is lossless and right shift can be safely unwrapped
let a_L_j = Scalar::from(amount_i.checked_shr(j as u32).unwrap() & 1);
let a_R_j = a_L_j - Scalar::one();

l_poly.0[i] = a_L_j - z;
Expand All @@ -148,13 +164,17 @@ impl RangeProof {

exp_y *= y;
exp_2 = exp_2 + exp_2;
i += 1;

// `i` is capped by the sum of vectors in `bit_lengths`
i = i.checked_add(1).unwrap();
}
exp_z *= z;
}

// define t(x) = <l(x), r(x)> = t_0 + t_1*x + t_2*x
let t_poly = l_poly.inner_product(&r_poly);
let t_poly = l_poly
.inner_product(&r_poly)
.ok_or(RangeProofGenerationError::InnerProductLengthMismatch)?;

// generate Pedersen commitment for the coefficients t_1 and t_2
let (T_1, t_1_blinding) = Pedersen::new(t_poly.1);
Expand Down Expand Up @@ -216,7 +236,7 @@ impl RangeProof {
l_vec,
r_vec,
transcript,
);
)?;

Ok(RangeProof {
A,
Expand All @@ -238,7 +258,9 @@ impl RangeProof {
transcript: &mut Transcript,
) -> Result<(), RangeProofVerificationError> {
// commitments and bit-lengths must be same length vectors
assert_eq!(comms.len(), bit_lengths.len());
if comms.len() != bit_lengths.len() {
return Err(RangeProofVerificationError::VectorLengthMismatch);
}

let m = bit_lengths.len();
let nm: usize = bit_lengths.iter().sum();
Expand Down
18 changes: 9 additions & 9 deletions zk-token-sdk/src/range_proof/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,20 @@ impl VecPoly1 {
VecPoly1(vec![Scalar::zero(); n], vec![Scalar::zero(); n])
}

pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
pub fn inner_product(&self, rhs: &VecPoly1) -> Option<Poly2> {
// Uses Karatsuba's method
let l = self;
let r = rhs;

let t0 = inner_product(&l.0, &r.0);
let t2 = inner_product(&l.1, &r.1);
let t0 = inner_product(&l.0, &r.0)?;
let t2 = inner_product(&l.1, &r.1)?;

let l0_plus_l1 = add_vec(&l.0, &l.1);
let r0_plus_r1 = add_vec(&r.0, &r.1);

let t1 = inner_product(&l0_plus_l1, &r0_plus_r1) - t0 - t2;
let t1 = inner_product(&l0_plus_l1, &r0_plus_r1)? - t0 - t2;

Poly2(t0, t1, t2)
Some(Poly2(t0, t1, t2))
}

pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
Expand Down Expand Up @@ -98,16 +98,16 @@ pub fn read32(data: &[u8]) -> [u8; 32] {
/// \\[
/// {\langle {\mathbf{a}}, {\mathbf{b}} \rangle} = \sum\_{i=0}^{n-1} a\_i \cdot b\_i.
/// \\]
/// Panics if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
/// Errors if the lengths of \\(\mathbf{a}\\) and \\(\mathbf{b}\\) are not equal.
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Option<Scalar> {
let mut out = Scalar::zero();
if a.len() != b.len() {
panic!("inner_product(a,b): lengths of vectors do not match");
return None;
}
for i in 0..a.len() {
out += a[i] * b[i];
}
out
Some(out)
}

/// Takes the sum of all the powers of `x`, up to `n`
Expand Down

0 comments on commit ded278f

Please sign in to comment.