diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs index a2ecd0ad..48031281 100644 --- a/crates/dapf/src/acceptance/mod.rs +++ b/crates/dapf/src/acceptance/mod.rs @@ -510,7 +510,7 @@ impl Test { // Send AggregationJobInitReq. let headers = construct_request_headers( - DapMediaType::AggregationJobInitReq.as_str_for_version(task_config.version), + DapMediaType::AggregationJobInitReq.as_str(), taskprov_advertisement.as_deref(), &self.bearer_token, ) @@ -604,7 +604,7 @@ impl Test { info!("Starting AggregationJobInitReq"); let start = Instant::now(); let headers = construct_request_headers( - DapMediaType::AggregateShareReq.as_str_for_version(version), + DapMediaType::AggregateShareReq.as_str(), taskprov_advertisement, &self.bearer_token, )?; diff --git a/crates/dapf/src/main.rs b/crates/dapf/src/main.rs index b41bff56..831bbc50 100644 --- a/crates/dapf/src/main.rs +++ b/crates/dapf/src/main.rs @@ -477,12 +477,7 @@ async fn handle_leader_actions( let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_str( - DapMediaType::Report - .as_str_for_version(version) - .ok_or_else(|| anyhow!("invalid content-type for dap version"))?, - ) - .expect("failecd to construct content-type header"), + reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()), ); let resp = http_client .post(leader_url.join("upload")?) @@ -523,12 +518,7 @@ async fn handle_leader_actions( let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_str( - DapMediaType::CollectReq - .as_str_for_version(version) - .ok_or_else(|| anyhow!("invalid content-type for dap version"))?, - ) - .expect("failed to construct content-type hader"), + reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()), ); if let Ok(token) = std::env::var("LEADER_BEARER_TOKEN") { headers.insert( diff --git a/crates/daphne-server/src/roles/leader.rs b/crates/daphne-server/src/roles/leader.rs index 521c6f2d..726e3231 100644 --- a/crates/daphne-server/src/roles/leader.rs +++ b/crates/daphne-server/src/roles/leader.rs @@ -126,16 +126,14 @@ impl crate::App { M: Send + ParameterizedEncode, { use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue}; - let content_type = req - .media_type - .and_then(|mt| mt.as_str_for_version(req.version)) - .ok_or_else(|| { - fatal_error!( - err = "failed to construct content-type", - ?req.media_type, - ?req.version, - ) - })?; + + let content_type = req.media_type.map(|mt| mt.as_str()).ok_or_else(|| { + fatal_error!( + err = "failed to construct content-type", + ?req.media_type, + ?req.version, + ) + })?; let mut headers = HeaderMap::new(); headers.insert( @@ -217,7 +215,7 @@ impl crate::App { .get_all(reqwest::header::CONTENT_TYPE) .into_iter() .filter_map(|h| h.to_str().ok()) - .find_map(|h| DapMediaType::from_str_for_version(req.version, h)) + .find_map(DapMediaType::from_http_content_type) .ok_or_else(|| fatal_error!(err = "peer response is missing media type"))?; let payload = reqwest_resp diff --git a/crates/daphne-server/src/router/extractor.rs b/crates/daphne-server/src/router/extractor.rs index 3bee7586..51511f9d 100644 --- a/crates/daphne-server/src/router/extractor.rs +++ b/crates/daphne-server/src/router/extractor.rs @@ -138,13 +138,12 @@ where let msg = "header value contains non ascii or invisible characters".into(); AxumDapResponse::new_error(DapAbort::BadRequest(msg), state.server_metrics()) })?; - let mt = - DapMediaType::from_str_for_version(version, content_type).ok_or_else(|| { - AxumDapResponse::new_error( - DapAbort::BadRequest("invalid media type".into()), - state.server_metrics(), - ) - })?; + let mt = DapMediaType::from_http_content_type(content_type).ok_or_else(|| { + AxumDapResponse::new_error( + DapAbort::BadRequest("invalid media type".into()), + state.server_metrics(), + ) + })?; Some(mt) } else { None @@ -454,12 +453,7 @@ mod test { "/{version}/{}/parse-mandatory-fields", task_id.to_base64url() )) - .header( - CONTENT_TYPE, - DapMediaType::AggregateShareReq - .as_str_for_version(version) - .unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::AggregateShareReq.as_str()) .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .body(Body::empty()) .unwrap(), @@ -486,12 +480,7 @@ mod test { task_id.to_base64url(), agg_job_id.to_base64url(), )) - .header( - CONTENT_TYPE, - DapMediaType::AggregationJobInitReq - .as_str_for_version(version) - .unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::AggregationJobInitReq.as_str()) .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .body(Body::empty()) .unwrap(), @@ -518,12 +507,7 @@ mod test { task_id.to_base64url(), collect_job_id.to_base64url(), )) - .header( - CONTENT_TYPE, - DapMediaType::CollectReq - .as_str_for_version(version) - .unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::CollectReq.as_str()) .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .body(Body::empty()) .unwrap(), @@ -543,10 +527,7 @@ mod test { let status_code = test( Request::builder() .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) - .header( - CONTENT_TYPE, - DapMediaType::Report.as_str_for_version(version).unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::Report.as_str()) .header(http_headers::DAP_AUTH_TOKEN, "something incorrect") .body(Body::empty()) .unwrap(), @@ -565,10 +546,7 @@ mod test { let status_code = test( Request::builder() .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) - .header( - CONTENT_TYPE, - DapMediaType::Report.as_str_for_version(version).unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::Report.as_str()) .body(Body::empty()) .unwrap(), ) @@ -586,10 +564,7 @@ mod test { let req = test( Request::builder() .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) - .header( - CONTENT_TYPE, - DapMediaType::Report.as_str_for_version(version).unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::Report.as_str()) .header("X-Client-Cert-Verified", "SUCCESS") .header(http_headers::DAP_TASKPROV, "some-taskprov-string") .body(Body::empty()) @@ -608,10 +583,7 @@ mod test { let code = test( Request::builder() .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) - .header( - CONTENT_TYPE, - DapMediaType::Report.as_str_for_version(version).unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::Report.as_str()) .header(http_headers::DAP_AUTH_TOKEN, "something incorrect") .header("X-Client-Cert-Verified", "SUCCESS") .body(Body::empty()) @@ -631,10 +603,7 @@ mod test { let code = test( Request::builder() .uri(format!("/{version}/{}/auth", mk_task_id().to_base64url())) - .header( - CONTENT_TYPE, - DapMediaType::Report.as_str_for_version(version).unwrap(), - ) + .header(CONTENT_TYPE, DapMediaType::Report.as_str()) .header(http_headers::DAP_AUTH_TOKEN, BEARER_TOKEN) .header("X-Client-Cert-Verified", "FAILED") .body(Body::empty()) diff --git a/crates/daphne-server/src/router/mod.rs b/crates/daphne-server/src/router/mod.rs index ad7b716f..2f5d45e3 100644 --- a/crates/daphne-server/src/router/mod.rs +++ b/crates/daphne-server/src/router/mod.rs @@ -18,10 +18,7 @@ use axum::{ response::IntoResponse, Json, }; -use daphne::{ - error::DapAbort, fatal_error, messages::TaskId, DapError, DapRequestMeta, DapResponse, - DapSender, -}; +use daphne::{error::DapAbort, messages::TaskId, DapError, DapRequestMeta, DapResponse, DapSender}; use daphne_service_utils::{bearer_token::BearerToken, metrics::DaphneServiceMetrics, DapRole}; use either::Either; use http::Request; @@ -148,24 +145,10 @@ impl AxumDapResponse { pub fn new_success_with_code( response: DapResponse, - metrics: &dyn DaphneServiceMetrics, + _metrics: &dyn DaphneServiceMetrics, status_code: StatusCode, ) -> Self { - let Some(media_type) = response.media_type.as_str_for_version(response.version) else { - return AxumDapResponse::new_error( - fatal_error!(err = "invalid content-type for DAP version"), - metrics, - ); - }; - let media_type = match HeaderValue::from_str(media_type) { - Ok(media_type) => media_type, - Err(e) => { - return AxumDapResponse::new_error( - fatal_error!(err = ?e, "content-type contained invalid bytes {media_type:?}"), - metrics, - ) - } - }; + let media_type = HeaderValue::from_static(response.media_type.as_str()); let headers = [(CONTENT_TYPE, media_type)]; diff --git a/crates/daphne-server/tests/e2e/e2e.rs b/crates/daphne-server/tests/e2e/e2e.rs index 6198cc10..ec796af7 100644 --- a/crates/daphne-server/tests/e2e/e2e.rs +++ b/crates/daphne-server/tests/e2e/e2e.rs @@ -275,11 +275,7 @@ async fn leader_upload(version: DapVersion) { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - DapMediaType::Report - .as_str_for_version(version) - .unwrap() - .parse() - .unwrap(), + reqwest::header::HeaderValue::from_static(DapMediaType::Report.as_str()), ); let builder = client.put(url.as_str()); let resp = builder diff --git a/crates/daphne-server/tests/e2e/test_runner.rs b/crates/daphne-server/tests/e2e/test_runner.rs index a1e6ef4d..0a668525 100644 --- a/crates/daphne-server/tests/e2e/test_runner.rs +++ b/crates/daphne-server/tests/e2e/test_runner.rs @@ -372,10 +372,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - media_type - .as_str_for_version(self.version) - .context("no string for version")? - .parse()?, + reqwest::header::HeaderValue::from_static(media_type.as_str()), ); if let Some(taskprov_advertisement) = taskprov { headers.insert( @@ -418,10 +415,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - media_type - .as_str_for_version(self.version) - .context("no string for version")? - .parse()?, + reqwest::header::HeaderValue::from_static(media_type.as_str()), ); if let Some(token) = dap_auth_token { headers.insert( @@ -481,10 +475,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - media_type - .as_str_for_version(self.version) - .context("no string for version")? - .parse()?, + reqwest::header::HeaderValue::from_static(media_type.as_str()), ); headers.insert( reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN), @@ -531,10 +522,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - media_type - .as_str_for_version(self.version) - .context("no string for version")? - .parse()?, + reqwest::header::HeaderValue::from_static(media_type.as_str()), ); if let Some(token) = dap_auth_token { headers.insert( @@ -589,11 +577,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_str( - DapMediaType::CollectReq - .as_str_for_version(self.version) - .context("no string for version")?, - )?, + reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()), ); headers.insert( reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN), @@ -778,11 +762,7 @@ impl TestRunner { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_str( - DapMediaType::CollectReq - .as_str_for_version(self.version) - .context("no string for version")?, - )?, + reqwest::header::HeaderValue::from_static(DapMediaType::CollectReq.as_str()), ); headers.insert( reqwest::header::HeaderName::from_static(http_headers::DAP_AUTH_TOKEN), diff --git a/crates/daphne/src/constants.rs b/crates/daphne/src/constants.rs index e9d2f420..e3c7e312 100644 --- a/crates/daphne/src/constants.rs +++ b/crates/daphne/src/constants.rs @@ -3,7 +3,10 @@ //! Constants used in the DAP protocol. -use crate::{DapSender, DapVersion}; +use core::fmt; +use std::str::FromStr; + +use crate::DapSender; // Media types for HTTP requests. const MEDIA_TYPE_AGG_JOB_INIT_REQ: &str = "application/dap-aggregation-job-init-req"; @@ -45,9 +48,31 @@ impl DapMediaType { } /// Parse the media type from the content-type HTTP header. - pub fn from_str_for_version(_version: DapVersion, content_type: &str) -> Option { + pub fn from_http_content_type(content_type: &str) -> Option { let (content_type, _) = content_type.split_once(';').unwrap_or((content_type, "")); - let media_type = match content_type { + content_type.parse().ok() + } + + /// If the media type is used with the current DAP version, then return its representation as + /// an HTTP content type. + pub const fn as_str(&self) -> &'static str { + match self { + Self::AggregationJobInitReq => MEDIA_TYPE_AGG_JOB_INIT_REQ, + Self::AggregationJobResp => MEDIA_TYPE_AGG_JOB_RESP, + Self::AggregateShareReq => MEDIA_TYPE_AGG_SHARE_REQ, + Self::AggregateShare => MEDIA_TYPE_AGG_SHARE, + Self::CollectReq => MEDIA_TYPE_COLLECT_REQ, + Self::Collection => MEDIA_TYPE_COLLECTION, + Self::HpkeConfigList => MEDIA_TYPE_HPKE_CONFIG_LIST, + Self::Report => MEDIA_TYPE_REPORT, + } + } +} + +impl FromStr for DapMediaType { + type Err = String; + fn from_str(s: &str) -> Result { + let media_type = match s { MEDIA_TYPE_AGG_JOB_INIT_REQ => Self::AggregationJobInitReq, MEDIA_TYPE_AGG_JOB_RESP => Self::AggregationJobResp, MEDIA_TYPE_AGG_SHARE => Self::AggregateShare, @@ -56,24 +81,15 @@ impl DapMediaType { MEDIA_TYPE_AGG_SHARE_REQ => Self::AggregateShareReq, MEDIA_TYPE_COLLECT_REQ => Self::CollectReq, MEDIA_TYPE_REPORT => Self::Report, - _ => return None, + _ => return Err(format!("invalid media type: {s}")), }; - Some(media_type) + Ok(media_type) } +} - /// If the media type is used with the current DAP version, then return its representation as - /// an HTTP content type. - pub fn as_str_for_version(&self, _version: DapVersion) -> Option<&'static str> { - match self { - Self::AggregationJobInitReq => Some(MEDIA_TYPE_AGG_JOB_INIT_REQ), - Self::AggregationJobResp => Some(MEDIA_TYPE_AGG_JOB_RESP), - Self::AggregateShareReq => Some(MEDIA_TYPE_AGG_SHARE_REQ), - Self::AggregateShare => Some(MEDIA_TYPE_AGG_SHARE), - Self::CollectReq => Some(MEDIA_TYPE_COLLECT_REQ), - Self::Collection => Some(MEDIA_TYPE_COLLECTION), - Self::HpkeConfigList => Some(MEDIA_TYPE_HPKE_CONFIG_LIST), - Self::Report => Some(MEDIA_TYPE_REPORT), - } +impl fmt::Display for DapMediaType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.as_str()) } } @@ -86,82 +102,59 @@ mod test { #[test] fn from_str_for_version() { assert_eq!( - DapMediaType::from_str_for_version( - DapVersion::Draft09, - "application/dap-hpke-config-list", - ), + DapMediaType::from_http_content_type("application/dap-hpke-config-list",), Some(DapMediaType::HpkeConfigList) ); assert_eq!( - DapMediaType::from_str_for_version( - DapVersion::Draft09, - "application/dap-aggregation-job-init-req" - ), + DapMediaType::from_http_content_type("application/dap-aggregation-job-init-req"), Some(DapMediaType::AggregationJobInitReq), ); assert_eq!( - DapMediaType::from_str_for_version( - DapVersion::Draft09, - "application/dap-aggregation-job-resp" - ), + DapMediaType::from_http_content_type("application/dap-aggregation-job-resp"), Some(DapMediaType::AggregationJobResp), ); assert_eq!( - DapMediaType::from_str_for_version( - DapVersion::Draft09, - "application/dap-aggregate-share-req" - ), + DapMediaType::from_http_content_type("application/dap-aggregate-share-req"), Some(DapMediaType::AggregateShareReq), ); assert_eq!( - DapMediaType::from_str_for_version( - DapVersion::Draft09, - "application/dap-aggregate-share" - ), + DapMediaType::from_http_content_type("application/dap-aggregate-share"), Some(DapMediaType::AggregateShare), ); assert_eq!( - DapMediaType::from_str_for_version(DapVersion::Draft09, "application/dap-collect-req"), + DapMediaType::from_http_content_type("application/dap-collect-req"), Some(DapMediaType::CollectReq), ); assert_eq!( - DapMediaType::from_str_for_version(DapVersion::Draft09, "application/dap-collection"), + DapMediaType::from_http_content_type("application/dap-collection"), Some(DapMediaType::Collection), ); // Invalid media type - assert_eq!( - DapMediaType::from_str_for_version(DapVersion::Draft09, "blah-blah-blah"), - None, - ); + assert_eq!(DapMediaType::from_http_content_type("blah-blah-blah"), None,); } // Test conversion of DAP media types to and from the content-type HTTP header. fn round_trip(version: DapVersion) { for media_type in DapMediaType::iter() { - if let Some(content_type) = media_type.as_str_for_version(version) { - // If the DAP media type is used for this version of DAP, then expect decoding the - // content-type should result in the same DAP media type. - assert_eq!( - DapMediaType::from_str_for_version(version, content_type).unwrap(), - media_type, - "round trip test failed for {version:?} and {media_type:?}" - ); - } + let content_type = media_type.as_str(); + assert_eq!( + DapMediaType::from_http_content_type(content_type).unwrap(), + media_type, + "round trip test failed for {version:?} and {media_type:?}" + ); } } test_versions! { round_trip } - fn media_type_parsing_ignores_content_type_paramters(version: DapVersion) { + #[test] + fn media_type_parsing_ignores_content_type_paramters() { assert_eq!( - DapMediaType::from_str_for_version( - version, + DapMediaType::from_http_content_type( "application/dap-aggregation-job-init-req;version=09", ), Some(DapMediaType::AggregationJobInitReq), ); } - - test_versions! { media_type_parsing_ignores_content_type_paramters } } diff --git a/crates/daphne/src/error/aborts.rs b/crates/daphne/src/error/aborts.rs index a110f704..91bfb764 100644 --- a/crates/daphne/src/error/aborts.rs +++ b/crates/daphne/src/error/aborts.rs @@ -183,25 +183,13 @@ impl DapAbort { /// Abort due to unexpected value for HTTP content-type header. pub fn content_type(req: &DapRequestMeta, expected: [DapMediaType; N]) -> Self { - let want_content_type = expected.map(|m| { - m.as_str_for_version(req.version).unwrap_or_else(|| { - unreachable!("unexpected content-type for DAP version {:?}", req.version) - }) - }); + let want_content_type = expected.map(|m| m.as_str()); let Some(media_type) = req.media_type else { return Self::BadRequest("missing content-type".into()); }; - let got_content_type = media_type - .as_str_for_version(req.version) - .unwrap_or_else(|| { - unreachable!( - "missing or unexpected content type for DAP version {:?}", - req.version - ) - }); - + let got_content_type = media_type.as_str(); Self::BadRequest(format!( "unexpected content-type: got {got_content_type}; want any of {want_content_type:?}" )) diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs index 0362cffe..ae38074c 100644 --- a/crates/daphne/src/roles/leader/mod.rs +++ b/crates/daphne/src/roles/leader/mod.rs @@ -631,8 +631,8 @@ fn check_response_content_type(resp: &DapResponse, expected: DapMediaType) -> Re if resp.media_type != expected { Err(fatal_error!( err = "response from peer has unexpected content-type", - got = resp.media_type.as_str_for_version(resp.version), - want = expected.as_str_for_version(resp.version), + got = resp.media_type.as_str(), + want = expected.as_str(), )) } else { Ok(())