Skip to content

Commit

Permalink
Clean up bearer token resolution logic
Browse files Browse the repository at this point in the history
Decouple taskprov from the `BearerTokenProvider` by getting rid of the
`is_taskprov_*token()` calls and forcing the `get_*token()` calls to
resolve the taskprov tokens.

Accordingly, pass the `DapTaskConfig` so that the implementer knows if
the task was configured via taskprov. This requires resolving the
`DapTaskConfig` before authorizing the request, so re-order the logic in
`DapLeader` and `DapHelper` accordingly.

While at it, pick up an efficiency improvement for DaphneWorker: if
taskprov is enabled for the task, then we use the token from the global
taskprov config and avoid a KV lookup.
  • Loading branch information
cjpatton committed Sep 6, 2023
1 parent 9e544f3 commit 2cbcadd
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 167 deletions.
39 changes: 15 additions & 24 deletions daphne/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
constants::DapMediaType,
fatal_error,
messages::{constant_time_eq, TaskId},
DapError, DapRequest, DapSender,
DapError, DapRequest, DapSender, DapTaskConfig,
};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -62,30 +62,27 @@ pub trait BearerTokenProvider {
async fn get_leader_bearer_token_for<'s>(
&'s self,
task_id: &'s TaskId,
task_config: &DapTaskConfig,
) -> Result<Option<Self::WrappedBearerToken<'s>>, DapError>;

/// Fetch the Collector's bearer token for the given task, if the task is recognized.
async fn get_collector_bearer_token_for<'s>(
&'s self,
task_id: &'s TaskId,
task_config: &DapTaskConfig,
) -> Result<Option<Self::WrappedBearerToken<'s>>, DapError>;

/// Returns true if the given bearer token matches the leader token configured for the "taskprov" extension.
fn is_taskprov_leader_bearer_token(&self, token: &BearerToken) -> bool;

/// Returns true if the given bearer token matches the collector token configured for the "taskprov" extension.
fn is_taskprov_collector_bearer_token(&self, token: &BearerToken) -> bool;

