Skip to content

Commit

Permalink
Refactor: core of seec
Browse files Browse the repository at this point in the history
This is a big refactor of the core of SEEC with
 the goal to increase the maintainability and
 flexibility of SEEC.
  • Loading branch information
robinhundt committed May 26, 2024
1 parent d337ba3 commit df40ef5
Show file tree
Hide file tree
Showing 32 changed files with 1,171 additions and 941 deletions.
21 changes: 19 additions & 2 deletions crates/seec-bitmatrix/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use rayon::iter::IndexedParallelIterator;
use rayon::slice::{ParallelSlice, ParallelSliceMut};
use serde::{Deserialize, Serialize};
use std::fmt::{Binary, Debug, Formatter};
use std::ops::{BitAnd, BitXor, Range};
use std::ops::{BitAnd, BitXor, Not, Range};
use std::slice::{ChunksExact, ChunksExactMut};

#[cfg(is_nightly)]
Expand Down Expand Up @@ -48,7 +48,7 @@ pub struct BitMatrixViewMut<'a, T> {
}

pub trait Storage:
bytemuck::Pod + BitXor<Output = Self> + BitAnd<Output = Self> + Send + Sync
bytemuck::Pod + BitXor<Output = Self> + BitAnd<Output = Self> + Not<Output = Self> + Send + Sync
{
const BITS: usize;

Expand Down Expand Up @@ -101,6 +101,14 @@ impl<T: Storage> BitMatrix<T> {
}
}

pub fn rows(&self) -> usize {
self.rows
}

pub fn cols(&self) -> usize {
self.cols
}

