Skip to content

Commit

Permalink
vidpf: Remove eval_with_cache(), limit eval() to testing
Browse files Browse the repository at this point in the history
Mastic uses `eval_prefix_tree_with_siblings()`. This method caches the
prefix tree just like `eval_with_cache()` does, but it doesn't try to
compute the onehot proof. It also concatenates the weight shares into
the output shares for us.

The only other use case for `eval_with_cache()` is for computing the
shares of beta during sharding. Replace this code with a simpler
implementation and remove `eval_with_cache()`.

Finally, fence `eval()` by the `cfg(test)` attribute. It's useful for
testing, but its not part of the actual specification.
  • Loading branch information
cjpatton committed Dec 29, 2024
1 parent ce10c4e commit 013db1c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 196 deletions.
20 changes: 6 additions & 14 deletions src/vdaf/mastic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,20 +301,12 @@ where
// keys for the measurement and evaluating each of them.
let public_share = self.vidpf.gen_with_keys(&vidpf_keys, alpha, &beta, nonce)?;

let leader_beta_share = self.vidpf.eval_root(
VidpfServerId::S0,
&vidpf_keys[0],
&public_share,
&mut BinaryTree::default(),
nonce,
)?;
let helper_beta_share = self.vidpf.eval_root(
VidpfServerId::S1,
&vidpf_keys[1],
&public_share,
&mut BinaryTree::default(),
nonce,
)?;
let leader_beta_share =
self.vidpf
.get_beta_sahre(VidpfServerId::S0, &public_share, &vidpf_keys[0], nonce)?;
let helper_beta_share =
self.vidpf
.get_beta_sahre(VidpfServerId::S1, &public_share, &vidpf_keys[1], nonce)?;

