
Make active_work match records_per_batch #1316

Merged
merged 7 commits into from Sep 27, 2024
Changes from 3 commits
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/batcher.rs
@@ -97,6 +97,10 @@ impl<'a, B> Batcher<'a, B> {
self.total_records = self.total_records.overwrite(total_records.into());
}

pub fn records_per_batch(&self) -> usize {
self.records_per_batch
}

fn batch_offset(&self, record_id: RecordId) -> usize {
let batch_index = usize::from(record_id) / self.records_per_batch;
batch_index
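The indexing in `batch_offset` is plain integer division; a minimal standalone sketch (the free function and the batch size of 4 are illustrative, not part of the PR):

```rust
// Illustrative sketch of the batch indexing above: a record's batch is its
// record id divided (integer division) by records_per_batch.
fn batch_index(record_id: usize, records_per_batch: usize) -> usize {
    record_id / records_per_batch
}

fn main() {
    let records_per_batch = 4;
    // Records 0..=3 fall in batch 0, records 4..=7 in batch 1, and so on.
    assert_eq!(batch_index(0, records_per_batch), 0);
    assert_eq!(batch_index(3, records_per_batch), 0);
    assert_eq!(batch_index(4, records_per_batch), 1);
    assert_eq!(batch_index(11, records_per_batch), 2);
    println!("ok");
}
```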
11 changes: 10 additions & 1 deletion ipa-core/src/protocol/context/dzkp_malicious.rs
@@ -1,4 +1,5 @@
use std::{
cmp::max,
fmt::{Debug, Formatter},
num::NonZeroUsize,
};
@@ -29,16 +30,24 @@ use crate::{
pub struct DZKPUpgraded<'a> {
validator_inner: Weak<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
active_work: NonZeroUsize,
}

impl<'a> DZKPUpgraded<'a> {
pub(super) fn new(
validator_inner: &Arc<MaliciousDZKPValidatorInner<'a>>,
base_ctx: MaliciousContext<'a>,
) -> Self {
// Adjust active_work to be at least records_per_batch. If it is less, we will
Collaborator: we should probably log the fact that we've adjusted it - maybe at the debug level

// stall, since every record in the batch remains incomplete until the batch is
// validated.
let records_per_batch = validator_inner.batcher.lock().unwrap().records_per_batch();
let active_work =
NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap();
Self {
validator_inner: Arc::downgrade(validator_inner),
base_ctx,
active_work,
}
}
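The adjustment in `new` can be exercised in isolation. A minimal sketch of the same clamp (the free function is hypothetical; the `max`/`NonZeroUsize` logic mirrors the diff): active work is raised to at least the batch size so that a whole batch can be in flight at once, since every record in a batch stays incomplete until the batch validates.

```rust
use std::{cmp::max, num::NonZeroUsize};

// Hypothetical standalone version of the adjustment above: if active_work is
// smaller than records_per_batch, raise it to records_per_batch; otherwise
// keep the configured value.
fn adjusted_active_work(active_work: NonZeroUsize, records_per_batch: usize) -> NonZeroUsize {
    NonZeroUsize::new(max(active_work.get(), records_per_batch)).unwrap()
}

fn main() {
    let base = NonZeroUsize::new(16).unwrap();
    // Batch larger than active work: raised to the batch size to avoid a stall.
    assert_eq!(adjusted_active_work(base, 100).get(), 100);
    // Batch already fits within active work: unchanged.
    assert_eq!(adjusted_active_work(base, 4).get(), 16);
    println!("ok");
}
```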

@@ -130,7 +139,7 @@ impl<'a> super::Context for DZKPUpgraded<'a> {

impl<'a> SeqJoin for DZKPUpgraded<'a> {
fn active_work(&self) -> NonZeroUsize {
self.base_ctx.active_work()
self.active_work
}
}

160 changes: 141 additions & 19 deletions ipa-core/src/protocol/context/dzkp_validator.rs
@@ -827,33 +827,150 @@ mod tests {
use bitvec::{order::Lsb0, prelude::BitArray, vec::BitVec};
use futures::{StreamExt, TryStreamExt};
use futures_util::stream::iter;
use proptest::{prop_compose, proptest, sample::select};
use rand::{thread_rng, Rng};
use proptest::{prelude::Strategy, prop_oneof, proptest};
use rand::{distributions::Standard, prelude::Distribution};

use crate::{
error::Error,
ff::{boolean::Boolean, Fp61BitPrime},
ff::{
boolean::Boolean,
boolean_array::{BooleanArray, BA16, BA20, BA256, BA32, BA64, BA8},
Fp61BitPrime,
},
protocol::{
basics::SecureMul,
basics::{select, BooleanArrayMul, SecureMul},
context::{
dzkp_field::{DZKPCompatibleField, BLOCK_SIZE},
dzkp_validator::{
Batch, DZKPValidator, Segment, SegmentEntry, BIT_ARRAY_LEN, TARGET_PROOF_SIZE,
},
Context, UpgradableContext, TEST_DZKP_STEPS,
Context, DZKPUpgradedMaliciousContext, DZKPUpgradedSemiHonestContext,
UpgradableContext, TEST_DZKP_STEPS,
},
Gate, RecordId,
},
rand::{thread_rng, Rng},
secret_sharing::{
replicated::semi_honest::AdditiveShare as Replicated, IntoShares, SharedValue,
Vectorizable,
},
seq_join::{seq_join, SeqJoin},
seq_join::seq_join,
sharding::NotSharded,
test_fixture::{join3v, Reconstruct, Runner, TestWorld},
};

async fn test_select_semi_honest<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedSemiHonestContext<'a, NotSharded>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
Collaborator: I just realized that we probably want to use the same seed for rng and test world (it is supported via the TestWorldConfig struct). That way we can make it reproducible if it ever fails.

Author: I made an issue for this. It's definitely worth doing in general, but it doesn't seem all that important for this particular test case, where the input values should be unrelated to the behavior of the test.

let context = world.contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let sh_ctx = v.context();

let result = select(
sh_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
async fn select_semi_honest() {
test_select_semi_honest::<BA8>().await;
Collaborator: is it worth testing it for weird types like BA3 and BA7 as well?

Author: Yes, I think so. I had BA4 but took it out because it's not covered by boolean_vector! (not for any good reason other than we haven't needed it). But I think it is worth having a less-than-one-byte case, and maybe even adding a new BA type so we can cover the between-one-and-two-bytes case.

test_select_semi_honest::<BA16>().await;
test_select_semi_honest::<BA20>().await;
test_select_semi_honest::<BA32>().await;
test_select_semi_honest::<BA64>().await;
test_select_semi_honest::<BA256>().await;
}

async fn test_select_malicious<V>()
where
V: BooleanArray,
for<'a> Replicated<V>: BooleanArrayMul<DZKPUpgradedMaliciousContext<'a>>,
Standard: Distribution<V>,
{
let world = TestWorld::default();
let context = world.malicious_contexts();
let mut rng = thread_rng();

let bit = rng.gen::<Boolean>();
let a = rng.gen::<V>();
let b = rng.gen::<V>();

let bit_shares = bit.share_with(&mut rng);
let a_shares = a.share_with(&mut rng);
let b_shares = b.share_with(&mut rng);

let futures = zip(context.iter(), zip(bit_shares, zip(a_shares, b_shares))).map(
|(ctx, (bit_share, (a_share, b_share)))| async move {
let v = ctx.clone().dzkp_validator(TEST_DZKP_STEPS, 1);
let m_ctx = v.context();

let result = select(
m_ctx.set_total_records(1),
RecordId::from(0),
&bit_share,
&a_share,
&b_share,
)
.await?;

v.validate().await?;

Ok::<_, Error>(result)
},
);

let [ab0, ab1, ab2] = join3v(futures).await;

let ab = [ab0, ab1, ab2].reconstruct();

assert_eq!(ab, if bit.into() { a } else { b });
}

#[tokio::test]
async fn select_malicious() {
test_select_malicious::<BA8>().await;
test_select_malicious::<BA16>().await;
test_select_malicious::<BA20>().await;
test_select_malicious::<BA32>().await;
test_select_malicious::<BA64>().await;
test_select_malicious::<BA256>().await;
}

#[tokio::test]
async fn dzkp_malicious() {
async fn two_multiplies_malicious() {
const COUNT: usize = 32;
let mut rng = thread_rng();

@@ -914,8 +1031,8 @@
}

/// test for testing `validated_seq_join`
/// similar to `complex_circuit` in `validator.rs`
async fn complex_circuit_dzkp(
/// similar to `complex_circuit` in `validator.rs` (which has a more detailed comment)
async fn chained_multiplies_dzkp(
count: usize,
max_multiplications_per_gate: usize,
) -> Result<(), Error> {
@@ -945,7 +1062,7 @@
.map(|(ctx, input_shares)| async move {
let v = ctx
.set_total_records(count - 1)
.dzkp_validator(TEST_DZKP_STEPS, ctx.active_work().get());
.dzkp_validator(TEST_DZKP_STEPS, max_multiplications_per_gate);
let m_ctx = v.context();

let m_results = v
@@ -1021,19 +1138,24 @@
Ok(())
}

prop_compose! {
fn arb_count_and_chunk()((log_count, log_multiplication_amount) in select(&[(5,5),(7,5),(5,8)])) -> (usize, usize) {
(1usize<<log_count, 1usize<<log_multiplication_amount)
}
fn record_count_strategy() -> impl Strategy<Value = usize> {
prop_oneof![1usize..=512, (1usize..=9).prop_map(|i| 1usize << i)]
Collaborator: I am being stupid and it is 12pm now - is it really testing $2^{512}$ records?

Author: What it does is:

  • 50% of the time it picks a random power of two between 2 and 512 (9 possible values)
  • 50% of the time it picks a random integer between 1 and 512 (512 possible values)

Two reasons to focus on powers of two, although I didn't put a huge amount of thought into this (e.g. it just occurred to me to add $2^0$):

  • Because we plan to restrict batch sizes to powers of two.
  • As the count increases, it makes sense to sample the space more sparsely; there's relatively less to learn from testing batch sizes 399, 400, and 401 than from testing batch sizes 7, 8, and 9.

I think it is clearer to have the two prop_oneof cases on different lines, but rustfmt insisted on a single line.

}

fn max_multiplications_per_gate_strategy() -> impl Strategy<Value = usize> {
prop_oneof![1usize..=128, (1usize..=7).prop_map(|i| 1usize << i)]
}
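A plain-Rust sketch (without the proptest machinery) of the two pools `record_count_strategy` mixes with equal weight, as discussed above; the helper name is illustrative:

```rust
// Sketch of the two arms of the strategy: a uniform draw over 1..=512, or a
// random power of two 2^1..=2^9.
fn power_of_two_pool() -> Vec<usize> {
    (1u32..=9).map(|i| 1usize << i).collect()
}

fn main() {
    let pows = power_of_two_pool();
    // Nine possible values in the power-of-two arm.
    assert_eq!(pows, vec![2, 4, 8, 16, 32, 64, 128, 256, 512]);
    // The uniform arm covers 512 values, so small counts are sampled densely
    // while large counts mostly appear as powers of two.
    assert_eq!((1usize..=512).count(), 512);
    println!("ok");
}
```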

proptest! {
#[test]
fn test_complex_circuit_dzkp((count, multiplication_amount) in arb_count_and_chunk()){
let future = async {
let _ = complex_circuit_dzkp(count, multiplication_amount).await;
};
tokio::runtime::Runtime::new().unwrap().block_on(future);
fn test_chained_multiplies_dzkp(
record_count in record_count_strategy(),
max_multiplications_per_gate in max_multiplications_per_gate_strategy(),
) {
println!("record_count {record_count} batch {max_multiplications_per_gate}");
tokio::runtime::Runtime::new().unwrap().block_on(async {
chained_multiplies_dzkp(record_count, max_multiplications_per_gate).await.unwrap();
});
}
}

15 changes: 12 additions & 3 deletions ipa-core/src/protocol/context/malicious.rs
@@ -1,5 +1,6 @@
use std::{
any::type_name,
cmp::max,
fmt::{Debug, Formatter},
num::NonZeroUsize,
};
@@ -174,13 +175,21 @@ pub(super) type MacBatcher<'a, F> = Mutex<Batcher<'a, validator::Malicious<'a, F
pub struct Upgraded<'a, F: ExtendableField> {
batch: Weak<MacBatcher<'a, F>>,
base_ctx: Context<'a>,
active_work: NonZeroUsize,
}

impl<'a, F: ExtendableField> Upgraded<'a, F> {
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, ctx: Context<'a>) -> Self {
pub(super) fn new(batch: &Arc<MacBatcher<'a, F>>, base_ctx: Context<'a>) -> Self {
// Adjust active_work to be at least records_per_batch. The MAC validator
// currently configures the batcher with records_per_batch = active_work, which
// makes this adjustment a no-op, but we do it this way to match the DZKP validator.
let records_per_batch = batch.lock().unwrap().records_per_batch();
Collaborator: would it be better to assert this to make sure we don't miss the misalignment between MAC and ZKP?

Author: Sure, I'm fine with that.

let active_work =
NonZeroUsize::new(max(base_ctx.active_work().get(), records_per_batch)).unwrap();
Self {
batch: Arc::downgrade(batch),
base_ctx: ctx,
base_ctx,
active_work,
}
}

@@ -297,7 +306,7 @@ impl<'a, F: ExtendableField> super::Context for Upgraded<'a, F> {

impl<'a, F: ExtendableField> SeqJoin for Upgraded<'a, F> {
fn active_work(&self) -> NonZeroUsize {
self.base_ctx.active_work()
self.active_work
}
}

2 changes: 1 addition & 1 deletion ipa-core/src/protocol/context/validator.rs
@@ -217,7 +217,7 @@ impl<'a, F: ExtendableField> BatchValidator<'a, F> {

// TODO: Right now we set the batch work to be equal to active_work,
// but it does not need to be. We can make this configurable if needed.
let records_per_batch = ctx.active_work().get().min(total_records.get());
let records_per_batch = ctx.active_work().get();
Author: Reducing if larger than total_records may have been necessary with an earlier version of the batcher, but the current version should take care of that internally, so I removed it here. No relation to the rest of these changes though.


Self {
protocol_ctx: ctx.narrow(&Step::MaliciousProtocol),
4 changes: 1 addition & 3 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
@@ -510,9 +510,7 @@ where
protocol: &Step::Attribute,
validate: &Step::AttributeValidate,
},
// The size of a single batch should not exceed the active work limit,
// otherwise it will stall
std::cmp::min(sh_ctx.active_work().get(), chunk_size),
chunk_size,
);
dzkp_validator.set_total_records(TotalRecords::specified(histogram[1]).unwrap());
let ctx_for_row_number = set_up_contexts(&dzkp_validator.context(), histogram)?;