Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Coroutine refactor #1213

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/coroutine/memoset/demo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,12 @@ impl<F: LurkField> CircuitQuery<F> for DemoCircuitQuery<F> {
}
}

fn synthesize_eval<CS: ConstraintSystem<F>>(
fn synthesize_eval<'a, CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, LogMemoCircuit<F>, Self::RD>,
scope: &mut CircuitScope<'a, F, LogMemoCircuit<'a, F>, Self::RD>,
acc: &AllocatedPtr<F>,
allocated_key: &AllocatedPtr<F>,
) -> Result<((AllocatedPtr<F>, AllocatedPtr<F>), AllocatedPtr<F>), SynthesisError> {
Expand Down
4 changes: 2 additions & 2 deletions src/coroutine/memoset/env.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,12 +145,12 @@ impl<F: LurkField> CircuitQuery<F> for EnvCircuitQuery<F> {
}
}

fn synthesize_eval<CS: ConstraintSystem<F>>(
fn synthesize_eval<'a, CS: ConstraintSystem<F>>(
&self,
cs: &mut CS,
g: &GlobalAllocator<F>,
store: &Store<F>,
scope: &mut CircuitScope<F, LogMemoCircuit<F>, Self::RD>,
scope: &mut CircuitScope<'a, F, LogMemoCircuit<'a, F>, Self::RD>,
acc: &AllocatedPtr<F>,
allocated_key: &AllocatedPtr<F>,
) -> Result<((AllocatedPtr<F>, AllocatedPtr<F>), AllocatedPtr<F>), SynthesisError> {
Expand Down
173 changes: 82 additions & 91 deletions src/coroutine/memoset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,75 +419,80 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
}

#[derive(Debug, Clone)]
pub struct CircuitScope<F: LurkField, CM, RD> {
pub struct CircuitScope<'a, F: LurkField, CM, RD> {
memoset: CM, // CircuitMemoSet
/// k -> prov
provenances: IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
/// k -> allocated v
transcript: CircuitTranscript<F>,
/// k -> prov
provenances: Option<&'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>>,
acc: Option<AllocatedPtr<F>>,
pub(crate) runtime_data: RD,
}

#[derive(Clone)]
pub struct CoroutineCircuit<F: LurkField, CM, Q: Query<F>> {
input: Option<Vec<Ptr>>,
provenances: IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
memoset: CM,
keys: Vec<Ptr>,
pub struct CoroutineCircuit<'a, F: LurkField, M, Q: Query<F>> {
query_index: usize,
next_query_index: usize,
store: Arc<Store<F>>,
rc: usize,
runtime_data: Q::RD,
_p: PhantomData<Q>,
store: &'a Store<F>,
runtime_data: &'a Q::RD,
// `None` for circuit synthesis
// `Some` for witness generation
witness_data: Option<WitnessData<'a, F, M>>,
}

// TODO: Make this generic rather than specialized to LogMemo.
// That will require a CircuitScopeTrait.
impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
pub fn new(
input: Option<Vec<Ptr>>,
scope: &Scope<Q, LogMemo<F>, F>,
memoset: LogMemo<F>,
keys: Vec<Ptr>,
#[derive(Clone, Copy)]
struct WitnessData<'a, F: LurkField, M> {
keys: &'a [Ptr],
memoset: &'a M,
provenances: &'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
next_query_index: usize,
}

impl<'a, F: LurkField, M, Q: Query<F>> CoroutineCircuit<'a, F, M, Q> {
pub fn blank(
query_index: usize,
next_query_index: usize,
rc: usize,
runtime_data: Q::RD,
store: &'a Store<F>,
runtime_data: &'a Q::RD,
) -> Self {
assert!(keys.len() <= rc);
Self {
input,
memoset,
provenances: scope.provenances().clone(), // FIXME
keys,
query_index,
next_query_index,
store: scope.store.clone(),
store,
rc,
runtime_data,
_p: Default::default(),
witness_data: None,
}
}

pub fn blank(
fn witness_data(&self) -> Option<&WitnessData<'a, F, M>> {
self.witness_data.as_ref()
}
}