let [leader_szk_proof_share, helper_szk_proof_share] = self.szk.prove(
&leader_beta_share.as_ref()[1..],
Expand Down
215 changes: 33 additions & 182 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,20 @@ impl<W: VidpfValue> Vidpf<W> {
Ok(VidpfPublicShare { cw })
}

/// Evaluate a given VIDPF (comprised of the key and public share) at a given input.
pub fn eval(
/// Evaluate a given VIDPF (comprised of the key and public share) at a given prefix. Return
/// the weight for that prefix along with a hash of the node proofs along the path from the
/// root to the prefix.
#[cfg(test)]
pub(crate) fn eval(
&self,
id: VidpfServerId,
key: &VidpfKey,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
nonce: &[u8],
) -> Result<(W, VidpfProof), VidpfError> {
use sha3::{Digest, Sha3_256};

let mut r = VidpfEvalResult {
state: VidpfEvalState::init_from_key(id, key),
share: W::zero(&self.weight_parameter), // not used
Expand All @@ -194,73 +199,15 @@ impl<W: VidpfValue> Vidpf<W> {
return Err(VidpfError::InvalidAttributeLength);
}

let mut onehot_proof = ONEHOT_PROOF_INIT;
let mut hash = Sha3_256::new();
for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) {
r = self.eval_next(cw, idx, &r.state, nonce);
onehot_proof = xor_proof(
onehot_proof,
&Self::hash_proof(xor_proof(onehot_proof, &r.state.node_proof)),
);
hash.update(&r.state.node_proof);
}

let mut weight = r.share;
weight.conditional_negate(Choice::from(id));
Ok((weight, onehot_proof))
}

/// Evaluates the entire `input` and produces a share of the
/// input's weight. It reuses computation from previous levels available in the
/// cache.
pub(crate) fn eval_with_cache(
&self,
id: VidpfServerId,
key: &VidpfKey,
public: &VidpfPublicShare<W>,
input: &VidpfInput,
cache_tree: &mut BinaryTree<VidpfEvalResult<W>>,
nonce: &[u8],
) -> Result<(W, VidpfProof), VidpfError> {
if input.len() > public.cw.len() {
return Err(VidpfError::InvalidAttributeLength);
}

let mut sub_tree = cache_tree.root.get_or_insert_with(|| {
Box::new(Node::new(VidpfEvalResult {
state: VidpfEvalState::init_from_key(id, key),
share: W::zero(&self.weight_parameter), // not used
}))
});

let mut onehot_proof = ONEHOT_PROOF_INIT;
for (idx, cw) in input.index_iter()?.zip(public.cw.iter()) {
sub_tree = if idx.bit.unwrap_u8() == 0 {
sub_tree.left.get_or_insert_with(|| {
Box::new(Node::new(self.eval_next(
cw,
idx,
&sub_tree.value.state,
nonce,
)))
})
} else {
sub_tree.right.get_or_insert_with(|| {
Box::new(Node::new(self.eval_next(
cw,
idx,
&sub_tree.value.state,
nonce,
)))
})
};
onehot_proof = xor_proof(
onehot_proof,
&Self::hash_proof(xor_proof(onehot_proof, &sub_tree.value.state.node_proof)),
);
}

let mut weight = sub_tree.value.to_share();
weight.conditional_negate(Choice::from(id));
Ok((weight, onehot_proof))
Ok((weight, hash.finalize().into()))
}

/// Evaluates the `input` at the given level using the provided initial
Expand Down Expand Up @@ -311,32 +258,34 @@ impl<W: VidpfValue> Vidpf<W> {
}
}

pub(crate) fn eval_root(
pub(crate) fn get_beta_sahre(
&self,
id: VidpfServerId,
public: &VidpfPublicShare<W>,
key: &VidpfKey,
public_share: &VidpfPublicShare<W>,
cache_tree: &mut BinaryTree<VidpfEvalResult<W>>,
nonce: &[u8],
) -> Result<W, VidpfError> {
let (weight_share_left, _onehot_proof_left) = self.eval_with_cache(
id,
key,
public_share,
&VidpfInput::from_bools(&[false]),
cache_tree,
nonce,
)?;

let (weight_share_right, _onehot_proof_right) = self.eval_with_cache(
id,
key,
public_share,
&VidpfInput::from_bools(&[true]),
cache_tree,
nonce,
)?;
let cw = public.cw.first().ok_or_else(|| VidpfError::InputTooLong)?;

let state = VidpfEvalState::init_from_key(id, key);
let idx_left = VidpfEvalIndex {
bit: Choice::from(0),
input: &VidpfInput::from_bools(&[false]),
level: 0,
};

let VidpfEvalResult {
state: _,
share: mut weight_share_left,
} = self.eval_next(cw, idx_left, &state, nonce);

let VidpfEvalResult {
state: _,
share: mut weight_share_right,
} = self.eval_next(cw, idx_left.right_sibling(), &state, nonce);

weight_share_left.conditional_negate(Choice::from(id));
weight_share_right.conditional_negate(Choice::from(id));
Ok(weight_share_left + weight_share_right)
}

Expand Down Expand Up @@ -624,12 +573,6 @@ pub(crate) struct VidpfEvalResult<W: VidpfValue> {
pub(crate) share: W,
}

impl<W: VidpfValue> VidpfEvalResult<W> {
fn to_share(&self) -> W {
self.share.clone()
}
}

const VIDPF_PROOF_SIZE: usize = 32;
const VIDPF_SEED_SIZE: usize = 16;

Expand Down Expand Up @@ -873,13 +816,9 @@ mod tests {

mod vidpf {
use crate::{
bt::BinaryTree,
codec::{Encode, ParameterizedDecode},
idpf::IdpfValue,
vidpf::{
Vidpf, VidpfEvalResult, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare,
VidpfServerId,
},
vidpf::{Vidpf, VidpfEvalState, VidpfInput, VidpfKey, VidpfPublicShare, VidpfServerId},
};

use super::{TestWeight, TEST_NONCE, TEST_NONCE_SIZE, TEST_WEIGHT_LEN};
Expand Down Expand Up @@ -1000,94 +939,6 @@ mod tests {
state_1 = r1.state;
}
}

#[test]
fn caching_at_each_level() {
let input = VidpfInput::from_bytes(&[0xFF]);
let weight = TestWeight::from(vec![21.into(), 22.into(), 23.into()]);
let (vidpf, public, keys, nonce) = vidpf_gen_setup(&input, &weight);

test_equivalence_of_eval_with_caching(&vidpf, &keys, &public, &input, &nonce);
}

/// Ensures that VIDPF outputs match regardless of whether the path to
/// each node is recomputed or cached during evaluation.
fn test_equivalence_of_eval_with_caching(
vidpf: &Vidpf<TestWeight>,
[key_0, key_1]: &[VidpfKey; 2],
public: &VidpfPublicShare<TestWeight>,
input: &VidpfInput,
nonce: &[u8],
) {
let mut cache_tree_0 = BinaryTree::<VidpfEvalResult<TestWeight>>::default();
let mut cache_tree_1 = BinaryTree::<VidpfEvalResult<TestWeight>>::default();

let n = input.len();
for level in 0..n {
let val_share_0 = vidpf
.eval(
VidpfServerId::S0,
key_0,
public,
&input.prefix(level),
nonce,
)
.unwrap();
let val_share_1 = vidpf
.eval(
VidpfServerId::S1,
key_1,
public,
&input.prefix(level),
nonce,
)
.unwrap();
let val_share_0_cached = vidpf
.eval_with_cache(
VidpfServerId::S0,
key_0,
public,
&input.prefix(level),
&mut cache_tree_0,
nonce,
)
.unwrap();
let val_share_1_cached = vidpf
.eval_with_cache(
VidpfServerId::S1,
key_1,
public,
&input.prefix(level),
&mut cache_tree_1,
nonce,
)
.unwrap();

assert_eq!(
val_share_0, val_share_0_cached,
"shares must be computed equally with or without caching: {:?}",
level
);

assert_eq!(
val_share_1, val_share_1_cached,
"shares must be computed equally with or without caching: {:?}",
level
);

assert_eq!(
val_share_0, val_share_0_cached,
"proofs must be equal with or without caching: {:?}",
level
);

assert_eq!(
val_share_1, val_share_1_cached,
"proofs must be equal with or without caching: {:?}",
level
);
}
}
}

mod weight {
Expand Down

0 comments on commit 013db1c

Please sign in to comment.