Skip to content

Commit

Permalink
Coroutine refactor (#1213)
Browse files Browse the repository at this point in the history
* CoroutineCircuit refactor

* Memoset prover simplified

* Fixed CircuitScope's provenances clone

* Fixed LogMemoCircuit's multiset clone
  • Loading branch information
gabriel-barrett authored Mar 13, 2024
1 parent e4bc325 commit 4e5cd03
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 221 deletions.
4 changes: 2 additions & 2 deletions src/coroutine/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
}
}

fn synthesize_eval<CS: ConstraintSystem<F>>(
fn synthesize_eval<'a, CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, LogMemoCircuit<F>, Self::RD>,
scope: &mut CircuitScope<'a, F, LogMemoCircuit<'a, F>, Self::RD>,
acc: &AllocatedPtr<F>,
allocated_key: &AllocatedPtr<F>,
) -> Result<((AllocatedPtr<F>, AllocatedPtr<F>), AllocatedPtr<F>), SynthesisError> {
Expand Down
4 changes: 2 additions & 2 deletions src/coroutine/memoset/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ impl<F: LurkField> CircuitQuery<F> for EnvCircuitQuery<F> {
}
}

fn synthesize_eval<CS: ConstraintSystem<F>>(
fn synthesize_eval<'a, CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, LogMemoCircuit<F>, Self::RD>,
scope: &mut CircuitScope<'a, F, LogMemoCircuit<'a, F>, Self::RD>,
acc: &AllocatedPtr<F>,
allocated_key: &AllocatedPtr<F>,
) -> Result<((AllocatedPtr<F>, AllocatedPtr<F>), AllocatedPtr<F>), SynthesisError> {
Expand Down
173 changes: 82 additions & 91 deletions src/coroutine/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,75 +419,80 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
}

#[derive(Debug, Clone)]
pub struct CircuitScope<F: LurkField, CM, RD> {
pub struct CircuitScope<'a, F: LurkField, CM, RD> {
memoset: CM, // CircuitMemoSet
/// k -> prov
provenances: IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
/// k -> allocated v
transcript: CircuitTranscript<F>,
/// k -> prov
provenances: Option<&'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>>,
acc: Option<AllocatedPtr<F>>,
pub(crate) runtime_data: RD,
}

#[derive(Clone)]
pub struct CoroutineCircuit<F: LurkField, CM, Q: Query<F>> {
input: Option<Vec<Ptr>>,
provenances: IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
memoset: CM,
keys: Vec<Ptr>,
pub struct CoroutineCircuit<'a, F: LurkField, M, Q: Query<F>> {
query_index: usize,
next_query_index: usize,
store: Arc<Store<F>>,
rc: usize,
runtime_data: Q::RD,
_p: PhantomData<Q>,
store: &'a Store<F>,
runtime_data: &'a Q::RD,
// `None` for circuit synthesis
// `Some` for witness generation
witness_data: Option<WitnessData<'a, F, M>>,
}

// TODO: Make this generic rather than specialized to LogMemo.
// That will require a CircuitScopeTrait.
impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
pub fn new(
input: Option<Vec<Ptr>>,
scope: &Scope<Q, LogMemo<F>, F>,
memoset: LogMemo<F>,
keys: Vec<Ptr>,
#[derive(Clone, Copy)]
struct WitnessData<'a, F: LurkField, M> {
keys: &'a [Ptr],
memoset: &'a M,
provenances: &'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
next_query_index: usize,
}

impl<'a, F: LurkField, M, Q: Query<F>> CoroutineCircuit<'a, F, M, Q> {
pub fn blank(
query_index: usize,
next_query_index: usize,
rc: usize,
runtime_data: Q::RD,
store: &'a Store<F>,
runtime_data: &'a Q::RD,
) -> Self {
assert!(keys.len() <= rc);
Self {
input,
memoset,
provenances: scope.provenances().clone(), // FIXME
keys,
query_index,
next_query_index,
store: scope.store.clone(),
store,
rc,
runtime_data,
_p: Default::default(),
witness_data: None,
}
}

pub fn blank(
fn witness_data(&self) -> Option<&WitnessData<'a, F, M>> {
self.witness_data.as_ref()
}
}

