From c2a363a5a1a3c91567d5bd6cf166c8aafa6a7531 Mon Sep 17 00:00:00 2001 From: Joey Kraut Date: Mon, 19 Aug 2024 15:28:08 -0700 Subject: [PATCH] online-phase: algebra: scalar: Optimize (de)serialization This has large effects throughout any stack built on top of the `Scalar` type, so we use the `ark_serialize` `CanonicalSerialize` and `CanonicalDeserialize` traits. --- online-phase/Cargo.toml | 5 ++ online-phase/benches/scalar_serialization.rs | 61 ++++++++++++++++++++ online-phase/src/algebra/scalar/scalar.rs | 22 +++++-- 3 files changed, 84 insertions(+), 4 deletions(-) create mode 100644 online-phase/benches/scalar_serialization.rs diff --git a/online-phase/Cargo.toml b/online-phase/Cargo.toml index a3ccc11..66f391b 100644 --- a/online-phase/Cargo.toml +++ b/online-phase/Cargo.toml @@ -31,6 +31,11 @@ path = "integration/main.rs" harness = false required-features = ["test_helpers"] +[[bench]] +name = "scalar_serialization" +harness = false +required-features = ["benchmarks", "test_helpers"] + [[bench]] name = "batch_ops" harness = false diff --git a/online-phase/benches/scalar_serialization.rs b/online-phase/benches/scalar_serialization.rs new file mode 100644 index 0000000..3ce1fa2 --- /dev/null +++ b/online-phase/benches/scalar_serialization.rs @@ -0,0 +1,61 @@ +use std::time::{Duration, Instant}; + +use ark_mpc::{algebra::Scalar, test_helpers::TestCurve}; +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use rand::thread_rng; + +/// Benchmark the serialization of scalars +fn bench_scalar_serialization(c: &mut Criterion) { + let mut rng = thread_rng(); + let mut group = c.benchmark_group("scalar_serialization"); + group.throughput(Throughput::Elements(1)); + + group.bench_function("scalar_serialization", |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::from_secs(0); + for _ in 0..n_iters { + let scalar = Scalar::::random(&mut rng); + + let start = Instant::now(); + let bytes = serde_json::to_value(scalar).unwrap(); + total_time += start.elapsed(); + + black_box(bytes); + } + total_time + }) + }); +} + +/// Benchmark the deserialization of scalars +fn bench_scalar_deserialization(c: &mut Criterion) { + let mut rng = thread_rng(); + let mut group = c.benchmark_group("scalar_deserialization"); + group.throughput(Throughput::Elements(1)); + + group.bench_function("scalar_deserialization", |b| { + b.iter_custom(|n_iters| { + let mut total_time = Duration::from_secs(0); + for _ in 0..n_iters { + let scalar = Scalar::::random(&mut rng); + let serialized = serde_json::to_value(scalar).unwrap(); + + // Time deserialization only + let start = Instant::now(); + let deserialized: Scalar = serde_json::from_value(serialized).unwrap(); + total_time += start.elapsed(); + + black_box(deserialized); + } + + total_time + }) + }); +} + +criterion_group! { + name = scalar_ops; + config = Criterion::default(); + targets = bench_scalar_serialization, bench_scalar_deserialization +} +criterion_main!(scalar_ops); diff --git a/online-phase/src/algebra/scalar/scalar.rs b/online-phase/src/algebra/scalar/scalar.rs index 502226c..4912a1a 100644 --- a/online-phase/src/algebra/scalar/scalar.rs +++ b/online-phase/src/algebra/scalar/scalar.rs @@ -13,6 +13,7 @@ use std::{ use ark_ec::CurveGroup; use ark_ff::{batch_inversion, FftField, Field, One, PrimeField, Zero}; use ark_poly::EvaluationDomain; +use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; use ark_std::UniformRand; use itertools::Itertools; use num_bigint::BigUint; @@ -168,16 +169,18 @@ impl Display for Scalar { impl Serialize for Scalar { fn serialize(&self, serializer: S) -> Result { - let bytes = self.to_bytes_be(); - bytes.serialize(serializer) + let mut bytes = Vec::with_capacity(n_bytes_field::()); + self.0.serialize_uncompressed(&mut bytes).map_err(serde::ser::Error::custom)?; + serializer.serialize_bytes(&bytes) } } impl<'de, C: CurveGroup> Deserialize<'de> for Scalar { fn deserialize>(deserializer: D) -> Result { let bytes = >::deserialize(deserializer)?; - let scalar = Scalar::from_be_bytes_mod_order(&bytes); - Ok(scalar) + let inner = C::ScalarField::deserialize_uncompressed(bytes.as_slice()) + .map_err(serde::de::Error::custom)?; + Ok(Scalar(inner)) } } @@ -720,6 +723,17 @@ mod test { use itertools::Itertools; use rand::{thread_rng, Rng, RngCore}; + /// Tests serialization and deserialization of scalars + #[test] + fn test_scalar_serialization() { + let mut rng = thread_rng(); + let scalar = Scalar::::random(&mut rng); + + let bytes = serde_json::to_vec(&scalar).unwrap(); + let deserialized: Scalar = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(scalar, deserialized); + } + /// Tests addition of raw scalars in a circuit #[tokio::test] async fn test_scalar_add() {