Skip to content

Commit

Permalink
Query status API
Browse files Browse the repository at this point in the history
  • Loading branch information
andyleiserson committed Jun 28, 2023
1 parent ef97d17 commit efbf594
Show file tree
Hide file tree
Showing 15 changed files with 453 additions and 125 deletions.
36 changes: 29 additions & 7 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use crate::{
},
hpke::{KeyPair, KeyRegistry},
protocol::QueryId,
query::{NewQueryError, QueryCompletionError, QueryInputError, QueryProcessor},
query::{
NewQueryError, QueryCompletionError, QueryInputError, QueryProcessor, QueryStatus,
QueryStatusError,
},
sync::Arc,
};

Expand Down Expand Up @@ -49,6 +52,7 @@ impl Setup {
let rqp = Arc::clone(query_processor);
let pqp = Arc::clone(query_processor);
let iqp = Arc::clone(query_processor);
let sqp = Arc::clone(query_processor);
let cqp = Arc::clone(query_processor);

TransportCallbacks {
Expand All @@ -68,6 +72,10 @@ impl Setup {
let processor = Arc::clone(&iqp);
Box::pin(async move { processor.receive_inputs(transport, query_input) })
}),
query_status: Box::new(move |_transport: TransportImpl, query_id| {
let processor = Arc::clone(&sqp);
Box::pin(async move { processor.query_status(query_id) })
}),
complete_query: Box::new(move |_transport: TransportImpl, query_id| {
let processor = Arc::clone(&cqp);
Box::pin(async move { processor.complete(query_id).await })
Expand Down Expand Up @@ -97,17 +105,29 @@ impl HelperApp {
.query_id)
}

/// Drives the given query to completion by providing the inputs to it and awaiting the results
/// of the computation.
/// Sends query input to a helper.
///
/// ## Errors
/// If a query with the given id is not running on this helper or if an error occurred while
/// processing this query.
pub async fn execute_query(&self, input: QueryInput) -> Result<Vec<u8>, Error> {
let query_id = input.query_id;
/// Propagates errors from the helper.
pub fn execute_query(&self, input: QueryInput) -> Result<(), Error> {
let transport = <TransportImpl as Clone>::clone(&self.transport);
self.query_processor.receive_inputs(transport, input)?;
Ok(())
}

/// Retrieves the status of a query.
///
/// ## Errors
/// Propagates errors from the helper.
pub fn query_status(&self, query_id: QueryId) -> Result<QueryStatus, Error> {
Ok(self.query_processor.query_status(query_id)?)
}

/// Waits for a query to complete and returns the result.
///
/// ## Errors
/// Propagates errors from the helper.
pub async fn complete_query(&self, query_id: QueryId) -> Result<Vec<u8>, Error> {
Ok(self.query_processor.complete(query_id).await?.into_bytes())
}
}
Expand All @@ -121,4 +141,6 @@ pub enum Error {
QueryInput(#[from] QueryInputError),
#[error(transparent)]
QueryCompletion(#[from] QueryCompletionError),
#[error(transparent)]
QueryStatus(#[from] QueryStatusError),
}
25 changes: 24 additions & 1 deletion src/cli/playbook/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::{
attribution::input::MCAggregateCreditOutputRow, ipa::IPAInputRow, BreakdownKey, MatchKey,
QueryId,
},
query::QueryStatus,
report::{KeyIdentifier, Report},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
test_fixture::{input::GenericReportTestInput, ipa::TestRawDataRecord, Reconstruct},
Expand All @@ -17,7 +18,12 @@ use futures_util::future::try_join_all;
use generic_array::GenericArray;
use rand::{distributions::Standard, prelude::Distribution, rngs::StdRng};
use rand_core::SeedableRng;
use std::{iter::zip, time::Instant};
use std::{
cmp::min,
iter::zip,
time::{Duration, Instant},
};
use tokio::time::sleep;
use typenum::Unsigned;

/// Semi-honest IPA protocol.
Expand Down Expand Up @@ -103,6 +109,23 @@ where
.await
.unwrap();

let mut delay = Duration::from_millis(125);
loop {
if try_join_all(clients.iter().map(|client| client.query_status(query_id)))
.await
.unwrap()
.into_iter()
.all(|status| status == QueryStatus::Completed)
{
break;
}

sleep(delay).await;
delay = min(Duration::from_secs(60), delay * 2);
// TODO: Add a timeout of some sort. Possibly, add some sort of progress indicator to
// the status API so we can check whether the query is making progress.
}

// wait until helpers have processed the query and get the results from them
let results: [_; 3] = try_join_all(clients.iter().map(|client| client.query_results(query_id)))
.await
Expand Down
9 changes: 9 additions & 0 deletions src/helpers/transport/callbacks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::{
protocol::QueryId,
query::{
NewQueryError, PrepareQueryError, ProtocolResult, QueryCompletionError, QueryInputError,
QueryStatus, QueryStatusError,
},
};
use std::{future::Future, pin::Pin};
Expand Down Expand Up @@ -54,6 +55,10 @@ callbacks! {
(QueryInputCallback, QueryInputResult):
async fn(T, QueryInput) -> Result<(), QueryInputError>;

/// Called by clients to retrieve query status.
(QueryStatusCallback, QueryStatusResult):
async fn(T, QueryId) -> Result<QueryStatus, QueryStatusError>;

/// Called by clients to drive query to completion and retrieve results.
(CompleteQueryCallback, CompleteQueryResult):
async fn(T, QueryId) -> Result<Box<dyn ProtocolResult>, QueryCompletionError>;
Expand All @@ -63,6 +68,7 @@ pub struct TransportCallbacks<T> {
pub receive_query: Box<dyn ReceiveQueryCallback<T>>,
pub prepare_query: Box<dyn PrepareQueryCallback<T>>,
pub query_input: Box<dyn QueryInputCallback<T>>,
pub query_status: Box<dyn QueryStatusCallback<T>>,
pub complete_query: Box<dyn CompleteQueryCallback<T>>,
}

Expand All @@ -83,6 +89,9 @@ impl<T> Default for TransportCallbacks<T> {
query_input: Box::new(move |_, _| {
Box::pin(async { panic!("unexpected call to query_input") })
}),
query_status: Box::new(move |_, _| {
Box::pin(async { panic!("unexpected call to query_status") })
}),
complete_query: Box::new(move |_, _| {
Box::pin(async { panic!("unexpected call to complete_query") })
}),
Expand Down
51 changes: 0 additions & 51 deletions src/helpers/transport/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ use crate::{
GatewayConfig, RoleAssignment, RouteId, RouteParams,
},
protocol::{step::Step, QueryId},
query::ProtocolResult,
};
use serde::{Deserialize, Deserializer, Serialize};
use std::{
fmt::{Debug, Display, Formatter},
num::NonZeroU32,
};
use tokio::sync::oneshot;

#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq)]
#[cfg_attr(feature = "enable-serde", derive(Serialize))]
Expand Down Expand Up @@ -197,55 +195,6 @@ impl Debug for QueryInput {
}
}

pub enum QueryCommand {
Create(QueryConfig, oneshot::Sender<QueryId>),
Prepare(PrepareQuery, oneshot::Sender<()>),
Input(QueryInput, oneshot::Sender<()>),
Results(QueryId, oneshot::Sender<Box<dyn ProtocolResult>>),
}

impl Debug for QueryCommand {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "QueryCommand: {:?}", self.query_id())?;
match self {
QueryCommand::Create(config, _) => {
write!(f, "[{config:?}]")
}
QueryCommand::Prepare(prepare, _) => {
write!(f, "[{prepare:?}]")
}
QueryCommand::Input(input, _) => {
write!(f, "[{input:?}]")
}
QueryCommand::Results(query_id, _) => {
write!(f, "{query_id:?} [Results]")
}
}
}
}

impl QueryCommand {
#[must_use]
pub fn name(&self) -> &'static str {
match self {
Self::Create(_, _) => "Query Create",
Self::Prepare(_, _) => "Query Prepare",
Self::Input(_, _) => "Query Input",
Self::Results(_, _) => "Query Results",
}
}

#[must_use]
pub fn query_id(&self) -> Option<QueryId> {
match self {
Self::Create(_, _) => None,
Self::Prepare(data, _) => Some(data.query_id),
Self::Input(data, _) => Some(data.query_id),
Self::Results(query_id, _) => Some(*query_id),
}
}
}

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))]
pub enum QueryType {
Expand Down
29 changes: 26 additions & 3 deletions src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,21 +311,42 @@ impl MpcHelperClient {
Ok(self.request(req))
}

/// Retrieve the status of a query.
///
/// ## Errors
/// If the request has illegal arguments, or fails to deliver to helper
#[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))]
pub async fn query_status(
&self,
query_id: QueryId,
) -> Result<crate::query::QueryStatus, Error> {
let req = http_serde::query::status::Request::new(query_id);
let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?;

let resp = self.request(req).await?;
if resp.status().is_success() {
let body_bytes = body::to_bytes(resp.into_body()).await?;
let http_serde::query::status::ResponseBody { status } =
serde_json::from_slice(&body_bytes)?;
Ok(status)
} else {
Err(Error::from_failed_resp(resp).await)
}
}

/// Wait for completion of the query and pull the results of this query. This is a blocking
/// API so it is not supposed to be used outside of CLI context.
///
/// ## Errors
/// If the request has illegal arguments, or fails to deliver to helper
/// # Panics
/// if there is a problem reading the response body
#[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))]
pub async fn query_results(&self, query_id: QueryId) -> Result<body::Bytes, Error> {
let req = http_serde::query::results::Request::new(query_id);
let req = req.try_into_http_request(self.scheme.clone(), self.authority.clone())?;

let resp = self.request(req).await?;
if resp.status().is_success() {
Ok(body::to_bytes(resp.into_body()).await.unwrap())
Ok(body::to_bytes(resp.into_body()).await?)
} else {
Err(Error::from_failed_resp(resp).await)
}
Expand Down Expand Up @@ -377,11 +398,13 @@ pub(crate) mod tests {
let ri = Arc::clone(inner);
let pi = Arc::clone(inner);
let qi = Arc::clone(inner);
let si = Arc::clone(inner);
let ci = Arc::clone(inner);
TransportCallbacks {
receive_query: Box::new(move |t, req| (ri.receive_query)(t, req)),
prepare_query: Box::new(move |t, req| (pi.prepare_query)(t, req)),
query_input: Box::new(move |t, req| (qi.query_input)(t, req)),
query_status: Box::new(move |t, req| (si.query_status)(t, req)),
complete_query: Box::new(move |t, req| (ci.complete_query)(t, req)),
}
}
Expand Down
54 changes: 54 additions & 0 deletions src/net/http_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,60 @@ pub mod query {
pub const AXUM_PATH: &str = "/:query_id/step/*step";
}

pub mod status {
use crate::{net::Error, protocol::QueryId, query::QueryStatus};
use async_trait::async_trait;
use axum::extract::{FromRequest, Path, RequestParts};
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone)]
pub struct Request {
pub query_id: QueryId,
}

impl Request {
#[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking
pub fn new(query_id: QueryId) -> Self {
Self { query_id }
}

#[cfg(any(all(test, not(feature = "shuttle")), feature = "cli"))] // needed because client is blocking; remove when non-blocking
pub fn try_into_http_request(
self,
scheme: axum::http::uri::Scheme,
authority: axum::http::uri::Authority,
) -> Result<hyper::Request<hyper::Body>, Error> {
let uri = axum::http::uri::Uri::builder()
.scheme(scheme)
.authority(authority)
.path_and_query(format!(
"{}/{}",
crate::net::http_serde::query::BASE_AXUM_PATH,
self.query_id.as_ref()
))
.build()?;
Ok(hyper::Request::get(uri).body(hyper::Body::empty())?)
}
}

#[async_trait]
impl<B: Send> FromRequest<B> for Request {
type Rejection = Error;

async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let Path(query_id) = req.extract().await?;
Ok(Request { query_id })
}
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseBody {
pub status: QueryStatus,
}

pub const AXUM_PATH: &str = "/:query_id";
}

pub mod results {
use crate::{net::Error, protocol::QueryId};
use async_trait::async_trait;
Expand Down
2 changes: 2 additions & 0 deletions src/net/server/handlers/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod create;
mod input;
mod prepare;
mod results;
mod status;
mod step;

use crate::{
Expand Down Expand Up @@ -29,6 +30,7 @@ pub fn query_router(transport: Arc<HttpTransport>) -> Router {
Router::new()
.merge(create::router(Arc::clone(&transport)))
.merge(input::router(Arc::clone(&transport)))
.merge(status::router(Arc::clone(&transport)))
.merge(results::router(transport))
}

Expand Down
Loading

0 comments on commit efbf594

Please sign in to comment.