Skip to content

Commit

Permalink
Move to functions::helper the code needed to make a request to the helper
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Oct 15, 2024
1 parent 77daa7e commit 4ef30f1
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 147 deletions.
181 changes: 34 additions & 147 deletions crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,14 @@

pub mod load_testing;

use crate::{
deduce_dap_version_from_url, response_to_anyhow, test_durations::TestDurations, HttpClient,
};
use crate::{deduce_dap_version_from_url, functions, test_durations::TestDurations, HttpClient};
use anyhow::{anyhow, bail, Context, Result};
use async_trait::async_trait;
use daphne::{
constants::DapMediaType,
error::aborts::ProblemDetails,
hpke::{HpkeConfig, HpkeKemId, HpkeReceiverConfig},
messages::{
self, AggregateShareReq, AggregationJobId, AggregationJobResp, Base64Encode, BatchId,
BatchSelector, PartialBatchSelector, TaskId,
self, AggregateShareReq, AggregationJobId, Base64Encode, BatchId, BatchSelector,
PartialBatchSelector, TaskId,
},
metrics::DaphneMetrics,
roles::DapReportInitializer,
Expand All @@ -39,9 +35,8 @@ use daphne::{
DapQueryConfig, DapTaskConfig, DapTaskParameters, DapVersion, EarlyReportStateConsumed,
EarlyReportStateInitialized, ReplayProtection,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use daphne_service_utils::bearer_token::BearerToken;
use futures::{future::OptionFuture, StreamExt, TryStreamExt};
use prio::codec::{Decode, ParameterizedEncode};
use prometheus::{Encoder, HistogramVec, IntCounterVec, IntGaugeVec, TextEncoder};
use rand::{rngs, Rng};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
Expand Down Expand Up @@ -509,62 +504,30 @@ impl Test {
.context("producing agg job init request")?;

// Send AggregationJobInitReq.
let headers = construct_request_headers(
DapMediaType::AggregationJobInitReq.as_str_for_version(task_config.version),
taskprov_advertisement.as_deref(),
&self.bearer_token,
)
.context("constructing request headers for AggregationJobInitReq")?;
let url = self.helper_url.join(&format!(
"tasks/{}/aggregation_jobs/{}",
task_id.to_base64url(),
agg_job_id.to_base64url()
))?;

// wait for all agg jobs to be ready to fire.
info!("Reports generated, waiting for other tasks...");
let _guard = load_control.wait().await;
info!("Starting AggregationJobInitReq");
let start = Instant::now();
let resp = send(
self.http_client
.put(url)
.body(
agg_job_init_req
.get_encoded_with_param(&task_config.version)
.unwrap(),
)
.headers(headers),
)
.await?;
let agg_job_resp = self
.http_client
.submit_aggregate_init_req(
self.helper_url.join(&format!(
"tasks/{}/aggregation_jobs/{}",
task_id.to_base64url(),
agg_job_id.to_base64url()
))?,
agg_job_init_req,
task_config.version,
functions::helper::Options {
taskprov_advertisement: taskprov_advertisement.as_deref(),
bearer_token: self.bearer_token.as_ref(),
},
)
.await?;
let duration = start.elapsed();
info!("Finished AggregationJobInitReq in {duration:#?}");

if resp.status() == 400 {
let text = resp.text().await?;
let problem_details: ProblemDetails =
serde_json::from_str(&text).with_context(|| {
format!("400 Bad Request: failed to parse problem details document: {text:?}")
})?;
return Err(anyhow!("400 Bad Request: {problem_details:?}"));
} else if resp.status() == 500 {
return Err(anyhow::anyhow!(
"500 Internal Server Error: {}",
resp.text().await?
));
} else if !resp.status().is_success() {
return Err(response_to_anyhow(resp).await)
.context("while running an AggregateInitReq");
}

// Handle AggregationJobResp..
let agg_job_resp = AggregationJobResp::get_decoded(
&resp
.bytes()
.await
.context("transfering bytes from the AggregateInitReq")?,
)
.with_context(|| "failed to parse response to AggregateInitReq from Helper")?;
let agg_share_span = task_config.consume_agg_job_resp(
task_id,
agg_job_state,
Expand Down Expand Up @@ -603,43 +566,21 @@ impl Test {
// Send AggregateShareReq.
info!("Starting AggregationJobInitReq");
let start = Instant::now();
let headers = construct_request_headers(
DapMediaType::AggregateShareReq.as_str_for_version(version),
taskprov_advertisement,
&self.bearer_token,
)?;
let url = self.helper_url.join(&format!(
"tasks/{}/aggregate_shares",
task_id.to_base64url()
))?;
let resp = send(
self.http_client
.post(url)
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
.headers(headers),
)
.await?;
let duration = start.elapsed();
info!("Finished AggregateShareReq in {duration:#?}");
if resp.status() == 400 {
let problem_details: ProblemDetails = serde_json::from_slice(
&resp
.bytes()
.await
.context("transfering bytes for AggregateShareReq")?,
self.http_client
.get_aggregate_share(
self.helper_url.join(&format!(
"tasks/{}/aggregate_shares",
task_id.to_base64url()
))?,
agg_share_req,
version,
functions::helper::Options {
taskprov_advertisement,
bearer_token: self.bearer_token.as_ref(),
},
)
.with_context(|| "400 Bad Request: failed to parse problem details document")?;
return Err(anyhow!("400 Bad Request: {problem_details:?}"));
} else if resp.status() == 500 {
return Err(anyhow::anyhow!(
"500 Internal Server Error: {}",
resp.text().await?
));
} else if !resp.status().is_success() {
return Err(response_to_anyhow(resp).await)
.context("while running an AggregateInitReq");
}
Ok(duration)
.await?;
Ok(start.elapsed())
}

pub async fn test_helper(&self, opt: &TestOptions) -> Result<TestDurations> {
Expand Down Expand Up @@ -794,60 +735,6 @@ impl DapReportInitializer for Test {
}
}

fn construct_request_headers<'a, M, T, B>(
media_type: M,
taskprov: T,
bearer_token: B,
) -> Result<reqwest::header::HeaderMap>
where
M: Into<Option<&'a str>>,
T: Into<Option<&'a str>>,
B: Into<Option<&'a BearerToken>>,
{
let mut headers = reqwest::header::HeaderMap::new();
if let Some(media_type) = media_type.into() {
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(media_type)?,
);
}
if let Some(taskprov) = taskprov.into() {
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_TASKPROV),
reqwest::header::HeaderValue::from_str(taskprov)?,
);
}
if let Some(token) = bearer_token.into() {
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
reqwest::header::HeaderValue::from_str(token.as_ref())?,
);
}
Ok(headers)
}

