Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Do not compare preparation states in non-test code. #1879

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 22 additions & 58 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1502,10 +1502,8 @@ impl VdafOps {
/// the sense that no new rows would need to be written to service the job.
async fn check_aggregation_job_idempotence<'b, const SEED_SIZE: usize, Q, A, C>(
tx: &Transaction<'b, C>,
vdaf: &A,
task: &Task,
incoming_aggregation_job: &AggregationJob<SEED_SIZE, Q, A>,
incoming_report_share_data: &[ReportShareData<SEED_SIZE, A>],
) -> Result<bool, Error>
where
Q: AccumulableQueryType,
Expand All @@ -1519,6 +1517,8 @@ impl VdafOps {
Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq,
A::OutputShare: Send + Sync + PartialEq,
{
// unwrap safety: this function should only be called if there is an existing aggregation
// job.
let existing_aggregation_job = tx
.get_aggregation_job(task.id(), incoming_aggregation_job.id())
.await?
Expand All @@ -1530,46 +1530,7 @@ impl VdafOps {
)
});

if !existing_aggregation_job.eq(incoming_aggregation_job) {
return Ok(false);
}

// Check the existing report aggregations for this job against the ones in the incoming
// message.
let existing_report_aggregations = tx
.get_report_aggregations_for_aggregation_job(
vdaf,
&Role::Helper,
task.id(),
incoming_aggregation_job.id(),
)
.await?;

if existing_report_aggregations.len() != incoming_report_share_data.len() {
return Ok(false);
}

// Check each report share in the incoming aggregation job against the already recorded
// report aggregations. `existing_report_aggregations` preserves the order in which the
// report shares were seen in the previous `AggregationJobInitReq`, and that order should be
// preserved in the repeated message, so it's OK to just zip the iterators together.
if incoming_report_share_data
.iter()
.zip(existing_report_aggregations)
.any(
|(incoming_report_share_data, existing_report_aggregation)| {
!existing_report_aggregation
.report_metadata()
.eq(incoming_report_share_data.report_share.metadata())
|| !existing_report_aggregation
.eq(&incoming_report_share_data.report_aggregation)
},
)
{
return Ok(false);
}

Ok(true)
Ok(existing_aggregation_job.eq(incoming_aggregation_job))
}