// TODO: Make this generic rather than specialized to LogMemo.
// That will require a CircuitScopeTrait.
impl<'a, F: LurkField, Q: Query<F>> CoroutineCircuit<'a, F, LogMemo<F>, Q> {
pub fn new(
scope: &'a Scope<Q, LogMemo<F>, F>,
keys: &'a [Ptr],
query_index: usize,
store: Arc<Store<F>>,
next_query_index: usize,
rc: usize,
runtime_data: Q::RD,
) -> CoroutineCircuit<F, LogMemo<F>, Q> {
) -> Self {
assert!(keys.len() <= rc);
let memoset = &scope.memoset;
let provenances = scope.provenances();
Self {
input: None,
memoset: Default::default(),
provenances: Default::default(),
keys: Default::default(),
query_index,
next_query_index: 0,
store,
rc,
runtime_data,
_p: Default::default(),
query_index,
store: &scope.store,
runtime_data: &scope.runtime_data,
witness_data: Some(WitnessData {
keys,
memoset,
provenances,
next_query_index,
}),
}
}

Expand All @@ -505,22 +510,25 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
unreachable!()
};

let multiset = self.witness_data().map(|w| &w.memoset.multiset);
let memoset = LogMemoCircuit {
multiset: self.memoset.multiset.clone(),
multiset,
r: r.hash().clone(),
};
let mut circuit_scope: CircuitScope<F, LogMemoCircuit<F>, Q::RD> = CircuitScope::new(
cs,
g,
&self.store,
memoset,
&self.provenances,
self.runtime_data.clone(),
);
let provenances = self.witness_data().map(|w| w.provenances);
let mut circuit_scope: CircuitScope<'_, F, LogMemoCircuit<'_, F>, Q::RD> =
CircuitScope::new(
cs,
g,
self.store,
memoset,
provenances,
self.runtime_data.clone(),
);
circuit_scope.update_from_io(memoset_acc.clone(), transcript.clone(), r);

for (i, key) in self
.keys
let keys: &[Ptr] = self.witness_data().map_or(&[], |w| w.keys);
for (i, key) in keys
.iter()
.map(Some)
.pad_using(self.rc, |_| None)
Expand All @@ -530,7 +538,7 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
circuit_scope.synthesize_prove_key_query::<_, Q>(
cs,
g,
&self.store,
self.store,
key,
self.query_index,
)?;
Expand All @@ -542,7 +550,8 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
let z_out = vec![c.clone(), e.clone(), k.clone(), memoset_acc, transcript, r];

let next_pc = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "next_pc"), || {
F::from_u64(self.next_query_index as u64)
let index = self.witness_data().unwrap().next_query_index;
F::from_u64(index as u64)
});
Ok((Some(next_pc), z_out))
}
Expand Down Expand Up @@ -833,14 +842,18 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
let s = self.store.as_ref();
// FIXME: Do we need to allocate a new GlobalAllocator here?
// Is it okay for this memoset circuit to be shared between all CoroutineCircuits?
let memoset_circuit = self.memoset.to_circuit(ns!(cs, "memoset_circuit"));
let r = self.memoset.allocated_r(ns!(cs, "memoset_allocated_r"));
let memoset_circuit = LogMemoCircuit {
multiset: Some(&self.memoset.multiset),
r,
};

let mut circuit_scope = CircuitScope::new(
ns!(cs, "transcript"),
g,
s,
memoset_circuit.clone(),
self.provenances(),
Some(self.provenances()),
self.runtime_data.clone(),
);

Expand Down Expand Up @@ -873,16 +886,8 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {

// `next_query_index` is only relevant for SuperNova
let next_query_index = 0;
let circuit: CoroutineCircuit<F, LogMemo<F>, Q> = CoroutineCircuit::new(
None,
self,
self.memoset.clone(),
chunk.to_vec(),
*index,
next_query_index,
rc,
self.runtime_data.clone(),
);
let circuit: CoroutineCircuit<'_, F, LogMemo<F>, Q> =
CoroutineCircuit::new(self, chunk, *index, next_query_index, rc);

let (_next_pc, z_out) = circuit.supernova_synthesize(cs, &z)?;
{
Expand Down Expand Up @@ -913,18 +918,18 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
}
}

