Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: address Beaver triple generation comments #348

Merged
merged 2 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions mpc-recovery/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ opentelemetry-otlp = { version = "0.13.0", features = [
] }
opentelemetry-semantic-conventions = "0.12.0"
prometheus = { version = "0.13.3", features = ["process"] }
rand = "0.7"
rand8 = { package = "rand", version = "0.8" }
rand = "0.8"
reqwest = { version = "0.11.16", features = ["blocking"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
Expand Down
2 changes: 1 addition & 1 deletion mpc-recovery/src/oauth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ mod tests {
use super::*;
use chrono::{Duration, Utc};
use jsonwebtoken::{encode, EncodingKey, Header};
use rand8::rngs::OsRng;
use rand::rngs::OsRng;
use rsa::{
pkcs1::{EncodeRsaPrivateKey, EncodeRsaPublicKey},
RsaPrivateKey, RsaPublicKey,
Expand Down
4 changes: 2 additions & 2 deletions mpc-recovery/src/sign_node/aggregate_signer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use curv::BigInt;
use ed25519_dalek::{Sha512, Signature, Verifier};
use multi_party_eddsa::protocols;
use multi_party_eddsa::protocols::aggsig::{self, KeyAgg, SignSecondMsg};
use rand8::rngs::OsRng;
use rand8::Rng;
use rand::rngs::OsRng;
use rand::Rng;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

Expand Down
12 changes: 8 additions & 4 deletions node/src/protocol/cryptography.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::protocol::message::{GeneratingMessage, ResharingMessage};
use crate::protocol::state::WaitingForConsensusState;
use crate::protocol::MpcMessage;
use async_trait::async_trait;
use cait_sith::protocol::{Action, Participant};
use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use k256::elliptic_curve::group::GroupEncoding;

pub trait CryptographicCtx {
Expand All @@ -18,6 +18,10 @@ pub enum CryptographicError {
SendError(#[from] SendError),
#[error("unknown participant: {0:?}")]
UnknownParticipant(Participant),
#[error("cait-sith initialization error: {0}")]
CaitSithInitializationError(#[from] InitializationError),
#[error("cait-sith protocol error: {0}")]
CaitSithProtocolError(#[from] ProtocolError),
}

#[async_trait]
Expand All @@ -36,7 +40,7 @@ impl CryptographicProtocol for GeneratingState {
) -> Result<NodeState, CryptographicError> {
tracing::info!("progressing key generation");
loop {
let action = self.protocol.poke().unwrap();
let action = self.protocol.poke()?;
match action {
Action::Wait => {
tracing::debug!("waiting");
Expand Down Expand Up @@ -170,9 +174,9 @@ impl CryptographicProtocol for RunningState {
ctx: C,
) -> Result<NodeState, CryptographicError> {
if self.triple_manager.potential_len() < 2 {
self.triple_manager.generate();
self.triple_manager.generate()?;
}
for (p, msg) in self.triple_manager.poke() {
for (p, msg) in self.triple_manager.poke()? {
let url = self.participants.get(&p).unwrap();
http_client::message(ctx.http_client(), url.clone(), MpcMessage::Triple(msg)).await?;
}
Expand Down
48 changes: 40 additions & 8 deletions node/src/protocol/message.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::{HashMap, VecDeque};

use super::state::{GeneratingState, NodeState, ResharingState, RunningState};
use cait_sith::protocol::{MessageData, Participant};
use cait_sith::protocol::{InitializationError, MessageData, Participant, ProtocolError};
use serde::{Deserialize, Serialize};

pub trait MessageCtx {
Expand Down Expand Up @@ -63,49 +63,81 @@ impl MpcMessageQueue {
}
}

#[derive(thiserror::Error, Debug)]
pub enum MessageHandleError {
#[error("cait-sith initialization error: {0}")]
CaitSithInitializationError(#[from] InitializationError),
#[error("cait-sith protocol error: {0}")]
CaitSithProtocolError(#[from] ProtocolError),
}

pub trait MessageHandler {
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue);
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError>;
}

impl MessageHandler for GeneratingState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
while let Some(msg) = queue.generating.pop_front() {
tracing::debug!("handling new generating message");
self.protocol.message(msg.from, msg.data);
}
Ok(())
}
}

impl MessageHandler for ResharingState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
let q = queue.resharing_bins.entry(self.old_epoch).or_default();
while let Some(msg) = q.pop_front() {
tracing::debug!("handling new resharing message");
self.protocol.message(msg.from, msg.data);
}
Ok(())
}
}

impl MessageHandler for RunningState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, _ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
_ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
for (id, queue) in queue.triple_bins.entry(self.epoch).or_default() {
if let Some(protocol) = self.triple_manager.get_or_generate(*id) {
if let Some(protocol) = self.triple_manager.get_or_generate(*id)? {
while let Some(message) = queue.pop_front() {
protocol.message(message.from, message.data);
}
}
}
Ok(())
}
}

impl MessageHandler for NodeState {
fn handle<C: MessageCtx + Send + Sync>(&mut self, ctx: C, queue: &mut MpcMessageQueue) {
fn handle<C: MessageCtx + Send + Sync>(
&mut self,
ctx: C,
queue: &mut MpcMessageQueue,
) -> Result<(), MessageHandleError> {
match self {
NodeState::Generating(state) => state.handle(ctx, queue),
NodeState::Resharing(state) => state.handle(ctx, queue),
NodeState::Running(state) => state.handle(ctx, queue),
_ => {
tracing::debug!("skipping message processing")
tracing::debug!("skipping message processing");
Ok(())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion node/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl MpcSignProtocol {
let mut state = std::mem::take(&mut *state_guard);
state = state.progress(&self.ctx).await?;
state = state.advance(&self.ctx, contract_state).await?;
state.handle(&self.ctx, &mut queue);
state.handle(&self.ctx, &mut queue)?;
*state_guard = state;
drop(state_guard);
tokio::time::sleep(Duration::from_millis(1000)).await;
Expand Down
136 changes: 74 additions & 62 deletions node/src/protocol/triple.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::message::TripleMessage;
use crate::types::TripleProtocol;
use crate::util::AffinePointExt;
use cait_sith::protocol::{Action, Participant};
use cait_sith::protocol::{Action, InitializationError, Participant, ProtocolError};
use cait_sith::triples::TripleGenerationOutput;
use k256::Secp256k1;
use std::collections::btree_map::Entry;
Expand Down Expand Up @@ -57,99 +57,111 @@ impl TripleManager {
}

/// Starts a new Beaver triple generation protocol.
pub fn generate(&mut self) {
pub fn generate(&mut self) -> Result<(), InitializationError> {
let id = rand::random();
tracing::info!(id, "starting protocol to generate a new triple");
let protocol: TripleProtocol = Box::new(
cait_sith::triples::generate_triple(&self.participants, self.me, self.threshold)
.unwrap(),
);
let protocol: TripleProtocol = Box::new(cait_sith::triples::generate_triple(
&self.participants,
self.me,
self.threshold,
)?);
self.generators.insert(id, protocol);
Ok(())
}

/// Take an unspent triple by its id with no way to return it.
/// It is very important to NOT reuse the same triple twice for two different
/// protocols.
pub fn take(&mut self, id: TripleId) -> Option<TripleGenerationOutput<Secp256k1>> {
match self.triples.entry(id) {
Entry::Vacant(_) => None,
Entry::Occupied(entry) => Some(entry.remove()),
}
self.triples.remove(&id)
}

/// Ensures that the triple with the given id is either:
/// 1) Already generated in which case returns `None`, or
/// 2) Is currently being generated by `protocol` in which case returns `Some(protocol)`, or
/// 3) Has never been seen by the manager in which case start a new protocol and returns `Some(protocol)`
// TODO: What if the triple completed generation and is already spent?
pub fn get_or_generate(&mut self, id: TripleId) -> Option<&mut TripleProtocol> {
pub fn get_or_generate(
&mut self,
id: TripleId,
) -> Result<Option<&mut TripleProtocol>, InitializationError> {
if self.triples.contains_key(&id) {
None
Ok(None)
} else {
Some(self.generators.entry(id).or_insert_with(|| {
tracing::info!(id, "joining protocol to generate a new triple");
Box::new(
cait_sith::triples::generate_triple(
match self.generators.entry(id) {
Entry::Vacant(e) => {
tracing::info!(id, "joining protocol to generate a new triple");
let protocol = cait_sith::triples::generate_triple(
&self.participants,
self.me,
self.threshold,
)
.unwrap(),
)
}))
)?;
Ok(Some(e.insert(Box::new(protocol))))
}
Entry::Occupied(e) => Ok(Some(e.into_mut())),
}
}
}

/// Pokes all of the ongoing generation protocols and returns a vector of
/// messages to be sent to the respective participant.
///
/// An empty vector means we cannot progress until we receive a new message.
pub fn poke(&mut self) -> Vec<(Participant, TripleMessage)> {
pub fn poke(&mut self) -> Result<Vec<(Participant, TripleMessage)>, ProtocolError> {
let mut messages = Vec::new();
self.generators.retain(|id, protocol| loop {
let action = protocol.poke().unwrap();
match action {
Action::Wait => {
tracing::debug!("waiting");
// Retain protocol until we are finished
return true;
}
Action::SendMany(data) => {
for p in &self.participants {
messages.push((
*p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
))
let mut result = Ok(());
self.generators.retain(|id, protocol| {
loop {
let action = match protocol.poke() {
Ok(action) => action,
Err(e) => {
result = Err(e);
break false;
}
};
match action {
Action::Wait => {
tracing::debug!("waiting");
// Retain protocol until we are finished
break true;
}
Action::SendMany(data) => {
for p in &self.participants {
messages.push((
*p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
))
}
}
Action::SendPrivate(p, data) => messages.push((
p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
)),
Action::Return(output) => {
tracing::info!(
id,
big_a = ?output.1.big_a.to_base58(),
big_b = ?output.1.big_b.to_base58(),
big_c = ?output.1.big_c.to_base58(),
"completed triple generation"
);
self.triples.insert(*id, output);
// Do not retain the protocol
break false;
}
}
Action::SendPrivate(p, data) => messages.push((
p,
TripleMessage {
id: *id,
epoch: self.epoch,
from: self.me,
data: data.clone(),
},
)),
Action::Return(output) => {
tracing::info!(
id,
big_a = ?output.1.big_a.into_base58(),
big_b = ?output.1.big_b.into_base58(),
big_c = ?output.1.big_c.into_base58(),
"completed triple generation"
);
self.triples.insert(*id, output);
// Do not retain the protocol
return false;
}
}
});
messages
result.map(|_| messages)
}
}
4 changes: 2 additions & 2 deletions node/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl NearPublicKeyExt for near_crypto::PublicKey {

pub trait AffinePointExt {
fn into_near_public_key(self) -> near_crypto::PublicKey;
fn into_base58(self) -> String;
fn to_base58(&self) -> String;
}

impl AffinePointExt for AffinePoint {
Expand All @@ -48,7 +48,7 @@ impl AffinePointExt for AffinePoint {
)
}

fn into_base58(self) -> String {
fn to_base58(&self) -> String {
let key = near_crypto::Secp256K1PublicKey::try_from(
&self.to_encoded_point(false).as_bytes()[1..65],
)
Expand Down
Loading