Breakdown key simplification cleanups #725

Merged
5 changes: 1 addition & 4 deletions src/bin/ipa_bench/cmd.rs
@@ -1,7 +1,4 @@
-use crate::sample::Sample;
-
-use super::gen_events::generate_events;
-
+use crate::{gen_events::generate_events, sample::Sample};
 use clap::Parser;
 use ipa::cli::Verbosity;
 use rand::{rngs::StdRng, SeedableRng};
7 changes: 4 additions & 3 deletions src/bin/ipa_bench/gen_events.rs
@@ -1,6 +1,7 @@
-use crate::models::{Epoch, Event, EventTimestamp, GenericReport, MatchKey, Number};
-
-use super::sample::Sample;
+use crate::{
+    models::{Epoch, Event, EventTimestamp, GenericReport, MatchKey, Number},
+    sample::Sample,
+};
 use bitvec::view::BitViewSized;
 use rand::{
     distributions::{Bernoulli, Distribution},
3 changes: 1 addition & 2 deletions src/bin/ipa_bench/sample.rs
@@ -1,11 +1,10 @@
+use crate::config::Config;
 use rand::{
     distributions::{Distribution, WeightedIndex},
     CryptoRng, Rng, RngCore,
 };
 use std::time::Duration;
 
-use crate::config::Config;
-
 pub struct Sample<'a> {
     config: &'a Config,
 
2 changes: 2 additions & 0 deletions src/error.rs
@@ -56,6 +56,8 @@ pub enum Error {
     InvalidQueryParameter(String),
     #[error("invalid report: {0}")]
     InvalidReport(#[from] InvalidReportError),
+    #[error("unsupported: {0}")]
+    Unsupported(String),
 }
 
 impl Default for Error {
11 changes: 7 additions & 4 deletions src/protocol/attribution/aggregate_credit.rs
@@ -4,7 +4,7 @@ use crate::{
     protocol::{
         attribution::input::{MCAggregateCreditInputRow, MCAggregateCreditOutputRow},
         basics::ZeroPositions,
-        context::{Context, UpgradableContext, UpgradedContext, Validator},
+        context::{UpgradableContext, UpgradedContext, Validator},
         modulus_conversion::{convert_bit, convert_bit_local},
         sort::{check_everything, generate_permutation::ShuffledPermutationWrapper},
         step::BitOpStep,
@@ -52,9 +52,12 @@ where

     if max_breakdown_key <= SIMPLE_AGGREGATION_BREAK_EVEN_POINT {
         let res = simple_aggregate_credit(m_ctx, capped_credits, max_breakdown_key).await?;
-        return Ok((validator, res));
+        Ok((validator, res))
+    } else {
+        Err(Error::Unsupported(
+            format!("query uses {max_breakdown_key} breakdown keys; only {SIMPLE_AGGREGATION_BREAK_EVEN_POINT} are supported")
+        ))
     }
-    panic!()
 }
 
 async fn simple_aggregate_credit<F, C, I, T, BK>(
@@ -65,7 +68,7 @@ async fn simple_aggregate_credit<F, C, I, T, BK>(
 where
     F: PrimeField,
     I: Iterator<Item = MCAggregateCreditInputRow<F, T>> + ExactSizeIterator + Send,
-    C: Context + UpgradedContext<F, Share = T>,
+    C: UpgradedContext<F, Share = T>,
     T: LinearSecretSharing<F> + BasicProtocols<C, F> + Serializable + 'static,
     BK: GaloisField,
 {
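
The net effect of this file's change, together with the new `Error::Unsupported` variant in src/error.rs, is that exceeding `SIMPLE_AGGREGATION_BREAK_EVEN_POINT` now surfaces as a recoverable error instead of a `panic!()`. A standalone sketch of that control flow follows; the constant's value and the helper names here are illustrative stand-ins, not code from this repository (the real crate derives the error message with thiserror's `#[error]` attribute, as the error.rs hunk above shows):

```rust
// Standalone sketch: mirrors the switch from `panic!()` to a recoverable
// `Unsupported` error. All names and values here are illustrative only.
const SIMPLE_AGGREGATION_BREAK_EVEN_POINT: u32 = 32; // hypothetical value

#[derive(Debug)]
enum Error {
    Unsupported(String),
}

fn aggregate(max_breakdown_key: u32) -> Result<Vec<u64>, Error> {
    if max_breakdown_key <= SIMPLE_AGGREGATION_BREAK_EVEN_POINT {
        // Stand-in for simple_aggregate_credit: one bucket per breakdown key.
        Ok(vec![0; max_breakdown_key as usize])
    } else {
        Err(Error::Unsupported(format!(
            "query uses {max_breakdown_key} breakdown keys; only {SIMPLE_AGGREGATION_BREAK_EVEN_POINT} are supported"
        )))
    }
}

fn main() {
    // Callers decide how to surface the failure; nothing panics.
    match aggregate(1024) {
        Ok(histogram) => println!("{} buckets", histogram.len()),
        Err(Error::Unsupported(msg)) => eprintln!("rejected: {msg}"),
    }
}
```

Returning `Result` here lets callers reject an over-wide query cleanly rather than aborting the helper process.
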
42 changes: 21 additions & 21 deletions src/protocol/attribution/input.rs
@@ -33,22 +33,22 @@ pub struct ApplyAttributionWindowInputRow<F: Field, BK: GaloisField> {
 }
 
 #[derive(Debug)]
-pub struct MCApplyAttributionWindowInputRow<F: Field, T: LinearSecretSharing<F>> {
-    pub timestamp: T,
-    pub is_trigger_report: T,
-    pub helper_bit: T,
+pub struct MCApplyAttributionWindowInputRow<F: Field, S: LinearSecretSharing<F>> {
+    pub timestamp: S,
+    pub is_trigger_report: S,
+    pub helper_bit: S,
     pub breakdown_key: Vec<Replicated<Gf2>>,
-    pub trigger_value: T,
+    pub trigger_value: S,
     _marker: PhantomData<F>,
 }
 
-impl<F: Field, T: LinearSecretSharing<F>> MCApplyAttributionWindowInputRow<F, T> {
+impl<F: Field, S: LinearSecretSharing<F>> MCApplyAttributionWindowInputRow<F, S> {
     pub fn new(
-        timestamp: T,
-        is_trigger_report: T,
-        helper_bit: T,
+        timestamp: S,
+        is_trigger_report: S,
+        helper_bit: S,
         breakdown_key: Vec<Replicated<Gf2>>,
-        trigger_value: T,
+        trigger_value: S,
     ) -> Self {
         Self {
             timestamp,
@@ -61,7 +61,7 @@ impl<F: Field, T: LinearSecretSharing<F>> MCApplyAttributionWindowInputRow<F, T>
     }
 }
 
-pub type MCApplyAttributionWindowOutputRow<F, T> = MCAccumulateCreditInputRow<F, T>;
+pub type MCApplyAttributionWindowOutputRow<F, S> = MCAccumulateCreditInputRow<F, S>;
 
 //
 // `accumulate_credit` protocol
@@ -76,22 +76,22 @@ pub struct AccumulateCreditInputRow<F: Field, BK: GaloisField> {
 }
 
 #[derive(Debug)]
-pub struct MCAccumulateCreditInputRow<F: Field, T: LinearSecretSharing<F>> {
-    pub is_trigger_report: T,
-    pub helper_bit: T,
-    pub active_bit: T,
+pub struct MCAccumulateCreditInputRow<F: Field, S: LinearSecretSharing<F>> {
+    pub is_trigger_report: S,
+    pub helper_bit: S,
+    pub active_bit: S,
     pub breakdown_key: Vec<Replicated<Gf2>>,
-    pub trigger_value: T,
+    pub trigger_value: S,
     _marker: PhantomData<F>,
 }
 
-impl<F: Field, T: LinearSecretSharing<F>> MCAccumulateCreditInputRow<F, T> {
+impl<F: Field, S: LinearSecretSharing<F>> MCAccumulateCreditInputRow<F, S> {
     pub fn new(
-        is_trigger_report: T,
-        helper_bit: T,
-        active_bit: T,
+        is_trigger_report: S,
+        helper_bit: S,
+        active_bit: S,
         breakdown_key: Vec<Replicated<Gf2>>,
-        trigger_value: T,
+        trigger_value: S,
     ) -> Self {
         Self {
             is_trigger_report,
3 changes: 1 addition & 2 deletions src/protocol/attribution/mod.rs
@@ -19,6 +19,7 @@ use crate::{
         basics::SecureMul,
         boolean::{bitwise_equal::bitwise_equal_gf2, or::or, RandomBits},
         context::{Context, UpgradableContext, UpgradedContext, Validator},
+        ipa::{ArithmeticallySharedIPAInputs, BinarySharedIPAInputs},
         modulus_conversion::{convert_bit, convert_bit_local, BitConversionTriple},
         sort::generate_permutation::ShuffledPermutationWrapper,
         step, BasicProtocols, RecordId,
@@ -36,8 +37,6 @@ use crate::{
 use futures::future::try_join;
 use std::iter::{empty, once, zip};
 
-use super::ipa::{ArithmeticallySharedIPAInputs, BinarySharedIPAInputs};
-
 /// Performs a set of attribution protocols on the sorted IPA input.
 ///
 /// # Errors
6 changes: 4 additions & 2 deletions src/protocol/boolean/bitwise_less_than_prime.rs
@@ -1,9 +1,11 @@
-use super::{any_ones, or::or};
 use crate::{
     error::Error,
     ff::PrimeField,
     protocol::{
-        boolean::multiply_all_shares, context::Context, step::BitOpStep, BasicProtocols, RecordId,
+        boolean::{any_ones, multiply_all_shares, or::or},
+        context::Context,
+        step::BitOpStep,
+        BasicProtocols, RecordId,
     },
     secret_sharing::Linear as LinearSecretSharing,
 };
9 changes: 6 additions & 3 deletions src/protocol/boolean/mod.rs
@@ -1,13 +1,16 @@
 use crate::{
     error::Error,
     ff::{Field, PrimeField},
-    protocol::{basics::SecureMul, BasicProtocols},
+    protocol::{
+        basics::{SecureMul, ShareKnownValue},
+        context::Context,
+        step::BitOpStep,
+        BasicProtocols, RecordId,
+    },
     secret_sharing::{Linear as LinearSecretSharing, SecretSharing},
 };
 use std::iter::repeat;
 
-use super::{basics::ShareKnownValue, context::Context, step::BitOpStep, RecordId};
-
 mod add_constant;
 mod bit_decomposition;
 pub mod bitwise_equal;
7 changes: 5 additions & 2 deletions src/protocol/boolean/solved_bits.rs
@@ -1,8 +1,11 @@
-use super::{bitwise_less_than_prime::BitwiseLessThanPrime, RandomBits};
 use crate::{
     error::Error,
     ff::{Field, PrimeField},
-    protocol::{context::Context, BasicProtocols, RecordId},
+    protocol::{
+        boolean::{bitwise_less_than_prime::BitwiseLessThanPrime, RandomBits},
+        context::Context,
+        BasicProtocols, RecordId,
+    },
     secret_sharing::{
         replicated::malicious::{
             AdditiveShare as MaliciousReplicated, DowngradeMalicious, ExtendableField,
10 changes: 7 additions & 3 deletions src/protocol/context/mod.rs
@@ -7,7 +7,13 @@ pub mod validator;
 use crate::{
     error::Error,
     helpers::{ChannelId, Gateway, Message, ReceivingEnd, Role, SendingEnd, TotalRecords},
-    protocol::{basics::ZeroPositions, prss::Endpoint as PrssEndpoint, step, NoRecord, RecordId},
+    protocol::{
+        basics::ZeroPositions,
+        prss::Endpoint as PrssEndpoint,
+        step,
+        step::{Gate, StepNarrow},
+        NoRecord, RecordId,
+    },
     secret_sharing::{
         replicated::{malicious::ExtendableField, semi_honest::AdditiveShare as Replicated},
         SecretSharing,
@@ -23,8 +29,6 @@ pub use semi_honest::{Context as SemiHonestContext, Upgraded as UpgradedSemiHone
 pub use upgrade::{UpgradeContext, UpgradeToMalicious};
 pub use validator::Validator;
 
-use super::step::{Gate, StepNarrow};
-
 /// Context used by each helper to perform secure computation. Provides access to shared randomness
 /// generator and communication channel.
 pub trait Context: Clone + Send + Sync + SeqJoin {
71 changes: 23 additions & 48 deletions src/protocol/context/upgrade.rs
@@ -17,10 +17,7 @@ use crate::{
 };
 use async_trait::async_trait;
 use futures::future::{try_join, try_join3};
-use std::{
-    iter::{repeat, zip},
-    marker::PhantomData,
-};
+use std::marker::PhantomData;
 
 /// Special context type used for malicious upgrades.
 ///
@@ -44,17 +41,6 @@ use std::{
 /// let _ = <UpgradeContext<C<'_, F>, F, NoRecord> as UpgradeToMalicious<Vec<Replicated<F>>, _>>::upgrade;
 /// let _ = <UpgradeContext<C<'_, F>, F, NoRecord> as UpgradeToMalicious<(Vec<Replicated<F>>, Vec<Replicated<F>>), _>>::upgrade;
 /// ```
-///
-/// ```compile_fail
-/// use ipa::protocol::{context::{UpgradeContext, UpgradeToMalicious, UpgradedMaliciousContext as C}, NoRecord, RecordId};
-/// use ipa::ff::Fp32BitPrime as F;
-/// use ipa::secret_sharing::replicated::{
-///     malicious::AdditiveShare as MaliciousReplicated, semi_honest::AdditiveShare as Replicated,
-/// };
-/// // This can't be upgraded with a record-bound context because the record ID
-/// // is used internally for vector indexing.
-/// let _ = <UpgradeContext<C<'_, F>, F, RecordId> as UpgradeToMalicious<Vec<Replicated<F>>, _>>::upgrade;
-/// ```
 pub struct UpgradeContext<
     'a,
     C: UpgradedContext<F>,
@@ -195,18 +181,21 @@ impl AsRef<str> for Upgrade2DVectors {
 }
 
 #[async_trait]
-impl<'a, C, F, T, M> UpgradeToMalicious<'a, Vec<T>, Vec<M>> for UpgradeContext<'a, C, F, NoRecord>
+impl<'a, C, F, T, M> UpgradeToMalicious<'a, T, Vec<M>> for UpgradeContext<'a, C, F, NoRecord>
 where
     C: UpgradedContext<F>,
     F: ExtendableField,
-    T: Send + 'static,
+    T: IntoIterator + Send + 'static,
+    T::IntoIter: ExactSizeIterator + Send,
+    T::Item: Send + 'static,
     M: Send + 'static,
-    for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, T, M>,
+    for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, T::Item, M>,
 {
-    async fn upgrade(self, input: Vec<T>) -> Result<Vec<M>, Error> {
-        let ctx = self.ctx.set_total_records(input.len());
+    async fn upgrade(self, input: T) -> Result<Vec<M>, Error> {
+        let iter = input.into_iter();
+        let ctx = self.ctx.set_total_records(iter.len());
         let ctx_ref = &ctx;
-        ctx.try_join(input.into_iter().enumerate().map(|(i, share)| async move {
+        ctx.try_join(iter.enumerate().map(|(i, share)| async move {
             // TODO: make it a bit more ergonomic to call with record id bound
             UpgradeContext::new(ctx_ref.clone(), RecordId::from(i))
                 .upgrade(share)
@@ -220,8 +209,7 @@ where
 /// It assumes the inner vector is much smaller (e.g. multiple bits per record) than the outer vector (e.g. records)
 /// Each inner vector element uses a different context and outer vector shares a context for the same inner vector index
 #[async_trait]
-impl<'a, C, F, T, M> UpgradeToMalicious<'a, Vec<Vec<T>>, Vec<Vec<M>>>
-    for UpgradeContext<'a, C, F, NoRecord>
+impl<'a, C, F, T, M> UpgradeToMalicious<'a, Vec<T>, Vec<M>> for UpgradeContext<'a, C, F, RecordId>
 where
     C: UpgradedContext<F>,
     F: ExtendableField,
@@ -230,28 +218,16 @@ where
     for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, T, M>,
 {
     /// # Panics
-    /// Panics if input is empty
-    async fn upgrade(self, input: Vec<Vec<T>>) -> Result<Vec<Vec<M>>, Error> {
-        let num_records = input.len();
-        let num_columns = input.first().map_or(1, Vec::len);
-        assert_ne!(num_columns, 0);
-        let ctx = self.ctx.set_total_records(num_records);
+    /// Only vectors with 64 or less items are supported; larger vectors cause a panic.
+    async fn upgrade(self, input: Vec<T>) -> Result<Vec<M>, Error> {
+        let ctx_ref = &self.ctx;
-        let all_ctx = (0..num_columns).map(|idx| ctx.narrow(&Upgrade2DVectors::V(idx)));
-
-        ctx_ref
-            .try_join(zip(repeat(all_ctx), input.into_iter()).enumerate().map(
-                |(record_idx, (all_ctx, one_input))| async move {
-                    // This inner join is truly concurrent.
-                    ctx_ref
-                        .parallel_join(zip(all_ctx, one_input).map(|(ctx, share)| async move {
-                            UpgradeContext::new(ctx, RecordId::from(record_idx))
-                                .upgrade(share)
-                                .await
-                        }))
-                        .await
-                },
-            ))
+        let record_id = self.record_binding;
+        self.ctx
+            .parallel_join(input.into_iter().enumerate().map(|(i, share)| async move {
+                UpgradeContext::new(ctx_ref.narrow(&Upgrade2DVectors::V(i)), record_id)
+                    .upgrade(share)
+                    .await
+            }))
             .await
     }
 }
@@ -395,15 +371,14 @@ where
 // context. This is only used for tests where the protocol takes a single `Replicated<F>` input.
 #[cfg(test)]
 #[async_trait]
-impl<'a, C, F, T, M> UpgradeToMalicious<'a, T, M> for UpgradeContext<'a, C, F, NoRecord>
+impl<'a, C, F, M> UpgradeToMalicious<'a, Replicated<F>, M> for UpgradeContext<'a, C, F, NoRecord>
 where
     C: UpgradedContext<F>,
     F: ExtendableField,
-    T: Send + 'static,
     M: 'static,
-    for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, T, M>,
+    for<'u> UpgradeContext<'u, C, F, RecordId>: UpgradeToMalicious<'u, Replicated<F>, M>,
 {
-    async fn upgrade(self, input: T) -> Result<M, Error> {
+    async fn upgrade(self, input: Replicated<F>) -> Result<M, Error> {
         let ctx = if self.ctx.total_records().is_unspecified() {
             self.ctx.set_total_records(1)
         } else {
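
The rewritten impls in this file split responsibilities: the `NoRecord`-bound impl now accepts any `IntoIterator`, assigns a `RecordId` per element, and delegates to the `RecordId`-bound impl, which narrows a per-column `Upgrade2DVectors::V(i)` step and upgrades each share; the dedicated `Vec<Vec<T>>` impl (and the `compile_fail` doctest that forbade record-bound vector upgrades) is gone. A simplified, synchronous sketch of that delegation shape, using stand-in types rather than the crate's real `UpgradeContext` machinery:

```rust
// Standalone sketch (not code from this PR): the delegation shape the refactor
// lands on, with simplified synchronous stand-ins for the crate's types.
#[derive(Clone, Copy, Debug)]
struct RecordId(usize);

#[derive(Clone, Debug)]
struct Ctx {
    step: String,
}

impl Ctx {
    // Stand-in for context narrowing: append a sub-step label.
    fn narrow(&self, step: &str) -> Ctx {
        Ctx {
            step: format!("{}/{}", self.step, step),
        }
    }
}

#[derive(Debug)]
struct Upgraded {
    value: u32,
    step: String,
    record: usize,
}

// Record-bound upgrade (cf. the `RecordId`-bound impl): one row, one RecordId,
// a per-column narrowed context for each share.
fn upgrade_row(ctx: &Ctx, record: RecordId, row: Vec<u32>) -> Vec<Upgraded> {
    row.into_iter()
        .enumerate()
        .map(|(col, value)| Upgraded {
            value,
            step: ctx.narrow(&format!("upgrade2d_v{col}")).step,
            record: record.0,
        })
        .collect()
}

// Record-agnostic upgrade (cf. the `NoRecord`-bound impl): iterate the input,
// assign a RecordId per element, and delegate row by row.
fn upgrade_all(ctx: &Ctx, rows: Vec<Vec<u32>>) -> Vec<Vec<Upgraded>> {
    rows.into_iter()
        .enumerate()
        .map(|(i, row)| upgrade_row(ctx, RecordId(i), row))
        .collect()
}

fn main() {
    let ctx = Ctx { step: "protocol".to_string() };
    let upgraded = upgrade_all(&ctx, vec![vec![1, 0, 1], vec![0, 1, 1]]);
    println!("{upgraded:#?}");
}
```

The real code additionally runs the per-record and per-column upgrades concurrently via `try_join`/`parallel_join`, as the diff above shows.
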