// TODO: Make this generic rather than specialized to LogMemo.
// That will require a CircuitScopeTrait.
impl<'a, F: LurkField, Q: Query<F>> CoroutineCircuit<'a, F, LogMemo<F>, Q> {
pub fn new(
scope: &'a Scope<Q, LogMemo<F>, F>,
keys: &'a [Ptr],
query_index: usize,
store: Arc<Store<F>>,
next_query_index: usize,
rc: usize,
runtime_data: Q::RD,
) -> CoroutineCircuit<F, LogMemo<F>, Q> {
) -> Self {
assert!(keys.len() <= rc);
let memoset = &scope.memoset;
let provenances = scope.provenances();
Self {
input: None,
memoset: Default::default(),
provenances: Default::default(),
keys: Default::default(),
query_index,
next_query_index: 0,
store,
rc,
runtime_data,
_p: Default::default(),
query_index,
store: &scope.store,
runtime_data: &scope.runtime_data,
witness_data: Some(WitnessData {
keys,
memoset,
provenances,
next_query_index,
}),
}
}

Expand All @@ -505,22 +510,25 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
unreachable!()
};

let multiset = self.witness_data().map(|w| &w.memoset.multiset);
let memoset = LogMemoCircuit {
multiset: self.memoset.multiset.clone(),
multiset,
r: r.hash().clone(),
};
let mut circuit_scope: CircuitScope<F, LogMemoCircuit<F>, Q::RD> = CircuitScope::new(
cs,
g,
&self.store,
memoset,
&self.provenances,
self.runtime_data.clone(),
);
let provenances = self.witness_data().map(|w| w.provenances);
let mut circuit_scope: CircuitScope<'_, F, LogMemoCircuit<'_, F>, Q::RD> =
CircuitScope::new(
cs,
g,
self.store,
memoset,
provenances,
self.runtime_data.clone(),
);
circuit_scope.update_from_io(memoset_acc.clone(), transcript.clone(), r);

for (i, key) in self
.keys
let keys: &[Ptr] = self.witness_data().map_or(&[], |w| w.keys);
for (i, key) in keys
.iter()
.map(Some)
.pad_using(self.rc, |_| None)
Expand All @@ -530,7 +538,7 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
circuit_scope.synthesize_prove_key_query::<_, Q>(
cs,
g,
&self.store,
self.store,
key,
self.query_index,
)?;
Expand All @@ -542,7 +550,8 @@ impl<F: LurkField, Q: Query<F>> CoroutineCircuit<F, LogMemo<F>, Q> {
let z_out = vec![c.clone(), e.clone(), k.clone(), memoset_acc, transcript, r];

let next_pc = AllocatedNum::alloc_infallible(&mut cs.namespace(|| "next_pc"), || {
F::from_u64(self.next_query_index as u64)
let index = self.witness_data().unwrap().next_query_index;
F::from_u64(index as u64)
});
Ok((Some(next_pc), z_out))
}
Expand Down Expand Up @@ -833,14 +842,18 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
let s = self.store.as_ref();
// FIXME: Do we need to allocate a new GlobalAllocator here?
// Is it okay for this memoset circuit to be shared between all CoroutineCircuits?
let memoset_circuit = self.memoset.to_circuit(ns!(cs, "memoset_circuit"));
let r = self.memoset.allocated_r(ns!(cs, "memoset_allocated_r"));
let memoset_circuit = LogMemoCircuit {
multiset: Some(&self.memoset.multiset),
r,
};

let mut circuit_scope = CircuitScope::new(
ns!(cs, "transcript"),
g,
s,
memoset_circuit.clone(),
self.provenances(),
Some(self.provenances()),
self.runtime_data.clone(),
);

