Skip to content

Commit

Permalink
Simplify ml evaluation testing (Nova Forward port, easy) (microsoft#185)
Browse files Browse the repository at this point in the history
* chore: refactor test imports

* refactor: simplify and reformulate evaluation testing
  • Loading branch information
huitseeker authored Dec 18, 2023
1 parent bb103dd commit 55d83ea
Showing 1 changed file with 27 additions and 86 deletions.
113 changes: 27 additions & 86 deletions src/spartan/polys/multilinear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ mod tests {
use crate::provider::{self, bn256_grumpkin::bn256, secp_secq::secp256k1};

use super::*;
use pasta_curves::Fp;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;

Expand Down Expand Up @@ -252,12 +251,12 @@ mod tests {

#[test]
fn test_multilinear_polynomial() {
test_multilinear_polynomial_with::<Fp>();
test_multilinear_polynomial_with::<pasta_curves::Fp>();
}

#[test]
fn test_sparse_polynomial() {
test_sparse_polynomial_with::<Fp>();
test_sparse_polynomial_with::<pasta_curves::Fp>();
}

fn test_mlp_add_with<F: PrimeField>() {
Expand All @@ -271,7 +270,7 @@ mod tests {

#[test]
fn test_mlp_add() {
test_mlp_add_with::<Fp>();
test_mlp_add_with::<pasta_curves::Fp>();
test_mlp_add_with::<bn256::Scalar>();
test_mlp_add_with::<secp256k1::Scalar>();
}
Expand Down Expand Up @@ -304,111 +303,53 @@ mod tests {

#[test]
fn test_evaluation() {
test_evaluation_with::<Fp>();
test_evaluation_with::<pasta_curves::Fp>();
test_evaluation_with::<provider::bn256_grumpkin::bn256::Scalar>();
test_evaluation_with::<provider::secp_secq::secp256k1::Scalar>();
}

/// This evaluates a multilinear polynomial at a partial point in the evaluation domain,
/// which forces us to model how we pass coordinates to the evaluation function precisely.
fn partial_eval<F: PrimeField>(
/// This binds the variables of a multilinear polynomial to a provided sequence
/// of values.
///
/// Assuming `bind_poly_var_top` defines the "top" variable of the polynomial,
/// this aims to test whether variables should be provided to the
/// `evaluate` function in topmost-first (big endian) of topmost-last (lower endian)
/// order.
fn bind_sequence<F: PrimeField>(
poly: &MultilinearPolynomial<F>,
point: &[F],
values: &[F],
) -> MultilinearPolynomial<F> {
// Get size of partial evaluation point u = (u_0,...,u_{m-1})
let m = point.len();

// Assert that the size of the polynomial being evaluated is a power of 2 greater than (1 << m)
// Assert that the size of the polynomial being evaluated is a power of 2 greater than (1 << values.len())
assert!(poly.Z.len().is_power_of_two());
assert!(poly.Z.len() >= 1 << m);
let n = poly.Z.len().trailing_zeros() as usize;

// Partial evaluation is done in m rounds l = 0,...,m-1.

// Temporary buffer of half the size of the polynomial
let mut n_l = 1 << (n - 1);
let mut tmp = vec![F::ZERO; n_l];

let prev = &poly.Z;
// Evaluate variable X_{n-1} at u_{m-1}
let u_l = point[m - 1];
for i in 0..n_l {
tmp[i] = prev[i] + u_l * (prev[i + n_l] - prev[i]);
}
assert!(poly.Z.len() >= 1 << values.len());

// Evaluate m-1 variables X_{n-l-1}, ..., X_{n-2} at m-1 remaining values u_0,...,u_{m-2})
for l in 1..m {
n_l = 1 << (n - l - 1);
let u_l = point[m - l - 1];
for i in 0..n_l {
tmp[i] = tmp[i] + u_l * (tmp[i + n_l] - tmp[i]);
}
let mut tmp = poly.clone();
for v in values.iter() {
tmp.bind_poly_var_top(v);
}
tmp.truncate(1 << (poly.num_vars - m));

MultilinearPolynomial::new(tmp)
}

fn partial_evaluate_mle_with<F: PrimeField>() {
// Initialize a random polynomial
let n = 5;
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let poly = MultilinearPolynomial::random(n, &mut rng);

// Define a random multivariate evaluation point u = (u_0, u_1, u_2, u_3, u_4)
let u_0 = F::random(&mut rng);
let u_1 = F::random(&mut rng);
let u_2 = F::random(&mut rng);
let u_3 = F::random(&mut rng);
let u_4 = F::random(&mut rng);
let u_challenge = [u_4, u_3, u_2, u_1, u_0];

// Directly computing v = p(u_0,...,u_4) and comparing it with the result of
// first computing the partial evaluation in the last 3 variables
// g(X_0,X_1) = p(X_0,X_1,u_2,u_3,u_4), then v = g(u_0,u_1)

// Compute v = p(u_0,...,u_4)
let v_expected = poly.evaluate(&u_challenge[..]);

// Compute g(X_0,X_1) = p(X_0,X_1,u_2,u_3,u_4), then v = g(u_0,u_1)
let u_part_1 = [u_1, u_0]; // note the endianness difference
let u_part_2 = [u_2, u_3, u_4];

// Note how we start with part 2, and continue with part 1
let partial_evaluated_poly = partial_eval(&poly, &u_part_2);
let v_result = partial_evaluated_poly.evaluate(&u_part_1);

assert_eq!(v_result, v_expected);
}

#[test]
fn test_partial_evaluate_mle() {
partial_evaluate_mle_with::<Fp>();
partial_evaluate_mle_with::<bn256::Scalar>();
partial_evaluate_mle_with::<secp256k1::Scalar>();
tmp
}

fn partial_and_evaluate_with<F: PrimeField>() {
for _i in 0..50 {
fn bind_and_evaluate_with<F: PrimeField>() {
for i in 0..50 {
// Initialize a random polynomial
let n = 7;
let mut rng = ChaCha20Rng::from_seed([0u8; 32]);
let mut rng = ChaCha20Rng::from_seed([i as u8; 32]);
let poly = MultilinearPolynomial::random(n, &mut rng);

// draw a random point
let pt: Vec<_> = std::iter::from_fn(|| Some(F::random(&mut rng)))
.take(n)
.collect();
// this shows the order in which coordinates are evaluated
let rev_pt: Vec<_> = pt.iter().cloned().rev().collect();
assert_eq!(poly.evaluate(&pt), partial_eval(&poly, &rev_pt).Z[0])
assert_eq!(poly.evaluate(&pt), bind_sequence(&poly, &pt).Z[0])
}
}

#[test]
fn test_partial_and_evaluate() {
partial_and_evaluate_with::<Fp>();
partial_and_evaluate_with::<bn256::Scalar>();
partial_and_evaluate_with::<secp256k1::Scalar>();
fn test_bind_and_evaluate() {
bind_and_evaluate_with::<pasta_curves::Fp>();
bind_and_evaluate_with::<bn256::Scalar>();
bind_and_evaluate_with::<secp256k1::Scalar>();
}
}

0 comments on commit 55d83ea

Please sign in to comment.