diff --git a/fastcrypto-tbls/src/dkg.rs b/fastcrypto-tbls/src/dkg.rs index 694d146253..087c63f682 100644 --- a/fastcrypto-tbls/src/dkg.rs +++ b/fastcrypto-tbls/src/dkg.rs @@ -16,8 +16,9 @@ use crate::types::ShareIndex; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{FiatShamirChallenge, GroupElement, MultiScalarMul}; use fastcrypto::traits::AllowedRng; +use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; /// Generics below use `G: GroupElement' for the group of the VSS public key, and `EG: GroupElement' /// for the group of the ECIES public key. @@ -64,6 +65,7 @@ pub struct Confirmation { pub complaints: Vec>, } +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ProcessedMessage { message: Message, shares: Vec>, //possibly empty @@ -73,13 +75,49 @@ pub struct ProcessedMessage { /// Mapping from node id to the shares received from that sender. pub type SharesMap = HashMap>>; +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct UsedProcessedMessages( + pub Vec>, +); + +impl From<&[ProcessedMessage]> + for UsedProcessedMessages +{ + fn from(msgs: &[ProcessedMessage]) -> Self { + let filtered = msgs + .iter() + .unique_by(|&m| m.message.sender) + .cloned() + .collect::>(); + Self(filtered) + } +} + +pub struct VerifiedProcessedMessages( + pub Vec>, +); + +impl VerifiedProcessedMessages { + fn filter_from(msgs: &UsedProcessedMessages, to_exclude: &[PartyId]) -> Self { + let filtered = msgs + .0 + .iter() + .filter(|m| !to_exclude.contains(&m.message.sender)) + .cloned() + .collect::>(); + Self(filtered) + } +} + /// [Output] is the final output of the DKG protocol in case it runs /// successfully. It can be used later with [ThresholdBls], see examples in tests. +/// +/// If shares is None, the object can only be used for verifying (partial and full) signatures. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Output { pub nodes: Nodes, pub vss_pk: Poly, - pub shares: Vec>, + pub shares: Option>>, // None if some shares are missing. } /// A dealer in the DKG ceremony. @@ -157,21 +195,28 @@ where } fn sanity_check_message(&self, msg: &Message) -> FastCryptoResult<()> { - self.nodes.node_id_to_node(msg.sender)?; + self.nodes + .node_id_to_node(msg.sender) + .map_err(|_| FastCryptoError::InvalidMessage)?; if self.t != msg.vss_pk.degree() + 1 { - return Err(FastCryptoError::InvalidInput); + return Err(FastCryptoError::InvalidMessage); } if self.nodes.num_nodes() != msg.encrypted_shares.len() { - return Err(FastCryptoError::InvalidInput); + return Err(FastCryptoError::InvalidMessage); } msg.encrypted_shares - .verify_knowledge(&self.random_oracle.extend(&format!("encs {}", msg.sender)))?; + .verify_knowledge(&self.random_oracle.extend(&format!("encs {}", msg.sender))) + .map_err(|_| FastCryptoError::InvalidMessage)?; Ok(()) } /// 5. Process a message and create the second message to be broadcasted. /// The second message contains the list of complaints on invalid shares. In addition, it /// returns a set of valid shares (so far). + /// + /// Returns error InvalidMessage if the message is invalid and should be ignored (note that we + /// could count it as part of the f+1 messages we wait for, but it's also safe to ignore it + /// and just wait for f+1 valid messages). pub fn process_message( &self, message: Message, @@ -242,64 +287,56 @@ where }) } - /// 6. Merge results from multiple process_message calls so only one message needs to be sent. - /// Returns InputTooShort if the threshold t is not met. + /// 6. Merge results from multiple ProcessedMessages so only one message needs to be sent. + /// Returns NotEnoughInputs if the threshold t is not met. pub fn merge( &self, processed_messages: &[ProcessedMessage], - ) -> FastCryptoResult<(SharesMap, Confirmation)> { - // Enforce unique senders - let processed_messages = processed_messages - .iter() - .map(|m| (m.message.sender, m)) - .collect::>(); + ) -> FastCryptoResult<(Confirmation, UsedProcessedMessages)> { + let filtered_messages = UsedProcessedMessages::from(processed_messages); // Verify we have enough messages - let total_weight = processed_messages - .keys() - .map(|sender| { + let total_weight = filtered_messages + .0 + .iter() + .map(|m| { self.nodes - .node_id_to_node(*sender) + .node_id_to_node(m.message.sender) .expect("checked in process_message") .weight as u32 }) .sum::(); if total_weight < self.t { - return Err(FastCryptoError::InputTooShort(self.t as usize)); + return Err(FastCryptoError::NotEnoughInputs); } - let mut shares = HashMap::new(); let mut conf = Confirmation { sender: self.id, complaints: Vec::new(), }; - for m in processed_messages.values() { - shares.insert(m.message.sender, m.shares.clone()); + for m in &filtered_messages.0 { if m.complaint.is_some() { - let complaint = m.complaint.clone().unwrap(); + let complaint = m.complaint.clone().expect("checked above"); conf.complaints.push(complaint); } } - Ok((shares, conf)) + Ok((conf, filtered_messages)) } - // TODO: Handle the case of not having enough valid shares gracefully (e.g., - // process_confirmations without my complaint). - /// 7. Process all confirmations, check all complaints, and update the local set of /// valid shares accordingly. /// /// minimal_threshold is the minimal number of second round messages we expect. Its value is /// application dependent but in most cases it should be at least t+f to guarantee that at /// least t honest nodes have valid shares. - /// Returns InputTooShort if the threshold minimal_threshold is not met. - pub fn process_confirmations( + /// + /// Returns NotEnoughInputs if the threshold minimal_threshold is not met. + pub(crate) fn process_confirmations( &self, - messages: &[Message], + messages: &UsedProcessedMessages, confirmations: &[Confirmation], - shares: SharesMap, minimal_threshold: u32, rng: &mut R, - ) -> Result, FastCryptoError> { + ) -> FastCryptoResult> { if minimal_threshold < self.t { return Err(FastCryptoError::InvalidInput); } @@ -307,6 +344,7 @@ where let confirmations = confirmations .iter() .filter(|c| self.nodes.node_id_to_node(c.sender).is_ok()) + .unique_by(|m| m.sender) .collect::>(); // Verify we have enough confirmations let total_weight = confirmations @@ -319,23 +357,25 @@ where }) .sum::(); if total_weight < minimal_threshold { - return Err(FastCryptoError::InputTooShort(minimal_threshold as usize)); + return Err(FastCryptoError::NotEnoughInputs); } // Two hash maps for faster access in the main loop below. - let id_to_pk: HashMap> = - self.nodes.iter().map(|n| (n.id, &n.pk)).collect(); - let id_to_m1: HashMap> = - messages.iter().map(|m| (m.sender, m)).collect(); + let id_to_pk = self + .nodes + .iter() + .map(|n| (n.id, &n.pk)) + .collect::>(); + let id_to_m1 = messages + .0 + .iter() + .map(|m| (m.message.sender, &m.message)) + .collect::>(); - let mut shares = shares; + let mut to_exclude = HashSet::new(); 'outer: for m2 in confirmations { - 'inner: for complaint in &m2.complaints[..] { + 'inner: for complaint in &m2.complaints { let accused = complaint.accused_sender; - // Ignore senders that are already not relevant, or invalid complaints. - if !shares.contains_key(&accused) { - continue 'inner; - } let accuser = m2.sender; let accuser_pk = id_to_pk .get(&accuser) @@ -347,7 +387,7 @@ where .expect("checked above that is not None") .encrypted_shares .get_encryption(accuser as usize) - .expect("checked above that there are enough encryptions"); + .expect("checked earlier that there are enough encryptions"); Self::check_delegated_key_and_share( &complaint.proof, accuser_pk, @@ -365,29 +405,34 @@ where // Ignore accused from now on, and continue processing complaints from the // current accuser. true => { - shares.remove(&accused); + to_exclude.insert(accused); continue 'inner; } // Ignore the accuser from now on, including its other complaints (not critical // for security, just saves some work). false => { - shares.remove(&accuser); + to_exclude.insert(accuser); continue 'outer; } } } } - Ok(shares) + let verified_messages = VerifiedProcessedMessages::filter_from( + messages, + &to_exclude.into_iter().collect::>(), + ); + + Ok(verified_messages) } /// 8. Aggregate the valid shares (as returned from the previous step) and the public key. - pub fn aggregate( - &self, - first_messages: &[Message], - shares: SharesMap, - ) -> Output { - let id_to_m1: HashMap<_, _> = first_messages.iter().map(|m| (m.sender, m)).collect(); + pub(crate) fn aggregate(&self, messages: &VerifiedProcessedMessages) -> Output { + let id_to_m1 = messages + .0 + .iter() + .map(|m| (m.message.sender, &m.message)) + .collect::>(); let mut vss_pk = PublicPoly::::zero(); let my_share_ids = self.nodes.share_ids_of(self.id); @@ -404,14 +449,14 @@ where }) .collect::>(); - for (from_sender, shares_from_sender) in shares { + for m in &messages.0 { vss_pk.add( &id_to_m1 - .get(&from_sender) + .get(&m.message.sender) .expect("shares only includes shares from valid first messages") .vss_pk, ); - for share in shares_from_sender { + for share in &m.shares { final_shares .get_mut(&share.index) .expect("created above") @@ -419,13 +464,34 @@ where } } + // If I didn't receive a valid share for one of the verified messages (i.e., my complaint + // was not processed), then I don't have a valid share for the final key. + let shares = if messages.0.iter().all(|m| m.complaint.is_none()) { + Some(final_shares.values().cloned().collect()) + } else { + None + }; + Output { nodes: self.nodes.clone(), vss_pk, - shares: final_shares.values().cloned().collect(), + shares, } } + /// Execute the previous two steps together. + pub fn complete( + &self, + messages: &UsedProcessedMessages, + confirmations: &[Confirmation], + minimal_threshold: u32, + rng: &mut R, + ) -> FastCryptoResult> { + let verified_messages = + self.process_confirmations(messages, confirmations, minimal_threshold, rng)?; + Ok(self.aggregate(&verified_messages)) + } + fn decrypt_and_get_share( sk: &ecies::PrivateKey, encrypted_shares: &ecies::Encryption, diff --git a/fastcrypto-tbls/src/dl_verification.rs b/fastcrypto-tbls/src/dl_verification.rs index fd4550a3c5..7bac24e54c 100644 --- a/fastcrypto-tbls/src/dl_verification.rs +++ b/fastcrypto-tbls/src/dl_verification.rs @@ -176,6 +176,5 @@ pub fn verify_equal_exponents( } pub fn get_random_scalars(n: u32, rng: &mut R) -> Vec { - // TODO: can use 40 bits instead of 64 ("& 0x000F_FFFF_FFFF_FFFF" below) (0..n).map(|_| S::from(rng.next_u64())).collect::>() } diff --git a/fastcrypto-tbls/src/ecies.rs b/fastcrypto-tbls/src/ecies.rs index cc679bedbd..bed75accd5 100644 --- a/fastcrypto-tbls/src/ecies.rs +++ b/fastcrypto-tbls/src/ecies.rs @@ -22,6 +22,7 @@ use typenum::consts::{U16, U32}; /// APIs that use a random oracle must receive one as an argument. That RO must be unique and thus /// the caller should initialize/derive it using a unique prefix. +// TODO: Use ZeroizeOnDrop. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct PrivateKey(G::ScalarType); diff --git a/fastcrypto-tbls/src/nodes.rs b/fastcrypto-tbls/src/nodes.rs index c536f5405e..f504723557 100644 --- a/fastcrypto-tbls/src/nodes.rs +++ b/fastcrypto-tbls/src/nodes.rs @@ -7,6 +7,7 @@ use fastcrypto::error::FastCryptoError::InvalidInput; use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::GroupElement; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::iter::Map; use std::ops::RangeInclusive; @@ -23,6 +24,7 @@ pub struct Node { pub struct Nodes { nodes: Vec>, n: u32, // share ids are 1..n + share_id_to_party_id: HashMap, } impl Nodes { @@ -36,7 +38,27 @@ impl Nodes { } // Get the total weight of the nodes let n = nodes.iter().map(|n| n.weight as u32).sum::(); - Ok(Self { nodes, n }) + + let share_id_to_party_id = Self::get_share_id_to_party_id(&nodes); + + Ok(Self { + nodes, + n, + share_id_to_party_id, + }) + } + + fn get_share_id_to_party_id(nodes: &Vec>) -> HashMap { + let mut curr_share_id = 1; + let mut share_id_to_party_id = HashMap::new(); + for n in nodes { + for _ in 1..=n.weight { + let share_id = ShareIndex::new(curr_share_id).expect("nonzero"); + share_id_to_party_id.insert(share_id, n.id); + curr_share_id += 1; + } + } + share_id_to_party_id } /// Total weight of the nodes. @@ -56,16 +78,10 @@ impl Nodes { /// Get the node corresponding to a share id. pub fn share_id_to_node(&self, share_id: &ShareIndex) -> FastCryptoResult<&Node> { - // TODO: [perf opt] Cache this - let mut curr_share_id = 1; - for n in &self.nodes { - if curr_share_id <= share_id.get() && share_id.get() < curr_share_id + (n.weight as u32) - { - return Ok(n); - } - curr_share_id += n.weight as u32; - } - Err(FastCryptoError::InvalidInput) + self.share_id_to_party_id + .get(share_id) + .map(|id| self.node_id_to_node(*id)) + .ok_or(FastCryptoError::InvalidInput)? } pub fn node_id_to_node(&self, party_id: PartyId) -> FastCryptoResult<&Node> { @@ -113,8 +129,16 @@ impl Nodes { weight: n.weight / max_d, }) .collect::>(); + let share_id_to_party_id = Self::get_share_id_to_party_id(&nodes); let n = nodes.iter().map(|n| n.weight as u32).sum::(); let new_t = t / max_d + (t % max_d != 0) as u16; - (Self { nodes, n }, new_t) + ( + Self { + nodes, + n, + share_id_to_party_id, + }, + new_t, + ) } } diff --git a/fastcrypto-tbls/src/polynomial.rs b/fastcrypto-tbls/src/polynomial.rs index 710bbc815c..49dbe49126 100644 --- a/fastcrypto-tbls/src/polynomial.rs +++ b/fastcrypto-tbls/src/polynomial.rs @@ -86,8 +86,8 @@ impl Poly { t: u32, shares: &[Eval], ) -> FastCryptoResult> { - if shares.len() < t.try_into().unwrap() { - return Err(FastCryptoError::InvalidInput); + if shares.len() < t as usize { + return Err(FastCryptoError::NotEnoughInputs); } // Check for duplicates. let mut ids_set = HashSet::new(); @@ -95,7 +95,7 @@ impl Poly { ids_set.insert(id); }); if ids_set.len() != shares.len() { - return Err(FastCryptoError::InvalidInput); + return Err(FastCryptoError::InvalidInput); // expected unique ids } let indices = shares diff --git a/fastcrypto-tbls/src/tbls.rs b/fastcrypto-tbls/src/tbls.rs index 307e0ab89a..3602617d92 100644 --- a/fastcrypto-tbls/src/tbls.rs +++ b/fastcrypto-tbls/src/tbls.rs @@ -7,9 +7,10 @@ use crate::dl_verification::{batch_coefficients, get_random_scalars}; use crate::polynomial::Poly; use crate::types::IndexedValue; -use fastcrypto::error::FastCryptoError; +use fastcrypto::error::{FastCryptoError, FastCryptoResult}; use fastcrypto::groups::{GroupElement, HashToGroupElement, MultiScalarMul, Scalar}; use fastcrypto::traits::AllowedRng; +use itertools::Itertools; pub type Share = IndexedValue; pub type PartialSignature = IndexedValue; @@ -99,9 +100,15 @@ pub trait ThresholdBls { fn aggregate( threshold: u32, partials: &[PartialSignature], - ) -> Result { + ) -> FastCryptoResult { + let unique_partials = partials + .iter() + .unique_by(|p| p.index) + .take(threshold as usize) + .cloned() + .collect::>(); // No conversion is required since PartialSignature and Eval are different aliases to // IndexedValue. - Poly::::recover_c0_msm(threshold, partials) + Poly::::recover_c0_msm(threshold, &unique_partials) } } diff --git a/fastcrypto-tbls/src/tests/dkg_tests.rs b/fastcrypto-tbls/src/tests/dkg_tests.rs index e570b6381f..972da07c4d 100644 --- a/fastcrypto-tbls/src/tests/dkg_tests.rs +++ b/fastcrypto-tbls/src/tests/dkg_tests.rs @@ -7,6 +7,7 @@ use crate::nodes::{Node, Nodes, PartyId}; use crate::random_oracle::RandomOracle; use crate::tbls::ThresholdBls; use crate::types::ThresholdBls12381MinSig; +use fastcrypto::error::FastCryptoError; use fastcrypto::groups::bls12381::G2Element; use fastcrypto::groups::ristretto255::RistrettoPoint; use rand::thread_rng; @@ -42,7 +43,7 @@ fn setup_party( Party::::new( keys.get(id as usize).unwrap().1.clone(), Nodes::new(nodes).unwrap(), - (keys.len() / 2) as u32, + 3, RandomOracle::new("dkg"), &mut thread_rng(), ) @@ -71,36 +72,34 @@ fn test_dkg_e2e_4_parties_threshold_2() { msg1.encrypted_shares.swap_for_testing(0, 1); // Don't send the message of d3 to d0 (emulating a slow party). let _msg3 = d3.create_message(&mut thread_rng()); + + assert_eq!( + d0.merge(&[d0.process_message(msg0.clone(), &mut thread_rng()).unwrap()]) + .err(), + Some(FastCryptoError::NotEnoughInputs) + ); + let r1_all = vec![msg0, msg1]; - let (shares0, conf0) = d0 - .merge( - &r1_all - .iter() - .map(|m| d0.process_message(m.clone(), &mut thread_rng()).unwrap()) - .collect::>(), - ) - .unwrap(); + let proc_msg0 = &r1_all + .iter() + .map(|m| d0.process_message(m.clone(), &mut thread_rng()).unwrap()) + .collect::>(); + let (conf0, used_msgs0) = d0.merge(proc_msg0).unwrap(); - let (shares1, conf1) = d1 - .merge( - &r1_all - .iter() - .map(|m| d1.process_message(m.clone(), &mut thread_rng()).unwrap()) - .collect::>(), - ) - .unwrap(); + let proc_msg1 = &r1_all + .iter() + .map(|m| d1.process_message(m.clone(), &mut thread_rng()).unwrap()) + .collect::>(); + let (conf1, used_msgs1) = d1.merge(proc_msg1).unwrap(); // Note that d3's first round message is not included but it should still be able to receive // shares and post complaints. - let (shares3, conf3) = d3 - .merge( - &r1_all - .iter() - .map(|m| d3.process_message(m.clone(), &mut thread_rng()).unwrap()) - .collect::>(), - ) - .unwrap(); + let proc_msg3 = &r1_all + .iter() + .map(|m| d3.process_message(m.clone(), &mut thread_rng()).unwrap()) + .collect::>(); + let (conf3, used_msgs3) = d3.merge(proc_msg3).unwrap(); // There should be some complaints on the first messages of d1. assert!( @@ -114,33 +113,30 @@ fn test_dkg_e2e_4_parties_threshold_2() { ); let r2_all = vec![conf0, conf1, conf3]; - let shares0 = d1 - .process_confirmations(&r1_all, &r2_all, shares0, 3, &mut thread_rng()) + let ver_msg0 = d1 + .process_confirmations(&used_msgs0, &r2_all, 3, &mut thread_rng()) .unwrap(); - let shares1 = d1 - .process_confirmations(&r1_all, &r2_all, shares1, 3, &mut thread_rng()) + let ver_msg1 = d1 + .process_confirmations(&used_msgs1, &r2_all, 3, &mut thread_rng()) .unwrap(); - let shares3 = d3 - .process_confirmations(&r1_all, &r2_all, shares3, 3, &mut thread_rng()) + let ver_msg3 = d3 + .process_confirmations(&used_msgs3, &r2_all, 3, &mut thread_rng()) .unwrap(); - // Only the first message of d0 passed all tests -> only one vss is used. - assert_eq!(shares0.len(), 1); - assert_eq!(shares1.len(), 1); - assert_eq!(shares3.len(), 1); - - let o0 = d0.aggregate(&r1_all, shares0); - let _o1 = d1.aggregate(&r1_all, shares1); - let o3 = d3.aggregate(&r1_all, shares3); + let o0 = d0.aggregate(&ver_msg0); + let _o1 = d1.aggregate(&ver_msg1); + let o3 = d3.aggregate(&ver_msg3); // Use the shares from 01 and o4 to sign a message. - let sig0 = S::partial_sign(&o0.shares[0], &MSG); - let sig3 = S::partial_sign(&o3.shares[0], &MSG); + let sig00 = S::partial_sign(&o0.shares.as_ref().unwrap()[0], &MSG); + let sig30 = S::partial_sign(&o3.shares.as_ref().unwrap()[0], &MSG); + let sig31 = S::partial_sign(&o3.shares.as_ref().unwrap()[1], &MSG); - S::partial_verify(&o0.vss_pk, &MSG, &sig0).unwrap(); - S::partial_verify(&o3.vss_pk, &MSG, &sig3).unwrap(); + S::partial_verify(&o0.vss_pk, &MSG, &sig00).unwrap(); + S::partial_verify(&o3.vss_pk, &MSG, &sig30).unwrap(); + S::partial_verify(&o3.vss_pk, &MSG, &sig31).unwrap(); - let sigs = vec![sig0, sig3]; + let sigs = vec![sig00, sig30, sig31]; let sig = S::aggregate(d0.t(), &sigs).unwrap(); S::verify(o0.vss_pk.c0(), &MSG, &sig).unwrap(); } diff --git a/fastcrypto/src/error.rs b/fastcrypto/src/error.rs index f1bbf49cc6..18b9b42ef7 100644 --- a/fastcrypto/src/error.rs +++ b/fastcrypto/src/error.rs @@ -40,6 +40,14 @@ pub enum FastCryptoError { #[error("Invalid proof was given to the function")] InvalidProof, + /// Not enough inputs were given to the function, retry with more + #[error("Not enough inputs were given to the function, retry with more")] + NotEnoughInputs, + + /// Invalid message was given to the function + #[error("Invalid message was given to the function")] + InvalidMessage, + /// General cryptographic error. #[error("General cryptographic error: {0}")] GeneralError(String),