Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Event generation using test_mpc and IPA integration test + migrate to HTTP2 #647

Merged
merged 13 commits into from
May 22, 2023
37 changes: 15 additions & 22 deletions benches/oneshot/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,8 @@ use ipa::{
ff::Fp32BitPrime,
helpers::{query::IpaQueryConfig, GatewayConfig},
test_fixture::{
ipa::{
generate_random_user_records_in_reverse_chronological_order, ipa_in_the_clear,
test_ipa, IpaSecurityModel,
},
TestWorld, TestWorldConfig,
ipa::{ipa_in_the_clear, test_ipa, IpaSecurityModel},
EventGenerator, EventGeneratorConfig, TestWorld, TestWorldConfig,
},
};
use rand::{rngs::StdRng, thread_rng, Rng, SeedableRng};
Expand Down Expand Up @@ -38,7 +35,7 @@ struct Args {
query_size: usize,
/// The maximum number of records for each person.
#[arg(short = 'u', long, default_value = "50")]
records_per_user: usize,
records_per_user: u32,
/// The contribution cap for each person.
#[arg(short = 'c', long, default_value = "3")]
per_user_cap: u32,
Expand Down Expand Up @@ -109,23 +106,19 @@ async fn run(args: Args) -> Result<(), Error> {
"Using random seed: {seed} for {q} records",
q = args.query_size
);
let mut rng = StdRng::seed_from_u64(seed);

let mut raw_data = Vec::with_capacity(args.query_size + args.records_per_user);
while raw_data.len() < args.query_size {
let mut records_for_user = generate_random_user_records_in_reverse_chronological_order(
&mut rng,
args.records_per_user,
args.breakdown_keys,
args.max_trigger_value,
);
records_for_user.truncate(args.query_size - raw_data.len());
raw_data.append(&mut records_for_user);
}
let rng = StdRng::seed_from_u64(seed);
let raw_data = EventGenerator::with_config(
rng,
EventGeneratorConfig {
max_trigger_value: NonZeroU32::try_from(args.max_trigger_value).unwrap(),
max_breakdown_key: NonZeroU32::try_from(args.breakdown_keys).unwrap(),
max_events_per_user: NonZeroU32::try_from(args.records_per_user).unwrap(),
..Default::default()
},
)
.take(args.query_size)
.collect::<Vec<_>>();

// Sort the records in chronological order
// This is part of the IPA spec. Callers should do this before sending a batch of records in for processing.
raw_data.sort_unstable_by(|a, b| a.timestamp.cmp(&b.timestamp));
let expected_results =
ipa_in_the_clear(&raw_data, args.per_user_cap, args.attribution_window());

Expand Down
139 changes: 105 additions & 34 deletions src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use hyper::http::uri::Scheme;
use ipa::{
cli::{
playbook::{secure_mul, semi_honest, InputSource},
Verbosity,
CsvSerializer, Verbosity,
},
config::NetworkConfig,
ff::{Field, FieldType, Fp31, Fp32BitPrime, Serializable},
Expand All @@ -16,9 +16,21 @@ use ipa::{
test_fixture::{
config::TestConfigBuilder,
ipa::{ipa_in_the_clear, TestRawDataRecord},
EventGenerator, EventGeneratorConfig,
},
};
use std::{error::Error, fmt::Debug, fs, ops::Add, path::PathBuf, time::Duration};
use rand::thread_rng;
use std::{
error::Error,
fmt::Debug,
fs,
fs::OpenOptions,
io,
io::{stdout, Write},
ops::Add,
path::PathBuf,
time::Duration,
};
use tokio::time::sleep;

#[derive(Debug, Parser)]
Expand Down Expand Up @@ -77,7 +89,29 @@ enum TestAction {
/// Execute end-to-end multiplication.
Multiply,
/// Execute IPA in semi-honest majority setting
SemiHonestIPA,
SemiHonestIpa(IpaQueryConfig),
/// Generate inputs for IPA
GenIpaInputs {
/// Number of records to generate
#[clap(long, short = 'n')]
count: u32,

/// The destination file for generated records
#[arg(long)]
output_file: Option<PathBuf>,

#[clap(flatten)]
gen_args: EventGeneratorConfig,
},
}

#[derive(Debug, clap::Args)]
struct GenInputArgs {
/// Maximum records per user
#[clap(long)]
max_per_user: u32,
/// number of breakdowns
breakdowns: u32,
}

async fn clients_ready(clients: &[MpcHelperClient; 3]) -> bool {
Expand Down Expand Up @@ -124,7 +158,7 @@ where
i += 1;
}

tracing::info!("{table}");
tracing::info!("\n{table}\n");

assert!(
mismatch.is_empty(),
Expand All @@ -135,61 +169,98 @@ where

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
fn make_clients(disable_https: bool, config_path: Option<&PathBuf>) -> [MpcHelperClient; 3] {
let scheme = if disable_https {
let args = Args::parse();
let _handle = args.logging.setup_logging();

let make_clients = || async {
let scheme = if args.disable_https {
Scheme::HTTP
} else {
Scheme::HTTPS
};

let config_path = args.network.as_deref();
let mut wait = args.wait;

let config = if let Some(path) = config_path {
NetworkConfig::from_toml_str(&fs::read_to_string(path).unwrap()).unwrap()
} else {
TestConfigBuilder::with_default_test_ports().build().network
}
.override_scheme(&scheme);
MpcHelperClient::from_conf(&config)
}
let clients = MpcHelperClient::from_conf(&config);
while wait > 0 && !clients_ready(&clients).await {
tracing::debug!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}

let args = Args::parse();
let _handle = args.logging.setup_logging();
clients
};

let clients = make_clients(args.disable_https, args.network.as_ref());
match args.action {
TestAction::Multiply => multiply(&args, &make_clients().await).await,
TestAction::SemiHonestIpa(config) => {
semi_honest_ipa(&args, &config, &make_clients().await).await
}
TestAction::GenIpaInputs {
count,
output_file,
gen_args,
} => gen_inputs(count, output_file, gen_args).unwrap(),
};

let mut wait = args.wait;
while wait > 0 && !clients_ready(&clients).await {
println!("waiting for servers to come up");
sleep(Duration::from_secs(1)).await;
wait -= 1;
}
Ok(())
}

match args.action {
TestAction::Multiply => multiply(args, &clients).await,
TestAction::SemiHonestIPA => semi_honest_ipa(args, &clients).await,
fn gen_inputs(
count: u32,
output_file: Option<PathBuf>,
args: EventGeneratorConfig,
) -> io::Result<()> {
let event_gen = EventGenerator::with_config(thread_rng(), args).take(count as usize);
let mut writer: Box<dyn Write> = if let Some(path) = output_file {
Box::new(OpenOptions::new().write(true).create_new(true).open(path)?)
} else {
Box::new(stdout().lock())
};

for event in event_gen {
event.to_csv(&mut writer)?;
writer.write(&[b'\n'])?;
}

Ok(())
}

async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn semi_honest_ipa(
args: &Args,
ipa_query_config: &IpaQueryConfig,
helper_clients: &[MpcHelperClient; 3],
) {
let input = InputSource::from(&args.input);
let ipa_query_config = IpaQueryConfig {
per_user_credit_cap: 3,
max_breakdown_key: 3,
num_multi_bits: 3,
attribution_window_seconds: None,
};
let query_type = QueryType::Ipa(ipa_query_config.clone());
let query_config = QueryConfig {
field_type: args.input.field,
query_type,
};
let query_id = helper_clients[0].create_query(query_config).await.unwrap();
let input_rows = input.iter::<TestRawDataRecord>().collect::<Vec<_>>();
let expected = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);
let expected = {
let mut r = ipa_in_the_clear(
&input_rows,
ipa_query_config.per_user_credit_cap,
ipa_query_config.attribution_window_seconds,
);

// pad the output vector to the max breakdown key, to make sure it is aligned with the MPC results
// truncate shouldn't happen unless in_the_clear is badly broken
r.resize(
usize::try_from(ipa_query_config.max_breakdown_key).unwrap(),
0,
);
r
};

let actual = match args.input.field {
FieldType::Fp31 => {
Expand All @@ -210,7 +281,7 @@ async fn semi_honest_ipa(args: Args, helper_clients: &[MpcHelperClient; 3]) {
}

async fn multiply_in_field<F: Field>(
args: Args,
args: &Args,
helper_clients: &[MpcHelperClient; 3],
query_id: QueryId,
) where
Expand All @@ -226,7 +297,7 @@ async fn multiply_in_field<F: Field>(
validate(expected, actual);
}

async fn multiply(args: Args, helper_clients: &[MpcHelperClient; 3]) {
async fn multiply(args: &Args, helper_clients: &[MpcHelperClient; 3]) {
let query_config = QueryConfig {
field_type: args.input.field,
query_type: QueryType::TestMultiply,
Expand Down
22 changes: 22 additions & 0 deletions src/cli/csv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::{io, io::Write};

pub trait Serializer {
/// Converts self into a CSV-encoded byte string
/// ## Errors
/// If this conversion fails due to insufficient capacity in `buf` or other reasons.
fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()>;
}

#[cfg(any(test, feature = "test-fixture"))]
impl Serializer for crate::test_fixture::ipa::TestRawDataRecord {
fn to_csv<W: Write>(&self, buf: &mut W) -> io::Result<()> {
// fmt::write is cool because it does not allocate when serializing integers
write!(buf, "{},", self.timestamp)?;
write!(buf, "{},", self.user_id)?;
write!(buf, "{},", u8::from(self.is_trigger_report))?;
write!(buf, "{},", self.breakdown_key)?;
write!(buf, "{}", self.trigger_value)?;

Ok(())
}
}
2 changes: 2 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod csv;
#[cfg(feature = "web-app")]
mod keygen;
mod metric_collector;
Expand All @@ -7,6 +8,7 @@ pub mod playbook;
mod test_setup;
mod verbosity;

pub use csv::Serializer as CsvSerializer;
#[cfg(feature = "web-app")]
pub use keygen::{keygen, KeygenArgs};
pub use metric_collector::{install_collector, CollectorHandle};
Expand Down
16 changes: 14 additions & 2 deletions src/cli/test_setup.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{
cli::{keygen, KeygenArgs},
config::{NetworkConfig, PeerConfig},
config::{ClientConfig, NetworkConfig, PeerConfig},
};
use clap::Args;
use std::{
Expand All @@ -24,6 +24,10 @@ pub struct TestSetupArgs {
#[arg(long)]
disable_https: bool,

/// Configure helper clients to use HTTP1 instead of default HTTP version (HTTP2 at the moment).
#[arg(long, default_value_t = false)]
use_http1: bool,

#[arg(short, long, num_args = 3, value_name = "PORT", default_values = vec!["3000", "3001", "3002"])]
ports: Vec<u16>,
}
Expand Down Expand Up @@ -66,7 +70,15 @@ pub fn test_setup(args: TestSetupArgs) -> Result<(), Box<dyn Error>> {
.try_into()
.unwrap();

let network_config = toml::to_string_pretty(&NetworkConfig { peers })?;
let client_config = if args.use_http1 {
ClientConfig::use_http1()
} else {
ClientConfig::default()
};
let network_config = toml::to_string_pretty(&NetworkConfig {
peers,
client: client_config,
})?;

fs::write(args.output_dir.join("network.toml"), network_config)?;

Expand Down
Loading