Skip to content

Commit

Permalink
DKG: support zero weights, and handle "impossible" edge cases more gr…
Browse files Browse the repository at this point in the history
…acefully (#708)
  • Loading branch information
benr-ml authored Dec 13, 2023
1 parent 14d62bb commit 3ae45d6
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 44 deletions.
65 changes: 51 additions & 14 deletions fastcrypto-tbls/src/dkg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use fastcrypto::traits::AllowedRng;
use itertools::Itertools;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use tracing::{debug, info, warn};
use tracing::{debug, error, info, warn};

use tap::prelude::*;

Expand Down Expand Up @@ -100,7 +100,7 @@ impl<G: GroupElement, EG: GroupElement> From<&[ProcessedMessage<G, EG>]>

/// Processed messages that were not excluded.
pub struct VerifiedProcessedMessages<G: GroupElement, EG: GroupElement>(
pub Vec<ProcessedMessage<G, EG>>,
Vec<ProcessedMessage<G, EG>>,
);

impl<G: GroupElement, EG: GroupElement> VerifiedProcessedMessages<G, EG> {
Expand All @@ -113,6 +113,18 @@ impl<G: GroupElement, EG: GroupElement> VerifiedProcessedMessages<G, EG> {
.collect::<Vec<_>>();
Self(filtered)
}

pub fn len(&self) -> usize {
self.0.len()
}

pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

pub fn data(&self) -> &[ProcessedMessage<G, EG>] {
&self.0
}
}

/// [Output] is the final output of the DKG protocol in case it runs
Expand All @@ -123,11 +135,9 @@ impl<G: GroupElement, EG: GroupElement> VerifiedProcessedMessages<G, EG> {
pub struct Output<G: GroupElement, EG: GroupElement> {
pub nodes: Nodes<EG>,
pub vss_pk: Poly<G>,
pub shares: Option<Vec<Share<G::ScalarType>>>, // None if some shares are missing.
pub shares: Option<Vec<Share<G::ScalarType>>>, // None if some shares are missing or weight is zero.
}

// TODO: Handle parties with zero weights (currently rejected by Nodes::new()).

/// A dealer in the DKG ceremony.
///
/// Can be instantiated with G1Curve or G2Curve.
Expand All @@ -149,11 +159,10 @@ where
) -> FastCryptoResult<Self> {
// Check that my ecies pk is in the nodes.
let enc_pk = ecies::PublicKey::<EG>::from_private_key(&enc_sk);
let my_id = nodes
let my_node = nodes
.iter()
.find(|n| n.pk == enc_pk)
.ok_or(FastCryptoError::InvalidInput)?
.id;
.ok_or(FastCryptoError::InvalidInput)?;
// Check that the threshold makes sense.
if t >= nodes.total_weight() || t == 0 {
return Err(FastCryptoError::InvalidInput);
Expand All @@ -164,9 +173,11 @@ where

// TODO: remove once the protocol is stable since it's a non negligible computation.
let vss_pk = vss_sk.commit::<G>();

info!(
"DKG: Creating party {}, nodes hash {:?}, t {}, n {}, ro {:?}, enc pk {:?}, vss pk c0 {:?}",
my_id,
"DKG: Creating party {} with weight {}, nodes hash {:?}, t {}, n {}, ro {:?}, enc pk {:?}, vss pk c0 {:?}",
my_node.id,
my_node.weight,
nodes.hash(),
t,
nodes.total_weight(),
Expand All @@ -176,7 +187,7 @@ where
);

Ok(Self {
id: my_id,
id: my_node.id,
nodes,
t,
random_oracle,
Expand Down Expand Up @@ -210,6 +221,7 @@ where
.iter()
.map(|share_id| self.vss_sk.eval(*share_id).value)
.collect::<Vec<_>>();
// Works even with empty shares_ids (will result in [0]).
let buff = bcs::to_bytes(&shares).expect("serialize of shares should never fail");
(node.pk.clone(), buff)
})
Expand Down Expand Up @@ -429,6 +441,14 @@ where
conf.complaints.push(complaint.clone());
}
}

if filtered_messages.0.iter().all(|m| m.complaint.is_some()) {
error!("DKG: All processed messages resulted in complaints, this should never happen");
return Err(FastCryptoError::GeneralError(
"All processed messages resulted in complaints".to_string(),
));
}

Ok((conf, filtered_messages))
}

Expand Down Expand Up @@ -542,6 +562,15 @@ where
&to_exclude.into_iter().collect::<Vec<_>>(),
);

if verified_messages.is_empty() {
error!(
"DKG: No verified messages after processing complaints, this should never happen"
);
return Err(FastCryptoError::GeneralError(
"No verified messages after processing complaints".to_string(),
));
}

// Log verified messages parties.
let used_parties = verified_messages
.0
Expand Down Expand Up @@ -601,11 +630,19 @@ where