/// Return a bearer token that can be used to authorize a request with the given task ID and
/// media type.
async fn authorize_with_bearer_token<'s>(
&'s self,
task_id: &'s TaskId,
task_config: &DapTaskConfig,
media_type: &DapMediaType,
) -> Result<Self::WrappedBearerToken<'s>, DapError> {
if matches!(media_type.sender(), Some(DapSender::Leader)) {
let token = self
.get_leader_bearer_token_for(task_id)
.get_leader_bearer_token_for(task_id, task_config)
.await?
.ok_or_else(|| {
fatal_error!(err = "attempted to authorize request with unknown task ID")
Expand All @@ -105,6 +102,7 @@ pub trait BearerTokenProvider {
/// is the reason for the failure.
async fn bearer_token_authorized<T: AsRef<BearerToken>>(
&self,
task_config: &DapTaskConfig,
req: &DapRequest<T>,
) -> Result<Option<String>, DapError> {
if req.task_id.is_none() {
Expand All @@ -121,38 +119,31 @@ pub trait BearerTokenProvider {
// token is not formatted properly.
if matches!(req.media_type.sender(), Some(DapSender::Leader)) {
if let Some(ref got) = req.sender_auth {
if let Some(expected) = self.get_leader_bearer_token_for(task_id).await? {
if let Some(expected) = self
.get_leader_bearer_token_for(task_id, task_config)
.await?
{
return Ok(if got.as_ref() == expected.as_ref() {
None
} else {
Some("The indicated beareer token is incorrect for the Leader.".into())
Some("The indicated bearer token is incorrect for the Leader.".into())
});
}
return Ok(if self.is_taskprov_leader_bearer_token(got.as_ref()) {
None
} else {
Some("The indicated beaer token is incorrect for Taskprov Leader.".into())
});
}
}

if matches!(req.media_type.sender(), Some(DapSender::Collector)) {
if let Some(ref got) = req.sender_auth {
if let Some(expected) = self.get_collector_bearer_token_for(task_id).await? {
if let Some(expected) = self
.get_collector_bearer_token_for(task_id, task_config)
.await?
{
return Ok(if got.as_ref() == expected.as_ref() {
None
} else {
Some("The indicated bearer token is incorrect for the Collector.".into())
});
}
return Ok(if self.is_taskprov_collector_bearer_token(got.as_ref()) {
None
} else {
Some(
"The indicated bearer token is incorrect for the Taskprov Collector."
.into(),
)
});
}
}

Expand Down
6 changes: 5 additions & 1 deletion daphne/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ pub trait DapAggregator<S>: HpkeDecrypter + DapReportInitializer + Sized {
/// If the return value is `None`, then the request is authorized. If the return value is
/// `Some(reason)`, then the request is denied and `reason` conveys details about how the
/// decision was reached.
async fn unauthorized_reason(&self, req: &DapRequest<S>) -> Result<Option<String>, DapError>;
async fn unauthorized_reason(
&self,
task_config: &DapTaskConfig,
req: &DapRequest<S>,
) -> Result<Option<String>, DapError>;

/// Look up the DAP global configuration.
fn get_global_config(&self) -> &DapGlobalConfig;
Expand Down
38 changes: 23 additions & 15 deletions daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,6 @@ pub trait DapHelper<S>: DapAggregator<S> {
return Err(DapAbort::version_unknown());
}

if let Some(reason) = self.unauthorized_reason(req).await? {
error!("aborted unauthorized collect request: {reason}");
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: task_id.clone(),
});
}

match req.media_type {
DapMediaType::AggregationJobInitReq => {
let agg_job_init_req =
Expand Down Expand Up @@ -110,6 +102,14 @@ pub trait DapHelper<S>: DapAggregator<S> {
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

if let Some(reason) = self.unauthorized_reason(task_config, req).await? {
error!("aborted unauthorized collect request: {reason}");
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: task_id.clone(),
});
}

// draft02 compatibility: In draft02, the aggregation job ID is parsed from the
// HTTP request payload; in the latest draft, the aggregation job ID is parsed from
// the request path.
Expand Down Expand Up @@ -198,6 +198,14 @@ pub trait DapHelper<S>: DapAggregator<S> {
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

if let Some(reason) = self.unauthorized_reason(task_config, req).await? {
error!("aborted unauthorized collect request: {reason}");
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: task_id.clone(),
});
}

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
return Err(DapAbort::version_mismatch(req.version, task_config.version));
Expand Down Expand Up @@ -306,20 +314,20 @@ pub trait DapHelper<S>: DapAggregator<S> {

resolve_taskprov(self, task_id, req, None).await?;

if let Some(reason) = self.unauthorized_reason(req).await? {
let wrapped_task_config = self
.get_task_config_for(Cow::Borrowed(req.task_id()?))
.await?
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

if let Some(reason) = self.unauthorized_reason(task_config, req).await? {
error!("aborted unauthorized collect request: {reason}");
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
task_id: task_id.clone(),
});
}

let wrapped_task_config = self
.get_task_config_for(Cow::Borrowed(req.task_id()?))
.await?
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
return Err(DapAbort::version_mismatch(req.version, task_config.version));
Expand Down
19 changes: 12 additions & 7 deletions daphne/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ async fn leader_send_http_request<S>(
task_id: Some(task_id.clone()),
resource,
url,
sender_auth: Some(role.authorize(task_id, &req_media_type, &req_data).await?),
sender_auth: Some(
role.authorize(task_id, task_config, &req_media_type, &req_data)
.await?,
),
payload: req_data,
taskprov: None,
};
Expand All @@ -83,6 +86,7 @@ pub trait DapAuthorizedSender<S> {
async fn authorize(
&self,
task_id: &TaskId,
task_config: &DapTaskConfig,
media_type: &DapMediaType,
payload: &[u8],
) -> Result<S, DapError>;
Expand Down Expand Up @@ -226,7 +230,13 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {

resolve_taskprov(self, task_id, req, None).await?;

if let Some(reason) = self.unauthorized_reason(req).await? {
let wrapped_task_config = self
.get_task_config_for(Cow::Borrowed(req.task_id()?))
.await?
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

if let Some(reason) = self.unauthorized_reason(task_config, req).await? {
error!("aborted unauthorized collect request: {reason}");
return Err(DapAbort::UnauthorizedRequest {
detail: reason,
Expand All @@ -237,11 +247,6 @@ pub trait DapLeader<S>: DapAuthorizedSender<S> + DapAggregator<S> {
let mut collect_req =
CollectionReq::get_decoded_with_param(&req.version, req.payload.as_ref())
.map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?;
let wrapped_task_config = self
.get_task_config_for(Cow::Borrowed(req.task_id()?))
.await?
.ok_or(DapAbort::UnrecognizedTask)?;
let task_config = wrapped_task_config.as_ref();

// Check whether the DAP version in the request matches the task config.
if task_config.version != req.version {
Expand Down
Loading

0 comments on commit 2cbcadd

Please sign in to comment.