Skip to content

Commit

Permalink
Merge pull request #1414 from cberkhoff/si-prep2
Browse files Browse the repository at this point in the history
Sharded Prepare Query API
  • Loading branch information
cberkhoff authored Nov 12, 2024
2 parents 0a34110 + d21afe7 commit 7c4b548
Show file tree
Hide file tree
Showing 6 changed files with 497 additions and 159 deletions.
54 changes: 39 additions & 15 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
query::{PrepareQuery, QueryConfig, QueryInput},
routing::{Addr, RouteId},
ApiError, BodyStream, HandlerBox, HandlerRef, HelperIdentity, HelperResponse,
MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport,
MpcTransportImpl, RequestHandler, ShardTransportImpl, Transport, TransportIdentity,
},
hpke::{KeyRegistry, PrivateKeyOnly},
protocol::QueryId,
Expand Down Expand Up @@ -124,7 +124,8 @@ impl HelperApp {
.inner
.query_processor
.new_query(
Transport::clone_ref(&self.inner.mpc_transport),
self.inner.mpc_transport.clone_ref(),
self.inner.shard_transport.clone_ref(),
query_config,
)
.await?
Expand All @@ -136,8 +137,8 @@ impl HelperApp {
/// ## Errors
/// Propagates errors from the helper.
pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> {
let mpc_transport = Transport::clone_ref(&self.inner.mpc_transport);
let shard_transport = Transport::clone_ref(&self.inner.shard_transport);
let mpc_transport = self.inner.mpc_transport.clone_ref();
let shard_transport = self.inner.shard_transport.clone_ref();
self.inner
.query_processor
.receive_inputs(mpc_transport, shard_transport, input)?;
Expand Down Expand Up @@ -166,14 +167,32 @@ impl HelperApp {
}
}

fn ext_query_id<I: TransportIdentity>(req: &Addr<I>) -> Result<QueryId, ApiError> {
req.query_id
.ok_or_else(|| ApiError::BadRequest("Query input is missing query_id argument".into()))
}

#[async_trait]
impl RequestHandler<ShardIndex> for Inner {
async fn handle(
&self,
_req: Addr<ShardIndex>,
req: Addr<ShardIndex>,
_data: BodyStream,
) -> Result<HelperResponse, ApiError> {
Ok(HelperResponse::ok())
let qp = &self.query_processor;

Ok(match req.route {
RouteId::PrepareQuery => {
let req = req.into::<PrepareQuery>()?;
HelperResponse::from(qp.prepare_shard(&self.shard_transport, req)?)
}
r => {
return Err(ApiError::BadRequest(
format!("{r:?} request must not be handled by shard query processing flow")
.into(),
))
}
})
}
}

Expand All @@ -184,12 +203,6 @@ impl RequestHandler<HelperIdentity> for Inner {
req: Addr<HelperIdentity>,
data: BodyStream,
) -> Result<HelperResponse, ApiError> {
fn ext_query_id(req: &Addr<HelperIdentity>) -> Result<QueryId, ApiError> {
req.query_id.ok_or_else(|| {
ApiError::BadRequest("Query input is missing query_id argument".into())
})
}

let qp = &self.query_processor;

Ok(match req.route {
Expand All @@ -202,13 +215,24 @@ impl RequestHandler<HelperIdentity> for Inner {
RouteId::ReceiveQuery => {
let req = req.into::<QueryConfig>()?;
HelperResponse::from(
qp.new_query(Transport::clone_ref(&self.mpc_transport), req)
.await?,
qp.new_query(
self.mpc_transport.clone_ref(),
self.shard_transport.clone_ref(),
req,
)
.await?,
)
}
RouteId::PrepareQuery => {
let req = req.into::<PrepareQuery>()?;
HelperResponse::from(qp.prepare(&self.mpc_transport, req)?)
HelperResponse::from(
qp.prepare_helper(
self.mpc_transport.clone_ref(),
self.shard_transport.clone_ref(),
req,
)
.await?,
)
}
RouteId::QueryInput => {
let query_id = ext_query_id(&req)?;
Expand Down
10 changes: 5 additions & 5 deletions ipa-core/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ pub use transport::{
config as in_memory_config, InMemoryMpcNetwork, InMemoryShardNetwork, InMemoryTransport,
};
pub use transport::{
make_owned_handler, query, routing, ApiError, BodyStream, BytesStream, HandlerBox, HandlerRef,
HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, LogErrors, NoQueryId,
NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, RecordsStream, RequestHandler,
RouteParams, SingleRecordStream, StepBinding, StreamCollection, StreamKey, Transport,
WrappedBoxBodyStream,
make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream,
HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, LengthDelimitedStream,
LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords,
RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, StreamCollection,
StreamKey, Transport, WrappedBoxBodyStream,
};
use typenum::{Const, ToUInt, Unsigned, U8};
use x25519_dalek::PublicKey;
Expand Down
47 changes: 42 additions & 5 deletions ipa-core/src/helpers/transport/in_memory/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{
helpers::{
in_memory_config::{passthrough, DynStreamInterceptor},
transport::in_memory::transport::{InMemoryTransport, Setup, TransportConfigBuilder},
HelperIdentity,
HandlerBox, HelperIdentity, RequestHandler,
},
sharding::ShardIndex,
sync::{Arc, Weak},
Expand Down Expand Up @@ -30,8 +30,24 @@ impl InMemoryShardNetwork {
shard_count: I,
interceptor: &DynStreamInterceptor,
) -> Self {
let shard_network = Self::create_shard_connections(shard_count, interceptor).map(
|(shard_connections, h)| {
shard_connections
.into_iter()
.map(|s| tracing::info_span!("", ?h).in_scope(|| s.start(None)))
.collect::<Vec<_>>()
.into()
},
);
Self { shard_network }
}

pub fn create_shard_connections<I: Into<ShardIndex>>(
shard_count: I,
interceptor: &DynStreamInterceptor,
) -> [(Vec<Setup<ShardIndex>>, HelperIdentity); 3] {
let shard_count = shard_count.into();
let shard_network: [_; 3] = HelperIdentity::make_three().map(|h| {
HelperIdentity::make_three().map(|h| {
let mut config_builder = TransportConfigBuilder::for_helper(h);
config_builder.with_interceptor(interceptor);

Expand All @@ -48,14 +64,35 @@ impl InMemoryShardNetwork {
}
}

(shard_connections, h)
})
}

pub fn with_shards_and_handlers<I, F>(
shard_count: I,
handler_fn: F,
) -> (Self, Vec<Arc<dyn RequestHandler<ShardIndex>>>)
where
I: Into<ShardIndex>,
F: Fn(ShardIndex) -> Arc<dyn RequestHandler<ShardIndex>>,
{
let connections = Self::create_shard_connections(shard_count, &passthrough());
let shard_count = connections[0].0.len();
let mut handlers = Vec::with_capacity(3 * shard_count);
let shard_network = connections.map(|(shard_connections, h)| {
shard_connections
.into_iter()
.map(|s| tracing::info_span!("", ?h).in_scope(|| s.start(None)))
.map(|s| {
tracing::info_span!("", ?h).in_scope(|| {
let handler = handler_fn(s.identity);
handlers.push(Arc::clone(&handler));
s.start(Some(HandlerBox::owning_ref(&handler)))
})
})
.collect::<Vec<_>>()
.into()
});

Self { shard_network }
(Self { shard_network }, handlers)
}

pub fn transport<I: Into<ShardIndex>>(
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ impl Debug for InMemoryStream {
}

pub struct Setup<I> {
identity: I,
pub identity: I,
tx: ConnectionTx<I>,
rx: ConnectionRx<I>,
connections: HashMap<I, ConnectionTx<I>>,
Expand Down
5 changes: 2 additions & 3 deletions ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ impl RouteParams<RouteId, QueryId, NoStep> for (RouteId, QueryId) {
#[derive(thiserror::Error, Debug)]
#[error("One or more peers rejected the request: {failures:?}")]
pub struct BroadcastError<I: TransportIdentity, E: Debug> {
failures: Vec<(I, E)>,
pub failures: Vec<(I, E)>,
}

impl<I: TransportIdentity, E: Debug> From<Vec<(I, E)>> for BroadcastError<I, E> {
Expand Down Expand Up @@ -339,8 +339,7 @@ pub trait Transport: Clone + Send + Sync + 'static {

/// Broadcasts a message to all peers, excluding this instance, collecting all failures and
/// successes. This method waits for all responses and returns only when all peers responded.
/// The routes and data will be cloned.
async fn broadcast<Q, S, R, D>(
async fn broadcast<Q, S, R>(
&self,
route: R,
) -> Result<(), BroadcastError<Self::Identity, Self::Error>>
Expand Down
Loading

0 comments on commit 7c4b548

Please sign in to comment.