From 3ae45d646d01ef10990a5a7536dc5d5d879087b4 Mon Sep 17 00:00:00 2001 From: benr-ml <112846738+benr-ml@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:40:56 +0200 Subject: [PATCH] DKG: support zero weights, and handle "impossible" edge cases more gracefully (#708) --- fastcrypto-tbls/src/dkg.rs | 65 ++++++++++++++++----- fastcrypto-tbls/src/nodes.rs | 48 ++++++++++------ fastcrypto-tbls/src/tests/dkg_tests.rs | 42 +++++++++++--- fastcrypto-tbls/src/tests/nodes_tests.rs | 72 ++++++++++++++++++++++-- 4 files changed, 183 insertions(+), 44 deletions(-) diff --git a/fastcrypto-tbls/src/dkg.rs b/fastcrypto-tbls/src/dkg.rs index e237425707..a532465d04 100644 --- a/fastcrypto-tbls/src/dkg.rs +++ b/fastcrypto-tbls/src/dkg.rs @@ -19,7 +19,7 @@ use fastcrypto::traits::AllowedRng; use itertools::Itertools; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; -use tracing::{debug, info, warn}; +use tracing::{debug, error, info, warn}; use tap::prelude::*; @@ -100,7 +100,7 @@ impl From<&[ProcessedMessage]> /// Processed messages that were not excluded. pub struct VerifiedProcessedMessages( - pub Vec>, + Vec>, ); impl VerifiedProcessedMessages { @@ -113,6 +113,18 @@ impl VerifiedProcessedMessages { .collect::>(); Self(filtered) } + + pub fn len(&self) -> usize { + self.0.len() + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn data(&self) -> &[ProcessedMessage] { + &self.0 + } } /// [Output] is the final output of the DKG protocol in case it runs @@ -123,11 +135,9 @@ impl VerifiedProcessedMessages { pub struct Output { pub nodes: Nodes, pub vss_pk: Poly, - pub shares: Option>>, // None if some shares are missing. + pub shares: Option>>, // None if some shares are missing or weight is zero. } -// TODO: Handle parties with zero weights (currently rejected by Nodes::new()). - /// A dealer in the DKG ceremony. /// /// Can be instantiated with G1Curve or G2Curve. @@ -149,11 +159,10 @@ where ) -> FastCryptoResult { // Check that my ecies pk is in the nodes. let enc_pk = ecies::PublicKey::::from_private_key(&enc_sk); - let my_id = nodes + let my_node = nodes .iter() .find(|n| n.pk == enc_pk) - .ok_or(FastCryptoError::InvalidInput)? - .id; + .ok_or(FastCryptoError::InvalidInput)?; // Check that the threshold makes sense. if t >= nodes.total_weight() || t == 0 { return Err(FastCryptoError::InvalidInput); @@ -164,9 +173,11 @@ where // TODO: remove once the protocol is stable since it's a non negligible computation. let vss_pk = vss_sk.commit::(); + info!( - "DKG: Creating party {}, nodes hash {:?}, t {}, n {}, ro {:?}, enc pk {:?}, vss pk c0 {:?}", - my_id, + "DKG: Creating party {} with weight {}, nodes hash {:?}, t {}, n {}, ro {:?}, enc pk {:?}, vss pk c0 {:?}", + my_node.id, + my_node.weight, nodes.hash(), t, nodes.total_weight(), @@ -176,7 +187,7 @@ where ); Ok(Self { - id: my_id, + id: my_node.id, nodes, t, random_oracle, @@ -210,6 +221,7 @@ where .iter() .map(|share_id| self.vss_sk.eval(*share_id).value) .collect::>(); + // Works even with empty shares_ids (will result in [0]). let buff = bcs::to_bytes(&shares).expect("serialize of shares should never fail"); (node.pk.clone(), buff) }) @@ -429,6 +441,14 @@ where conf.complaints.push(complaint.clone()); } } + + if filtered_messages.0.iter().all(|m| m.complaint.is_some()) { + error!("DKG: All processed messages resulted in complaints, this should never happen"); + return Err(FastCryptoError::GeneralError( + "All processed messages resulted in complaints".to_string(), + )); + } + Ok((conf, filtered_messages)) } @@ -542,6 +562,15 @@ where &to_exclude.into_iter().collect::>(), ); + if verified_messages.is_empty() { + error!( + "DKG: No verified messages after processing complaints, this should never happen" + ); + return Err(FastCryptoError::GeneralError( + "No verified messages after processing complaints".to_string(), + )); + } + // Log verified messages parties. let used_parties = verified_messages .0 @@ -601,11 +630,19 @@ where // If I didn't receive a valid share for one of the verified messages (i.e., my complaint // was not processed), then I don't have a valid share for the final key. - let shares = if messages.0.iter().all(|m| m.complaint.is_none()) { - info!("DKG: Aggregating my shares succeeded"); + let has_invalid_share = messages.0.iter().any(|m| m.complaint.is_some()); + let has_zero_shares = final_shares.is_empty(); + info!( + "DKG: Aggregating my shares completed with has_invalid_share={}, has_zero_shares={}", + has_invalid_share, has_zero_shares + ); + if has_invalid_share { + warn!("DKG: Aggregating my shares failed"); + } + + let shares = if !has_invalid_share && !has_zero_shares { Some(final_shares.values().cloned().collect()) } else { - warn!("DKG: Aggregating my shares failed"); None }; diff --git a/fastcrypto-tbls/src/nodes.rs b/fastcrypto-tbls/src/nodes.rs index 213536bd7a..b2e0f1c909 100644 --- a/fastcrypto-tbls/src/nodes.rs +++ b/fastcrypto-tbls/src/nodes.rs @@ -16,15 +16,17 @@ pub type PartyId = u16; pub struct Node { pub id: PartyId, pub pk: ecies::PublicKey, - pub weight: u16, + pub weight: u16, // May be zero } /// Wrapper for a set of nodes. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct Nodes { - nodes: Vec>, // Party ids are 0..len(nodes)-1 - total_weight: u32, // Share ids are 1..total_weight + nodes: Vec>, // Party ids are 0..len(nodes)-1 (inclusive) + total_weight: u32, // Share ids are 1..total_weight (inclusive) + // Next two fields are used to map share ids to party ids. accumulated_weights: Vec, // Accumulated sum of all nodes' weights. Used to map share ids to party ids. + nodes_with_nonzero_weight: Vec, // Indexes of nodes with non-zero weight } impl Nodes { @@ -40,13 +42,11 @@ impl Nodes { if nodes.is_empty() || nodes.len() > 1000 { return Err(FastCryptoError::InvalidInput); } - // Check that all weights are non-zero - if nodes.iter().any(|n| n.weight == 0) { - return Err(FastCryptoError::InvalidInput); - } - // We use accumulated weights to map share ids to party ids. + // We use the next two to map share ids to party ids. let accumulated_weights = Self::get_accumulated_weights(&nodes); + let nodes_with_nonzero_weight = Self::filter_nonzero_weights(&nodes); + let total_weight = *accumulated_weights .last() .expect("Number of nodes is non-zero"); @@ -55,13 +55,28 @@ impl Nodes { nodes, total_weight, accumulated_weights, + nodes_with_nonzero_weight, }) } + fn filter_nonzero_weights(nodes: &[Node]) -> Vec { + nodes + .iter() + .enumerate() + .filter_map(|(i, n)| if n.weight > 0 { Some(i as u16) } else { None }) + .collect::>() + } + fn get_accumulated_weights(nodes: &[Node]) -> Vec { nodes .iter() - .map(|n| n.weight as u32) + .filter_map(|n| { + if n.weight > 0 { + Some(n.weight as u32) + } else { + None + } + }) .scan(0, |accumulated_weight, weight| { *accumulated_weight += weight; Some(*accumulated_weight) @@ -86,13 +101,14 @@ impl Nodes { /// Get the node corresponding to a share id. pub fn share_id_to_node(&self, share_id: &ShareIndex) -> FastCryptoResult<&Node> { - let node_id: PartyId = match self.accumulated_weights.binary_search(&share_id.get()) { + let nonzero_node_id = match self.accumulated_weights.binary_search(&share_id.get()) { Ok(i) => i, Err(i) => i, + }; + match self.nodes_with_nonzero_weight.get(nonzero_node_id) { + Some(node_id) => self.node_id_to_node(*node_id), + None => Err(InvalidInput), } - .try_into() - .map_err(|_| InvalidInput)?; - self.node_id_to_node(node_id) } pub fn node_id_to_node(&self, party_id: PartyId) -> FastCryptoResult<&Node> { @@ -132,10 +148,6 @@ impl Nodes { pub fn reduce(&self, t: u16, allowed_delta: u16) -> (Self, u16) { let mut max_d = 1; for d in 2..=40 { - // TODO: [perf] Remove once the DKG & Nodes can work with zero weights. - if self.nodes.iter().any(|n| n.weight < d) { - break; - } let sum = self.nodes.iter().map(|n| n.weight % d).sum::(); if sum <= allowed_delta { max_d = d; @@ -151,6 +163,7 @@ impl Nodes { }) .collect::>(); let accumulated_weights = Self::get_accumulated_weights(&nodes); + let nodes_with_nonzero_weight = Self::filter_nonzero_weights(&nodes); let total_weight = nodes.iter().map(|n| n.weight as u32).sum::(); let new_t = t / max_d + (t % max_d != 0) as u16; ( @@ -158,6 +171,7 @@ impl Nodes { nodes, total_weight, accumulated_weights, + nodes_with_nonzero_weight, }, new_t, ) diff --git a/fastcrypto-tbls/src/tests/dkg_tests.rs b/fastcrypto-tbls/src/tests/dkg_tests.rs index 8bdb875e6f..92f4efc7f4 100644 --- a/fastcrypto-tbls/src/tests/dkg_tests.rs +++ b/fastcrypto-tbls/src/tests/dkg_tests.rs @@ -34,7 +34,7 @@ fn gen_keys_and_nodes(n: usize) -> (Vec>, Nodes) { .map(|(id, _sk, pk)| Node:: { id: *id, pk: pk.clone(), - weight: 2 + id, + weight: if *id == 2 { 0 } else { 2 + id }, }) .collect(); let nodes = Nodes::new(nodes).unwrap(); @@ -67,7 +67,15 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() { &mut thread_rng(), ) .unwrap(); - // The third party (d2) is ignored (emulating a byzantine party). + // Party with weight 0 + let d2 = Party::::new( + keys.get(2_usize).unwrap().1.clone(), + nodes.clone(), + t, + ro.clone(), + &mut thread_rng(), + ) + .unwrap(); let d3 = Party::::new( keys.get(3_usize).unwrap().1.clone(), nodes.clone(), @@ -141,6 +149,13 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() { .collect::>(); let (conf1, used_msgs1) = d1.merge(proc_msg1).unwrap(); + let proc_msg2 = &all_messages + .iter() + .map(|m| d2.process_message(m.clone(), &mut thread_rng()).unwrap()) + .collect::>(); + let (conf2, used_msgs2) = d2.merge(proc_msg2).unwrap(); + assert!(conf2.complaints.is_empty()); + // Note that d3's first round message is not included but it should still be able to receive // shares and post complaints. let proc_msg3 = &all_messages @@ -203,24 +218,33 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() { let ver_msg1 = d1 .process_confirmations(&used_msgs1, &all_confirmations, 3, &mut thread_rng()) .unwrap(); + let ver_msg2 = d2 + .process_confirmations(&used_msgs2, &all_confirmations, 3, &mut thread_rng()) + .unwrap(); let ver_msg3 = d3 .process_confirmations(&used_msgs3, &all_confirmations, 3, &mut thread_rng()) .unwrap(); let ver_msg5 = d5 .process_confirmations(&used_msgs5, &all_confirmations, 3, &mut thread_rng()) .unwrap(); - assert_eq!(ver_msg0.0.len(), 2); // only msg0, msg5 were valid and didn't send invalid complaints - assert_eq!(ver_msg1.0.len(), 2); - assert_eq!(ver_msg3.0.len(), 2); - assert_eq!(ver_msg5.0.len(), 2); + assert_eq!(ver_msg0.len(), 2); // only msg0, msg5 were valid and didn't send invalid complaints + assert_eq!(ver_msg1.len(), 2); + assert_eq!(ver_msg2.len(), 2); + assert_eq!(ver_msg3.len(), 2); + assert_eq!(ver_msg5.len(), 2); let o0 = d0.aggregate(&ver_msg0); let _o1 = d1.aggregate(&ver_msg1); + let o2 = d2.aggregate(&ver_msg2); let o3 = d3.aggregate(&ver_msg3); let o5 = d5.aggregate(&ver_msg5); assert!(o0.shares.is_some()); + assert!(o2.shares.is_none()); assert!(o3.shares.is_some()); assert!(o5.shares.is_none()); // recall that it didn't receive valid share from msg0 + assert_eq!(o0.vss_pk, o2.vss_pk); + assert_eq!(o0.vss_pk, o3.vss_pk); + assert_eq!(o0.vss_pk, o5.vss_pk); // check the resulting vss pk let mut poly = msg0.vss_pk.clone(); @@ -478,7 +502,7 @@ fn test_test_process_confirmations() { // d3 is ignored because it sent an invalid message assert_eq!( ver_msg - .0 + .data() .iter() .map(|m| m.message.sender) .collect::>(), @@ -499,7 +523,7 @@ fn test_test_process_confirmations() { // d3 is not ignored since conf7 is ignored assert_eq!( ver_msg - .0 + .data() .iter() .map(|m| m.message.sender) .collect::>(), @@ -521,7 +545,7 @@ fn test_test_process_confirmations() { // now also d2 is ignored because it sent an invalid complaint assert_eq!( ver_msg - .0 + .data() .iter() .map(|m| m.message.sender) .collect::>(), diff --git a/fastcrypto-tbls/src/tests/nodes_tests.rs b/fastcrypto-tbls/src/tests/nodes_tests.rs index 546d621693..11b2c97d0b 100644 --- a/fastcrypto-tbls/src/tests/nodes_tests.rs +++ b/fastcrypto-tbls/src/tests/nodes_tests.rs @@ -48,10 +48,6 @@ fn test_new_failures() { // too little let nodes_vec: Vec> = Vec::new(); assert!(Nodes::new(nodes_vec).is_err()); - // with zero weight - let mut nodes_vec = get_nodes::(20); - nodes_vec[19].weight = 0; - assert!(Nodes::new(nodes_vec).is_err()); } #[test] @@ -66,6 +62,74 @@ fn test_new_order() { assert_eq!(nodes1.hash(), nodes2.hash()); } +#[test] +fn test_zero_weight() { + // The basic case + let nodes_vec = get_nodes::(10); + let nodes1 = Nodes::new(nodes_vec.clone()).unwrap(); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(1).unwrap()) + .unwrap() + .id, + 0 + ); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(2).unwrap()) + .unwrap() + .id, + 1 + ); + assert_eq!(nodes1.share_ids_of(0), vec![NonZeroU32::new(1).unwrap()]); + + // first node's weight is 0 + let mut nodes_vec = get_nodes::(10); + nodes_vec[0].weight = 0; + let nodes1 = Nodes::new(nodes_vec.clone()).unwrap(); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(1).unwrap()) + .unwrap() + .id, + 1 + ); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(2).unwrap()) + .unwrap() + .id, + 1 + ); + assert_eq!(nodes1.share_ids_of(0), vec![]); + + // last node's weight is 0 + let mut nodes_vec = get_nodes::(10); + nodes_vec[9].weight = 0; + let nodes1 = Nodes::new(nodes_vec.clone()).unwrap(); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(nodes1.total_weight()).unwrap()) + .unwrap() + .id, + 8 + ); + assert_eq!(nodes1.share_ids_of(9), vec![]); + + // third node's weight is 0 + let mut nodes_vec = get_nodes::(10); + nodes_vec[2].weight = 0; + let nodes1 = Nodes::new(nodes_vec.clone()).unwrap(); + assert_eq!( + nodes1 + .share_id_to_node(&NonZeroU32::new(4).unwrap()) + .unwrap() + .id, + 3 + ); + assert_eq!(nodes1.share_ids_of(2), vec![]); +} + #[test] fn test_interfaces() { let nodes_vec = get_nodes::(100);