Skip to content

Commit

Permalink
test(proc): add dyn node test
Browse files Browse the repository at this point in the history
  • Loading branch information
grjte committed Sep 4, 2023
1 parent 5e57ddc commit 61bf464
Showing 1 changed file with 190 additions and 0 deletions.
190 changes: 190 additions & 0 deletions processor/src/decoder/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1464,6 +1464,166 @@ fn syscall_block() {
assert_eq!(expected_rows, aux_hints.block_hash_table_rows());
}

// DYN BLOCK TESTS
// ================================================================================================
#[test]
fn dyn_block() {
// build a dynamic block which looks like this:
// push.1 add

let foo_root = CodeBlock::new_span(vec![Operation::Push(ONE), Operation::Add]);
let mul_span = CodeBlock::new_span(vec![Operation::Mul]);
let save_span = CodeBlock::new_span(vec![Operation::MovDn4]);
let join = CodeBlock::new_join([mul_span.clone(), save_span.clone()]);
// This dyn will point to foo.
let dyn_block = CodeBlock::new_dyn();
let program = CodeBlock::new_join([join.clone(), dyn_block.clone()]);

let (trace, aux_hints, trace_len) = build_dyn_trace(
&[
foo_root.hash()[0].as_int(),
foo_root.hash()[1].as_int(),
foo_root.hash()[2].as_int(),
foo_root.hash()[3].as_int(),
2,
4,
],
&program,
foo_root.clone(),
);

// --- check block address, op_bits, group count, op_index, and in_span columns ---------------
check_op_decoding(&trace, 0, ZERO, Operation::Join, 0, 0, 0);
// starting inner join
let join_addr = INIT_ADDR + EIGHT;
check_op_decoding(&trace, 1, INIT_ADDR, Operation::Join, 0, 0, 0);
// starting first span
let mul_span_addr = join_addr + EIGHT;
check_op_decoding(&trace, 2, join_addr, Operation::Span, 1, 0, 0);
check_op_decoding(&trace, 3, mul_span_addr, Operation::Mul, 0, 0, 1);
check_op_decoding(&trace, 4, mul_span_addr, Operation::End, 0, 0, 0);
// starting second span
let save_span_addr = mul_span_addr + EIGHT;
check_op_decoding(&trace, 5, join_addr, Operation::Span, 1, 0, 0);
check_op_decoding(&trace, 6, save_span_addr, Operation::MovDn4, 0, 0, 1);
check_op_decoding(&trace, 7, save_span_addr, Operation::End, 0, 0, 0);
// end inner join
check_op_decoding(&trace, 8, join_addr, Operation::End, 0, 0, 0);
// dyn
check_op_decoding(&trace, 9, INIT_ADDR, Operation::Dyn, 0, 0, 0);
// starting foo span
let dyn_addr = save_span_addr + EIGHT;
let add_span_addr = dyn_addr + EIGHT;
check_op_decoding(&trace, 10, dyn_addr, Operation::Span, 2, 0, 0);
check_op_decoding(&trace, 11, add_span_addr, Operation::Push(ONE), 1, 0, 1);
check_op_decoding(&trace, 12, add_span_addr, Operation::Add, 0, 1, 1);
check_op_decoding(&trace, 13, add_span_addr, Operation::End, 0, 0, 0);
// end dyn
check_op_decoding(&trace, 14, dyn_addr, Operation::End, 0, 0, 0);
// end outer join
check_op_decoding(&trace, 15, INIT_ADDR, Operation::End, 0, 0, 0);

// --- check hasher state columns -------------------------------------------------------------

// in the first row, the hasher state is set to hashes of both child nodes
let join_hash: Word = join.hash().into();
let dyn_hash: Word = dyn_block.hash().into();
assert_eq!(join_hash, get_hasher_state1(&trace, 0));
assert_eq!(dyn_hash, get_hasher_state2(&trace, 0));

// in the second row, the hasher set is set to hashes of both child nodes of the inner JOIN
let mul_span_hash: Word = mul_span.hash().into();
let save_span_hash: Word = save_span.hash().into();
assert_eq!(mul_span_hash, get_hasher_state1(&trace, 1));
assert_eq!(save_span_hash, get_hasher_state2(&trace, 1));

// at the end of the first SPAN, the hasher state is set to the hash of the first child
assert_eq!(mul_span_hash, get_hasher_state1(&trace, 4));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 4));

// at the end of the second SPAN, the hasher state is set to the hash of the second child
assert_eq!(save_span_hash, get_hasher_state1(&trace, 7));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 7));

// at the end of the inner JOIN, the hasher set is set to the hash of the JOIN
assert_eq!(join_hash, get_hasher_state1(&trace, 8));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 8));

// at the start of the DYN block, the hasher state is set to the hash of its child (foo span)
let foo_hash: Word = foo_root.hash().into();
assert_eq!(foo_hash, get_hasher_state1(&trace, 9));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 9));

// at the end of the DYN SPAN, the hasher state is set to the hash of the foo span
assert_eq!(foo_hash, get_hasher_state1(&trace, 13));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 13));

// at the end of the DYN block, the hasher state is set to the hash of the DYN node
assert_eq!(dyn_hash, get_hasher_state1(&trace, 14));

