Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Allen's Poseidon2 fixes #1099

Merged
merged 13 commits into from
Jul 17, 2024
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ use sp1_recursion_program::machine::{
pub use sp1_recursion_program::machine::{
SP1DeferredMemoryLayout, SP1RecursionMemoryLayout, SP1ReduceMemoryLayout, SP1RootMemoryLayout,
};
use tracing::{info_span, instrument};
use tracing::instrument;
pub use types::*;
use utils::words_to_bytes;

Expand Down
4 changes: 2 additions & 2 deletions recursion/compiler/src/asm/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,12 @@ impl<F: PrimeField32 + TwoAdicField, EF: ExtensionField<F> + TwoAdicField> AsmCo
_ => unimplemented!(),
}
}
DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input) => match input {
DslIr::Poseidon2AbsorbBabyBear(p2_hash_and_absorb_num, input) => match input {
Array::Dyn(input, input_size) => {
if let Usize::Var(input_size) = input_size {
self.push(
AsmInstruction::Poseidon2Absorb(
p2_hash_num.fp(),
p2_hash_and_absorb_num.fp(),
input.fp(),
input_size.fp(),
),
Expand Down
30 changes: 16 additions & 14 deletions recursion/compiler/src/asm/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -854,17 +854,19 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
false,
"".to_string(),
),
AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => Instruction::new(
Opcode::Poseidon2Absorb,
i32_f(hash_num),
i32_f_arr(input_ptr),
i32_f_arr(input_len),
F::zero(),
F::zero(),
false,
false,
"".to_string(),
),
AsmInstruction::Poseidon2Absorb(hash_and_absorb_num, input_ptr, input_len) => {
Instruction::new(
Opcode::Poseidon2Absorb,
i32_f(hash_and_absorb_num),
i32_f_arr(input_ptr),
i32_f_arr(input_len),
F::zero(),
F::zero(),
false,
false,
"".to_string(),
)
}
AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => Instruction::new(
Opcode::Poseidon2Finalize,
i32_f(hash_num),
Expand Down Expand Up @@ -1174,15 +1176,15 @@ impl<F: PrimeField32, EF: ExtensionField<F>> AsmInstruction<F, EF> {
result, src1, src2
)
}
AsmInstruction::Poseidon2Absorb(hash_num, input_ptr, input_len) => {
AsmInstruction::Poseidon2Absorb(hash_and_absorb_num, input_ptr, input_len) => {
write!(
f,
"poseidon2_absorb ({})fp, {})fp, ({})fp",
hash_num, input_ptr, input_len,
hash_and_absorb_num, input_ptr, input_len,
)
}
AsmInstruction::Poseidon2Finalize(hash_num, output_ptr) => {
write!(f, "poseidon2_finalize ({})fp, {})fp", hash_num, output_ptr,)
write!(f, "poseidon2_finalize ({})fp, ({})fp", hash_num, output_ptr,)
}
AsmInstruction::Commit(val, index) => {
write!(f, "commit ({})fp ({})fp", val, index)
Expand Down
18 changes: 14 additions & 4 deletions recursion/compiler/src/ir/poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,15 @@ impl<C: Config> Builder<C> {
/// Applies the Poseidon2 absorb function to the given array.
///
/// Reference: [p3_symmetric::PaddingFreeSponge]
pub fn poseidon2_absorb(&mut self, p2_hash_num: Var<C::N>, input: &Array<C, Felt<C::F>>) {
self.operations
.push(DslIr::Poseidon2AbsorbBabyBear(p2_hash_num, input.clone()));
pub fn poseidon2_absorb(
&mut self,
p2_hash_and_absorb_num: Var<C::N>,
input: &Array<C, Felt<C::F>>,
) {
self.operations.push(DslIr::Poseidon2AbsorbBabyBear(
p2_hash_and_absorb_num,
input.clone(),
));
}

/// Applies the Poseidon2 finalize to the given hash number.
Expand Down Expand Up @@ -128,9 +134,13 @@ impl<C: Config> Builder<C> {
self.cycle_tracker("poseidon2-hash");

let p2_hash_num = self.p2_hash_num;
let two_power_12: Var<_> = self.eval(C::N::from_canonical_u32(1 << 12));

self.range(0, array.len()).for_each(|i, builder| {
let subarray = builder.get(array, i);
builder.poseidon2_absorb(p2_hash_num, &subarray);
let p2_hash_and_absorb_num: Var<_> = builder.eval(p2_hash_num * two_power_12 + i);

builder.poseidon2_absorb(p2_hash_and_absorb_num, &subarray);
});

let output: Array<C, Felt<C::F>> = self.dyn_array(DIGEST_SIZE);
Expand Down
122 changes: 96 additions & 26 deletions recursion/core/src/poseidon2_wide/air/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_row.syscall_params(),
send_range_check,
);

builder
.when(local_control_flow.is_syscall_row)
.assert_one(local_is_real);
}

/// This function will verify that all hash rows are before the compress rows and that the first
Expand All @@ -80,47 +84,67 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_is_real: AB::Expr,
next_is_real: AB::Expr,
) {
// We require that the first row is an absorb syscall and that the hash_num == 0.
// We require that the first row is an absorb syscall and that the hash_num == 0 and absorb_num == 0.
let mut first_row_builder = builder.when_first_row();
first_row_builder.assert_one(local_control_flow.is_absorb);
first_row_builder.assert_one(local_control_flow.is_syscall_row);
first_row_builder.assert_zero(local_syscall_params.absorb().hash_num);
first_row_builder.assert_zero(local_opcode_workspace.absorb().hash_num);
first_row_builder.assert_zero(local_opcode_workspace.absorb().absorb_num);
first_row_builder.assert_one(local_opcode_workspace.absorb().is_first_hash_row);

let mut transition_builder = builder.when_transition();

// For absorb rows, constrain the following:
// 1) next row is either an absorb or syscall finalize.
// 2) when last absorb row, then the next row is a syscall row.
// 2) hash_num == hash_num'.
// 1) when last absorb row, then the next row is a either an absorb or finalize syscall row.
// 2) when last absorb row and the next row is an absorb row, then absorb_num' = absorb_num + 1.
// 3) when not last absorb row, then the next row is an absorb non syscall row.
// 4) when not last absorb row, then absorb_num' = absorb_num.
// 5) hash_num == hash_num'.
{
let mut absorb_transition_builder =
transition_builder.when(local_control_flow.is_absorb);
absorb_transition_builder
let mut transition_builder = builder.when_transition();

let mut absorb_last_row_builder =
transition_builder.when(local_control_flow.is_absorb_last_row);
absorb_last_row_builder
.assert_one(next_control_flow.is_absorb + next_control_flow.is_finalize);
absorb_transition_builder
.when(local_opcode_workspace.absorb().is_last_row::<AB>())
.assert_one(next_control_flow.is_syscall_row);
absorb_last_row_builder.assert_one(next_control_flow.is_syscall_row);
absorb_last_row_builder
.when(next_control_flow.is_absorb)
.assert_eq(
next_opcode_workspace.absorb().absorb_num,
local_opcode_workspace.absorb().absorb_num + AB::Expr::one(),
);

let mut absorb_not_last_row_builder =
transition_builder.when(local_control_flow.is_absorb_not_last_row);
absorb_not_last_row_builder.assert_one(next_control_flow.is_absorb);
absorb_not_last_row_builder.assert_zero(next_control_flow.is_syscall_row);
absorb_not_last_row_builder.assert_eq(
local_opcode_workspace.absorb().absorb_num,
next_opcode_workspace.absorb().absorb_num,
);

let mut absorb_transition_builder =
transition_builder.when(local_control_flow.is_absorb);
absorb_transition_builder
.when(next_control_flow.is_absorb)
.assert_eq(
local_syscall_params.absorb().hash_num,
next_syscall_params.absorb().hash_num,
local_opcode_workspace.absorb().hash_num,
next_opcode_workspace.absorb().hash_num,
);
absorb_transition_builder
.when(next_control_flow.is_finalize)
.assert_eq(
local_syscall_params.absorb().hash_num,
local_opcode_workspace.absorb().hash_num,
next_syscall_params.finalize().hash_num,
);
}

// For finalize rows, constrain the following:
// 1) next row is syscall compress or syscall absorb.
// 2) if next row is absorb -> hash_num + 1 == hash_num'
// 3) if next row is absorb -> is_first_hash' == true
// 3) if next row is absorb -> absorb_num' == 0
// 4) if next row is absorb -> is_first_hash' == true
{
let mut transition_builder = builder.when_transition();
let mut finalize_transition_builder =
transition_builder.when(local_control_flow.is_finalize);

Expand All @@ -132,8 +156,11 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
.when(next_control_flow.is_absorb)
.assert_eq(
local_syscall_params.finalize().hash_num + AB::Expr::one(),
next_syscall_params.absorb().hash_num,
next_opcode_workspace.absorb().hash_num,
);
finalize_transition_builder
.when(next_control_flow.is_absorb)
.assert_zero(next_opcode_workspace.absorb().absorb_num);
finalize_transition_builder
.when(next_control_flow.is_absorb)
.assert_one(next_opcode_workspace.absorb().is_first_hash_row);
Expand All @@ -143,26 +170,33 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
// 1) if compress syscall -> next row is a compress output
// 2) if compress output -> next row is a compress syscall or not real
{
builder.assert_eq(
local_control_flow.is_compress_output,
local_control_flow.is_compress
* (AB::Expr::one() - local_control_flow.is_syscall_row),
);

let mut transition_builder = builder.when_transition();

transition_builder
.when(local_control_flow.is_compress)
.when(local_control_flow.is_syscall_row)
.assert_one(next_control_flow.is_compress_output);

// When we are at a compress output row, then ensure next row is either not real or is a compress syscall row.
transition_builder
.when(local_control_flow.is_compress_output)
.assert_one(
next_control_flow.is_compress + (AB::Expr::one() - next_is_real.clone()),
(AB::Expr::one() - next_is_real.clone())
+ next_control_flow.is_compress * next_control_flow.is_syscall_row,
);

transition_builder
.when(local_control_flow.is_compress_output)
.when(next_control_flow.is_compress)
.assert_one(next_control_flow.is_syscall_row);
}

// Constrain that there is only one is_real -> not is real transition. Also contrain that
// the last real row is a compress output row.
{
let mut transition_builder = builder.when_transition();

transition_builder
.when_not(local_is_real.clone())
.assert_zero(next_is_real.clone());
Expand Down Expand Up @@ -194,6 +228,29 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
let last_row_ending_cursor_is_seven =
local_hash_workspace.last_row_ending_cursor_is_seven.result;

// Verify that the hash_num and absorb_num are correctly decomposed from the syscall
// hash_and_absorb_num param.
// Also range check that both hash_num is within [0, 2^16 - 1] and absorb_num is within [0, 2^12 - 1];
{
let mut absorb_builder = builder.when(local_control_flow.is_absorb);

absorb_builder.assert_eq(
local_hash_workspace.hash_num * AB::Expr::from_canonical_u32(1 << 12)
+ local_hash_workspace.absorb_num,
local_syscall_params.absorb().hash_and_absorb_num,
);
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_hash_workspace.hash_num,
send_range_check,
);
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U12 as u8),
local_hash_workspace.absorb_num,
send_range_check,
);
}

// Constrain the materialized control flow flags.
{
let mut absorb_builder = builder.when(local_control_flow.is_absorb);
Expand Down Expand Up @@ -232,12 +289,16 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
local_control_flow.is_absorb
* (AB::Expr::one() - local_hash_workspace.is_last_row::<AB>()),
);
builder.assert_eq(
local_control_flow.is_absorb_last_row,
local_control_flow.is_absorb * local_hash_workspace.is_last_row::<AB>(),
);

builder.assert_eq(
local_control_flow.is_absorb_no_perm,
local_control_flow.is_absorb
* (AB::Expr::one() - local_hash_workspace.do_perm::<AB>()),
)
);
}

// For the absorb syscall row, ensure correct value of num_remaining_rows, last_row_num_consumed,
Expand Down Expand Up @@ -274,7 +335,16 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
expected_last_row_ending_cursor,
);

// Range check that num_remaining_rows is between [0, 2^18-1].
// Range check that input_len < 2^16. This check is only needed for absorb syscall rows,
// but we send it for all absorb rows, since the `is_real` parameter must be an expression
// with at most degree 1.
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_syscall_params.absorb().input_len,
send_range_check,
);

// Range check that num_remaining_rows is between [0, 2^16-1].
builder.send_range_check(
AB::Expr::from_canonical_u8(RangeCheckOpcode::U16 as u8),
local_hash_workspace.num_remaining_rows,
Expand Down
11 changes: 10 additions & 1 deletion recursion/core/src/poseidon2_wide/air/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,21 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
}

// Verify that all elements of start_mem_idx_bitmap and end_mem_idx_bitmap are bool.
// Also verify that exactly one of the bits in start_mem_idx_bitmap and end_mem_idx_bitmap
// is one.
let mut start_mem_idx_bitmap_sum = AB::Expr::zero();
start_mem_idx_bitmap.iter().for_each(|bit| {
absorb_builder.assert_bool(*bit);
start_mem_idx_bitmap_sum += (*bit).into();
});
absorb_builder.assert_one(start_mem_idx_bitmap_sum);

let mut end_mem_idx_bitmap_sum = AB::Expr::zero();
end_mem_idx_bitmap.iter().for_each(|bit| {
absorb_builder.assert_bool(*bit);
end_mem_idx_bitmap_sum += (*bit).into();
});
absorb_builder.assert_one(end_mem_idx_bitmap_sum);

// Verify correct value of start_mem_idx_bitmap and end_mem_idx_bitmap.
let start_mem_idx: AB::Expr = start_mem_idx_bitmap
Expand All @@ -209,7 +218,7 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
// When we are not in the last row, end_mem_idx should be zero.
absorb_builder
.when_not(opcode_workspace.absorb().is_last_row::<AB>())
.assert_zero(end_mem_idx.clone());
.assert_zero(end_mem_idx.clone() - AB::Expr::from_canonical_usize(7));

// When we are in the last row, end_mem_idx bitmap should equal last_row_ending_cursor.
absorb_builder
Expand Down
4 changes: 2 additions & 2 deletions recursion/core/src/poseidon2_wide/air/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! # Layout of the poseidon2 chip:
//!
//! All the hash related rows should be in the first part of the chip and all the compress
//! related rows in the second part. E.g. the chip should has this format:
//! related rows in the second part. E.g. the chip should have this format:
//!
//! absorb row (for hash num 1)
//! absorb row (for hash num 1)
Expand Down Expand Up @@ -34,7 +34,7 @@
//! last_row_ending_cursor will be copied down to all of the rows. Also, for the next absorb/finalize
//! syscall, its state_cursor is set to (last_row_ending_cursor + 1) % RATE.
//!
//! From num_remaining_rows and syscall column, we know the absorb 's first row and last row.
//! From num_remaining_rows and syscall column, we know the absorb's first row and last row.
//! From that fact, we can then enforce the following state writes.
//!
//! 1. is_first_row && is_last_row -> state writes are [state_cursor..state_cursor + last_row_ending_cursor]
Expand Down
6 changes: 4 additions & 2 deletions recursion/core/src/poseidon2_wide/air/syscall_params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ impl<const DEGREE: usize> Poseidon2WideChip<DEGREE> {
let next_syscall_params = next_syscall.absorb();

absorb_syscall_builder.assert_eq(local_syscall_params.clk, next_syscall_params.clk);
absorb_syscall_builder
.assert_eq(local_syscall_params.hash_num, next_syscall_params.hash_num);
absorb_syscall_builder.assert_eq(
local_syscall_params.hash_and_absorb_num,
next_syscall_params.hash_and_absorb_num,
);
absorb_syscall_builder.assert_eq(
local_syscall_params.input_ptr,
next_syscall_params.input_ptr,
Expand Down
Loading
Loading