Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(recursion): public values constraints #748

Merged
merged 25 commits into from
May 18, 2024
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/src/stark/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ pub struct ShardOpenedValues<T: Serialize> {
/// The maximum number of elements that can be stored in the public values vec. Both SP1 and recursive
/// proofs need to pad their public_values vec to this length. This is required since the recursion
/// verification program expects the public values vec to be fixed length.
pub const PROOF_MAX_NUM_PVS: usize = 232;
pub const PROOF_MAX_NUM_PVS: usize = 240;

#[derive(Serialize, Deserialize, Clone)]
#[serde(bound = "")]
Expand Down
3 changes: 3 additions & 0 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
DslIr::Commit(val, index) => {
self.push(AsmInstruction::Commit(val.fp(), index.fp()), trace);
}
DslIr::RegisterPublicValue(val) => {
self.push(AsmInstruction::RegisterPublicValue(val.fp()), trace);
}
DslIr::LessThan(dst, left, right) => {
self.push(
AsmInstruction::LessThan(dst.fp(), left.fp(), right.fp()),
Expand Down
17 changes: 17 additions & 0 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ pub enum AsmInstruction<F, EF> {
// Commit(val, index).
Commit(i32, i32),

// RegisterPublicValue(val).
RegisterPublicValue(i32),

LessThan(i32, i32, i32),

CycleTracker(String),
Expand Down Expand Up @@ -849,6 +852,17 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
true,
"".to_string(),
),
AsmInstruction::RegisterPublicValue(val) => Instruction::new(
Opcode::RegisterPublicValue,
i32_f(val),
f_u32(F::zero()),
f_u32(F::zero()),
F::zero(),
F::zero(),
false,
true,
"".to_string(),
),
}
}

Expand Down Expand Up @@ -1115,6 +1129,9 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
AsmInstruction::Commit(val, index) => {
write!(f, "commit ({})fp ({})fp", val, index)
}
AsmInstruction::RegisterPublicValue(val) => {
write!(f, "register_public_value ({})fp", val)
}
AsmInstruction::CycleTracker(name) => {
write!(f, "cycle-tracker {}", name)
}
Expand Down
7 changes: 6 additions & 1 deletion recursion/compiler/src/ir/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,12 @@ impl<C: Config> Builder<C> {
}
}

/// Commits a felt in public values.
/// Register a felt as public value. This is append to the proof's public values buffer.
pub fn register_public_value(&mut self, val: Felt<C::F>) {
self.operations.push(DslIr::RegisterPublicValue(val));
}

