Skip to content

Commit

Permalink
Merge pull request #1410 from akoshelev/shard-prss
Browse files Browse the repository at this point in the history
Support cross-shard PRSS in sharded MPC circuits (in memory only for now)
  • Loading branch information
akoshelev authored Nov 8, 2024
2 parents ef3e3b3 + e6e08af commit eecd25d
Show file tree
Hide file tree
Showing 17 changed files with 256 additions and 115 deletions.
5 changes: 3 additions & 2 deletions ipa-core/src/bin/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use ipa_core::{
executor::IpaRuntime,
helpers::HelperIdentity,
net::{ClientIdentity, IpaHttpClient, MpcHttpTransport, ShardHttpTransport},
sharding::Sharded,
sharding::ShardIndex,
AppConfig, AppSetup, NonZeroU32PowerOfTwo,
};
use tokio::runtime::Runtime;
Expand Down Expand Up @@ -185,7 +185,8 @@ async fn server(args: ServerArgs, logging_handle: LoggingHandle) -> Result<(), B
let shard_network_config = NetworkConfig::new_shards(vec![], shard_clients_config);
let (shard_transport, _shard_server) = ShardHttpTransport::new(
IpaRuntime::from_tokio_runtime(&http_runtime),
Sharded::new(0, 1),
ShardIndex::FIRST,
ShardIndex::from(1),
shard_server_config,
shard_network_config,
vec![],
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/helpers/transport/in_memory/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
helpers::{HelperIdentity, Role, RoleAssignment},
protocol::Gate,
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::Arc,
};