/// Implements the aggregate initialization request portion of the `/aggregate` endpoint for the
Expand All @@ -1596,6 +1557,8 @@ impl VdafOps {
Send + Sync + Encode + ParameterizedDecode<(&'a A, usize)> + PartialEq,
A::OutputShare: Send + Sync + PartialEq,
{
// unwrap safety: SHA-256 computed by ring should always be 32 bytes
let request_hash = digest(&SHA256, req_bytes).as_ref().try_into().unwrap();
let req = AggregationJobInitializeReq::<Q>::get_decoded(req_bytes)?;

// If two ReportShare messages have the same report ID, then the helper MUST abort with
Expand Down Expand Up @@ -1818,19 +1781,22 @@ impl VdafOps {
.difference(min_client_timestamp)?
.add(&Duration::from_seconds(1))?,
)?;
let aggregation_job = Arc::new(AggregationJob::<SEED_SIZE, Q, A>::new(
*task.id(),
*aggregation_job_id,
agg_param,
req.batch_selector().batch_identifier().clone(),
client_timestamp_interval,
if saw_continue {
AggregationJobState::InProgress
} else {
AggregationJobState::Finished
},
AggregationJobRound::from(0),
));
let aggregation_job = Arc::new(
AggregationJob::<SEED_SIZE, Q, A>::new(
*task.id(),
*aggregation_job_id,
agg_param,
req.batch_selector().batch_identifier().clone(),
client_timestamp_interval,
if saw_continue {
AggregationJobState::InProgress
} else {
AggregationJobState::Finished
},
AggregationJobRound::from(0),
)
.with_last_request_hash(request_hash),
);
let interval_per_batch_identifier = Arc::new(interval_per_batch_identifier);

Ok(datastore
Expand Down Expand Up @@ -1901,10 +1867,8 @@ impl VdafOps {
// the datastore, which we must now check.
if !Self::check_aggregation_job_idempotence(
tx,
&vdaf,
task.borrow(),
aggregation_job.borrow(),
&report_share_data,
)
.await
.map_err(|e| datastore::Error::User(e.into()))?
Expand Down Expand Up @@ -2060,7 +2024,7 @@ impl VdafOps {
// but the leader never got our response and so retried stepping the job.
// TODO(issue #1087): measure how often this happens with a Prometheus metric
if helper_aggregation_job.round() == leader_aggregation_job.round() {
match helper_aggregation_job.last_continue_request_hash() {
match helper_aggregation_job.last_request_hash() {
None => {
return Err(datastore::Error::User(
Error::Internal(format!(
Expand Down
2 changes: 1 addition & 1 deletion aggregator/src/aggregator/aggregation_job_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ impl VdafOps {
))
}
})
.with_last_continue_request_hash(request_hash);
.with_last_request_hash(request_hash);

try_join!(
tx.update_aggregation_job(&helper_aggregation_job),
Expand Down
5 changes: 3 additions & 2 deletions aggregator/src/aggregator/aggregation_job_driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ impl AggregationJobDriver {
// operation to avoid needing to join many futures?
let client_reports: HashMap<_, _> =
try_join_all(report_aggregations.iter().filter_map(|report_aggregation| {
if report_aggregation.state() == &ReportAggregationState::Start {
if matches!(report_aggregation.state(), &ReportAggregationState::Start)
{
Some(
tx.get_client_report(
vdaf.as_ref(),
Expand Down Expand Up @@ -308,7 +309,7 @@ impl AggregationJobDriver {
let report_aggregations: Vec<_> = report_aggregations
.into_iter()
.filter(|report_aggregation| {
report_aggregation.state() == &ReportAggregationState::Start
matches!(report_aggregation.state(), &ReportAggregationState::Start)
})
.collect();

Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator/http_handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2710,7 +2710,7 @@ mod tests {
AggregationJobState::Finished,
AggregationJobRound::from(1),
)
.with_last_continue_request_hash(aggregation_job.last_continue_request_hash().unwrap())
.with_last_request_hash(aggregation_job.last_request_hash().unwrap())
);
assert_eq!(
report_aggregations,
Expand Down Expand Up @@ -3633,7 +3633,7 @@ mod tests {
AggregationJobState::Finished,
AggregationJobRound::from(1),
)
.with_last_continue_request_hash(aggregation_job.last_continue_request_hash().unwrap())
.with_last_request_hash(aggregation_job.last_request_hash().unwrap())
);
assert_eq!(
report_aggregation,
Expand Down
22 changes: 11 additions & 11 deletions aggregator_core/src/datastore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1734,7 +1734,7 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"SELECT
aggregation_param, batch_id, client_timestamp_interval, state, round,
last_continue_request_hash
last_request_hash
FROM aggregation_jobs
JOIN tasks ON tasks.id = aggregation_jobs.task_id
WHERE tasks.task_id = $1
Expand Down Expand Up @@ -1771,7 +1771,7 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"SELECT
aggregation_job_id, aggregation_param, batch_id, client_timestamp_interval,
state, round, last_continue_request_hash
state, round, last_request_hash
FROM aggregation_jobs
JOIN tasks ON tasks.id = aggregation_jobs.task_id
WHERE tasks.task_id = $1
Expand Down Expand Up @@ -1817,10 +1817,10 @@ impl<C: Clock> Transaction<'_, C> {
row.get_postgres_integer_and_convert::<i32, _, _>("round")?,
);

if let Some(hash) = row.try_get::<_, Option<Vec<u8>>>("last_continue_request_hash")? {
job = job.with_last_continue_request_hash(hash.try_into().map_err(|h| {
if let Some(hash) = row.try_get::<_, Option<Vec<u8>>>("last_request_hash")? {
job = job.with_last_request_hash(hash.try_into().map_err(|h| {
Error::DbState(format!(
"last_continue_request_hash value {h:?} cannot be converted to 32 byte array"
"last_request_hash value {h:?} cannot be converted to 32 byte array"
))
})?);
}
Expand Down Expand Up @@ -1951,7 +1951,7 @@ impl<C: Clock> Transaction<'_, C> {
.prepare_cached(
"INSERT INTO aggregation_jobs
(task_id, aggregation_job_id, aggregation_param, batch_id,
client_timestamp_interval, state, round, last_continue_request_hash)
client_timestamp_interval, state, round, last_request_hash)
VALUES ((SELECT id FROM tasks WHERE task_id = $1), $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT DO NOTHING
RETURNING COALESCE(UPPER(client_timestamp_interval) < COALESCE($9::TIMESTAMP - (SELECT report_expiry_age FROM tasks WHERE task_id = $1) * '1 second'::INTERVAL, '-infinity'::TIMESTAMP), FALSE) AS is_expired",
Expand All @@ -1971,8 +1971,8 @@ impl<C: Clock> Transaction<'_, C> {
&SqlInterval::from(aggregation_job.client_timestamp_interval()),
/* state */ &aggregation_job.state(),
/* round */ &(u16::from(aggregation_job.round()) as i32),
/* last_continue_request_hash */
&aggregation_job.last_continue_request_hash(),
/* last_request_hash */
&aggregation_job.last_request_hash(),
/* now */ &self.clock.now().as_naive_date_time()?,
],
)
Expand Down Expand Up @@ -2022,7 +2022,7 @@ impl<C: Clock> Transaction<'_, C> {
"UPDATE aggregation_jobs SET
state = $1,
round = $2,
last_continue_request_hash = $3
last_request_hash = $3
FROM tasks
WHERE tasks.task_id = $4
AND aggregation_jobs.aggregation_job_id = $5
Expand All @@ -2035,8 +2035,8 @@ impl<C: Clock> Transaction<'_, C> {
&[
/* state */ &aggregation_job.state(),
/* round */ &(u16::from(aggregation_job.round()) as i32),
/* last_continue_request_hash */
&aggregation_job.last_continue_request_hash(),
/* last_request_hash */
&aggregation_job.last_request_hash(),
/* task_id */ &aggregation_job.task_id().as_ref(),
/* aggregation_job_id */ &aggregation_job.id().as_ref(),
/* now */ &self.clock.now().as_naive_date_time()?,
Expand Down
35 changes: 26 additions & 9 deletions aggregator_core/src/datastore/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ pub struct AggregationJob<const SEED_SIZE: usize, Q: QueryType, A: vdaf::Aggrega
/// The SHA-256 hash of the most recent [`janus_messages::AggregationJobContinueReq`]
/// received for this aggregation job. Will only be set for helpers, and only after the
/// first round of the job.
last_continue_request_hash: Option<[u8; 32]>,
last_request_hash: Option<[u8; 32]>,
}

impl<const SEED_SIZE: usize, Q: QueryType, A: vdaf::Aggregator<SEED_SIZE, 16>>
Expand All @@ -261,7 +261,7 @@ impl<const SEED_SIZE: usize, Q: QueryType, A: vdaf::Aggregator<SEED_SIZE, 16>>
client_timestamp_interval,
state,
round,
last_continue_request_hash: None,
last_request_hash: None,
}
}

Expand Down Expand Up @@ -317,16 +317,17 @@ impl<const SEED_SIZE: usize, Q: QueryType, A: vdaf::Aggregator<SEED_SIZE, 16>>
}

/// Returns the SHA-256 digest of the most recent
/// [`janus_messages::AggregationJobInitializeReq`] or
/// [`janus_messages::AggregationJobContinueReq`] for the job, if any.
pub fn last_continue_request_hash(&self) -> Option<[u8; 32]> {
self.last_continue_request_hash
pub fn last_request_hash(&self) -> Option<[u8; 32]> {
self.last_request_hash
}

/// Returns a new [`AggregationJob`] corresponding to this aggregation job updated to have
/// the given last continue request hash.
pub fn with_last_continue_request_hash(self, hash: [u8; 32]) -> Self {
/// Returns a new [`AggregationJob`] corresponding to this aggregation job updated to have the
/// given last request hash.
pub fn with_last_request_hash(self, hash: [u8; 32]) -> Self {
Self {
last_continue_request_hash: Some(hash),
last_request_hash: Some(hash),
..self
}
}
Expand Down Expand Up @@ -354,7 +355,7 @@ where
&& self.client_timestamp_interval == other.client_timestamp_interval
&& self.state == other.state
&& self.round == other.round
&& self.last_continue_request_hash == other.last_continue_request_hash
&& self.last_request_hash == other.last_request_hash
}
}

Expand Down Expand Up @@ -669,6 +670,10 @@ impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> ReportAggregati
}
}

// This trait implementation is gated on the `test-util` feature as we do not wish to compare
// preparation states in non-test code, since doing so would require a constant-time comparison to
// avoid risking leaking information about the preparation state.
#[cfg(feature = "test-util")]
impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> PartialEq
for ReportAggregation<SEED_SIZE, A>
where
Expand All @@ -688,6 +693,10 @@ where
}
}

// This trait implementation is gated on the `test-util` feature as we do not wish to compare
// preparation states in non-test code, since doing so would require a constant-time comparison to
// avoid risking leaking information about the preparation state.
#[cfg(feature = "test-util")]
impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> Eq
for ReportAggregation<SEED_SIZE, A>
where
Expand Down Expand Up @@ -775,6 +784,10 @@ pub enum ReportAggregationStateCode {
Failed,
}

// This trait implementation is gated on the `test-util` feature as we do not wish to compare
// preparation states in non-test code, since doing so would require a constant-time comparison to
// avoid risking leaking information about the preparation state.
#[cfg(feature = "test-util")]
impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> PartialEq
for ReportAggregationState<SEED_SIZE, A>
where
Expand All @@ -797,6 +810,10 @@ where
}
}

// This trait implementation is gated on the `test-util` feature as we do not wish to compare
// preparation states in non-test code, since doing so would require a constant-time comparison to
// avoid risking leaking information about the preparation state.
#[cfg(feature = "test-util")]
impl<const SEED_SIZE: usize, A: vdaf::Aggregator<SEED_SIZE, 16>> Eq
for ReportAggregationState<SEED_SIZE, A>
where
Expand Down
5 changes: 2 additions & 3 deletions aggregator_core/src/datastore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1270,8 +1270,7 @@ async fn roundtrip_aggregation_job(ephemeral_datastore: EphemeralDatastore) {
let new_leader_aggregation_job = leader_aggregation_job
.clone()
.with_state(AggregationJobState::Finished);
let new_helper_aggregation_job =
helper_aggregation_job.with_last_continue_request_hash([3; 32]);
let new_helper_aggregation_job = helper_aggregation_job.with_last_request_hash([3; 32]);
ds.run_tx(|tx| {
let (new_leader_aggregation_job, new_helper_aggregation_job) = (
new_leader_aggregation_job.clone(),
Expand Down Expand Up @@ -1812,7 +1811,7 @@ async fn get_aggregation_jobs_for_task(ephemeral_datastore: EphemeralDatastore)
AggregationJobState::InProgress,
AggregationJobRound::from(0),
)
.with_last_continue_request_hash([3; 32]);
.with_last_request_hash([3; 32]);

let mut want_agg_jobs = Vec::from([
first_aggregation_job,
Expand Down
3 changes: 1 addition & 2 deletions db/00000000000001_initial_schema.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ CREATE TABLE aggregation_jobs(
client_timestamp_interval TSRANGE NOT NULL, -- the minimal interval containing all of client timestamps included in this aggregation job
state AGGREGATION_JOB_STATE NOT NULL, -- current state of the aggregation job
round INTEGER NOT NULL, -- current round of the VDAF preparation protocol
last_continue_request_hash BYTEA, -- SHA-256 hash of the most recently received AggregationJobContinueReq
-- (helper only and only after the first round of the job)
last_request_hash BYTEA, -- SHA-256 hash of the most recently received AggregationJobContinueReq (helper only)
trace_context JSONB, -- distributed tracing metadata

lease_expiry TIMESTAMP NOT NULL DEFAULT TIMESTAMP '-infinity', -- when lease on this aggregation job expires; -infinity implies no current lease
Expand Down