/// Register and commits a felt as public value. This value will be constrained when verified.
pub fn commit_public_value(&mut self, val: Felt<C::F>) {
if self.nb_public_values.is_none() {
self.nb_public_values = Some(self.eval(C::N::zero()));
Expand Down
1 change: 1 addition & 0 deletions recursion/compiler/src/ir/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ pub enum DslIr<C: Config> {
WitnessFelt(Felt<C::F>, u32),
WitnessExt(Ext<C::F, C::EF>, u32),
Commit(Felt<C::F>, Var<C::N>),
RegisterPublicValue(Felt<C::F>),
Halt,

// Public inputs for circuits.
Expand Down
7 changes: 3 additions & 4 deletions recursion/core/src/air/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use p3_air::AirBuilder;
use p3_field::AbstractField;
use p3_field::ExtensionField;
use p3_field::Field;
use p3_field::PrimeField32;
use serde::{Deserialize, Serialize};
use sp1_core::air::ExtensionAirBuilder;
use sp1_core::air::{BinomialExtension, SP1AirBuilder};
Expand Down Expand Up @@ -74,9 +73,9 @@ impl<T> From<[T; D]> for Block<T> {
}
}

impl<F: PrimeField32> From<F> for Block<F> {
fn from(value: F) -> Self {
Self([value, F::zero(), F::zero(), F::zero()])
impl<T: AbstractField> From<T> for Block<T> {
fn from(value: T) -> Self {
Self([value, T::zero(), T::zero(), T::zero()])
}
}

Expand Down
3 changes: 3 additions & 0 deletions recursion/core/src/air/public_values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,4 +99,7 @@ pub struct RecursionPublicValues<T> {

/// Whether the proof completely proves the program execution.
pub is_complete: T,

/// The digest of all the public values elements.
pub digest: [T; DIGEST_SIZE],
}
19 changes: 18 additions & 1 deletion recursion/core/src/cpu/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod branch;
mod jump;
mod memory;
mod operands;
mod public_values;
mod system;

use std::borrow::Borrow;
Expand All @@ -13,7 +14,7 @@ use p3_matrix::Matrix;
use sp1_core::air::BaseAirBuilder;

use crate::{
air::SP1RecursionAirBuilder,
air::{RecursionPublicValues, SP1RecursionAirBuilder, RECURSIVE_PROOF_NUM_PV_ELTS},
cpu::{CpuChip, CpuCols},
memory::MemoryCols,
};
Expand All @@ -27,6 +28,11 @@ where
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local: &CpuCols<AB::Var> = (*local).borrow();
let next: &CpuCols<AB::Var> = (*next).borrow();
let pv = builder.public_values();
let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
core::array::from_fn(|i| pv[i].into());
let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();

let zero = AB::Expr::zero();
let one = AB::Expr::one();

Expand Down Expand Up @@ -74,6 +80,9 @@ where
];
builder.send_table(local.instruction.opcode, &operands, send_syscall);

// Constrain the public values digest.
self.eval_commit(builder, local, public_values.digest.clone());

// Constrain the clk.
self.eval_clk(builder, local, next);

Expand Down Expand Up @@ -175,4 +184,12 @@ impl<F: Field> CpuChip<F> {
+ local.selectors.is_poseidon
+ local.selectors.is_store
}

/// Expr to check for instructions that are commit instructions.
pub fn is_commit_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
where
AB: SP1RecursionAirBuilder<F = F>,
{
local.selectors.is_commit.into()
}
}
61 changes: 61 additions & 0 deletions recursion/core/src/cpu/air/public_values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use p3_air::AirBuilder;
use p3_field::{AbstractField, Field};

use crate::{
air::{BlockBuilder, SP1RecursionAirBuilder},
cpu::{CpuChip, CpuCols},
memory::MemoryCols,
runtime::DIGEST_SIZE,
};

impl<F: Field> CpuChip<F> {
/// Eval the COMMIT instructions.
///
/// This method will verify the committed public value.
pub fn eval_commit<AB>(
&self,
builder: &mut AB,
local: &CpuCols<AB::Var>,
commit_digest: [AB::Expr; DIGEST_SIZE],
) where
AB: SP1RecursionAirBuilder<F = F>,
{
let public_values_cols = local.opcode_specific.public_values();
let is_commit_instruction = self.is_commit_instruction::<AB>(local);

// Verify all elements in the index bitmap are bools.
let mut bitmap_sum = AB::Expr::zero();
for bit in public_values_cols.idx_bitmap.iter() {
builder
.when(is_commit_instruction.clone())
.assert_bool(*bit);
bitmap_sum += (*bit).into();
}
// When the instruction is COMMIT there should be exactly one set bit.
builder
.when(is_commit_instruction.clone())
.assert_one(bitmap_sum.clone());

// Verify that idx passed in the b operand corresponds to the set bit in index bitmap.
for (i, bit) in public_values_cols.idx_bitmap.iter().enumerate() {
builder
.when(*bit * is_commit_instruction.clone())
.assert_block_eq(
*local.b.prev_value(),
AB::Expr::from_canonical_u32(i as u32).into(),
);
}

// Calculated the expected public value.
let expected_pv_digest_element =
builder.index_array(&commit_digest, &public_values_cols.idx_bitmap);

// Get the committed public value in the program from operand a.
let digest_element = local.a.prev_value();

// Verify the public value element.
builder
.when(is_commit_instruction.clone())
.assert_block_eq(expected_pv_digest_element.into(), *digest_element);
}
}
1 change: 1 addition & 0 deletions recursion/core/src/cpu/columns/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod instruction;
mod memory;
mod opcode;
mod opcode_specific;
mod public_values;

pub use instruction::*;
pub use opcode::*;
Expand Down
1 change: 1 addition & 0 deletions recursion/core/src/cpu/columns/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ impl<F: Field> OpcodeSelectorCols<F> {
Opcode::PrintF => self.is_noop = F::one(),
Opcode::PrintE => self.is_noop = F::one(),
Opcode::Commit => self.is_commit = F::one(),
Opcode::RegisterPublicValue => self.is_noop = F::one(),
_ => {}
}

Expand Down
10 changes: 10 additions & 0 deletions recursion/core/src/cpu/columns/opcode_specific.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::mem::{size_of, transmute};

use super::branch::BranchCols;
use super::memory::MemoryCols;
use super::public_values::PublicValuesCols;

pub const NUM_OPCODE_SPECIFIC_COLS: usize = size_of::<OpcodeSpecificCols<u8>>();

Expand All @@ -12,6 +13,7 @@ pub const NUM_OPCODE_SPECIFIC_COLS: usize = size_of::<OpcodeSpecificCols<u8>>();
pub union OpcodeSpecificCols<T: Copy> {
branch: BranchCols<T>,
memory: MemoryCols<T>,
public_values: PublicValuesCols<T>,
}

impl<T: Copy + Default> Default for OpcodeSpecificCols<T> {
Expand Down Expand Up @@ -46,4 +48,12 @@ impl<T: Copy> OpcodeSpecificCols<T> {
pub fn memory_mut(&mut self) -> &mut MemoryCols<T> {
unsafe { &mut self.memory }
}

pub fn public_values(&self) -> &PublicValuesCols<T> {
unsafe { &self.public_values }
}

pub fn public_values_mut(&mut self) -> &mut PublicValuesCols<T> {
unsafe { &mut self.public_values }
}
}
13 changes: 13 additions & 0 deletions recursion/core/src/cpu/columns/public_values.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
use sp1_derive::AlignedBorrow;
use std::mem::size_of;

use crate::runtime::DIGEST_SIZE;

#[allow(dead_code)]
pub const NUM_PUBLIC_VALUES_COLS: usize = size_of::<PublicValuesCols<u8>>();

#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct PublicValuesCols<T> {
pub(crate) idx_bitmap: [T; DIGEST_SIZE],
}
7 changes: 7 additions & 0 deletions recursion/core/src/cpu/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for CpuChip<F> {
};
}

// Populate the public values columns.
if event.instruction.opcode == Opcode::Commit {
let public_values_cols = cols.opcode_specific.public_values_mut();
let idx = cols.b.prev_value()[0].as_canonical_u32() as usize;
public_values_cols.idx_bitmap[idx] = F::one();
}

cols.is_real = F::one();
row
})
Expand Down
8 changes: 2 additions & 6 deletions recursion/core/src/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,13 +795,9 @@ where
next_clk = timestamp;
(a, b, c) = (a_val, b_val, c_val);
}
Opcode::Commit => {
// For both the Commit and RegisterPublicValue opcodes, we record the public value
Opcode::Commit | Opcode::RegisterPublicValue => {
let (a_val, b_val, c_val) = self.all_rr(&instruction);

// Ensure that writes are in order (index should == public_values.len)
let index = b_val[0].as_canonical_u32() as usize;
debug_assert_eq!(index, self.record.public_values.len());

self.record.public_values.push(a_val[0]);

(a, b, c) = (a_val, b_val, c_val);
Expand Down
5 changes: 3 additions & 2 deletions recursion/core/src/runtime/opcode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ pub enum Opcode {
Hint = 38,
BNEINC = 40,
Commit = 41,
LessThanF = 42,
CycleTracker = 43,
RegisterPublicValue = 42,
LessThanF = 43,
CycleTracker = 44,
}

impl Opcode {
Expand Down
10 changes: 5 additions & 5 deletions recursion/program/src/machine/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use crate::utils::{
get_challenger_public_values, hash_vkey,
};

use super::utils::proof_data_from_vk;
use super::utils::{commit_public_values, proof_data_from_vk, verify_public_values_hash};

/// A program to verify a batch of recursive proofs and aggregate their public values.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -251,6 +251,9 @@ where
let current_public_values: &RecursionPublicValues<Felt<C::F>> =
current_public_values_elements.as_slice().borrow();

// Check that the public values digest is correct.
verify_public_values_hash(builder, current_public_values);

// If the proof is the first proof, initialize the values.
builder.if_eq(i, C::N::zero()).then(|builder| {
// Initialize global and accumulated values.
Expand Down Expand Up @@ -468,10 +471,7 @@ where
},
);

// Commit the public values.
for value in reduce_public_values_stream {
builder.commit_public_value(value);
}
commit_public_values(builder, reduce_public_values);

builder.halt();
}
Expand Down
7 changes: 2 additions & 5 deletions recursion/program/src/machine/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::types::ShardProofVariable;
use crate::types::VerifyingKeyVariable;
use crate::utils::{const_fri_config, felt2var, get_challenger_public_values, hash_vkey, var2felt};

use super::utils::assert_complete;
use super::utils::{assert_complete, commit_public_values};

/// A program for recursively verifying a batch of SP1 proofs.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -317,10 +317,7 @@ where
assert_complete(builder, recursion_public_values, &reconstruct_challenger)
});

// Commit to the public values.
for value in recursion_public_values_stream {
builder.commit_public_value(value);
}
commit_public_values(builder, recursion_public_values);

builder.halt();
}
Expand Down
Loading
Loading