impl<F: LurkField, RD> CircuitScope<F, LogMemoCircuit<F>, RD> {
impl<'a, F: LurkField, RD> CircuitScope<'a, F, LogMemoCircuit<'a, F>, RD> {
fn new<CS: ConstraintSystem<F>>(
cs: &mut CS,
g: &GlobalAllocator<F>,
s: &Store<F>,
memoset: LogMemoCircuit<F>,
provenances: &IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
memoset: LogMemoCircuit<'a, F>,
provenances: Option<&'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>>,
runtime_data: RD,
) -> Self {
Self {
memoset,
provenances: provenances.clone(), // FIXME
provenances,
transcript: CircuitTranscript::new(cs, g, s),
acc: Default::default(),
runtime_data,
Expand Down Expand Up @@ -1127,7 +1132,7 @@ impl<F: LurkField, RD> CircuitScope<F, LogMemoCircuit<F>, RD> {
let provenance = AllocatedPtr::alloc(ns!(cs, "provenance"), || {
Ok(if not_dummy.get_value() == Some(true) {
*key.get_value()
.and_then(|k| self.provenances.get(&k))
.and_then(|k| self.provenances.unwrap().get(&k))
.ok_or(SynthesisError::AssignmentMissing)?
} else {
// Dummy value that will not be used.
Expand Down Expand Up @@ -1327,10 +1332,6 @@ pub trait CircuitMemoSet<F: LurkField>: Clone {
}

pub trait MemoSet<F: LurkField>: Clone {
type CM: CircuitMemoSet<F>;

fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> Self::CM;

fn is_finalized(&self) -> bool;
fn finalize_transcript(&mut self, s: &Store<F>, transcript: Transcript<F>);
fn r(&self) -> Option<&F>;
Expand All @@ -1351,8 +1352,8 @@ pub struct LogMemo<F: LurkField> {
}

#[derive(Debug, Clone)]
pub struct LogMemoCircuit<F: LurkField> {
multiset: MultiSet<Ptr>,
pub struct LogMemoCircuit<'a, F: LurkField> {
multiset: Option<&'a MultiSet<Ptr>>,
r: AllocatedNum<F>,
}

Expand Down Expand Up @@ -1387,16 +1388,6 @@ impl<F: LurkField> LogMemo<F> {
}

impl<F: LurkField> MemoSet<F> for LogMemo<F> {
type CM = LogMemoCircuit<F>;

fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> Self::CM {
let r = self.allocated_r(cs);
LogMemoCircuit {
multiset: self.multiset.clone(),
r,
}
}

fn count(&self, form: &Ptr) -> usize {
self.multiset.get(form).unwrap_or(0)
}
Expand Down Expand Up @@ -1430,7 +1421,7 @@ impl<F: LurkField> MemoSet<F> for LogMemo<F> {
}
}

impl<F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<F> {
impl<'a, F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<'a, F> {
fn allocated_r(&self) -> AllocatedNum<F> {
self.r.clone()
}
Expand Down Expand Up @@ -1473,7 +1464,7 @@ impl<F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<F> {
}

fn count(&self, form: &Ptr) -> usize {
self.multiset.get(form).unwrap_or(0)
self.multiset.and_then(|m| m.get(form)).unwrap_or(0)
}
}

Expand Down
Loading

1 comment on commit 4e5cd03

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmarks

Table of Contents

Overview

This benchmark report shows the Fibonacci GPU benchmark.
NVIDIA L4
Intel(R) Xeon(R) CPU @ 2.20GHz
32 vCPUs
125 GB RAM
Workflow run: https://github.com/lurk-lab/lurk-rs/actions/runs/8267672791

Benchmark Results

LEM Fibonacci Prove - rc = 100

ref=e4bc325367c2ea9e4f70ec2bef63cfc7f41a2f29 ref=4e5cd03bdb3f1b354856ea3283b323589933c1c9
num-100 1.46 s (✅ 1.00x) 1.45 s (✅ 1.00x faster)
num-200 2.77 s (✅ 1.00x) 2.77 s (✅ 1.00x faster)

LEM Fibonacci Prove - rc = 600

ref=e4bc325367c2ea9e4f70ec2bef63cfc7f41a2f29 ref=4e5cd03bdb3f1b354856ea3283b323589933c1c9
num-100 1.85 s (✅ 1.00x) 1.83 s (✅ 1.01x faster)
num-200 3.05 s (✅ 1.00x) 3.07 s (✅ 1.00x slower)

Made with criterion-table

Please sign in to comment.