// If I didn't receive a valid share for one of the verified messages (i.e., my complaint
// was not processed), then I don't have a valid share for the final key.
let shares = if messages.0.iter().all(|m| m.complaint.is_none()) {
info!("DKG: Aggregating my shares succeeded");
let has_invalid_share = messages.0.iter().any(|m| m.complaint.is_some());
let has_zero_shares = final_shares.is_empty();
info!(
"DKG: Aggregating my shares completed with has_invalid_share={}, has_zero_shares={}",
has_invalid_share, has_zero_shares
);
if has_invalid_share {
warn!("DKG: Aggregating my shares failed");
}

let shares = if !has_invalid_share && !has_zero_shares {
Some(final_shares.values().cloned().collect())
} else {
warn!("DKG: Aggregating my shares failed");
None
};

Expand Down
48 changes: 31 additions & 17 deletions fastcrypto-tbls/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@ pub type PartyId = u16;
pub struct Node<G: GroupElement> {
pub id: PartyId,
pub pk: ecies::PublicKey<G>,
pub weight: u16,
pub weight: u16, // May be zero
}

/// Wrapper for a set of nodes.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Nodes<G: GroupElement> {
nodes: Vec<Node<G>>, // Party ids are 0..len(nodes)-1
total_weight: u32, // Share ids are 1..total_weight
nodes: Vec<Node<G>>, // Party ids are 0..len(nodes)-1 (inclusive)
total_weight: u32, // Share ids are 1..total_weight (inclusive)
// Next two fields are used to map share ids to party ids.
accumulated_weights: Vec<u32>, // Accumulated sum of all nodes' weights. Used to map share ids to party ids.
nodes_with_nonzero_weight: Vec<u16>, // Indexes of nodes with non-zero weight
}

impl<G: GroupElement + Serialize> Nodes<G> {
Expand All @@ -40,13 +42,11 @@ impl<G: GroupElement + Serialize> Nodes<G> {
if nodes.is_empty() || nodes.len() > 1000 {
return Err(FastCryptoError::InvalidInput);
}
// Check that all weights are non-zero
if nodes.iter().any(|n| n.weight == 0) {
return Err(FastCryptoError::InvalidInput);
}

// We use accumulated weights to map share ids to party ids.
// We use the next two to map share ids to party ids.
let accumulated_weights = Self::get_accumulated_weights(&nodes);
let nodes_with_nonzero_weight = Self::filter_nonzero_weights(&nodes);

let total_weight = *accumulated_weights
.last()
.expect("Number of nodes is non-zero");
Expand All @@ -55,13 +55,28 @@ impl<G: GroupElement + Serialize> Nodes<G> {
nodes,
total_weight,
accumulated_weights,
nodes_with_nonzero_weight,
})
}

fn filter_nonzero_weights(nodes: &[Node<G>]) -> Vec<u16> {
nodes
.iter()
.enumerate()
.filter_map(|(i, n)| if n.weight > 0 { Some(i as u16) } else { None })
.collect::<Vec<_>>()
}

