Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove DapVersion from DapMediaType parsing #689

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ impl Test {

// Send AggregationJobInitReq.
let headers = construct_request_headers(
DapMediaType::AggregationJobInitReq.as_str_for_version(task_config.version),
DapMediaType::AggregationJobInitReq.as_str(),
taskprov_advertisement.as_deref(),
&self.bearer_token,
)
Expand Down Expand Up @@ -604,7 +604,7 @@ impl Test {
info!("Starting AggregationJobInitReq");
let start = Instant::now();
let headers = construct_request_headers(
DapMediaType::AggregateShareReq.as_str_for_version(version),
DapMediaType::AggregateShareReq.as_str(),
taskprov_advertisement,
&self.bearer_token,
)?;
Expand Down
14 changes: 2 additions & 12 deletions crates/dapf/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,12 +477,7 @@ async fn handle_leader_actions(
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::Report
.as_str_for_version(version)
.ok_or_else(|| anyhow!("invalid content-type for dap version"))?,
)
.expect("failecd to construct content-type header"),
reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()),
);
let resp = http_client
.post(leader_url.join("upload")?)
Expand Down Expand Up @@ -523,12 +518,7 @@ async fn handle_leader_actions(
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(version)
.ok_or_else(|| anyhow!("invalid content-type for dap version"))?,
)
.expect("failed to construct content-type hader"),
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
if let Ok(token) = std::env::var("LEADER_BEARER_TOKEN") {
headers.insert(
Expand Down
20 changes: 9 additions & 11 deletions crates/daphne-server/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,14 @@ impl crate::App {
M: Send + ParameterizedEncode<DapVersion>,
{
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
let content_type = req
.media_type
.and_then(|mt| mt.as_str_for_version(req.version))
.ok_or_else(|| {
fatal_error!(
err = "failed to construct content-type",
?req.media_type,
?req.version,
)
})?;

let content_type = req.media_type.map(|mt| mt.as_str()).ok_or_else(|| {
fatal_error!(
err = "failed to construct content-type",
?req.media_type,
?req.version,
)
})?;

let mut headers = HeaderMap::new();
headers.insert(
Expand Down Expand Up @@ -217,7 +215,7 @@ impl crate::App {
.get_all(reqwest::header::CONTENT_TYPE)
.into_iter()
.filter_map(|h| h.to_str().ok())
.find_map(|h| DapMediaType::from_str_for_version(req.version, h))
.find_map(DapMediaType::from_http_content_type)
.ok_or_else(|| fatal_error!(err = "peer response is missing media type"))?;

let payload = reqwest_resp
Expand Down
59 changes: 14 additions & 45 deletions crates/daphne-server/src/router/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,12 @@ where
let msg = "header value contains non ascii or invisible characters".into();
AxumDapResponse::new_error(DapAbort::BadRequest(msg), state.server_metrics())
})?;
let mt =
DapMediaType::from_str_for_version(version, content_type).ok_or_else(|| {
AxumDapResponse::new_error(
DapAbort::BadRequest("invalid media type".into()),
state.server_metrics(),
)
})?;
let mt = DapMediaType::from_http_content_type(content_type).ok_or_else(|| {
AxumDapResponse::new_error(
DapAbort::BadRequest("invalid media type".into()),
state.server_metrics(),
)
})?;
Some(mt)
} else {
None
Expand Down Expand Up @@ -454,12 +453,7 @@ mod test {
"/{version}/{}/parse-mandatory-fields",
task_id.to_base64url()
))
.header(
CONTENT_TYPE,
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::AggregateShareReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -486,12 +480,7 @@ mod test {
task_id.to_base64url(),
agg_job_id.to_base64url(),
))
.header(
CONTENT_TYPE,
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::AggregationJobInitReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -518,12 +507,7 @@ mod test {
task_id.to_base64url(),
collect_job_id.to_base64url(),
))
.header(
CONTENT_TYPE,
DapMediaType::CollectReq
.as_str_for_version(version)
.unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::CollectReq.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.body(Body::empty())
.unwrap(),
Expand All @@ -543,10 +527,7 @@ mod test {
let status_code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
.body(Body::empty())
.unwrap(),
Expand All @@ -565,10 +546,7 @@ mod test {
let status_code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.body(Body::empty())
.unwrap(),
)
Expand All @@ -586,10 +564,7 @@ mod test {
let req = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header("X-Client-Cert-Verified", "SUCCESS")
.header(http_headers::DAP_TASKPROV, "some-taskprov-string")
.body(Body::empty())
Expand All @@ -608,10 +583,7 @@ mod test {
let code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, "something incorrect")
.header("X-Client-Cert-Verified", "SUCCESS")
.body(Body::empty())
Expand All @@ -631,10 +603,7 @@ mod test {
let code = test(
Request::builder()
.uri(format!("/{version}/{}/auth", mk_task_id().to_base64url()))
.header(
CONTENT_TYPE,
DapMediaType::Report.as_str_for_version(version).unwrap(),
)
.header(CONTENT_TYPE, DapMediaType::Report.as_str())
.header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN)
.header("X-Client-Cert-Verified", "FAILED")
.body(Body::empty())
Expand Down
23 changes: 3 additions & 20 deletions crates/daphne-server/src/router/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ use axum::{
response::IntoResponse,
Json,
};
use daphne::{
error::DapAbort, fatal_error, messages::TaskId, DapError, DapRequestMeta, DapResponse,
DapSender,
};
use daphne::{error::DapAbort, messages::TaskId, DapError, DapRequestMeta, DapResponse, DapSender};
use daphne_service_utils::{bearer_token::BearerToken, metrics::DaphneServiceMetrics, DapRole};
use either::Either;
use http::Request;
Expand Down Expand Up @@ -148,24 +145,10 @@ impl AxumDapResponse {

pub fn new_success_with_code(
response: DapResponse,
metrics: &dyn DaphneServiceMetrics,
_metrics: &dyn DaphneServiceMetrics,
status_code: StatusCode,
) -> Self {
let Some(media_type) = response.media_type.as_str_for_version(response.version) else {
return AxumDapResponse::new_error(
fatal_error!(err = "invalid content-type for DAP version"),
metrics,
);
};
let media_type = match HeaderValue::from_str(media_type) {
Ok(media_type) => media_type,
Err(e) => {
return AxumDapResponse::new_error(
fatal_error!(err = ?e, "content-type contained invalid bytes {media_type:?}"),
metrics,
)
}
};
let media_type = HeaderValue::from_static(response.media_type.as_str());

let headers = [(CONTENT_TYPE, media_type)];

Expand Down
6 changes: 1 addition & 5 deletions crates/daphne-server/tests/e2e/e2e.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,7 @@ async fn leader_upload(version: DapVersion) {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
DapMediaType::Report
.as_str_for_version(version)
.unwrap()
.parse()
.unwrap(),
reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()),
);
let builder = client.put(url.as_str());
let resp = builder
Expand Down
32 changes: 6 additions & 26 deletions crates/daphne-server/tests/e2e/test_runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(taskprov_advertisement) = taskprov {
headers.insert(
Expand Down Expand Up @@ -418,10 +415,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(token) = dap_auth_token {
headers.insert(
Expand Down Expand Up @@ -481,10 +475,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down Expand Up @@ -531,10 +522,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
media_type
.as_str_for_version(self.version)
.context("no string for version")?
.parse()?,
reqwest::header::HeaderValue::from_static(media_type.as_str()),
);
if let Some(token) = dap_auth_token {
headers.insert(
Expand Down Expand Up @@ -589,11 +577,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(self.version)
.context("no string for version")?,
)?,
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down Expand Up @@ -778,11 +762,7 @@ impl TestRunner {
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
reqwest::header::CONTENT_TYPE,
reqwest::header::HeaderValue::from_str(
DapMediaType::CollectReq
.as_str_for_version(self.version)
.context("no string for version")?,
)?,
reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()),
);
headers.insert(
reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN),
Expand Down
Loading
Loading