Skip to content

Commit

Permalink
Use batcher for DZKPs. Enable tests of DZKP batching.
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Sep 5, 2024
1 parent 53ef056 commit e42d0b1
Show file tree
Hide file tree
Showing 12 changed files with 290 additions and 507 deletions.
11 changes: 2 additions & 9 deletions ipa-core/src/protocol/basics/mul/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
protocol::{
basics::{mul::semi_honest::multiplication_protocol, SecureMul},
context::{
dzkp_field::DZKPCompatibleField, dzkp_validator::Segment, Context, DZKPContext,
dzkp_field::DZKPCompatibleField, dzkp_validator::Segment, Context,
DZKPUpgradedMaliciousContext,
},
prss::SharedRandomness,
Expand Down Expand Up @@ -81,11 +81,10 @@ impl<'a, F: Field + DZKPCompatibleField<N>, const N: usize>
#[cfg(all(test, unit_test))]
mod test {
use crate::{
error::Error,
ff::boolean::Boolean,
protocol::{
basics::SecureMul,
context::{dzkp_validator::DZKPValidator, Context, DZKPContext, UpgradableContext},
context::{dzkp_validator::DZKPValidator, Context, UpgradableContext},
RecordId,
},
rand::{thread_rng, Rng},
Expand All @@ -109,15 +108,9 @@ mod test {
.await
.unwrap();

// batch contains elements
assert!(matches!(mctx.is_verified(), Err(Error::ContextUnsafe(_))));

// validate all elements in the batch
validator.validate().await.unwrap();

// batch is empty now
assert!(mctx.is_verified().is_ok());

result
})
.await;
Expand Down
23 changes: 9 additions & 14 deletions ipa-core/src/protocol/basics/reveal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::{

use embed_doc_image::embed_doc_image;
use futures::{FutureExt, TryFutureExt};
use ipa_step::{Step, StepNarrow};

use crate::{
error::Error,
Expand All @@ -14,10 +13,10 @@ use crate::{
protocol::{
boolean::step::TwoHundredFiftySixBitOpStep,
context::{
dzkp_validator::DZKPValidator, Context, DZKPUpgradedMaliciousContext,
DZKPUpgradedSemiHonestContext, UpgradedMaliciousContext, UpgradedSemiHonestContext,
Context, DZKPContext, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradedMaliciousContext, UpgradedSemiHonestContext,
},
Gate, RecordId,
RecordId,
},
secret_sharing::{
replicated::{
Expand Down Expand Up @@ -382,21 +381,17 @@ where
S::generic_reveal(v, ctx, record_id, excluded)
}

pub async fn validated_partial_reveal<'fut, V, S, STEP>(
validator: V,
step: &'fut STEP,
pub async fn validated_partial_reveal<'fut, C, S>(
ctx: C,
record_id: RecordId,
excluded: Role,
v: &'fut S,
) -> Result<Option<<S as Reveal<V::Context>>::Output>, Error>
) -> Result<Option<<S as Reveal<C>>::Output>, Error>
where
V: DZKPValidator + 'fut,
S: Reveal<V::Context> + Send + Sync + ?Sized,
STEP: Step + Send + Sync + 'static,
Gate: StepNarrow<STEP>,
C: DZKPContext + 'fut,
S: Reveal<C> + Send + Sync + ?Sized,
{
let ctx = validator.context().narrow(step);
validator.validate_record(record_id).await?;
ctx.validate_record(record_id).await?;
partial_reveal(ctx, record_id, excluded, v).await
}

Expand Down
2 changes: 0 additions & 2 deletions ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ impl<'a, B> Batcher<'a, B> {
///
/// # Panics
/// If the batcher contains more than one batch.
#[allow(dead_code)]
pub fn into_single_batch(mut self) -> B {
assert!(self.first_batch == 0);
assert!(self.batches.len() <= 1);
Expand All @@ -206,7 +205,6 @@ impl<'a, B> Batcher<'a, B> {
}
}

#[allow(dead_code)]
pub fn iter(&self) -> impl Iterator<Item = &BatchState<B>> {
self.batches.iter()
}
Expand Down
126 changes: 56 additions & 70 deletions ipa-core/src/protocol/context/dzkp_malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,99 +8,108 @@ use ipa_step::{Step, StepNarrow};

use crate::{
error::Error,
helpers::{ChannelId, Gateway, MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords},
helpers::{MpcMessage, MpcReceivingEnd, Role, SendingEnd, TotalRecords},
protocol::{
context::{
dzkp_validator::{DZKPBatch, Segment},
batcher::Batcher,
dzkp_validator::{Batch, Segment},
prss::InstrumentedIndexedSharedRandomness,
Base, Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness,
step::ZeroKnowledgeProofValidateStep,
Context as ContextTrait, DZKPContext, InstrumentedSequentialSharedRandomness,
MaliciousContext,
},
prss::Endpoint as PrssEndpoint,
Gate, RecordId,
},
seq_join::SeqJoin,
sync::Arc,
sync::{Arc, Mutex, Weak},
};

pub(super) type DzkpBatcher<'a> = Mutex<Batcher<'a, Batch>>;

/// Represents protocol context in malicious setting when using zero-knowledge proofs,
/// i.e. secure against one active adversary in 3 party MPC ring.
#[derive(Clone)]
pub struct DZKPUpgraded<'a> {
/// TODO (alex): Arc is required here because of the `TestWorld` structure. Real world
/// may operate with raw references and be more efficient
inner: Arc<DZKPUpgradedInner<'a>>,
gate: Gate,
total_records: TotalRecords,
batcher: Weak<DzkpBatcher<'a>>,
base_ctx: MaliciousContext<'a>,
}

impl<'a> DZKPUpgraded<'a> {
pub(super) fn new<S: Step + ?Sized>(
source: &Base<'a>,
malicious_step: &S,
batch: DZKPBatch,
) -> Self
where
Gate: StepNarrow<S>,
{
pub(super) fn new(batch: &Arc<DzkpBatcher<'a>>, base_ctx: MaliciousContext<'a>) -> Self {
Self {
inner: DZKPUpgradedInner::new(source, batch),
gate: source.gate().narrow(malicious_step),
total_records: source.total_records,
batcher: Arc::downgrade(batch),
base_ctx,
}
}

pub fn push(&self, record_id: RecordId, segment: Segment) {
self.with_batch(record_id, |batch| {
batch.push(self.base_ctx.gate().clone(), record_id, segment);
});
}

fn with_batch<C: FnOnce(&mut Batch) -> T, T>(&self, record_id: RecordId, action: C) -> T {
let batcher = self.batcher.upgrade().expect("Validator is active");

let mut batch = batcher.lock().unwrap();
let state = batch.get_batch(record_id);
(action)(&mut state.batch)
}
}

#[async_trait]
impl<'a> DZKPContext for DZKPUpgraded<'a> {
fn is_verified(&self) -> Result<(), Error> {
if self.inner.batch.is_empty() {
Ok(())
} else {
Err(Error::ContextUnsafe(format!("{self:?}")))
}
}

fn push(&self, record_id: RecordId, segment: Segment) {
self.inner.batch.push(self.gate.clone(), record_id, segment);
async fn validate_record(&self, record_id: RecordId) -> Result<(), Error> {
let validation_future = self
.batcher
.upgrade()
.expect("Validation batch is active")
.lock()
.unwrap()
.validate_record(record_id, |batch_idx, batch| {
batch.validate(
self.base_ctx
.narrow(&ZeroKnowledgeProofValidateStep::DZKPValidate(batch_idx))
.validator_context(),
)
});

validation_future.await
}
}

impl<'a> super::Context for DZKPUpgraded<'a> {
fn role(&self) -> Role {
self.inner.gateway.role()
self.base_ctx.role()
}

fn gate(&self) -> &Gate {
&self.gate
self.base_ctx.gate()

Check warning on line 87 in ipa-core/src/protocol/context/dzkp_malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/dzkp_malicious.rs#L87

Added line #L87 was not covered by tests
}

fn narrow<S: Step + ?Sized>(&self, step: &S) -> Self
where
Gate: StepNarrow<S>,
{
Self {
inner: Arc::clone(&self.inner),
gate: self.gate.narrow(step),
total_records: self.total_records,
base_ctx: self.base_ctx.narrow(step),
..self.clone()
}
}

fn set_total_records<T: Into<TotalRecords>>(&self, total_records: T) -> Self {
Self {
inner: Arc::clone(&self.inner),
gate: self.gate.clone(),
total_records: self.total_records.overwrite(total_records),
base_ctx: self.base_ctx.set_total_records(total_records),
..self.clone()
}
}

fn total_records(&self) -> TotalRecords {
self.total_records
self.base_ctx.total_records()

Check warning on line 108 in ipa-core/src/protocol/context/dzkp_malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/dzkp_malicious.rs#L108

Added line #L108 was not covered by tests
}

fn prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
let prss = self.inner.prss.indexed(self.gate());

InstrumentedIndexedSharedRandomness::new(prss, &self.gate, self.role())
self.base_ctx.prss()
}

fn prss_rng(
Expand All @@ -109,29 +118,21 @@ impl<'a> super::Context for DZKPUpgraded<'a> {
InstrumentedSequentialSharedRandomness<'_>,
InstrumentedSequentialSharedRandomness<'_>,
) {
let (left, right) = self.inner.prss.sequential(self.gate());
(
InstrumentedSequentialSharedRandomness::new(left, self.gate(), self.role()),
InstrumentedSequentialSharedRandomness::new(right, self.gate(), self.role()),
)
self.base_ctx.prss_rng()

Check warning on line 121 in ipa-core/src/protocol/context/dzkp_malicious.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/protocol/context/dzkp_malicious.rs#L121

Added line #L121 was not covered by tests
}

fn send_channel<M: MpcMessage>(&self, role: Role) -> SendingEnd<Role, M> {
self.inner
.gateway
.get_mpc_sender(&ChannelId::new(role, self.gate.clone()), self.total_records)
self.base_ctx.send_channel(role)
}

fn recv_channel<M: MpcMessage>(&self, role: Role) -> MpcReceivingEnd<M> {
self.inner
.gateway
.get_mpc_receiver(&ChannelId::new(role, self.gate.clone()))
self.base_ctx.recv_channel(role)
}
}

impl<'a> SeqJoin for DZKPUpgraded<'a> {
fn active_work(&self) -> NonZeroUsize {
self.inner.gateway.config().active_work()
self.base_ctx.active_work()
}
}

Expand All @@ -140,18 +141,3 @@ impl Debug for DZKPUpgraded<'_> {
write!(f, "DZKPMaliciousContext")
}
}
struct DZKPUpgradedInner<'a> {
prss: &'a PrssEndpoint,
gateway: &'a Gateway,
batch: DZKPBatch,
}

impl<'a> DZKPUpgradedInner<'a> {
fn new(base_context: &Base<'a>, batch: DZKPBatch) -> Arc<Self> {
Arc::new(DZKPUpgradedInner {
prss: base_context.inner.prss,
gateway: base_context.inner.gateway,
batch,
})
}
}
9 changes: 2 additions & 7 deletions ipa-core/src/protocol/context/dzkp_semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
},
protocol::{
context::{
dzkp_validator::Segment, Base, DZKPContext, InstrumentedIndexedSharedRandomness,
Base, DZKPContext, InstrumentedIndexedSharedRandomness,
InstrumentedSequentialSharedRandomness,
},
Gate, RecordId,
Expand Down Expand Up @@ -108,14 +108,9 @@ impl<'a, B: ShardBinding> SeqJoin for DZKPUpgraded<'a, B> {

#[async_trait]
impl<'a, B: ShardBinding> DZKPContext for DZKPUpgraded<'a, B> {
fn is_verified(&self) -> Result<(), Error> {
async fn validate_record(&self, _record_id: RecordId) -> Result<(), Error> {
Ok(())
}

fn push(&self, _record_id: RecordId, _segment: Segment) {
// in the semi-honest setting, the segment is not added
// therefore this function does nothing
}
}

impl<B: ShardBinding> Debug for DZKPUpgraded<'_, B> {
Expand Down
Loading

0 comments on commit e42d0b1

Please sign in to comment.