Skip to content

Commit

Permalink
fix: address Beaver triple generation comments (#348)
Browse files Browse the repository at this point in the history
  • Loading branch information
itegulov authored Nov 10, 2023
1 parent 920853c commit 785456e
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 83 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions mpc-recovery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ opentelemetry-otlp = { version = "0.13.0", features = [
] }
opentelemetry-semantic-conventions = "0.12.0"
prometheus = { version = "0.13.3", features = ["process"] }
rand = "0.7"
rand8 = { package = "rand", version = "0.8" }
rand = "0.8"
reqwest = { version = "0.11.16", features = ["blocking"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand Down
2 changes: 1 addition & 1 deletion mpc-recovery/src/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ mod tests {
use super::*;
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, EncodingKey, Header};
use rand8::rngs::OsRng;
use rand::rngs::OsRng;
use rsa::{
pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey},
RsaPrivateKey, RsaPublicKey,
Expand Down
4 changes: 2 additions & 2 deletions mpc-recovery/src/sign_node/aggregate_signer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use curv::BigInt;
use ed25519_dalek::{Sha512, Signature, Verifier};
use multi_party_eddsa::protocols;
use multi_party_eddsa::protocols::aggsig::{self, KeyAgg, SignSecondMsg};
use rand8::rngs::OsRng;
use rand8::Rng;
use rand::rngs::OsRng;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

Expand Down
12 changes: 8 additions & 4 deletions node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::protocol::message::{GeneratingMessage, ResharingMessage};
use crate::protocol::state::WaitingForConsensusState;
use crate::protocol::MpcMessage;
use async_trait::async_trait;
use cait_sith::protocol::{Action, Participant};
use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use k256::elliptic_curve::group::GroupEncoding;

pub trait CryptographicCtx {
Expand All @@ -18,6 +18,10 @@ pub enum CryptographicError {
SendError(#[from] SendError),
#[error("unknown participant: {0:?}")]
UnknownParticipant(Participant),
#[error("cait-sith initialization error: {0}")]
CaitSithInitializationError(#[from] InitializationError),
#[error("cait-sith protocol error: {0}")]
CaitSithProtocolError(#[from] ProtocolError),
}

#[async_trait]
Expand All @@ -36,7 +40,7 @@ impl CryptographicProtocol for GeneratingState {
) -> Result<NodeState, CryptographicError> {
tracing::info!("progressing key generation");
loop {
let action = self.protocol.poke().unwrap();
let action = self.protocol.poke()?;
match action {
Action::Wait => {
tracing::debug!("waiting");
Expand Down Expand Up @@ -170,9 +174,9 @@ impl CryptographicProtocol for RunningState {
ctx: C,
) -> Result<NodeState, CryptographicError> {
if self.triple_manager.potential_len() < 2 {
self.triple_manager.generate();
self.triple_manager.generate()?;
}
for (p, msg) in self.triple_manager.poke() {
for (p, msg) in self.triple_manager.poke()? {
let url = self.participants.get(&p).unwrap();
http_client::message(ctx.http_client(), url.clone(), MpcMessage::Triple(msg)).await?;
}
Expand Down
48 changes: 40 additions & 8 deletions node/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{HashMap, VecDeque};

use super::state::{GeneratingState, NodeState, ResharingState, RunningState};
use cait_sith::protocol::{MessageData, Participant};
use cait_sith::protocol::{InitializationError, MessageData, Participant, ProtocolError};
use serde::{Deserialize, Serialize};

pub trait MessageCtx {
Expand Down Expand Up @@ -63,49 +63,81 @@ impl MpcMessageQueue {
}
}

#[derive(thiserror::Error, Debug)]
pub enum MessageHandleError {
#[error("cait-sith initialization error: {0}")]
CaitSithInitializationError(#[from] InitializationError),
#[error("cait-sith protocol error: {0}")]
CaitSithProtocolError(#[from] ProtocolError),
}

pub trait MessageHandler {
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue);
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError>;
}

impl MessageHandler for GeneratingState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
while let Some(msg) = queue.generating.pop_front() {
tracing::debug!("handling new generating message");
self.protocol.message(msg.from, msg.data);
}
Ok(())
}
}

impl MessageHandler for ResharingState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
let q = queue.resharing_bins.entry(self.old_epoch).or_default();
while let Some(msg) = q.pop_front() {
tracing::debug!("handling new resharing message");
self.protocol.message(msg.from, msg.data);
}
Ok(())
}
}

impl MessageHandler for RunningState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
for (id, queue) in queue.triple_bins.entry(self.epoch).or_default() {
if let Some(protocol) = self.triple_manager.get_or_generate(*id) {
if let Some(protocol) = self.triple_manager.get_or_generate(*id)? {
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
}
}
}
Ok(())
}
}

impl MessageHandler for NodeState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
match self {
NodeState::Generating(state) => state.handle(ctx, queue),
NodeState::Resharing(state) => state.handle(ctx, queue),
NodeState::Running(state) => state.handle(ctx, queue),
_ => {
tracing::debug!("skipping message processing")
tracing::debug!("skipping message processing");
Ok(())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl MpcSignProtocol {
let mut state = std::mem::take(&mut *state_guard);
state = state.progress(&self.ctx).await?;
state = state.advance(&self.ctx, contract_state).await?;
state.handle(&self.ctx, &mut queue);
state.handle(&self.ctx, &mut queue)?;
*state_guard = state;
drop(state_guard);
tokio::time::sleep(Duration::from_millis(1000)).await;
Expand Down
136 changes: 74 additions & 62 deletions node/src/protocol/triple.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::message::TripleMessage;
use crate::types::TripleProtocol;
use crate::util::AffinePointExt;
use cait_sith::protocol::{Action, Participant};
use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use cait_sith::triples::TripleGenerationOutput;
use k256::Secp256k1;
use std::collections::btree_map::Entry;
Expand Down Expand Up @@ -57,99 +57,111 @@ impl TripleManager {
}

/// Starts a new Beaver triple generation protocol.
pub fn generate(&mut self) {
pub fn generate(&mut self) -> Result<(), InitializationError> {
let id = rand::random();
tracing::info!(id, "starting protocol to generate a new triple");
let protocol: TripleProtocol = Box::new(
cait_sith::triples::generate_triple(&self.participants, self.me, self.threshold)
.unwrap(),
);
let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple(
&self.participants,
self.me,
self.threshold,
)?);
self.generators.insert(id, protocol);
Ok(())
}

/// Take an unspent triple by its id with no way to return it.
/// It is very important to NOT reuse the same triple twice for two different
/// protocols.
pub fn take(&mut self, id: TripleId) -> Option<TripleGenerationOutput<Secp256k1>> {
match self.triples.entry(id) {
Entry::Vacant(_) => None,
Entry::Occupied(entry) => Some(entry.remove()),
}
self.triples.remove(&id)
}

/// Ensures that the triple with the given id is either:
/// 1) Already generated in which case returns `None`, or
/// 2) Is currently being generated by `protocol` in which case returns `Some(protocol)`, or
/// 3) Has never been seen by the manager in which case start a new protocol and returns `Some(protocol)`
// TODO: What if the triple completed generation and is already spent?
pub fn get_or_generate(&mut self, id: TripleId) -> Option<&mut TripleProtocol> {
pub fn get_or_generate(
&mut self,
id: TripleId,
) -> Result<Option<&mut TripleProtocol>, InitializationError> {
if self.triples.contains_key(&id) {
None
Ok(None)
} else {
Some(self.generators.entry(id).or_insert_with(|| {
tracing::info!(id, "joining protocol to generate a new triple");
Box::new(
cait_sith::triples::generate_triple(
match self.generators.entry(id) {
Entry::Vacant(e) => {
tracing::info!(id, "joining protocol to generate a new triple");
let protocol = cait_sith::triples::generate_triple(
&self.participants,
self.me,
self.threshold,
)
.unwrap(),
)
}))
)?;
Ok(Some(e.insert(Box::new(protocol))))
}
Entry::Occupied(e) => Ok(Some(e.into_mut())),
}
}
}

/// Pokes all of the ongoing generation protocols and returns a vector of
/// messages to be sent to the respective participant.
///
/// An empty vector means we cannot progress until we receive a new message.
pub fn poke(&mut self) -> Vec<(Participant, TripleMessage)> {
pub fn poke(&mut self) -> Result<Vec<(Participant, TripleMessage)>, ProtocolError> {
let mut messages = Vec::new();
self.generators.retain(|id, protocol| loop {
let action = protocol.poke().unwrap();
match action {
Action::Wait => {
tracing::debug!("waiting");
// Retain protocol until we are finished
return true;
}
Action::SendMany(data) => {
for p in &self.participants {
messages.push((
*p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
))
let mut result = Ok(());
self.generators.retain(|id, protocol| {
loop {
let action = match protocol.poke() {
Ok(action) => action,
Err(e) => {
result = Err(e);
break false;
}
};
match action {
Action::Wait => {
tracing::debug!("waiting");
// Retain protocol until we are finished
break true;
}
Action::SendMany(data) => {
for p in &self.participants {
messages.push((
*p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
))
}
}
Action::SendPrivate(p, data) => messages.push((
p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
)),
Action::Return(output) => {
tracing::info!(
id,
big_a = ?output.1.big_a.to_base58(),
big_b = ?output.1.big_b.to_base58(),
big_c = ?output.1.big_c.to_base58(),
"completed triple generation"
);
self.triples.insert(*id, output);
// Do not retain the protocol
break false;
}
}
Action::SendPrivate(p, data) => messages.push((
p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
)),
Action::Return(output) => {
tracing::info!(
id,
big_a = ?output.1.big_a.into_base58(),
big_b = ?output.1.big_b.into_base58(),
big_c = ?output.1.big_c.into_base58(),
"completed triple generation"
);
self.triples.insert(*id, output);
// Do not retain the protocol
return false;
}
}
});
messages
result.map(|_| messages)
}
}
4 changes: 2 additions & 2 deletions node/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl NearPublicKeyExt for near_crypto::PublicKey {

pub trait AffinePointExt {
fn into_near_public_key(self) -> near_crypto::PublicKey;
fn into_base58(self) -> String;
fn to_base58(&self) -> String;
}

impl AffinePointExt for AffinePoint {
Expand All @@ -48,7 +48,7 @@ impl AffinePointExt for AffinePoint {
)
}

fn into_base58(self) -> String {
fn to_base58(&self) -> String {
let key = near_crypto::Secp256K1PublicKey::try_from(
&self.to_encoded_point(false).as_bytes()[1..65],
)
Expand Down

0 comments on commit 785456e

Please sign in to comment.