From fe2e4211c689fa5b46cdd990f9ccb78696fc38b1 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 6 Sep 2023 15:36:42 -0700 Subject: [PATCH 1/7] Remove non-constant-time comparisons of secret values. This affects the types representing: * Input shares * Preparation states * Output shares * Aggregate shares Mostly, the comparisons were either dropped entirely or updated to be test-only. Input shares were instead given a constant-time equality implementation, as it is believed this is required for DAP implementations. PartialEq & Eq implementations are still derived when the "test-util" feature is enabled. This is to ease testing for users of this library. --- Cargo.toml | 3 ++- src/vdaf.rs | 24 +++++++++++++++--- src/vdaf/poplar1.rs | 38 ++++++++++++++++++++++++---- src/vdaf/prio2.rs | 4 ++- src/vdaf/prio3.rs | 61 ++++++++++++++++++++++++++++++++++++++++----- src/vdaf/xof.rs | 10 +++----- 6 files changed, 117 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0020ce2ef..720912b15 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ itertools = "0.11.0" modinverse = "0.1.0" num-bigint = "0.4.4" once_cell = "1.18.0" -prio = { path = ".", features = ["crypto-dependencies"] } +prio = { path = ".", features = ["crypto-dependencies", "test-util"] } rand = "0.8" serde_json = "1.0" statrs = "0.16.0" @@ -58,6 +58,7 @@ experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", multithreaded = ["rayon"] prio2 = ["crypto-dependencies", "hmac", "sha2"] crypto-dependencies = ["aes", "ctr"] +test-util = [] [workspace] members = [".", "binaries"] diff --git a/src/vdaf.rs b/src/vdaf.rs index cc2a306fe..b4a8100d5 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -18,6 +18,7 @@ use crate::{ }; use serde::{Deserialize, Serialize}; use std::{fmt::Debug, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; /// A component of the domain-separation tag, used to bind the VDAF operations to the document /// version. This will be revised with each draft with breaking changes. @@ -57,7 +58,9 @@ pub enum VdafError { } /// An additive share of a vector of field elements. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub enum Share { /// An uncompressed share, typically sent to the leader. Leader(Vec), @@ -78,6 +81,18 @@ impl Share { } } +impl ConstantTimeEq for Share { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + // We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types' + // contents. + match (self, other) { + (Share::Leader(self_val), Share::Leader(other_val)) => self_val.ct_eq(other_val), + (Share::Helper(self_val), Share::Helper(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + /// Parameters needed to decode a [`Share`] #[derive(Clone, Debug, PartialEq, Eq)] pub(crate) enum ShareDecodingParameter { @@ -310,7 +325,7 @@ pub trait Aggregatable: Clone + Debug + From { } /// An output share comprised of a vector of field elements. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct OutputShare(Vec); impl AsRef<[F]> for OutputShare { @@ -339,7 +354,10 @@ impl Encode for OutputShare { /// /// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field /// elements, and output shares need no special transformation to be merged into an aggregate share. -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] + pub struct AggregateShare(Vec); impl AsRef<[F]> for AggregateShare { diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 7f9a277d6..34549f81b 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -110,7 +110,9 @@ impl ParameterizedDecode> for P /// /// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch /// during preparation. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Poplar1InputShare { /// IDPF key share. idpf_key: Seed<16>, @@ -128,6 +130,24 @@ pub struct Poplar1InputShare { corr_leaf: [Field255; 2], } +impl ConstantTimeEq for Poplar1InputShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We short-circuit on the length of corr_inner being different. Only the content is + // protected. + if self.corr_inner.len() != other.corr_inner.len() { + return Choice::from(0); + } + + let mut rslt = self.idpf_key.ct_eq(&other.idpf_key) + & self.corr_seed.ct_eq(&other.corr_seed) + & self.corr_leaf.ct_eq(&other.corr_leaf); + for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) { + rslt &= x.ct_eq(y); + } + rslt + } +} + impl Encode for Poplar1InputShare { fn encode(&self, bytes: &mut Vec) { self.idpf_key.encode(bytes); @@ -174,7 +194,9 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 ParameterizedDecode<(&'a Poplar1), Leaf(PrepareState), @@ -252,7 +276,9 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 { sketch: SketchState, output_share: Vec, @@ -450,7 +476,9 @@ impl ParameterizedDecode for Poplar1PrepareMessage { } /// A vector of field elements transmitted while evaluating Poplar1. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub enum Poplar1FieldVec { /// Field type for inner nodes of the IDPF tree. Inner(Vec), diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index ff61cf52d..62b7234e3 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -165,7 +165,9 @@ impl Client<16> for Prio2 { } /// State of each [`Aggregator`] during the Preparation phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio2PrepareState(Share); impl Encode for Prio2PrepareState { diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 851689e36..4487a9ba0 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -61,6 +61,7 @@ use std::fmt::Debug; use std::io::Cursor; use std::iter::{self, IntoIterator}; use std::marker::PhantomData; +use subtle::{Choice, ConstantTimeEq}; const DST_MEASUREMENT_SHARE: u16 = 1; const DST_PROOF_SHARE: u16 = 2; @@ -595,7 +596,7 @@ where } /// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase. -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] pub struct Prio3PublicShare { /// Contributions to the joint randomness from every aggregator's share. joint_rand_parts: Option>>, @@ -620,6 +621,22 @@ impl Encode for Prio3PublicShare { } } +impl PartialEq for Prio3PublicShare { + fn eq(&self, other: &Self) -> bool { + // Handle case that both join_rand_parts are populated. + if let Some(self_joint_rand_parts) = &self.joint_rand_parts { + if let Some(other_joint_rand_parts) = &other.joint_rand_parts { + return self_joint_rand_parts.ct_eq(other_joint_rand_parts).into(); + } + } + + // Handle case that at least one joint_rand_parts is not populated. + self.joint_rand_parts.is_none() && other.joint_rand_parts.is_none() + } +} + +impl Eq for Prio3PublicShare {} + impl ParameterizedDecode> for Prio3PublicShare where @@ -646,7 +663,9 @@ where } /// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio3InputShare { /// The measurement share. measurement_share: Share, @@ -659,6 +678,24 @@ pub struct Prio3InputShare { joint_rand_blind: Option>, } +impl ConstantTimeEq for Prio3InputShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the existence (but not contents) of the joint_rand_blind, + // as its existence is a property of the type in use. + let joint_rand_eq = match (&self.joint_rand_blind, &other.joint_rand_blind) { + (Some(self_joint_rand), Some(other_joint_rand)) => { + self_joint_rand.ct_eq(other_joint_rand) + } + (None, None) => Choice::from(1), + _ => Choice::from(0), + }; + + joint_rand_eq + & self.measurement_share.ct_eq(&other.measurement_share) + & self.proof_share.ct_eq(&other.proof_share) + } +} + impl Encode for Prio3InputShare { fn encode(&self, bytes: &mut Vec) { if matches!( @@ -726,7 +763,9 @@ where } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] /// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`].) @@ -783,7 +822,9 @@ impl } } -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] /// Result of combining a round of [`Prio3PrepareShare`] messages. pub struct Prio3PrepareMessage { /// The joint randomness seed computed by the Aggregators. @@ -841,7 +882,9 @@ where } /// State of each [`Aggregator`] during the Preparation phase. -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio3PrepareState { measurement_share: Share, joint_rand_seed: Option>, @@ -1111,7 +1154,13 @@ where ) -> Result, VdafError> { if self.typ.joint_rand_len() > 0 { // Check that the joint randomness was correct. - if step.joint_rand_seed.as_ref().unwrap() != msg.joint_rand_seed.as_ref().unwrap() { + if (!step + .joint_rand_seed + .as_ref() + .unwrap() + .ct_eq(msg.joint_rand_seed.as_ref().unwrap())) + .into() + { return Err(VdafError::Uncategorized( "joint randomness mismatch".to_string(), )); diff --git a/src/vdaf/xof.rs b/src/vdaf/xof.rs index 64af49a01..ed8f2511a 100644 --- a/src/vdaf/xof.rs +++ b/src/vdaf/xof.rs @@ -38,7 +38,9 @@ use std::{ use subtle::{Choice, ConstantTimeEq}; /// Input of [`Xof`]. -#[derive(Clone, Debug, Eq)] +#[derive(Clone, Debug)] +// Only derive equality checks in test code, as the content of this type is sometimes a secret. +#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Seed(pub(crate) [u8; SEED_SIZE]); impl Seed { @@ -67,12 +69,6 @@ impl ConstantTimeEq for Seed { } } -impl PartialEq for Seed { - fn eq(&self, other: &Self) -> bool { - self.ct_eq(other).into() - } -} - impl Encode for Seed { fn encode(&self, bytes: &mut Vec) { bytes.extend_from_slice(&self.0[..]); From 0b624de69fa3b3cca4e0847bd07bb07551ddeaf5 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Thu, 7 Sep 2023 14:39:32 -0700 Subject: [PATCH 2/7] Code review. --- Cargo.toml | 3 +- src/vdaf.rs | 26 +++++++-- src/vdaf/poplar1.rs | 134 +++++++++++++++++++++++++++++++++++++++----- src/vdaf/prio2.rs | 17 +++++- src/vdaf/prio3.rs | 104 ++++++++++++++++++++++++---------- src/vdaf/xof.rs | 10 +++- 6 files changed, 240 insertions(+), 54 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 720912b15..0020ce2ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ itertools = "0.11.0" modinverse = "0.1.0" num-bigint = "0.4.4" once_cell = "1.18.0" -prio = { path = ".", features = ["crypto-dependencies", "test-util"] } +prio = { path = ".", features = ["crypto-dependencies"] } rand = "0.8" serde_json = "1.0" statrs = "0.16.0" @@ -58,7 +58,6 @@ experimental = ["bitvec", "fiat-crypto", "fixed", "num-bigint", "num-rational", multithreaded = ["rayon"] prio2 = ["crypto-dependencies", "hmac", "sha2"] crypto-dependencies = ["aes", "ctr"] -test-util = [] [workspace] members = [".", "binaries"] diff --git a/src/vdaf.rs b/src/vdaf.rs index b4a8100d5..a47a15a7d 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -59,8 +59,6 @@ pub enum VdafError { /// An additive share of a vector of field elements. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub enum Share { /// An uncompressed share, typically sent to the leader. Leader(Vec), @@ -81,6 +79,14 @@ impl Share { } } +impl PartialEq for Share { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Share {} + impl ConstantTimeEq for Share { fn ct_eq(&self, other: &Self) -> subtle::Choice { // We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types' @@ -355,11 +361,23 @@ impl Encode for OutputShare { /// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field /// elements, and output shares need no special transformation to be merged into an aggregate share. #[derive(Clone, Debug, Serialize, Deserialize)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct AggregateShare(Vec); +impl PartialEq for AggregateShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for AggregateShare {} + +impl ConstantTimeEq for AggregateShare { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.0.ct_eq(&other.0) + } +} + impl AsRef<[F]> for AggregateShare { fn as_ref(&self) -> &[F] { &self.0 diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 34549f81b..07b01d881 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -111,8 +111,6 @@ impl ParameterizedDecode> for P /// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch /// during preparation. #[derive(Debug, Clone)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Poplar1InputShare { /// IDPF key share. idpf_key: Seed<16>, @@ -130,6 +128,14 @@ pub struct Poplar1InputShare { corr_leaf: [Field255; 2], } +impl PartialEq for Poplar1InputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1InputShare {} + impl ConstantTimeEq for Poplar1InputShare { fn ct_eq(&self, other: &Self) -> Choice { // We short-circuit on the length of corr_inner being different. Only the content is @@ -138,13 +144,13 @@ impl ConstantTimeEq for Poplar1InputShare { return Choice::from(0); } - let mut rslt = self.idpf_key.ct_eq(&other.idpf_key) + let mut res = self.idpf_key.ct_eq(&other.idpf_key) & self.corr_seed.ct_eq(&other.corr_seed) & self.corr_leaf.ct_eq(&other.corr_leaf); for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) { - rslt &= x.ct_eq(y); + res &= x.ct_eq(y); } - rslt + res } } @@ -195,10 +201,22 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1PrepareState {} + +impl ConstantTimeEq for Poplar1PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl Encode for Poplar1PrepareState { fn encode(&self, bytes: &mut Vec) { self.0.encode(bytes) @@ -224,13 +242,30 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1), Leaf(PrepareState), } +impl PartialEq for PrepareStateVariant { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for PrepareStateVariant {} + +impl ConstantTimeEq for PrepareStateVariant { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Self::Inner(self_val), Self::Inner(other_val)) => self_val.ct_eq(other_val), + (Self::Leaf(self_val), Self::Leaf(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + impl Encode for PrepareStateVariant { fn encode(&self, bytes: &mut Vec) { match self { @@ -277,13 +312,25 @@ impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1 { sketch: SketchState, output_share: Vec, } +impl PartialEq for PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for PrepareState {} + +impl ConstantTimeEq for PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.sketch.ct_eq(&other.sketch) & self.output_share.ct_eq(&other.output_share) + } +} + impl Encode for PrepareState { fn encode(&self, bytes: &mut Vec) { self.sketch.encode(bytes); @@ -323,7 +370,7 @@ impl<'a, P, F: FieldElement, const SEED_SIZE: usize> } } -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug)] enum SketchState { #[allow(non_snake_case)] RoundOne { @@ -334,6 +381,44 @@ enum SketchState { RoundTwo, } +impl PartialEq for SketchState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for SketchState {} + +impl ConstantTimeEq for SketchState { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the round (RoundOne vs RoundTwo). + match (self, other) { + ( + SketchState::RoundOne { + A_share: self_a_share, + B_share: self_b_share, + is_leader: self_is_leader, + }, + SketchState::RoundOne { + A_share: other_a_share, + B_share: other_b_share, + is_leader: other_is_leader, + }, + ) => { + let self_is_leader = Choice::from(*self_is_leader as u8); + let other_is_leader = Choice::from(*other_is_leader as u8); + + self_a_share.ct_eq(other_a_share) + & self_b_share.ct_eq(other_b_share) + & self_is_leader.ct_eq(&other_is_leader) + } + + (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1), + _ => Choice::from(0), + } + } +} + impl Encode for SketchState { fn encode(&self, bytes: &mut Vec) { match self { @@ -477,8 +562,6 @@ impl ParameterizedDecode for Poplar1PrepareMessage { /// A vector of field elements transmitted while evaluating Poplar1. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub enum Poplar1FieldVec { /// Field type for inner nodes of the IDPF tree. Inner(Vec), @@ -497,6 +580,29 @@ impl Poplar1FieldVec { } } +impl PartialEq for Poplar1FieldVec { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1FieldVec {} + +impl ConstantTimeEq for Poplar1FieldVec { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Poplar1FieldVec::Inner(self_val), Poplar1FieldVec::Inner(other_val)) => { + self_val.ct_eq(&other_val) + } + (Poplar1FieldVec::Leaf(self_val), Poplar1FieldVec::Leaf(other_val)) => { + self_val.ct_eq(&other_val) + } + _ => Choice::from(0), + } + } +} + impl Encode for Poplar1FieldVec { fn encode(&self, bytes: &mut Vec) { match self { diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index 62b7234e3..cce0bde47 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -22,6 +22,7 @@ use crate::{ use hmac::{Hmac, Mac}; use sha2::Sha256; use std::{convert::TryFrom, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; mod client; mod server; @@ -166,10 +167,22 @@ impl Client<16> for Prio2 { /// State of each [`Aggregator`] during the Preparation phase. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio2PrepareState(Share); +impl PartialEq for Prio2PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio2PrepareState {} + +impl ConstantTimeEq for Prio2PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl Encode for Prio2PrepareState { fn encode(&self, bytes: &mut Vec) { self.0.encode(bytes); diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 4487a9ba0..3aab245ac 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -623,15 +623,11 @@ impl Encode for Prio3PublicShare { impl PartialEq for Prio3PublicShare { fn eq(&self, other: &Self) -> bool { - // Handle case that both join_rand_parts are populated. - if let Some(self_joint_rand_parts) = &self.joint_rand_parts { - if let Some(other_joint_rand_parts) = &other.joint_rand_parts { - return self_joint_rand_parts.ct_eq(other_joint_rand_parts).into(); - } + match (&self.joint_rand_parts, &other.joint_rand_parts) { + (Some(self_jrps), Some(other_jrps)) => self_jrps.ct_eq(other_jrps).into(), + (None, None) => true, + _ => false, } - - // Handle case that at least one joint_rand_parts is not populated. - self.joint_rand_parts.is_none() && other.joint_rand_parts.is_none() } } @@ -664,8 +660,6 @@ where /// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio3InputShare { /// The measurement share. measurement_share: Share, @@ -678,19 +672,17 @@ pub struct Prio3InputShare { joint_rand_blind: Option>, } +impl PartialEq for Prio3InputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3InputShare {} + impl ConstantTimeEq for Prio3InputShare { fn ct_eq(&self, other: &Self) -> Choice { - // We allow short-circuiting on the existence (but not contents) of the joint_rand_blind, - // as its existence is a property of the type in use. - let joint_rand_eq = match (&self.joint_rand_blind, &other.joint_rand_blind) { - (Some(self_joint_rand), Some(other_joint_rand)) => { - self_joint_rand.ct_eq(other_joint_rand) - } - (None, None) => Choice::from(1), - _ => Choice::from(0), - }; - - joint_rand_eq + option_ct_eq(&self.joint_rand_blind, &other.joint_rand_blind) & self.measurement_share.ct_eq(&other.measurement_share) & self.proof_share.ct_eq(&other.proof_share) } @@ -764,9 +756,8 @@ where } #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] /// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. + pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`].) verifier: Vec, @@ -775,6 +766,21 @@ pub struct Prio3PrepareShare { joint_rand_part: Option>, } +impl PartialEq for Prio3PrepareShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareShare {} + +impl ConstantTimeEq for Prio3PrepareShare { + fn ct_eq(&self, other: &Self) -> Choice { + option_ct_eq(&self.joint_rand_part, &other.joint_rand_part) + & self.verifier.ct_eq(&other.verifier) + } +} + impl Encode for Prio3PrepareShare { @@ -823,14 +829,26 @@ impl } #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] /// Result of combining a round of [`Prio3PrepareShare`] messages. pub struct Prio3PrepareMessage { /// The joint randomness seed computed by the Aggregators. joint_rand_seed: Option>, } +impl PartialEq for Prio3PrepareMessage { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareMessage {} + +impl ConstantTimeEq for Prio3PrepareMessage { + fn ct_eq(&self, other: &Self) -> Choice { + option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) + } +} + impl Encode for Prio3PrepareMessage { fn encode(&self, bytes: &mut Vec) { if let Some(ref seed) = self.joint_rand_seed { @@ -883,8 +901,6 @@ where /// State of each [`Aggregator`] during the Preparation phase. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Prio3PrepareState { measurement_share: Share, joint_rand_seed: Option>, @@ -892,6 +908,23 @@ pub struct Prio3PrepareState { verifier_len: usize, } +impl PartialEq for Prio3PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio3PrepareState {} + +impl ConstantTimeEq for Prio3PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) + & self.measurement_share.ct_eq(&other.measurement_share) + & self.agg_id.ct_eq(&other.agg_id) + & self.verifier_len.ct_eq(&other.verifier_len) + } +} + impl Encode for Prio3PrepareState { @@ -1154,12 +1187,12 @@ where ) -> Result, VdafError> { if self.typ.joint_rand_len() > 0 { // Check that the joint randomness was correct. - if (!step + if step .joint_rand_seed .as_ref() .unwrap() - .ct_eq(msg.joint_rand_seed.as_ref().unwrap())) - .into() + .ct_ne(msg.joint_rand_seed.as_ref().unwrap()) + .into() { return Err(VdafError::Uncategorized( "joint randomness mismatch".to_string(), @@ -1308,6 +1341,17 @@ where } } +// This function determines equality between two optional, constant-time comparable values. It +// short-circuits on the existence (but not contents) of the values -- timing information may reveal +// whether the values match on Some or None. +fn option_ct_eq(left: &Option, right: &Option) -> Choice { + match (left, right) { + (Some(left), Some(right)) => left.ct_eq(right), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/vdaf/xof.rs b/src/vdaf/xof.rs index ed8f2511a..03c944368 100644 --- a/src/vdaf/xof.rs +++ b/src/vdaf/xof.rs @@ -39,8 +39,6 @@ use subtle::{Choice, ConstantTimeEq}; /// Input of [`Xof`]. #[derive(Clone, Debug)] -// Only derive equality checks in test code, as the content of this type is sometimes a secret. -#[cfg_attr(feature = "test-util", derive(PartialEq, Eq))] pub struct Seed(pub(crate) [u8; SEED_SIZE]); impl Seed { @@ -63,6 +61,14 @@ impl AsRef<[u8; SEED_SIZE]> for Seed { } } +impl PartialEq for Seed { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Seed {} + impl ConstantTimeEq for Seed { fn ct_eq(&self, other: &Self) -> Choice { self.0.ct_eq(&other.0) From 1d54c6d275298eaf29961b6e957e8d6a447418cf Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Thu, 7 Sep 2023 16:17:07 -0700 Subject: [PATCH 3/7] Tests. --- src/vdaf.rs | 73 ++++++++++++++++ src/vdaf/poplar1.rs | 198 +++++++++++++++++++++++++++++++++++++++++++- src/vdaf/prio2.rs | 25 +++++- src/vdaf/prio3.rs | 147 +++++++++++++++++++++++++++++++- src/vdaf/xof.rs | 7 +- 5 files changed, 444 insertions(+), 6 deletions(-) diff --git a/src/vdaf.rs b/src/vdaf.rs index a47a15a7d..5fb568514 100644 --- a/src/vdaf.rs +++ b/src/vdaf.rs @@ -334,6 +334,20 @@ pub trait Aggregatable: Clone + Debug + From { #[derive(Clone, Debug)] pub struct OutputShare(Vec); +impl PartialEq for OutputShare { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for OutputShare {} + +impl ConstantTimeEq for OutputShare { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + impl AsRef<[F]> for OutputShare { fn as_ref(&self) -> &[F] { &self.0 @@ -588,6 +602,65 @@ where assert_eq!(encoded, bytes); } +#[cfg(test)] +fn equality_comparison_test(values: &[T]) +where + T: Debug + PartialEq, +{ + use std::ptr; + + // This function expects that every value passed in `values` is distinct, i.e. should not + // compare as equal to any other element. We test both (i, j) and (j, i) to gain confidence that + // equality implementations are symmetric. + for (i, i_val) in values.iter().enumerate() { + for (j, j_val) in values.iter().enumerate() { + if i == j { + assert!(ptr::eq(i_val, j_val)); // sanity + assert_eq!( + i_val, j_val, + "Expected element at index {i} to be equal to itself, but it was not" + ); + } else { + assert_ne!( + i_val, j_val, + "Expected elements at indices {i} & {j} to not be equal, but they were" + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::vdaf::{equality_comparison_test, xof::Seed, AggregateShare, OutputShare, Share}; + + #[test] + fn share_equality_test() { + equality_comparison_test(&[ + Share::Leader(Vec::from([1, 2, 3])), + Share::Leader(Vec::from([3, 2, 1])), + Share::Helper(Seed([1, 2, 3])), + Share::Helper(Seed([3, 2, 1])), + ]) + } + + #[test] + fn output_share_equality_test() { + equality_comparison_test(&[ + OutputShare(Vec::from([1, 2, 3])), + OutputShare(Vec::from([3, 2, 1])), + ]) + } + + #[test] + fn aggregate_share_equality_test() { + equality_comparison_test(&[ + AggregateShare(Vec::from([1, 2, 3])), + AggregateShare(Vec::from([3, 2, 1])), + ]) + } +} + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] #[cfg_attr( docsrs, diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 07b01d881..12f6c57df 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -593,10 +593,10 @@ impl ConstantTimeEq for Poplar1FieldVec { // We allow short-circuiting on the type (Inner vs Leaf). match (self, other) { (Poplar1FieldVec::Inner(self_val), Poplar1FieldVec::Inner(other_val)) => { - self_val.ct_eq(&other_val) + self_val.ct_eq(other_val) } (Poplar1FieldVec::Leaf(self_val), Poplar1FieldVec::Leaf(other_val)) => { - self_val.ct_eq(&other_val) + self_val.ct_eq(other_val) } _ => Choice::from(0), } @@ -1502,7 +1502,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::vdaf::run_vdaf_prepare; + use crate::vdaf::{equality_comparison_test, run_vdaf_prepare}; use assert_matches::assert_matches; use rand::prelude::*; use serde::Deserialize; @@ -2263,4 +2263,196 @@ mod tests { fn test_vec_poplar1_3() { check_test_vec(include_str!("test_vec/07/Poplar1_3.json")); } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified idpf_key. + Poplar1InputShare { + idpf_key: Seed([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_seed. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([18, 17, 16]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_inner. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(24), Field64::from(23)], + [Field64::from(22), Field64::from(21)], + [Field64::from(20), Field64::from(19)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_leaf. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(26), Field255::from(25)], + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + // This test effectively covers PrepareStateVariant, PrepareState, SketchState as well. + equality_comparison_test(&[ + // Inner, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(100), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(101), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: true, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Inner, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Leaf, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(100), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(101), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: true, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + // Leaf, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + ]) + } + + #[test] + fn field_vec_equality_test() { + equality_comparison_test(&[ + // Inner. (default) + Poplar1FieldVec::Inner(Vec::from([Field64::from(0), Field64::from(1)])), + // Inner, modified value. + Poplar1FieldVec::Inner(Vec::from([Field64::from(1), Field64::from(0)])), + // Leaf. (deafult) + Poplar1FieldVec::Leaf(Vec::from([Field255::from(0), Field255::from(1)])), + // Leaf, modified value. + Poplar1FieldVec::Leaf(Vec::from([Field255::from(1), Field255::from(0)])), + ]) + } } diff --git a/src/vdaf/prio2.rs b/src/vdaf/prio2.rs index cce0bde47..778d87b36 100644 --- a/src/vdaf/prio2.rs +++ b/src/vdaf/prio2.rs @@ -385,7 +385,10 @@ fn role_try_from(agg_id: usize) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::vdaf::{fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, run_vdaf}; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, + run_vdaf, + }; use assert_matches::assert_matches; use rand::prelude::*; @@ -516,4 +519,24 @@ mod tests { assert_eq!(reconstructed, test_vector.reference_sum); } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(0), + FieldPrio2::from(1), + ]))), + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(1), + FieldPrio2::from(0), + ]))), + Prio2PrepareState(Share::Helper(Seed( + (0..32).collect::>().try_into().unwrap(), + ))), + Prio2PrepareState(Share::Helper(Seed( + (1..33).collect::>().try_into().unwrap(), + ))), + ]) + } } diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 3aab245ac..3821029f0 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -1357,7 +1357,9 @@ mod tests { use super::*; #[cfg(feature = "experimental")] use crate::flp::gadgets::ParallelSumGadget; - use crate::vdaf::{fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare}; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare, + }; use assert_matches::assert_matches; #[cfg(feature = "experimental")] use fixed::{ @@ -1825,4 +1827,147 @@ mod tests { 12, ); } + + #[test] + fn public_share_equality_test() { + equality_comparison_test(&[ + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([0])])), + }, + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([1])])), + }, + Prio3PublicShare { + joint_rand_parts: None, + }, + ]) + } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified measurement share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([100])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified proof share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([101])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([102])), + }, + // Missing joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: None, + }, + ]) + } + + #[test] + fn prepare_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([1])), + }, + // Modified verifier. + Prio3PrepareShare { + verifier: Vec::from([100]), + joint_rand_part: Some(Seed([1])), + }, + // Modified joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([101])), + }, + // Missing joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: None, + }, + ]) + } + + #[test] + fn prepare_message_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([0])), + }, + // Modified joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([100])), + }, + // Missing joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: None, + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified measurement share. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([100])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([101])), + agg_id: 2, + verifier_len: 3, + }, + // Missing joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: None, + agg_id: 2, + verifier_len: 3, + }, + // Modified agg_id. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 102, + verifier_len: 3, + }, + // Modified verifier_len. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 103, + }, + ]) + } } diff --git a/src/vdaf/xof.rs b/src/vdaf/xof.rs index 03c944368..1456c5588 100644 --- a/src/vdaf/xof.rs +++ b/src/vdaf/xof.rs @@ -407,7 +407,7 @@ impl SeedStreamFixedKeyAes128 { #[cfg(test)] mod tests { use super::*; - use crate::field::Field128; + use crate::{field::Field128, vdaf::equality_comparison_test}; use serde::{Deserialize, Serialize}; use std::{convert::TryInto, io::Cursor}; @@ -537,4 +537,9 @@ mod tests { assert_eq!(output_1_trait_api, output_1_alternate_api); assert_eq!(output_2_trait_api, output_2_alternate_api); } + + #[test] + fn seed_equality_test() { + equality_comparison_test(&[Seed([1, 2, 3]), Seed([3, 2, 1])]) + } } From 676609b8f456ff35d92900a2daf863637bdfb167 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Fri, 8 Sep 2023 15:00:16 -0700 Subject: [PATCH 4/7] Code review. --- src/vdaf/poplar1.rs | 11 ++++++----- src/vdaf/prio3.rs | 32 +++++++++++++++++++++++--------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 12f6c57df..46a014920 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -391,7 +391,8 @@ impl Eq for SketchState {} impl ConstantTimeEq for SketchState { fn ct_eq(&self, other: &Self) -> Choice { - // We allow short-circuiting on the round (RoundOne vs RoundTwo). + // We allow short-circuiting on the round (RoundOne vs RoundTwo), as well as is_leader for + // RoundOne comparisons. match (self, other) { ( SketchState::RoundOne { @@ -405,12 +406,12 @@ impl ConstantTimeEq for SketchState { is_leader: other_is_leader, }, ) => { - let self_is_leader = Choice::from(*self_is_leader as u8); - let other_is_leader = Choice::from(*other_is_leader as u8); - + if self_is_leader != other_is_leader { + return Choice::from(0); + } + self_a_share.ct_eq(other_a_share) & self_b_share.ct_eq(other_b_share) - & self_is_leader.ct_eq(&other_is_leader) } (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1), diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 3821029f0..90726e173 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -623,16 +623,23 @@ impl Encode for Prio3PublicShare { impl PartialEq for Prio3PublicShare { fn eq(&self, other: &Self) -> bool { - match (&self.joint_rand_parts, &other.joint_rand_parts) { - (Some(self_jrps), Some(other_jrps)) => self_jrps.ct_eq(other_jrps).into(), - (None, None) => true, - _ => false, - } + self.ct_eq(other).into() } } impl Eq for Prio3PublicShare {} +impl ConstantTimeEq for Prio3PublicShare { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_parts. + match (&self.joint_rand_parts, &other.joint_rand_parts) { + (Some(self_joint_rand_parts), Some(other_joint_rand_parts)) => self_joint_rand_parts.ct_eq(other_joint_rand_parts).into(), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } + } +} + impl ParameterizedDecode> for Prio3PublicShare where @@ -682,6 +689,7 @@ impl Eq for Prio3InputShare ConstantTimeEq for Prio3InputShare { fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_blind. option_ct_eq(&self.joint_rand_blind, &other.joint_rand_blind) & self.measurement_share.ct_eq(&other.measurement_share) & self.proof_share.ct_eq(&other.proof_share) @@ -776,6 +784,7 @@ impl Eq for Prio3PrepareShare ConstantTimeEq for Prio3PrepareShare { fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_part. option_ct_eq(&self.joint_rand_part, &other.joint_rand_part) & self.verifier.ct_eq(&other.verifier) } @@ -845,6 +854,7 @@ impl Eq for Prio3PrepareMessage {} impl ConstantTimeEq for Prio3PrepareMessage { fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presnce or absence of the joint_rand_seed. option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) } } @@ -918,10 +928,14 @@ impl Eq for Prio3PrepareState ConstantTimeEq for Prio3PrepareState { fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as + // the aggregator ID & verifier length parameters. + if self.agg_id != other.agg_id || self.verifier_len != other.verifier_len { + return Choice::from(0); + } + option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) & self.measurement_share.ct_eq(&other.measurement_share) - & self.agg_id.ct_eq(&other.agg_id) - & self.verifier_len.ct_eq(&other.verifier_len) } } @@ -1342,8 +1356,8 @@ where } // This function determines equality between two optional, constant-time comparable values. It -// short-circuits on the existence (but not contents) of the values -- timing information may reveal -// whether the values match on Some or None. +// short-circuits on the existence (but not contents) of the values -- a timing side-channel may +// reveal whether the values match on Some or None. fn option_ct_eq(left: &Option, right: &Option) -> Choice { match (left, right) { (Some(left), Some(right)) => left.ct_eq(right), From 7618ecf684d71941bb557b40f597a2cdb0131f45 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Fri, 8 Sep 2023 15:10:50 -0700 Subject: [PATCH 5/7] One last review comment (oops). --- src/vdaf/prio3.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 90726e173..34f72e6aa 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -765,7 +765,6 @@ where #[derive(Clone, Debug)] /// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. - pub struct Prio3PrepareShare { /// A share of the FLP verifier message. (See [`Type`].) verifier: Vec, From 00fef143d0f2458a441c6a3d376b0ff3be92cb7a Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Fri, 8 Sep 2023 15:33:30 -0700 Subject: [PATCH 6/7] Last, last review comment. Getting `option_ct_eq` to work in the last place it needed to required a few changes. --- src/vdaf/prio3.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index 34f72e6aa..a26841150 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -632,11 +632,7 @@ impl Eq for Prio3PublicShare {} impl ConstantTimeEq for Prio3PublicShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_parts. - match (&self.joint_rand_parts, &other.joint_rand_parts) { - (Some(self_joint_rand_parts), Some(other_joint_rand_parts)) => self_joint_rand_parts.ct_eq(other_joint_rand_parts).into(), - (None, None) => Choice::from(1), - _ => Choice::from(0), - } + option_ct_eq(self.joint_rand_parts.as_deref(), other.joint_rand_parts.as_deref()) } } @@ -690,7 +686,7 @@ impl Eq for Prio3InputShare ConstantTimeEq for Prio3InputShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_blind. - option_ct_eq(&self.joint_rand_blind, &other.joint_rand_blind) + option_ct_eq(self.joint_rand_blind.as_ref(), other.joint_rand_blind.as_ref()) & self.measurement_share.ct_eq(&other.measurement_share) & self.proof_share.ct_eq(&other.proof_share) } @@ -784,7 +780,7 @@ impl Eq for Prio3PrepareShare ConstantTimeEq for Prio3PrepareShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_part. - option_ct_eq(&self.joint_rand_part, &other.joint_rand_part) + option_ct_eq(self.joint_rand_part.as_ref(), other.joint_rand_part.as_ref()) & self.verifier.ct_eq(&other.verifier) } } @@ -854,7 +850,7 @@ impl Eq for Prio3PrepareMessage {} impl ConstantTimeEq for Prio3PrepareMessage { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presnce or absence of the joint_rand_seed. - option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) + option_ct_eq(self.joint_rand_seed.as_ref(), other.joint_rand_seed.as_ref()) } } @@ -933,7 +929,7 @@ impl ConstantTimeEq for Prio3PrepareS return Choice::from(0); } - option_ct_eq(&self.joint_rand_seed, &other.joint_rand_seed) + option_ct_eq(self.joint_rand_seed.as_ref(), other.joint_rand_seed.as_ref()) & self.measurement_share.ct_eq(&other.measurement_share) } } @@ -1357,7 +1353,8 @@ where // This function determines equality between two optional, constant-time comparable values. It // short-circuits on the existence (but not contents) of the values -- a timing side-channel may // reveal whether the values match on Some or None. -fn option_ct_eq(left: &Option, right: &Option) -> Choice { +#[inline] +fn option_ct_eq(left: Option<&T>, right: Option<&T>) -> Choice where T: ConstantTimeEq + ?Sized { match (left, right) { (Some(left), Some(right)) => left.ct_eq(right), (None, None) => Choice::from(1), From c1d5494a841d0a85276d00d09d21d639f652a047 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Fri, 8 Sep 2023 15:42:18 -0700 Subject: [PATCH 7/7] cargo fmt --- src/vdaf/poplar1.rs | 5 ++--- src/vdaf/prio3.rs | 33 ++++++++++++++++++++++++--------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/vdaf/poplar1.rs b/src/vdaf/poplar1.rs index 46a014920..7b09e5f53 100644 --- a/src/vdaf/poplar1.rs +++ b/src/vdaf/poplar1.rs @@ -409,9 +409,8 @@ impl ConstantTimeEq for SketchState { if self_is_leader != other_is_leader { return Choice::from(0); } - - self_a_share.ct_eq(other_a_share) - & self_b_share.ct_eq(other_b_share) + + self_a_share.ct_eq(other_a_share) & self_b_share.ct_eq(other_b_share) } (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1), diff --git a/src/vdaf/prio3.rs b/src/vdaf/prio3.rs index a26841150..7d0b107e2 100644 --- a/src/vdaf/prio3.rs +++ b/src/vdaf/prio3.rs @@ -632,7 +632,10 @@ impl Eq for Prio3PublicShare {} impl ConstantTimeEq for Prio3PublicShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_parts. - option_ct_eq(self.joint_rand_parts.as_deref(), other.joint_rand_parts.as_deref()) + option_ct_eq( + self.joint_rand_parts.as_deref(), + other.joint_rand_parts.as_deref(), + ) } } @@ -686,8 +689,10 @@ impl Eq for Prio3InputShare ConstantTimeEq for Prio3InputShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_blind. - option_ct_eq(self.joint_rand_blind.as_ref(), other.joint_rand_blind.as_ref()) - & self.measurement_share.ct_eq(&other.measurement_share) + option_ct_eq( + self.joint_rand_blind.as_ref(), + other.joint_rand_blind.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) & self.proof_share.ct_eq(&other.proof_share) } } @@ -780,8 +785,10 @@ impl Eq for Prio3PrepareShare ConstantTimeEq for Prio3PrepareShare { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presence or absence of the joint_rand_part. - option_ct_eq(self.joint_rand_part.as_ref(), other.joint_rand_part.as_ref()) - & self.verifier.ct_eq(&other.verifier) + option_ct_eq( + self.joint_rand_part.as_ref(), + other.joint_rand_part.as_ref(), + ) & self.verifier.ct_eq(&other.verifier) } } @@ -850,7 +857,10 @@ impl Eq for Prio3PrepareMessage {} impl ConstantTimeEq for Prio3PrepareMessage { fn ct_eq(&self, other: &Self) -> Choice { // We allow short-circuiting on the presnce or absence of the joint_rand_seed. - option_ct_eq(self.joint_rand_seed.as_ref(), other.joint_rand_seed.as_ref()) + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) } } @@ -929,8 +939,10 @@ impl ConstantTimeEq for Prio3PrepareS return Choice::from(0); } - option_ct_eq(self.joint_rand_seed.as_ref(), other.joint_rand_seed.as_ref()) - & self.measurement_share.ct_eq(&other.measurement_share) + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) } } @@ -1354,7 +1366,10 @@ where // short-circuits on the existence (but not contents) of the values -- a timing side-channel may // reveal whether the values match on Some or None. #[inline] -fn option_ct_eq(left: Option<&T>, right: Option<&T>) -> Choice where T: ConstantTimeEq + ?Sized { +fn option_ct_eq(left: Option<&T>, right: Option<&T>) -> Choice +where + T: ConstantTimeEq + ?Sized, +{ match (left, right) { (Some(left), Some(right)) => left.ct_eq(right), (None, None) => Choice::from(1),