async fn send(req: reqwest::RequestBuilder) -> reqwest::Result<reqwest::Response> {
for i in 0..4 {
let resp = req.try_clone().unwrap().send().await;
match &resp {
Ok(r) if r.status() != reqwest::StatusCode::BAD_GATEWAY => {
return resp;
}
Ok(r) if r.status().is_client_error() => {
return resp;
}
Ok(_) => {}
Err(e) => {
tracing::error!("request failed: {e:?}");
}
}
if i == 3 {
return resp;
}
}
unreachable!()
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Now(u64);
pub fn now() -> Now {
Expand Down
137 changes: 137 additions & 0 deletions crates/dapf/src/functions/helper.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use anyhow::{anyhow, Context as _};
use daphne::{
constants::DapMediaType,
error::aborts::ProblemDetails,
messages::{AggregateShareReq, AggregationJobInitReq, AggregationJobResp},
DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use prio::codec::{Decode as _, ParameterizedEncode as _};
use reqwest::header;
use url::Url;

use crate::{response_to_anyhow, HttpClient};

impl HttpClient {
pub async fn submit_aggregate_init_req(
&self,
url: Url,
agg_job_init_req: AggregationJobInitReq,
version: DapVersion,
opts: Options<'_>,
) -> anyhow::Result<AggregationJobResp> {
let resp = self
.put(url)
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregationJobInitReq media type is not defined for {version}")
})?,
opts,
)?)
.send()
.await
.context("sending AggregationJobInitReq")?;
if resp.status() == 400 {
let text = resp.text().await?;
let problem_details: ProblemDetails =
serde_json::from_str(&text).with_context(|| {
format!("400 Bad Request: failed to parse problem details document: {text:?}")
})?;
Err(anyhow!("400 Bad Request: {problem_details:?}"))
} else if resp.status() == 500 {
Err(anyhow::anyhow!(
"500 Internal Server Error: {}",
resp.text().await?
))
} else if !resp.status().is_success() {
Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq")
} else {
AggregationJobResp::get_decoded(
&resp
.bytes()
.await
.context("transfering bytes from the AggregateInitReq")?,
)
.with_context(|| "failed to parse response to AggregateInitReq from Helper")
}
}