Expand Down Expand Up @@ -90,7 +90,7 @@ pub enum InspectContext {
MpcMessage {
/// The shard of this instance.
/// This is `None` for non-sharded helpers.
shard: Option<Sharded>,
shard: Option<ShardIndex>,
/// Helper sending this stream.
source: HelperIdentity,
/// Helper that will receive this stream.
Expand Down Expand Up @@ -161,7 +161,7 @@ impl<F: Fn(&MaliciousHelperContext, &mut Vec<u8>) + Send + Sync> MaliciousHelper
pub struct MaliciousHelperContext {
/// The shard of this instance.
/// This is `None` for non-sharded helpers.
pub shard: Option<Sharded>,
pub shard: Option<ShardIndex>,
/// Helper that will receive this stream.
pub dest: Role,
/// Circuit gate this stream is tied to.
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/helpers/transport/in_memory/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{
in_memory_config::DynStreamInterceptor, transport::in_memory::config::passthrough,
HandlerRef, HelperIdentity,
},
sharding::Sharded,
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -50,13 +50,13 @@ impl InMemoryMpcNetwork {
pub fn with_stream_interceptor(
handlers: [Option<HandlerRef>; 3],
interceptor: &DynStreamInterceptor,
shard_context: Option<Sharded>,
shard: Option<ShardIndex>,
) -> Self {
let [mut first, mut second, mut third]: [_; 3] = HelperIdentity::make_three().map(|i| {
let mut config_builder = TransportConfigBuilder::for_helper(i);
config_builder.with_interceptor(interceptor);

Setup::with_config(i, config_builder.with_sharding(shard_context))
Setup::with_config(i, config_builder.with_sharding(shard))
});

first.connect(&mut second);
Expand Down
12 changes: 2 additions & 10 deletions ipa-core/src/helpers/transport/in_memory/sharding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
transport::in_memory::transport::{InMemoryTransport, Setup, TransportConfigBuilder},
HelperIdentity,
},
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -37,15 +37,7 @@ impl InMemoryShardNetwork {

let mut shard_connections = shard_count
.iter()
.map(|i| {
Setup::with_config(
i,
config_builder.with_sharding(Some(Sharded {
shard_id: i,
shard_count,
})),
)
})
.map(|i| Setup::with_config(i, config_builder.with_sharding(Some(i))))
.collect::<Vec<_>>();
for i in 0..shard_connections.len() {
let (lhs, rhs) = shard_connections.split_at_mut(i);
Expand Down
14 changes: 7 additions & 7 deletions ipa-core/src/helpers/transport/in_memory/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::{
Transport, TransportIdentity,
},
protocol::{Gate, QueryId},
sharding::Sharded,
sharding::ShardIndex,
sync::{Arc, Weak},
};

Expand Down Expand Up @@ -192,8 +192,8 @@ impl<I: TransportIdentity> Transport for Weak<InMemoryTransport<I>> {
let gate = addr.gate.clone();

let (ack_tx, ack_rx) = oneshot::channel();
let context = gate
.map(|gate| dest.inspect_context(this.config.shard_config, this.config.identity, gate));
let context =
gate.map(|gate| dest.inspect_context(this.config.shard, this.config.identity, gate));

channel
.send((
Expand Down Expand Up @@ -628,7 +628,7 @@ mod tests {
}

pub struct TransportConfig {
pub shard_config: Option<Sharded>,
pub shard: Option<ShardIndex>,
pub identity: HelperIdentity,
pub stream_interceptor: DynStreamInterceptor,
}
Expand All @@ -652,17 +652,17 @@ impl TransportConfigBuilder {
self
}

pub fn with_sharding(&self, shard_config: Option<Sharded>) -> TransportConfig {
pub fn with_sharding(&self, shard: Option<ShardIndex>) -> TransportConfig {
TransportConfig {
shard_config,
shard,
identity: self.identity,
stream_interceptor: Arc::clone(&self.stream_interceptor),
}
}

pub fn not_sharded(&self) -> TransportConfig {
TransportConfig {
shard_config: None,
shard: None,
identity: self.identity,
stream_interceptor: Arc::clone(&self.stream_interceptor),
}
Expand Down
12 changes: 5 additions & 7 deletions ipa-core/src/helpers/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use futures::{stream::FuturesUnordered, FutureExt, Stream, StreamExt};

#[cfg(feature = "in-memory-infra")]
use crate::helpers::in_memory_config::InspectContext;
#[cfg(feature = "in-memory-infra")]
use crate::sharding::Sharded;
use crate::{
helpers::{transport::routing::RouteId, HelperIdentity, Role, TransportIdentity},
protocol::{Gate, QueryId},
Expand Down Expand Up @@ -58,7 +56,7 @@ pub trait Identity:
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext;
Expand All @@ -84,13 +82,13 @@ impl Identity for ShardIndex {
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext {
InspectContext::ShardMessage {
helper,
source: shard.unwrap().shard_id,
source: shard.unwrap(),
dest: *self,
gate,
}
Expand Down Expand Up @@ -125,7 +123,7 @@ impl Identity for HelperIdentity {
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
shard: Option<Sharded>,
shard: Option<ShardIndex>,
helper: HelperIdentity,
gate: Gate,
) -> InspectContext {
Expand Down Expand Up @@ -167,7 +165,7 @@ impl Identity for Role {
#[cfg(feature = "in-memory-infra")]
fn inspect_context(
&self,
_shard: Option<Sharded>,
_shard: Option<ShardIndex>,
_helper: HelperIdentity,
_gate: Gate,
) -> InspectContext {
Expand Down
6 changes: 2 additions & 4 deletions ipa-core/src/net/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,10 +246,8 @@ impl TestApp {
);
let (shard_transport, shard_server) = super::ShardHttpTransport::new(
IpaRuntime::current(),
crate::sharding::Sharded {
shard_id: sid.shard_index,
shard_count: self.shard_network_config.shard_count(),
},
sid.shard_index,
self.shard_network_config.shard_count(),
self.shard_server.config,
self.shard_network_config,
shard_clients,
Expand Down
17 changes: 8 additions & 9 deletions ipa-core/src/net/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::{
},
net::{client::IpaHttpClient, error::Error, IpaHttpServer},
protocol::{Gate, QueryId},
sharding::{ShardIndex, Sharded},
sharding::ShardIndex,
sync::Arc,
};

Expand All @@ -45,7 +45,7 @@ pub struct MpcHttpTransport {
#[derive(Clone)]
pub struct ShardHttpTransport {
pub(super) inner_transport: Arc<HttpTransport<Shard>>,
pub(super) shard_config: Sharded,
pub(super) shard_count: ShardIndex,
}

impl RouteParams<RouteId, NoQueryId, NoStep> for QueryConfig {
Expand Down Expand Up @@ -297,7 +297,9 @@ impl ShardHttpTransport {
#[must_use]
pub fn new(
http_runtime: IpaRuntime,
shard_config: Sharded,
// todo: maybe a wrapper struct for it
shard_id: ShardIndex,
shard_count: ShardIndex,
server_config: ServerConfig,
network_config: NetworkConfig<Shard>,
clients: Vec<IpaHttpClient<Shard>>,
Expand All @@ -306,12 +308,12 @@ impl ShardHttpTransport {
let transport = Self {
inner_transport: Arc::new(HttpTransport {
http_runtime,
identity: shard_config.shard_id,
identity: shard_id,
clients,
handler,
record_streams: StreamCollection::default(),
}),
shard_config,
shard_count,
};

let server = IpaHttpServer::new_shards(&transport, server_config, network_config);
Expand All @@ -331,10 +333,7 @@ impl Transport for ShardHttpTransport {

fn peers(&self) -> impl Iterator<Item = Self::Identity> {
let this = self.identity();
self.shard_config
.shard_count
.iter()
.filter(move |&v| v != this)
self.shard_count.iter().filter(move |&v| v != this)
}

async fn send<D, Q, S, R>(
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/dzkp_semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl<'a> super::ShardedContext for DZKPUpgraded<'a, Sharded> {
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}
}

impl<'a, B: ShardBinding> super::Context for DZKPUpgraded<'a, B> {
Expand Down
4 changes: 4 additions & 0 deletions ipa-core/src/protocol/context/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ impl ShardedContext for Context<'_, Sharded> {
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}
}

impl<'a> Context<'a, NotSharded> {
Expand Down
15 changes: 15 additions & 0 deletions ipa-core/src/protocol/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ impl ShardedContext for Base<'_, Sharded> {
.gateway
.get_shard_receiver(&ChannelId::new(origin, self.gate.clone()))
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
InstrumentedIndexedSharedRandomness::new(
self.sharding.cross_shard_prss().indexed(self.gate()),
self.gate(),
self.inner.gateway.role(),
)
}
}

impl<'a, B: ShardBinding> Context for Base<'a, B> {
Expand Down Expand Up @@ -325,6 +333,13 @@ pub trait ShardedContext: Context + ShardConfiguration {

ShardIndex::from(shard_index)
}

/// Get the indexed PRSS instance shared across all shards on this helper.
/// Each shard will see the same random values generated by it.
/// This is still PRSS - the corresponding shards on other helpers will share
/// the left and the right part
#[must_use]
fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_>;
}

impl ShardConfiguration for Base<'_, Sharded> {
Expand Down
8 changes: 8 additions & 0 deletions ipa-core/src/protocol/context/semi_honest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ impl ShardedContext for Context<'_, Sharded> {
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}
}

impl<'a, B: ShardBinding> super::Context for Context<'a, B> {
Expand Down Expand Up @@ -218,6 +222,10 @@ impl<F: ExtendableField> ShardedContext for Upgraded<'_, Sharded, F> {
fn shard_recv_channel<M: Message>(&self, origin: ShardIndex) -> ShardReceivingEnd<M> {
self.inner.shard_recv_channel(origin)
}

fn cross_shard_prss(&self) -> InstrumentedIndexedSharedRandomness<'_> {
self.inner.cross_shard_prss()
}
}

impl<'a, B: ShardBinding, F: ExtendableField> super::Context for Upgraded<'a, B, F> {
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/src/protocol/ipa_prf/shuffle/malicious.rs
Original file line number Diff line number Diff line change
Expand Up @@ -834,7 +834,7 @@ mod tests {
// changing x2
if ctx.gate.as_ref().contains("transfer_x_y")
&& ctx.dest == Role::H2
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand All @@ -850,7 +850,7 @@ mod tests {
// changing y1
if ctx.gate.as_ref().contains("transfer_x_y")
&& ctx.dest == Role::H3
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand All @@ -866,7 +866,7 @@ mod tests {
// changing c_hat_2
if ctx.gate.as_ref().contains("transfer_c")
&& ctx.dest == Role::H2
&& ctx.shard.map(|s| s.shard_id) == target_shard
&& ctx.shard == target_shard
{
data[0] ^= 1u8;
}
Expand Down
Loading

0 comments on commit eecd25d

Please sign in to comment.