Expand Down Expand Up @@ -873,16 +886,8 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {

// `next_query_index` is only relevant for SuperNova
let next_query_index = 0;
let circuit: CoroutineCircuit<F, LogMemo<F>, Q> = CoroutineCircuit::new(
None,
self,
self.memoset.clone(),
chunk.to_vec(),
*index,
next_query_index,
rc,
self.runtime_data.clone(),
);
let circuit: CoroutineCircuit<'_, F, LogMemo<F>, Q> =
CoroutineCircuit::new(self, chunk, *index, next_query_index, rc);

let (_next_pc, z_out) = circuit.supernova_synthesize(cs, &z)?;
{
Expand Down Expand Up @@ -913,18 +918,18 @@ impl<F: LurkField, Q: Query<F>> Scope<Q, LogMemo<F>, F> {
}
}

impl<F: LurkField, RD> CircuitScope<F, LogMemoCircuit<F>, RD> {
impl<'a, F: LurkField, RD> CircuitScope<'a, F, LogMemoCircuit<'a, F>, RD> {
fn new<CS: ConstraintSystem<F>>(
cs: &mut CS,
g: &GlobalAllocator<F>,
s: &Store<F>,
memoset: LogMemoCircuit<F>,
provenances: &IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>,
memoset: LogMemoCircuit<'a, F>,
provenances: Option<&'a IndexMap<ZPtr<Tag, F>, ZPtr<Tag, F>>>,
runtime_data: RD,
) -> Self {
Self {
memoset,
provenances: provenances.clone(), // FIXME
provenances,
transcript: CircuitTranscript::new(cs, g, s),
acc: Default::default(),
runtime_data,
Expand Down Expand Up @@ -1127,7 +1132,7 @@ impl<F: LurkField, RD> CircuitScope<F, LogMemoCircuit<F>, RD> {
let provenance = AllocatedPtr::alloc(ns!(cs, "provenance"), || {
Ok(if not_dummy.get_value() == Some(true) {
*key.get_value()
.and_then(|k| self.provenances.get(&k))
.and_then(|k| self.provenances.unwrap().get(&k))
.ok_or(SynthesisError::AssignmentMissing)?
} else {
// Dummy value that will not be used.
Expand Down Expand Up @@ -1327,10 +1332,6 @@ pub trait CircuitMemoSet<F: LurkField>: Clone {
}

pub trait MemoSet<F: LurkField>: Clone {
type CM: CircuitMemoSet<F>;

fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> Self::CM;

fn is_finalized(&self) -> bool;
fn finalize_transcript(&mut self, s: &Store<F>, transcript: Transcript<F>);
fn r(&self) -> Option<&F>;
Expand All @@ -1351,8 +1352,8 @@ pub struct LogMemo<F: LurkField> {
}

#[derive(Debug, Clone)]
pub struct LogMemoCircuit<F: LurkField> {
multiset: MultiSet<Ptr>,
pub struct LogMemoCircuit<'a, F: LurkField> {
multiset: Option<&'a MultiSet<Ptr>>,
r: AllocatedNum<F>,
}

Expand Down Expand Up @@ -1387,16 +1388,6 @@ impl<F: LurkField> LogMemo<F> {
}

impl<F: LurkField> MemoSet<F> for LogMemo<F> {
type CM = LogMemoCircuit<F>;

fn to_circuit<CS: ConstraintSystem<F>>(&self, cs: &mut CS) -> Self::CM {
let r = self.allocated_r(cs);
LogMemoCircuit {
multiset: self.multiset.clone(),
r,
}
}

fn count(&self, form: &Ptr) -> usize {
self.multiset.get(form).unwrap_or(0)
}
Expand Down Expand Up @@ -1430,7 +1421,7 @@ impl<F: LurkField> MemoSet<F> for LogMemo<F> {
}
}

impl<F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<F> {
impl<'a, F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<'a, F> {
fn allocated_r(&self) -> AllocatedNum<F> {
self.r.clone()
}
Expand Down Expand Up @@ -1473,7 +1464,7 @@ impl<F: LurkField> CircuitMemoSet<F> for LogMemoCircuit<F> {
}

fn count(&self, form: &Ptr) -> usize {
self.multiset.get(form).unwrap_or(0)
self.multiset.and_then(|m| m.get(form)).unwrap_or(0)
}
}

Expand Down
Loading