Skip to content

Commit

Permalink
Merge pull request #1056 from benjaminsavage/fewer_steps
Browse files Browse the repository at this point in the history
  • Loading branch information
benjaminsavage authored May 13, 2024
2 parents bf75126 + 3e73128 commit 10521e6
Show file tree
Hide file tree
Showing 3 changed files with 1,191 additions and 1,275 deletions.
26 changes: 16 additions & 10 deletions ipa-core/src/protocol/ipa_prf/aggregation/bucket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@ use crate::{
error::Error,
ff::boolean::Boolean,
helpers::repeat_n,
protocol::{
basics::SecureMul, boolean::and::bool_and_9_bit, context::Context,
ipa_prf::prf_sharding::BinaryTreeDepthStep, RecordId,
},
protocol::{basics::SecureMul, boolean::and::bool_and_9_bit, context::Context, RecordId},
secret_sharing::{replicated::semi_honest::AdditiveShare, BitDecomposed, FieldSimd},
};

const MAX_BREAKDOWNS: usize = 512; // constrained by the compact step ability to generate dynamic steps

#[derive(Step)]
pub enum BucketStep {
#[dynamic(256)]
#[dynamic(512)] // should be equal to MAX_BREAKDOWNS
Bit(usize),
}

Expand Down Expand Up @@ -88,7 +87,6 @@ where
Boolean: FieldSimd<N>,
AdditiveShare<Boolean, N>: SecureMul<C>,
{
const MAX_BREAKDOWNS: usize = 512; // constrained by the compact step ability to generate dynamic steps
let mut step: usize = 1 << bd_key.len();

if breakdown_count > step {
Expand All @@ -107,19 +105,26 @@ where

let mut row_contribution = vec![value; breakdown_count];

for (tree_depth, bit_of_bdkey) in bd_key.iter().enumerate().rev() {
// To move a value to one of 2^bd_key_bits buckets requires 2^bd_key_bits - 1 multiplications
// They happen in a tree like fashion:
// 1 multiplication for the first bit
// 2 for the second bit
// 4 for the 3rd bit
// And so on. Simply ordering them sequentially is a functional way
// of enumerating them without creating more step transitions than necessary
let mut multiplication_channel = 0;

for bit_of_bdkey in bd_key.iter().rev() {
let span = step >> 1;
if !robust && span > breakdown_count {
step = span;
continue;
}

let depth_c = ctx.narrow(&BinaryTreeDepthStep::from(tree_depth));

let contributions = ctx
.parallel_join((0..breakdown_count).step_by(step).enumerate().filter_map(
|(i, tree_index)| {
let bucket_c = depth_c.narrow(&BucketStep::from(i));
let bucket_c = ctx.narrow(&BucketStep::from(multiplication_channel + i));

let index_contribution = &row_contribution[tree_index];

Expand All @@ -134,6 +139,7 @@ where
},
))
.await?;
multiplication_channel += contributions.len();

for (index, bdbit_contribution) in contributions.into_iter().enumerate() {
let left_index = index * step;
Expand Down
12 changes: 0 additions & 12 deletions ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,6 @@ impl From<usize> for UserNthRowStep {
}
}

#[derive(Step)]
pub enum BinaryTreeDepthStep {
#[dynamic(64)]
Depth(usize),
}

impl From<usize> for BinaryTreeDepthStep {
fn from(v: usize) -> Self {
Self::Depth(v)
}
}

#[derive(Step)]
pub(crate) enum Step {
BinaryValidator,
Expand Down
Loading

0 comments on commit 10521e6

Please sign in to comment.