Skip to content

Commit

Permalink
Merge pull request #835 from private-attribution/debugging_stall
Browse files Browse the repository at this point in the history
Moving collect to a better place in OPRF IPA
  • Loading branch information
benjaminsavage authored Nov 8, 2023
2 parents 650cb4b + 50ef10e commit 4dd9554
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ where
/// Takes an input stream of `PrfShardedIpaInputRow` which is assumed to have all records with a given PRF adjacent
/// and converts it into a stream of vectors of `PrfShardedIpaInputRow` having the same PRF.
///
/// Filters out users that have only a single row, since they will produce no attributed conversions.
/// Note that the final user in the stream is still emitted even when it has only a single row.
///
fn chunk_rows_by_user<IS, BK, TV, TS>(
input_stream: IS,
first_row: PrfShardedIpaInputRow<BK, TV, TS>,
Expand All @@ -361,13 +363,16 @@ where
{
unfold(Some((input_stream, first_row)), |state| async move {
let (mut s, last_row) = state?;
let last_row_prf = last_row.prf_of_match_key;
let mut last_row_prf = last_row.prf_of_match_key;
let mut current_chunk = vec![last_row];
while let Some(row) = s.next().await {
if row.prf_of_match_key == last_row_prf {
current_chunk.push(row);
} else {
} else if current_chunk.len() > 1 {
return Some((current_chunk, Some((s, row))));
} else {
last_row_prf = row.prf_of_match_key;
current_chunk = vec![row];
}
}
Some((current_chunk, None))
Expand Down Expand Up @@ -435,8 +440,11 @@ where
let first_row = first_row.unwrap();
let rows_chunked_by_user = chunk_rows_by_user(input_stream, first_row);

let mut collected = rows_chunked_by_user.collect::<Vec<_>>().await;
collected.sort_by(|a, b| std::cmp::Ord::cmp(&b.len(), &a.len()));

// Convert to a stream of async futures that represent the result of executing the per-user circuit
let stream_of_per_user_circuits = pin!(rows_chunked_by_user.then(|rows_for_user| {
let stream_of_per_user_circuits = pin!(stream_iter(collected).then(|rows_for_user| {
let num_user_rows = rows_for_user.len();
let contexts = ctx_for_row_number[..num_user_rows - 1].to_owned();
let record_ids = record_id_for_row_depth[..num_user_rows].to_owned();
Expand All @@ -458,16 +466,15 @@ where
}));

// Execute all of the async futures (sequentially), and flatten the result
let collected_per_user_results = stream_of_per_user_circuits.collect::<Vec<_>>().await;
let per_user_attribution_outputs = sh_ctx.parallel_join(collected_per_user_results).await?;
let flattenned_stream = per_user_attribution_outputs.into_iter().flatten();
let flattenned_stream = seq_join(sh_ctx.active_work(), stream_of_per_user_circuits)
.flat_map(|x| stream_iter(x.unwrap()));

// modulus convert breakdown keys and trigger values
let converted_bks_and_tvs = convert_bits(
prime_field_ctx
.narrow(&Step::ModulusConvertBreakdownKeyBitsAndTriggerValues)
.set_total_records(num_outputs),
stream_iter(flattenned_stream),
flattenned_stream,
0..BK::BITS + TV::BITS,
);

Expand Down

0 comments on commit 4dd9554

Please sign in to comment.