fn get_accumulated_weights(nodes: &[Node<G>]) -> Vec<u32> {
nodes
.iter()
.map(|n| n.weight as u32)
.filter_map(|n| {
if n.weight > 0 {
Some(n.weight as u32)
} else {
None
}
})
.scan(0, |accumulated_weight, weight| {
*accumulated_weight += weight;
Some(*accumulated_weight)
Expand All @@ -86,13 +101,14 @@ impl<G: GroupElement + Serialize> Nodes<G> {

/// Get the node corresponding to a share id.
pub fn share_id_to_node(&self, share_id: &ShareIndex) -> FastCryptoResult<&Node<G>> {
let node_id: PartyId = match self.accumulated_weights.binary_search(&share_id.get()) {
let nonzero_node_id = match self.accumulated_weights.binary_search(&share_id.get()) {
Ok(i) => i,
Err(i) => i,
};
match self.nodes_with_nonzero_weight.get(nonzero_node_id) {
Some(node_id) => self.node_id_to_node(*node_id),
None => Err(InvalidInput),
}
.try_into()
.map_err(|_| InvalidInput)?;
self.node_id_to_node(node_id)
}

pub fn node_id_to_node(&self, party_id: PartyId) -> FastCryptoResult<&Node<G>> {
Expand Down Expand Up @@ -132,10 +148,6 @@ impl<G: GroupElement + Serialize> Nodes<G> {
pub fn reduce(&self, t: u16, allowed_delta: u16) -> (Self, u16) {
let mut max_d = 1;
for d in 2..=40 {
// TODO: [perf] Remove once the DKG & Nodes can work with zero weights.
if self.nodes.iter().any(|n| n.weight < d) {
break;
}
let sum = self.nodes.iter().map(|n| n.weight % d).sum::<u16>();
if sum <= allowed_delta {
max_d = d;
Expand All @@ -151,13 +163,15 @@ impl<G: GroupElement + Serialize> Nodes<G> {
})
.collect::<Vec<_>>();
let accumulated_weights = Self::get_accumulated_weights(&nodes);
let nodes_with_nonzero_weight = Self::filter_nonzero_weights(&nodes);
let total_weight = nodes.iter().map(|n| n.weight as u32).sum::<u32>();
let new_t = t / max_d + (t % max_d != 0) as u16;
(
Self {
nodes,
total_weight,
accumulated_weights,
nodes_with_nonzero_weight,
},
new_t,
)
Expand Down
42 changes: 33 additions & 9 deletions fastcrypto-tbls/src/tests/dkg_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn gen_keys_and_nodes(n: usize) -> (Vec<KeyNodePair<EG>>, Nodes<EG>) {
.map(|(id, _sk, pk)| Node::<EG> {
id: *id,
pk: pk.clone(),
weight: 2 + id,
weight: if *id == 2 { 0 } else { 2 + id },
})
.collect();
let nodes = Nodes::new(nodes).unwrap();
Expand Down Expand Up @@ -67,7 +67,15 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() {
&mut thread_rng(),
)
.unwrap();
// The third party (d2) is ignored (emulating a byzantine party).
// Party with weight 0
let d2 = Party::<G, EG>::new(
keys.get(2_usize).unwrap().1.clone(),
nodes.clone(),
t,
ro.clone(),
&mut thread_rng(),
)
.unwrap();
let d3 = Party::<G, EG>::new(
keys.get(3_usize).unwrap().1.clone(),
nodes.clone(),
Expand Down Expand Up @@ -141,6 +149,13 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() {
.collect::<Vec<_>>();
let (conf1, used_msgs1) = d1.merge(proc_msg1).unwrap();

let proc_msg2 = &all_messages
.iter()
.map(|m| d2.process_message(m.clone(), &mut thread_rng()).unwrap())
.collect::<Vec<_>>();
let (conf2, used_msgs2) = d2.merge(proc_msg2).unwrap();
assert!(conf2.complaints.is_empty());

// Note that d3's first round message is not included but it should still be able to receive
// shares and post complaints.
let proc_msg3 = &all_messages
Expand Down Expand Up @@ -203,24 +218,33 @@ fn test_dkg_e2e_5_parties_min_weight_2_threshold_4() {
let ver_msg1 = d1
.process_confirmations(&used_msgs1, &all_confirmations, 3, &mut thread_rng())
.unwrap();
let ver_msg2 = d2
.process_confirmations(&used_msgs2, &all_confirmations, 3, &mut thread_rng())
.unwrap();
let ver_msg3 = d3
.process_confirmations(&used_msgs3, &all_confirmations, 3, &mut thread_rng())
.unwrap();
let ver_msg5 = d5
.process_confirmations(&used_msgs5, &all_confirmations, 3, &mut thread_rng())
.unwrap();
assert_eq!(ver_msg0.0.len(), 2); // only msg0, msg5 were valid and didn't send invalid complaints
assert_eq!(ver_msg1.0.len(), 2);
assert_eq!(ver_msg3.0.len(), 2);
assert_eq!(ver_msg5.0.len(), 2);
assert_eq!(ver_msg0.len(), 2); // only msg0, msg5 were valid and didn't send invalid complaints
assert_eq!(ver_msg1.len(), 2);
assert_eq!(ver_msg2.len(), 2);
assert_eq!(ver_msg3.len(), 2);
assert_eq!(ver_msg5.len(), 2);

let o0 = d0.aggregate(&ver_msg0);
let _o1 = d1.aggregate(&ver_msg1);
let o2 = d2.aggregate(&ver_msg2);
let o3 = d3.aggregate(&ver_msg3);
let o5 = d5.aggregate(&ver_msg5);
assert!(o0.shares.is_some());
assert!(o2.shares.is_none());
assert!(o3.shares.is_some());
assert!(o5.shares.is_none()); // recall that it didn't receive valid share from msg0
assert_eq!(o0.vss_pk, o2.vss_pk);
assert_eq!(o0.vss_pk, o3.vss_pk);
assert_eq!(o0.vss_pk, o5.vss_pk);

// check the resulting vss pk
let mut poly = msg0.vss_pk.clone();
Expand Down Expand Up @@ -478,7 +502,7 @@ fn test_test_process_confirmations() {
// d3 is ignored because it sent an invalid message
assert_eq!(
ver_msg
.0
.data()
.iter()
.map(|m| m.message.sender)
.collect::<Vec<_>>(),
Expand All @@ -499,7 +523,7 @@ fn test_test_process_confirmations() {
// d3 is not ignored since conf7 is ignored
assert_eq!(
ver_msg
.0
.data()
.iter()
.map(|m| m.message.sender)
.collect::<Vec<_>>(),
Expand All @@ -521,7 +545,7 @@ fn test_test_process_confirmations() {
// now also d2 is ignored because it sent an invalid complaint
assert_eq!(
ver_msg
.0
.data()
.iter()
.map(|m| m.message.sender)
.collect::<Vec<_>>(),
Expand Down
Loading

0 comments on commit 3ae45d6

Please sign in to comment.