// at the end of the program, the hasher state is set to the hash of the entire program
let program_hash: Word = program.hash().into();
assert_eq!(program_hash, get_hasher_state1(&trace, 15));
assert_eq!([ZERO, ZERO, ZERO, ZERO], get_hasher_state2(&trace, 15));

// the HALT opcode and program hash get propagated to the last row
for i in 16..trace_len {
assert!(contains_op(&trace, i, Operation::Halt));
assert_eq!(ZERO, trace[OP_BITS_EXTRA_COLS_RANGE.start][i]);
assert_eq!(ONE, trace[OP_BITS_EXTRA_COLS_RANGE.start + 1][i]);
assert_eq!(program_hash, get_hasher_state1(&trace, i));
}

// --- check op_group table hints -------------------------------------------------------------
// 1 op group should be inserted at cycle 10, and removed in the subsequent cycle
let expected_ogt_hints =
vec![(10, OpGroupTableUpdate::InsertRows(1)), (11, OpGroupTableUpdate::RemoveRow)];
assert_eq!(&expected_ogt_hints, aux_hints.op_group_table_hints());

// the group is an op group with a single ADD
let expected_ogt_rows = vec![OpGroupTableRow::new(add_span_addr, ONE, ONE)];
assert_eq!(expected_ogt_rows, aux_hints.op_group_table_rows());

// --- check block execution hints ------------------------------------------------------------
let expected_hints = vec![
(0, BlockTableUpdate::BlockStarted(2)), // outer join start
(1, BlockTableUpdate::BlockStarted(2)), // inner join start
(2, BlockTableUpdate::BlockStarted(0)), // mul span start
(4, BlockTableUpdate::BlockEnded(true)), // mul span end
(5, BlockTableUpdate::BlockStarted(0)), // save span start
(7, BlockTableUpdate::BlockEnded(false)), // save span end
(8, BlockTableUpdate::BlockEnded(true)), // inner join end
(9, BlockTableUpdate::BlockStarted(0)), // dyn start
(10, BlockTableUpdate::BlockStarted(0)), // foo span start
(13, BlockTableUpdate::BlockEnded(false)), // foo span end
(14, BlockTableUpdate::BlockEnded(false)), // dyn end
(15, BlockTableUpdate::BlockEnded(false)), // outer join end
];
assert_eq!(expected_hints, aux_hints.block_exec_hints());

// --- check block stack table hints ----------------------------------------------------------
let expected_rows = vec![
BlockStackTableRow::new_test(INIT_ADDR, ZERO, false), // join
BlockStackTableRow::new_test(join_addr, INIT_ADDR, false), // inner join
BlockStackTableRow::new_test(mul_span_addr, join_addr, false), // mul span
BlockStackTableRow::new_test(save_span_addr, join_addr, false), // save span
BlockStackTableRow::new_test(dyn_addr, INIT_ADDR, false), // dyn
BlockStackTableRow::new_test(add_span_addr, dyn_addr, false), // foo span
];
assert_eq!(expected_rows, aux_hints.block_stack_table_rows());

// --- check block hash table hints ----------------------------------------------------------
let expected_rows = vec![
BlockHashTableRow::from_program_hash(program_hash),
BlockHashTableRow::new_test(INIT_ADDR, join_hash, true, false),
BlockHashTableRow::new_test(INIT_ADDR, dyn_hash, false, false),
BlockHashTableRow::new_test(join_addr, mul_span_hash, true, false),
BlockHashTableRow::new_test(join_addr, save_span_hash, false, false),
BlockHashTableRow::new_test(dyn_addr, foo_hash, false, false),
];
assert_eq!(expected_rows, aux_hints.block_hash_table_rows());
}

// HELPER REGISTERS TESTS
// ================================================================================================
#[test]
Expand Down Expand Up @@ -1517,6 +1677,35 @@ fn build_trace(stack_inputs: &[u64], program: &CodeBlock) -> (DecoderTrace, AuxT
)
}

fn build_dyn_trace(
stack_inputs: &[u64],
program: &CodeBlock,
fn_block: CodeBlock,
) -> (DecoderTrace, AuxTraceHints, usize) {
let stack_inputs = StackInputs::try_from_values(stack_inputs.iter().copied()).unwrap();
let advice_provider = MemAdviceProvider::default();
let mut process =
Process::new(Kernel::default(), stack_inputs, advice_provider, ExecutionOptions::default());

// build code block table
let mut cb_table = CodeBlockTable::default();
cb_table.insert(fn_block);

process.execute_code_block(program, &cb_table).unwrap();

let (trace, aux_hints, _) = ExecutionTrace::test_finalize_trace(process);
let trace_len = get_trace_len(&trace) - ExecutionTrace::NUM_RAND_ROWS;

(
trace[DECODER_TRACE_RANGE]
.to_vec()
.try_into()
.expect("failed to convert vector to array"),
aux_hints.decoder,
trace_len,
)
}

fn build_call_trace(
program: &CodeBlock,
fn_block: CodeBlock,
Expand Down Expand Up @@ -1570,6 +1759,7 @@ fn check_op_decoding(
) {
let opcode = read_opcode(trace, row_idx);

// TODO: restore this
assert_eq!(trace[ADDR_COL_IDX][row_idx], addr);
assert_eq!(op.op_code(), opcode);
assert_eq!(trace[IN_SPAN_COL_IDX][row_idx], Felt::new(in_span));
Expand Down

0 comments on commit 61bf464

Please sign in to comment.