Skip to content

Commit

Permalink
implement local-only for preprocessed columns
Browse files Browse the repository at this point in the history
  • Loading branch information
rkm0959 committed Oct 31, 2024
1 parent 84b7331 commit 5b00ac3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 18 deletions.
1 change: 1 addition & 0 deletions crates/core/machine/src/riscv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,7 @@ pub mod tests {
assert_eq!(pk.traces, deserialized_pk.traces);
assert_eq!(pk.data.root(), deserialized_pk.data.root());
assert_eq!(pk.chip_ordering, deserialized_pk.chip_ordering);
assert_eq!(pk.local_only, deserialized_pk.local_only);

let serialized_vk = bincode::serialize(&vk).unwrap();
let deserialized_vk: StarkVerifyingKey<BabyBearPoseidon2> =
Expand Down
16 changes: 12 additions & 4 deletions crates/recursion/circuit/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,18 @@ where
.map(|(name, domain, _)| {
let i = chip_ordering[name];
let values = opened_values.chips[i].preprocessed.clone();
TwoAdicPcsMatsVariable::<C> {
domain: *domain,
points: vec![zeta, domain.next_point_variable(builder, zeta)],
values: vec![values.local, values.next],
if !chips[i].local_only() {
TwoAdicPcsMatsVariable::<C> {
domain: *domain,
points: vec![zeta, domain.next_point_variable(builder, zeta)],
values: vec![values.local, values.next],
}
} else {
TwoAdicPcsMatsVariable::<C> {
domain: *domain,
points: vec![zeta],
values: vec![values.local],
}
}
})
.collect::<Vec<_>>();
Expand Down
18 changes: 13 additions & 5 deletions crates/stark/src/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub struct StarkProvingKey<SC: StarkGenericConfig> {
pub data: PcsProverData<SC>,
/// The preprocessed chip ordering.
pub chip_ordering: HashMap<String, usize>,
/// The preprocessed chip local only information.
pub local_only: Vec<bool>,
}

impl<SC: StarkGenericConfig> StarkProvingKey<SC> {
Expand Down Expand Up @@ -196,19 +198,19 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
chip.preprocessed_width(),
"Incorrect number of preprocessed columns for chip {chip_name}"
);
prep_trace.map(move |t| (chip_name, t))
prep_trace.map(move |t| (chip_name, chip.local_only(), t))
})
.collect::<Vec<_>>()
});

// Order the chips and traces by trace size (biggest first), and get the ordering map.
named_preprocessed_traces
.sort_by_key(|(name, trace)| (Reverse(trace.height()), name.clone()));
.sort_by_key(|(name, _, trace)| (Reverse(trace.height()), name.clone()));

let pcs = self.config.pcs();
let (chip_information, domains_and_traces): (Vec<_>, Vec<_>) = named_preprocessed_traces
.iter()
.map(|(name, trace)| {
.map(|(name, _, trace)| {
let domain = pcs.natural_domain_for_degree(trace.height());
((name.to_owned(), domain, trace.dimensions()), (domain, trace.to_owned()))
})
Expand All @@ -222,12 +224,17 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
let chip_ordering = named_preprocessed_traces
.iter()
.enumerate()
.map(|(i, (name, _))| (name.to_owned(), i))
.map(|(i, (name, _, _))| (name.to_owned(), i))
.collect::<HashMap<_, _>>();

let local_only = named_preprocessed_traces
.iter()
.map(|(_, local_only, _)| local_only.to_owned())
.collect::<Vec<_>>();

// Get the preprocessed traces
let traces =
named_preprocessed_traces.into_iter().map(|(_, trace)| trace).collect::<Vec<_>>();
named_preprocessed_traces.into_iter().map(|(_, _, trace)| trace).collect::<Vec<_>>();

let pc_start = program.pc_start();

Expand All @@ -238,6 +245,7 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> StarkMachine<SC, A> {
traces,
data,
chip_ordering: chip_ordering.clone(),
local_only,
},
StarkVerifyingKey { commit, pc_start, chip_information, chip_ordering },
)
Expand Down
22 changes: 17 additions & 5 deletions crates/stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,9 +593,14 @@ where
tracing::debug_span!("compute preprocessed opening points").in_scope(|| {
pk.traces
.iter()
.map(|trace| {
.zip(pk.local_only.iter())
.map(|(trace, local_only)| {
let domain = pcs.natural_domain_for_degree(trace.height());
vec![zeta, domain.next_point(zeta).unwrap()]
if !local_only {
vec![zeta, domain.next_point(zeta).unwrap()]
} else {
vec![zeta]
}
})
.collect::<Vec<_>>()
});
Expand Down Expand Up @@ -684,9 +689,16 @@ where

let preprocessed_opened_values = preprocessed_values
.into_iter()
.map(|op| {
let [local, next] = op.try_into().unwrap();
AirOpenedValues { local, next }
.zip(pk.local_only.iter())
.map(|(op, local_only)| {
if !local_only {
let [local, next] = op.try_into().unwrap();
AirOpenedValues { local, next }
} else {
let [local] = op.try_into().unwrap();
let width = local.len();
AirOpenedValues { local, next: vec![SC::Challenge::zero(); width] }
}
})
.collect::<Vec<_>>();

Expand Down
12 changes: 8 additions & 4 deletions crates/stark/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,14 @@ impl<SC: StarkGenericConfig, A: MachineAir<Val<SC>>> Verifier<SC, A> {
.map(|(name, domain, _)| {
let i = chip_ordering[name];
let values = opened_values.chips[i].preprocessed.clone();
(
*domain,
vec![(zeta, values.local), (domain.next_point(zeta).unwrap(), values.next)],
)
if !chips[i].local_only() {
(
*domain,
vec![(zeta, values.local), (domain.next_point(zeta).unwrap(), values.next)],
)
} else {
(*domain, vec![(zeta, values.local)])
}
})
.collect::<Vec<_>>();

Expand Down

0 comments on commit 5b00ac3

Please sign in to comment.