pub async fn get_aggregate_share(
&self,
url: Url,
agg_share_req: AggregateShareReq,
version: DapVersion,
opts: Options<'_>,
) -> anyhow::Result<()> {
let resp = self
.post(url)
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregateShareReq media type is not defined for {version}")
})?,
opts,
)?)
.send()
.await
.context("sending AggregateShareReq")?;
if resp.status() == 400 {
let problem_details: ProblemDetails = serde_json::from_slice(
&resp
.bytes()
.await
.context("transfering bytes for AggregateShareReq")?,
)
.with_context(|| "400 Bad Request: failed to parse problem details document")?;
Err(anyhow!("400 Bad Request: {problem_details:?}"))
} else if resp.status() == 500 {
Err(anyhow!("500 Internal Server Error: {}", resp.text().await?))
} else if !resp.status().is_success() {
Err(response_to_anyhow(resp).await).context("while running an AggregateShareReq")
} else {
Ok(())
}
}
}

/// Optional request parameters attached to helper requests.
///
/// Both fields are borrowed, so deriving `Clone`/`Copy` is free and lets a
/// single `Options` value be reused across several requests instead of being
/// rebuilt per call.
#[derive(Default, Debug, Clone, Copy)]
pub struct Options<'s> {
    /// Taskprov advertisement, sent under the `http_headers::DAP_TASKPROV`
    /// header when present.
    pub taskprov_advertisement: Option<&'s str>,
    /// Bearer token, sent under the `http_headers::DAP_AUTH_TOKEN` header
    /// when present.
    pub bearer_token: Option<&'s BearerToken>,
}

/// Build the HTTP header map for a request to the helper.
///
/// `media_type` becomes the `Content-Type` header; the taskprov advertisement
/// and bearer token from `options` are attached under their DAP-specific
/// header names only when present.
///
/// # Errors
///
/// Fails if any of the provided strings is not a valid HTTP header value.
fn construct_request_headers(
    media_type: &str,
    options: Options<'_>,
) -> Result<header::HeaderMap, header::InvalidHeaderValue> {
    let Options {
        taskprov_advertisement,
        bearer_token,
    } = options;

    let mut map = header::HeaderMap::new();
    map.insert(
        header::CONTENT_TYPE,
        header::HeaderValue::from_str(media_type)?,
    );
    if let Some(advertisement) = taskprov_advertisement {
        map.insert(
            const { header::HeaderName::from_static(http_headers::DAP_TASKPROV) },
            header::HeaderValue::from_str(advertisement)?,
        );
    }
    if let Some(token) = bearer_token {
        map.insert(
            const { header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN) },
            header::HeaderValue::from_str(token.as_str())?,
        );
    }
    Ok(map)
}
1 change: 1 addition & 0 deletions crates/dapf/src/functions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

/// Requests made against a DAP helper (aggregation job init, aggregate share).
pub mod helper;
pub mod hpke;
pub mod test_routes;

0 comments on commit 4ef30f1

Please sign in to comment.