Skip to content

Commit

Permalink
Remove version from DapMediaType parsing and stringifying methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Oct 8, 2024
1 parent e56f4f2 commit a0acfef
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 195 deletions.
4 changes: 2 additions & 2 deletions crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ impl Test {

// Send AggregationJobInitReq.
let headers = construct_request_headers(
DapMediaType::AggregationJobInitReq.as_str_for_version(task_config.version),
DapMediaType::AggregationJobInitReq.as_str(),
taskprov_advertisement.as_deref(),
&self.bearer_token,
)
Expand Down Expand Up @@ -604,7 +604,7 @@ impl Test {
info!("Starting AggregationJobInitReq");
let start = Instant::now();
let headers = construct_request_headers(
DapMediaType::AggregateShareReq.as_str_for_version(version),
DapMediaType::AggregateShareReq.as_str(),
taskprov_advertisement,
&self.bearer_token,
)?;
Expand Down
14 changes: 2 additions & 12 deletions crates/dapf/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,7 @@ async fn handle_leader_actions(
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::Report
.as_str_for_version(version)
.ok_or_else(|| anyhow!("invalid content-type for dap version"))?,
)
.expect("failecd to construct content-type header"),
reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()),
);
let resp = http_client
.post(leader_url.join("upload")?)
Expand Down Expand Up @@ -523,12 +518,7 @@ async fn handle_leader_actions(
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(version)
.ok_or_else(|| anyhow!("invalid content-type for dap version"))?,
)
.expect("failed to construct content-type hader"),
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
if let Ok(token) = std::env::var("LEADER_BEARER_TOKEN") {
headers.insert(
Expand Down
20 changes: 9 additions & 11 deletions crates/daphne-server/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,14 @@ impl crate::App {
M: Send + ParameterizedEncode<DapVersion>,
{
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
let content_type = req
.media_type
.and_then(|mt| mt.as_str_for_version(req.version))
.ok_or_else(|| {
fatal_error!(
err = "failed to construct content-type",
?req.media_type,
?req.version,
)
})?;

let content_type = req.media_type.map(|mt| mt.as_str()).ok_or_else(|| {
fatal_error!(
err = "failed to construct content-type",
?req.media_type,
?req.version,
)
})?;

let mut headers = HeaderMap::new();
headers.insert(
Expand Down Expand Up @@ -217,7 +215,7 @@ impl crate::App {
.get_all(reqwest::header::CONTENT_TYPE)
.into_iter()
.filter_map(|h| h.to_str().ok())
.find_map(|h| DapMediaType::from_str_for_version(req.version, h))
.find_map(DapMediaType::from_http_content_type)
.ok_or_else(|| fatal_error!(err = "peer response is missing media type"))?;

let payload = reqwest_resp
Expand Down
59 changes: 14 additions & 45 deletions crates/daphne-server/src/router/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,12 @@ where
let msg = "header value contains non ascii or invisible characters".into();
AxumDapResponse::new_error(DapAbort::BadRequest(msg), state.server_metrics())
})?;
let mt =
DapMediaType::from_str_for_version(version, content_type).ok_or_else(|| {
AxumDapResponse::new_error(
DapAbort::BadRequest("invalid media type".into()),
state.server_metrics(),
)
})?;
let mt = DapMediaType::from_http_content_type(content_type).ok_or_else(|| {
AxumDapResponse::new_error(
DapAbort::BadRequest("invalid media type".into()),
state.server_metrics(),
)
})?;
Some(mt)
} else {
None
Expand Down Expand Up @@ -454,12 +453,7 @@ mod test {
"/{version}/{}/parse-mandatory-fields",
task_id.to_base64url()
))
.header(
CONTENT_TYPE,
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::AggregateShareReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -486,12 +480,7 @@ mod test {
task_id.to_base64url(),
agg_job_id.to_base64url(),
))
.header(
CONTENT_TYPE,
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::AggregationJobInitReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -518,12 +507,7 @@ mod test {
task_id.to_base64url(),
collect_job_id.to_base64url(),
))
.header(
CONTENT_TYPE,
DapMediaType::CollectReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::CollectReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -543,10 +527,7 @@ mod test {
let status_code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
.body(Body::empty())
.unwrap(),
Expand All @@ -565,10 +546,7 @@ mod test {
let status_code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.body(Body::empty())
.unwrap(),
)
Expand All @@ -586,10 +564,7 @@ mod test {
let req = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header("X-Client-Cert-Verified", "SUCCESS")
.header(http_headers::DAP_TASKPROV, "some-taskprov-string")
.body(Body::empty())
Expand All @@ -608,10 +583,7 @@ mod test {
let code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
.header("X-Client-Cert-Verified", "SUCCESS")
.body(Body::empty())
Expand All @@ -631,10 +603,7 @@ mod test {
let code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.header("X-Client-Cert-Verified", "FAILED")
.body(Body::empty())
Expand Down
23 changes: 3 additions & 20 deletions crates/daphne-server/src/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ use axum::{
response::IntoResponse,
Json,
};
use daphne::{
error::DapAbort, fatal_error, messages::TaskId, DapError, DapRequestMeta, DapResponse,
DapSender,
};
use daphne::{error::DapAbort, messages::TaskId, DapError, DapRequestMeta, DapResponse, DapSender};
use daphne_service_utils::{bearer_token::BearerToken, metrics::DaphneServiceMetrics, DapRole};
use either::Either;
use http::Request;
Expand Down Expand Up @@ -148,24 +145,10 @@ impl AxumDapResponse {

pub fn new_success_with_code(
response: DapResponse,
metrics: &dyn DaphneServiceMetrics,
_metrics: &dyn DaphneServiceMetrics,
status_code: StatusCode,
) -> Self {
let Some(media_type) = response.media_type.as_str_for_version(response.version) else {
return AxumDapResponse::new_error(
fatal_error!(err = "invalid content-type for DAP version"),
metrics,
);
};
let media_type = match HeaderValue::from_str(media_type) {
Ok(media_type) => media_type,
Err(e) => {
return AxumDapResponse::new_error(
fatal_error!(err = ?e, "content-type contained invalid bytes {media_type:?}"),
metrics,
)
}
};
let media_type = HeaderValue::from_static(response.media_type.as_str());

let headers = [(CONTENT_TYPE, media_type)];

Expand Down
6 changes: 1 addition & 5 deletions crates/daphne-server/tests/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,7 @@ async fn leader_upload(version: DapVersion) {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
DapMediaType::Report
.as_str_for_version(version)
.unwrap()
.parse()
.unwrap(),
reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()),
);
let builder = client.put(url.as_str());
let resp = builder
Expand Down
32 changes: 6 additions & 26 deletions crates/daphne-server/tests/e2e/test_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(taskprov_advertisement) = taskprov {
headers.insert(
Expand Down Expand Up @@ -418,10 +415,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(token) = dap_auth_token {
headers.insert(
Expand Down Expand Up @@ -481,10 +475,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down Expand Up @@ -531,10 +522,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(token) = dap_auth_token {
headers.insert(
Expand Down Expand Up @@ -589,11 +577,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(self.version)
.context("no string for version")?,
)?,
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down Expand Up @@ -778,11 +762,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(self.version)
.context("no string for version")?,
)?,
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down
Loading

0 comments on commit a0acfef

Please sign in to comment.