-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make active_work match records_per_batch #1316
Changes from 3 commits
b628bf2
a4c6f03
1a6e388
706dcbe
c5c0ccf
966160f
d2512f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -827,33 +827,150 @@ mod tests { | |
use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec}; | ||
use futures::{StreamExt, TryStreamExt}; | ||
use futures_util::stream::iter; | ||
use proptest::{prop_compose, proptest, sample::select}; | ||
use rand::{thread_rng, Rng}; | ||
use proptest::{prelude::Strategy, prop_oneof, proptest}; | ||
use rand::{distributions::Standard, prelude::Distribution}; | ||
|
||
use crate::{ | ||
error::Error, | ||
ff::{boolean::Boolean, Fp61BitPrime}, | ||
ff::{ | ||
boolean::Boolean, | ||
boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8}, | ||
Fp61BitPrime, | ||
}, | ||
protocol::{ | ||
basics::SecureMul, | ||
basics::{select, BooleanArrayMul, SecureMul}, | ||
context::{ | ||
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE}, | ||
dzkp_validator::{ | ||
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE, | ||
}, | ||
Context, UpgradableContext, TEST_DZKP_STEPS, | ||
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext, | ||
UpgradableContext, TEST_DZKP_STEPS, | ||
}, | ||
Gate, RecordId, | ||
}, | ||
rand::{thread_rng, Rng}, | ||
secret_sharing::{ | ||
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue, | ||
Vectorizable, | ||
}, | ||
seq_join::{seq_join, SeqJoin}, | ||
seq_join::seq_join, | ||
sharding::NotSharded, | ||
test_fixture::{join3v, Reconstruct, Runner, TestWorld}, | ||
}; | ||
|
||
async fn test_select_semi_honest<V>() | ||
where | ||
V: BooleanArray, | ||
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>, | ||
Standard: Distribution<V>, | ||
{ | ||
let world = TestWorld::default(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized that we probably want to use the same seed for rng and test world (it is supported via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I made an issue for this. It's definitely worth doing in general, but it doesn't seem all that important for this particular test case, where the input values should be unrelated to the behavior of the test. |
||
let context = world.contexts(); | ||
let mut rng = thread_rng(); | ||
|
||
let bit = rng.gen::<Boolean>(); | ||
let a = rng.gen::<V>(); | ||
let b = rng.gen::<V>(); | ||
|
||
let bit_shares = bit.share_with(&mut rng); | ||
let a_shares = a.share_with(&mut rng); | ||
let b_shares = b.share_with(&mut rng); | ||
|
||
let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( | ||
|(ctx, (bit_share, (a_share, b_share)))| async move { | ||
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); | ||
let sh_ctx = v.context(); | ||
|
||
let result = select( | ||
sh_ctx.set_total_records(1), | ||
RecordId::from(0), | ||
&bit_share, | ||
&a_share, | ||
&b_share, | ||
) | ||
.await?; | ||
|
||
v.validate().await?; | ||
|
||
Ok::<_, Error>(result) | ||
}, | ||
); | ||
|
||
let [ab0, ab1, ab2] = join3v(futures).await; | ||
|
||
let ab = [ab0, ab1, ab2].reconstruct(); | ||
|
||
assert_eq!(ab, if bit.into() { a } else { b }); | ||
} | ||
|
||
#[tokio::test] | ||
async fn select_semi_honest() { | ||
test_select_semi_honest::<BA8>().await; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it worth testing it for for weird types like BA3 and BA7 as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think so. I had BA4 but took it out because it's not |
||
test_select_semi_honest::<BA16>().await; | ||
test_select_semi_honest::<BA20>().await; | ||
test_select_semi_honest::<BA32>().await; | ||
test_select_semi_honest::<BA64>().await; | ||
test_select_semi_honest::<BA256>().await; | ||
} | ||
|
||
async fn test_select_malicious<V>() | ||
where | ||
V: BooleanArray, | ||
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>, | ||
Standard: Distribution<V>, | ||
{ | ||
let world = TestWorld::default(); | ||
let context = world.malicious_contexts(); | ||
let mut rng = thread_rng(); | ||
|
||
let bit = rng.gen::<Boolean>(); | ||
let a = rng.gen::<V>(); | ||
let b = rng.gen::<V>(); | ||
|
||
let bit_shares = bit.share_with(&mut rng); | ||
let a_shares = a.share_with(&mut rng); | ||
let b_shares = b.share_with(&mut rng); | ||
|
||
let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map( | ||
|(ctx, (bit_share, (a_share, b_share)))| async move { | ||
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1); | ||
let m_ctx = v.context(); | ||
|
||
let result = select( | ||
m_ctx.set_total_records(1), | ||
RecordId::from(0), | ||
&bit_share, | ||
&a_share, | ||
&b_share, | ||
) | ||
.await?; | ||
|
||
v.validate().await?; | ||
|
||
Ok::<_, Error>(result) | ||
}, | ||
); | ||
|
||
let [ab0, ab1, ab2] = join3v(futures).await; | ||
|
||
let ab = [ab0, ab1, ab2].reconstruct(); | ||
|
||
assert_eq!(ab, if bit.into() { a } else { b }); | ||
} | ||
|
||
#[tokio::test] | ||
async fn select_malicious() { | ||
test_select_malicious::<BA8>().await; | ||
test_select_malicious::<BA16>().await; | ||
test_select_malicious::<BA20>().await; | ||
test_select_malicious::<BA32>().await; | ||
test_select_malicious::<BA64>().await; | ||
test_select_malicious::<BA256>().await; | ||
} | ||
|
||
#[tokio::test] | ||
async fn dzkp_malicious() { | ||
async fn two_multiplies_malicious() { | ||
const COUNT: usize = 32; | ||
let mut rng = thread_rng(); | ||
|
||
|
@@ -914,8 +1031,8 @@ mod tests { | |
} | ||
|
||
/// test for testing `validated_seq_join` | ||
/// similar to `complex_circuit` in `validator.rs` | ||
async fn complex_circuit_dzkp( | ||
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment) | ||
async fn chained_multiplies_dzkp( | ||
count: usize, | ||
max_multiplications_per_gate: usize, | ||
) -> Result<(), Error> { | ||
|
@@ -945,7 +1062,7 @@ mod tests { | |
.map(|(ctx, input_shares)| async move { | ||
let v = ctx | ||
.set_total_records(count - 1) | ||
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get()); | ||
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate); | ||
let m_ctx = v.context(); | ||
|
||
let m_results = v | ||
|
@@ -1021,19 +1138,24 @@ mod tests { | |
Ok(()) | ||
} | ||
|
||
prop_compose! { | ||
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) { | ||
(1usize<<log_count, 1usize<<log_multiplication_amount) | ||
} | ||
fn record_count_strategy() -> impl Strategy<Value = usize> { | ||
prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am being stupid and it is 12pm now - is it really testing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What it does is:
Two reasons to focus on powers of two, although I didn't put a huge amount of thought into this (e.g. it just occurred to me to add
I think it is clearer to have the two |
||
} | ||
|
||
fn max_multiplications_per_gate_strategy() -> impl Strategy<Value = usize> { | ||
prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)] | ||
} | ||
|
||
proptest! { | ||
#[test] | ||
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){ | ||
let future = async { | ||
let _ = complex_circuit_dzkp(count, multiplication_amount).await; | ||
}; | ||
tokio::runtime::Runtime::new().unwrap().block_on(future); | ||
fn test_chained_multiplies_dzkp( | ||
record_count in record_count_strategy(), | ||
max_multiplications_per_gate in max_multiplications_per_gate_strategy(), | ||
) { | ||
println!("record_count {record_count} batch {max_multiplications_per_gate}"); | ||
tokio::runtime::Runtime::new().unwrap().block_on(async { | ||
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap(); | ||
}); | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
use std::{ | ||
any::type_name, | ||
cmp::max, | ||
fmt::{Debug, Formatter}, | ||
num::NonZeroUsize, | ||
}; | ||
|
@@ -174,13 +175,21 @@ pub(super) type MacBatcher<'a, F> = Mutex<Batcher<'a, validator::Malicious<'a, F | |
pub struct Upgraded<'a, F: ExtendableField> { | ||
batch: Weak<MacBatcher<'a, F>>, | ||
base_ctx: Context<'a>, | ||
active_work: NonZeroUsize, | ||
} | ||
|
||
impl<'a, F: ExtendableField> Upgraded<'a, F> { | ||
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, ctx: Context<'a>) -> Self { | ||
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, base_ctx: Context<'a>) -> Self { | ||
// Adjust active_work to be at least records_per_batch. The MAC validator | ||
// currently configures the batcher with records_per_batch = active_work, which | ||
// makes this adjustment a no-op, but we do it this way to match the DZKP validator. | ||
let records_per_batch = batch.lock().unwrap().records_per_batch(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. would it be better to assert this to make sure we don't miss the misalignment between MAC and ZKP? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, I'm fine with that. |
||
let active_work = | ||
NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap(); | ||
Self { | ||
batch: Arc::downgrade(batch), | ||
base_ctx: ctx, | ||
base_ctx, | ||
active_work, | ||
} | ||
} | ||
|
||
|
@@ -297,7 +306,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> { | |
|
||
impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> { | ||
fn active_work(&self) -> NonZeroUsize { | ||
self.base_ctx.active_work() | ||
self.active_work | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> { | |
|
||
// TODO: Right now we set the batch work to be equal to active_work, | ||
// but it does not need to be. We can make this configurable if needed. | ||
let records_per_batch = ctx.active_work().get().min(total_records.get()); | ||
let records_per_batch = ctx.active_work().get(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reducing if larger than |
||
|
||
Self { | ||
protocol_ctx: ctx.narrow(&Step::MaliciousProtocol), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we should probably log the fact that we've adjusted it - maybe at the debug level