// Returns dimensions (rows, columns).
pub fn dim(&self) -> (usize, usize) {
(self.rows, self.cols)
Expand Down Expand Up @@ -311,6 +319,15 @@ fn raw_row_idx<T: Storage>(row: usize, cols: usize) -> Range<usize> {
start_idx..end_idx
}

impl<T: Storage> Not for BitMatrix<T> {
type Output = Self;

fn not(mut self) -> Self::Output {
self.data.iter_mut().for_each(|el| *el = !*el);
self
}
}

impl<T: Storage> BitXor for BitMatrix<T> {
type Output = BitMatrix<T>;

Expand Down
13 changes: 7 additions & 6 deletions crates/seec-macros/src/sub_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub(crate) fn sub_circuit(input: ItemFn) -> TokenStream {

let inputs_size_ty = quote!((#(<#input_tys as ::seec::SubCircuitInput>::Size, )*));
let input_protocol_ty = quote!(<#first_input_ty as ::seec::SubCircuitInput>::Protocol);
let input_plain_ty = quote!(<#input_protocol_ty as ::seec::protocols::Protocol>::Plain);
let input_gate_ty = quote!(<#input_protocol_ty as ::seec::protocols::Protocol>::Gate);
let input_idx_ty = quote!(<#first_input_ty as ::seec::SubCircuitInput>::Idx);

Expand Down Expand Up @@ -76,12 +77,12 @@ pub(crate) fn sub_circuit(input: ItemFn) -> TokenStream {
::parking_lot::Mutex<
::std::collections::HashMap<
#inputs_size_ty,
(::seec::circuit::SharedCircuit<#input_gate_ty, #input_idx_ty>, _internal_ForceSendSync<#inner_ret>)
(::seec::circuit::SharedCircuit<#input_plain_ty, #input_gate_ty, #input_idx_ty>, _internal_ForceSendSync<#inner_ret>)
>
>
> = ::once_cell::sync::Lazy::new(|| ::parking_lot::Mutex::new(::std::collections::HashMap::new()));

::seec::CircuitBuilder::<#input_gate_ty, #input_idx_ty>::with_global(|builder| {
::seec::CircuitBuilder::<#input_plain_ty, #input_gate_ty, #input_idx_ty>::with_global(|builder| {
builder.add_cache(&*CACHE);
});

Expand All @@ -92,20 +93,20 @@ pub(crate) fn sub_circuit(input: ItemFn) -> TokenStream {

let (sc_id, ret) = match CACHE.lock().entry(input_size.clone()) {
::std::collections::hash_map::Entry::Vacant(entry) => {
let sub_circuit = ::seec::SharedCircuit::<#input_gate_ty, #input_idx_ty>::default();
let sc_id = ::seec::CircuitBuilder::<#input_gate_ty, #input_idx_ty>::push_global_circuit(sub_circuit.clone());
let sub_circuit = ::seec::SharedCircuit::<#input_plain_ty, #input_gate_ty, #input_idx_ty>::default();
let sc_id = ::seec::CircuitBuilder::<#input_plain_ty, #input_gate_ty, #input_idx_ty>::push_global_circuit(sub_circuit.clone());
let ret = #call_inner;
let ret = ::seec::SubCircuitOutput::create_output_gates(ret);
entry.insert((sub_circuit, _internal_ForceSendSync(ret.clone())));
(sc_id, ret)
}
::std::collections::hash_map::Entry::Occupied(entry) => {
let (sub_circuit, ret) = entry.get();
let sc_id = ::seec::CircuitBuilder::<#input_gate_ty, #input_idx_ty>::push_global_circuit(sub_circuit.clone());
let sc_id = ::seec::CircuitBuilder::<#input_plain_ty, #input_gate_ty, #input_idx_ty>::push_global_circuit(sub_circuit.clone());
(sc_id, ret.0.clone())
}
};
::seec::CircuitBuilder::<#input_gate_ty, #input_idx_ty>::with_global(|builder| {
::seec::CircuitBuilder::<#input_plain_ty, #input_gate_ty, #input_idx_ty>::with_global(|builder| {
builder.connect_sub_circuit(&circ_inputs, sc_id);
});

Expand Down
4 changes: 2 additions & 2 deletions crates/seec/benches/benchmarks/layer_iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use seec::secret::{inputs, low_depth_reduce, Secret};
use seec::{sub_circuit, BooleanGate, CircuitBuilder};

fn build_circuit(keyword_size: usize, target_text_size: usize) -> Circuit {
CircuitBuilder::<BooleanGate, u32>::new().install();
CircuitBuilder::<bool, BooleanGate, u32>::new().install();
let keyword: Vec<_> = (0..keyword_size)
.map(|_| inputs(8).try_into().unwrap())
.collect();
Expand All @@ -17,7 +17,7 @@ fn build_circuit(keyword_size: usize, target_text_size: usize) -> Circuit {

create_search_circuit(&keyword, &target_text);

let circ = CircuitBuilder::<BooleanGate, u32>::global_into_circuit();
let circ = CircuitBuilder::<bool, BooleanGate, u32>::global_into_circuit();
circ
}

Expand Down
9 changes: 5 additions & 4 deletions crates/seec/examples/aes_cbc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ async fn execute(args: &ExecuteArgs) -> Result<()> {
}

async fn bench_execute(args: &ExecuteArgs) -> Result<()> {
let exec_circ: ExecutableCircuit<BooleanGate, usize> = bincode::deserialize_from(
let exec_circ: ExecutableCircuit<bool, BooleanGate, usize> = bincode::deserialize_from(
BufReader::new(File::open(&args.circuit).context("Failed to open circuit file")?),
)?;

Expand Down Expand Up @@ -351,7 +351,7 @@ async fn encrypt(
shared_key: &BitSlice<usize>,
shared_iv: &BitSlice<usize>,
) -> Result<Output<BitVec<usize>>> {
let exec_circ: ExecutableCircuit<BooleanGate, usize> = bincode::deserialize_from(
let exec_circ: ExecutableCircuit<bool, BooleanGate, usize> = bincode::deserialize_from(
BufReader::new(File::open(&args.circuit).context("Failed to open circuit file")?),
)?;

Expand Down Expand Up @@ -382,7 +382,7 @@ async fn encrypt(
fn build_enc_circuit(
data_size_bits: usize,
use_sc: bool,
) -> Result<ExecutableCircuit<BooleanGate, usize>> {
) -> Result<ExecutableCircuit<bool, BooleanGate, usize>> {
assert_eq!(
data_size_bits % 128,
0,
Expand All @@ -399,6 +399,7 @@ fn build_enc_circuit(
.for_each(|chunk| aes_cbc_chunk(&key, chunk, &mut chaining_state, use_sc));

Ok(ExecutableCircuit::DynLayers(CircuitBuilder::<
bool,
BooleanGate,
usize,
>::global_into_circuit()))
Expand Down Expand Up @@ -427,7 +428,7 @@ fn aes128(
chunk: &[Secret<BooleanGmw, usize>],
use_sc: bool,
) -> Vec<Secret<BooleanGmw, usize>> {
static AES_CIRC: Lazy<SharedCircuit<BooleanGate, usize>> = Lazy::new(|| {
static AES_CIRC: Lazy<SharedCircuit<bool, BooleanGate, usize>> = Lazy::new(|| {
let aes_circ_str = include_str!("../test_resources/bristol-circuits/aes_128.bristol");
BaseCircuit::from_bristol(
seec::bristol::circuit(aes_circ_str).expect("parsing AES circuit failed"),
Expand Down
2 changes: 1 addition & 1 deletion crates/seec/examples/bristol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async fn execute(execute_args: ExecuteArgs) -> Result<()> {
Ok(())
}

fn load_circ(args: &ExecuteArgs) -> Result<ExecutableCircuit<BooleanGate, u32>> {
fn load_circ(args: &ExecuteArgs) -> Result<ExecutableCircuit<bool, BooleanGate, u32>> {
let res = bincode::deserialize_from(BufReader::new(
File::open(&args.circuit).context("Failed to open circuit file")?,
));
Expand Down
4 changes: 2 additions & 2 deletions crates/seec/examples/fuse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use clap::{Args, Parser};
use seec::bench::BenchParty;
use seec::circuit::ExecutableCircuit;
use seec::parse::fuse::{CallMode, FuseConverter};
use seec::protocols::mixed_gmw::{MixedGate, MixedGmw};
use seec::protocols::mixed_gmw::{Mixed, MixedGate, MixedGmw};
use std::fs::File;
use std::io;
use std::io::{stdout, BufReader, BufWriter, Write};
Expand Down Expand Up @@ -168,7 +168,7 @@ async fn execute(execute_args: ExecuteArgs) -> Result<()> {
Ok(())
}

fn load_circ(args: &ExecuteArgs) -> Result<ExecutableCircuit<MixedGate<u32>, u32>> {
fn load_circ(args: &ExecuteArgs) -> Result<ExecutableCircuit<Mixed<u32>, MixedGate<u32>, u32>> {
bincode::deserialize_from(BufReader::new(
File::open(&args.circuit).context("Failed to open circuit file")?,
))
Expand Down
2 changes: 1 addition & 1 deletion crates/seec/examples/privmail.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ async fn main() -> anyhow::Result<()> {
args.duplication_factor,
);

let circuit: ExecutableCircuit<_, _> =
let circuit: ExecutableCircuit<bool, _, _> =
ExecutableCircuit::DynLayers(CircuitBuilder::global_into_circuit());
// if args.save_circuit {
// circuit.save_dot("privmail.dot")?;
Expand Down
2 changes: 1 addition & 1 deletion crates/seec/examples/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct Args {
async fn main() -> Result<()> {
let _guard = init_tracing()?;
let args = Args::parse();
let circuit: ExecutableCircuit<BooleanGate, u32> = ExecutableCircuit::DynLayers(
let circuit: ExecutableCircuit<bool, BooleanGate, u32> = ExecutableCircuit::DynLayers(
BaseCircuit::load_bristol(args.circuit, Load::Circuit)?.into(),
);

Expand Down
2 changes: 1 addition & 1 deletion crates/seec/examples/sub_circuits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn main() {

(or_out ^ false).output();

let circuit: Circuit<BooleanGate, DefaultIdx> = CircuitBuilder::global_into_circuit();
let circuit: Circuit<bool, BooleanGate, DefaultIdx> = CircuitBuilder::global_into_circuit();
let layer_iter = CircuitLayerIter::new(&circuit);
for layer in layer_iter {
dbg!(layer);
Expand Down
20 changes: 10 additions & 10 deletions crates/seec/src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use crate::executor::{Executor, Input, Message};
use crate::mul_triple::storage::MTStorage;
use crate::mul_triple::{boolean, MTProvider};
use crate::protocols::boolean_gmw::BooleanGmw;
use crate::protocols::mixed_gmw::{MixedGmw, MixedShare};
use crate::protocols::{mixed_gmw, Gate, Protocol, Ring, Share, ShareStorage};
use crate::protocols::mixed_gmw::{Mixed, MixedGmw};
use crate::protocols::{mixed_gmw, Protocol, Ring, Share, ShareStorage};
use crate::utils::{BoxError, ErasedError};
use crate::CircuitBuilder;
use anyhow::{anyhow, Context};
Expand Down Expand Up @@ -78,18 +78,18 @@ where
}

// TODO this is wrong to just always generate arith shares, so it lives here in the bench API
impl<R> Distribution<MixedShare<R>> for Standard
impl<R> Distribution<Mixed<R>> for Standard
where
Standard: Distribution<R>,
{
fn sample<RNG: Rng + ?Sized>(&self, rng: &mut RNG) -> MixedShare<R> {
MixedShare::Arith(rng.sample(Standard))
fn sample<RNG: Rng + ?Sized>(&self, rng: &mut RNG) -> Mixed<R> {
Mixed::Arith(rng.sample(Standard))
}
}

pub struct BenchParty<P: Protocol, Idx> {
id: usize,
circ: Option<ExecutableCircuit<P::Gate, Idx>>,
circ: Option<ExecutableCircuit<P::Plain, P::Gate, Idx>>,
server: Option<SocketAddr>,
meta: String,
insecure_setup: bool,
Expand Down Expand Up @@ -117,9 +117,9 @@ pub struct BenchResult {
impl<P, Idx> BenchParty<P, Idx>
where
P: BenchProtocol,
Standard: Distribution<<<P as Protocol>::Gate as Gate>::Share>,
Standard: Distribution<P::Share>,
Idx: GateIdx,
<P::Gate as Gate>::Share: Share<SimdShare = P::ShareStorage>,
P::Share: Share<SimdShare = P::ShareStorage>,
{
pub fn new(id: usize) -> Self {
Self {
Expand All @@ -143,7 +143,7 @@ where
self
}

pub fn explicit_circuit(mut self, circuit: ExecutableCircuit<P::Gate, Idx>) -> Self {
pub fn explicit_circuit(mut self, circuit: ExecutableCircuit<P::Plain, P::Gate, Idx>) -> Self {
self.circ = Some(circuit);
self
}
Expand Down Expand Up @@ -247,7 +247,7 @@ where
let circ = match &self.circ {
Some(circ) => circ,
None => {
let circ = CircuitBuilder::<P::Gate, Idx>::global_into_circuit();
let circ = CircuitBuilder::<P::Plain, P::Gate, Idx>::global_into_circuit();
if self.precompute_layers {
owned_circ = ExecutableCircuit::DynLayers(circ).precompute_layers();
&owned_circ
Expand Down
Loading

0 comments on commit df40ef5

Please sign in to comment.