Skip to content

Commit

Permalink
fix(avm): fix tests under proving (#8640)
Browse files Browse the repository at this point in the history
There was a bug in `commit_sparse` which broke one of the tests but
Lucas fixed it.
See
https://aztecprotocol.slack.com/archives/C04DL2L1UP2/p1726738000560929?thread_ts=1726728397.210449&cid=C04DL2L1UP2

This PR also fixes the other tests that were failing, and re-enables the
bb-prover test.

---------

Co-authored-by: Maddiaa0 <47148561+Maddiaa0@users.noreply.github.com>
  • Loading branch information
fcarreiro and Maddiaa0 authored Sep 20, 2024
1 parent 4f69412 commit 8bfc769
Show file tree
Hide file tree
Showing 12 changed files with 82 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ namespace bb {

AvmCircuitBuilder::ProverPolynomials AvmCircuitBuilder::compute_polynomials() const
{
const size_t num_rows = get_num_gates();
const size_t circuit_subgroup_size = get_circuit_subgroup_size();
// FIXME: Either some algo or the Polynomial class seems to require this to be a power of 2.
const size_t num_rows = numeric::round_up_power_2(get_num_gates());
ASSERT(num_rows <= circuit_subgroup_size);
ProverPolynomials polys;

// Allocate mem for each column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ void AvmProver::execute_log_derivative_inverse_commitments_round()
{
// Commit to all logderivative inverse polynomials
for (auto [commitment, key_poly] : zip_view(witness_commitments.get_derived(), key->get_derived())) {
// We don't use commit_sparse here because the logderivative inverse polynomials are dense
commitment = commitment_key->commit(key_poly);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ std::vector<ThreeOpParamRow> gen_three_op_params(std::vector<ThreeOpParam> opera
*/
void validate_trace_check_circuit(std::vector<Row>&& trace)
{
validate_trace(std::move(trace), {}, {}, {}, false);
auto circuit_builder = AvmCircuitBuilder();
circuit_builder.set_trace(std::move(trace));
EXPECT_TRUE(circuit_builder.check_circuit());
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ TEST_F(AvmMemOpcodeTests, allDirectCMovA)
compute_cmov_indices(0);
common_cmov_trace_validate(
false, 1979, 1980, 987162, 10, 11, 12, 20, AvmMemoryTag::U16, AvmMemoryTag::U128, AvmMemoryTag::U64);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

TEST_F(AvmMemOpcodeTests, allDirectCMovB)
Expand All @@ -442,7 +442,7 @@ TEST_F(AvmMemOpcodeTests, allDirectCMovB)
compute_cmov_indices(0);
common_cmov_trace_validate(
false, 1979, 1980, 0, 10, 11, 12, 20, AvmMemoryTag::U8, AvmMemoryTag::U8, AvmMemoryTag::U64);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

TEST_F(AvmMemOpcodeTests, allDirectCMovConditionUninitialized)
Expand All @@ -459,7 +459,7 @@ TEST_F(AvmMemOpcodeTests, allDirectCMovConditionUninitialized)
compute_cmov_indices(0);
common_cmov_trace_validate(
false, 1979, 1980, 0, 10, 11, 12, 20, AvmMemoryTag::U8, AvmMemoryTag::U8, AvmMemoryTag::U0);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

TEST_F(AvmMemOpcodeTests, allDirectCMovOverwriteA)
Expand All @@ -475,7 +475,7 @@ TEST_F(AvmMemOpcodeTests, allDirectCMovOverwriteA)
compute_cmov_indices(0);
common_cmov_trace_validate(
false, 1979, 1980, 0, 10, 11, 10, 20, AvmMemoryTag::U8, AvmMemoryTag::U8, AvmMemoryTag::U64);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

TEST_F(AvmMemOpcodeTests, allIndirectCMovA)
Expand All @@ -502,7 +502,7 @@ TEST_F(AvmMemOpcodeTests, allIndirectCMovA)
compute_cmov_indices(15);
common_cmov_trace_validate(
true, 1979, 1980, 987162, 10, 11, 12, 20, AvmMemoryTag::U16, AvmMemoryTag::U128, AvmMemoryTag::U64);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

TEST_F(AvmMemOpcodeTests, allIndirectCMovAllUnitialized)
Expand All @@ -513,7 +513,7 @@ TEST_F(AvmMemOpcodeTests, allIndirectCMovAllUnitialized)

compute_cmov_indices(15);
common_cmov_trace_validate(true, 0, 0, 0, 0, 0, 0, 0, AvmMemoryTag::U0, AvmMemoryTag::U0, AvmMemoryTag::U0);
validate_trace_check_circuit(std::move(trace));
validate_trace(std::move(trace), public_inputs);
}

/******************************************************************************
Expand Down
28 changes: 11 additions & 17 deletions barretenberg/cpp/src/barretenberg/vm/avm/tests/slice.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "barretenberg/vm/avm/trace/common.hpp"
#include "common.test.hpp"
#include "gtest/gtest.h"
#include <cstddef>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

Expand Down Expand Up @@ -43,9 +44,9 @@ class AvmSliceTests : public ::testing::Test {
}

gen_trace_builder(calldata);
trace_builder.op_set(0, col_offset, 10000, AvmMemoryTag::U32);
trace_builder.op_set(0, copy_size, 10001, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(static_cast<uint8_t>(indirect), 10000, 10001, dst_offset);
trace_builder.op_set(0, col_offset, 0, AvmMemoryTag::U32);
trace_builder.op_set(0, copy_size, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(static_cast<uint8_t>(indirect), 0, 1, dst_offset);
trace_builder.op_return(0, 0, 0);
trace = trace_builder.finalize();
}
Expand Down Expand Up @@ -138,8 +139,9 @@ TEST_F(AvmSliceTests, singleCopyCDElement)

TEST_F(AvmSliceTests, longCopyAllCDValues)
{
gen_single_calldata_copy(false, 2000, 0, 2000, 873);
validate_single_calldata_copy_trace(0, 2000, 873);
const size_t cd_size = 2000;
gen_single_calldata_copy(false, cd_size, 0, cd_size, 20);
validate_single_calldata_copy_trace(0, cd_size, 20);
}

TEST_F(AvmSliceTests, copyFirstHalfCDValues)
Expand All @@ -162,9 +164,7 @@ TEST_F(AvmSliceTests, copyToHighestMemOffset)

TEST_F(AvmSliceTests, twoCallsNoOverlap)
{
calldata = { 2, 3, 4, 5, 6 };

gen_trace_builder(calldata);
gen_trace_builder({ 2, 3, 4, 5, 6 });
trace_builder.op_set(0, 2, 1, AvmMemoryTag::U32);
trace_builder.op_calldata_copy(0, 0, 1, 34);
trace_builder.op_set(0, 3, 1, AvmMemoryTag::U32);
Expand Down Expand Up @@ -199,9 +199,7 @@ TEST_F(AvmSliceTests, twoCallsNoOverlap)

TEST_F(AvmSliceTests, indirectTwoCallsOverlap)
{
calldata = { 2, 3, 4, 5, 6 };

gen_trace_builder(calldata);
gen_trace_builder({ 2, 3, 4, 5, 6 });
trace_builder.op_set(0, 34, 100, AvmMemoryTag::U32); // indirect address 100 resolves to 34
trace_builder.op_set(0, 2123, 101, AvmMemoryTag::U32); // indirect address 101 resolves to 2123
trace_builder.op_set(0, 1, 1, AvmMemoryTag::U32);
Expand Down Expand Up @@ -242,9 +240,7 @@ TEST_F(AvmSliceTests, indirectTwoCallsOverlap)

TEST_F(AvmSliceTests, indirectFailedResolution)
{
calldata = { 2, 3, 4, 5, 6 };

gen_trace_builder(calldata);
gen_trace_builder({ 2, 3, 4, 5, 6 });
trace_builder.op_set(0, 34, 100, AvmMemoryTag::U16); // indirect address 100 resolves to 34
trace_builder.op_set(0, 1, 1, AvmMemoryTag::U32);
trace_builder.op_set(0, 3, 3, AvmMemoryTag::U32);
Expand Down Expand Up @@ -318,9 +314,7 @@ TEST_F(AvmSliceNegativeTests, wrongCDValueInCalldataColumn)

TEST_F(AvmSliceNegativeTests, wrongCDValueInCalldataVerifier)
{
calldata = { 2, 3, 4, 5, 6 };

gen_trace_builder(calldata);
gen_trace_builder({ 2, 3, 4, 5, 6 });
trace_builder.op_calldata_copy(0, 1, 3, 100);
trace_builder.op_return(0, 0, 0);
trace = trace_builder.finalize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ void show_trace_info(const auto& trace)
}());

// The following computations are expensive, so we only do them in verbose mode.
if (verbose_logging) {
if (!verbose_logging) {
return;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class AvmKernelTraceBuilder {
{}

void reset();
size_t size() const { return kernel_trace.size(); }
void finalize(std::vector<AvmFullRow<FF>>& main_trace);
void finalize_columns(std::vector<AvmFullRow<FF>>& main_trace) const;

Expand Down
23 changes: 17 additions & 6 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3555,15 +3555,16 @@ std::vector<Row> AvmTraceBuilder::finalize()
size_t bin_trace_size = bin_trace_builder.size();
size_t gas_trace_size = gas_trace_builder.size();
size_t slice_trace_size = slice_trace.size();
size_t kernel_trace_size = kernel_trace_builder.size();

// Range check size is 1 less than it needs to be since we insert a "first row" at the top of the trace at the
// end, with clk 0 (this doubles as our range check)
size_t const range_check_size = range_check_required ? UINT16_MAX : 0;
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size + 1, alu_trace_size,
range_check_size, conv_trace_size, sha256_trace_size,
poseidon2_trace_size, pedersen_trace_size, gas_trace_size + 1,
KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, fixed_gas_table.size(),
slice_trace_size, calldata.size() };
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size + 1, alu_trace_size,
range_check_size, conv_trace_size, sha256_trace_size,
poseidon2_trace_size, pedersen_trace_size, gas_trace_size + 1,
KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, kernel_trace_size,
fixed_gas_table.size(), slice_trace_size, calldata.size() };
auto trace_size = std::max_element(trace_sizes.begin(), trace_sizes.end());

// Before making any changes to the main trace, mark the real rows.
Expand Down Expand Up @@ -3971,7 +3972,17 @@ std::vector<Row> AvmTraceBuilder::finalize()
"\n\trange_check_trace_size: ",
range_entries.size(),
"\n\tcmp_trace_size: ",
cmp_trace_size);
cmp_trace_size,
"\n\tkeccak_trace_size: ",
keccak_trace_size,
"\n\tkernel_trace_size: ",
kernel_trace_size,
"\n\tKERNEL_INPUTS_LENGTH: ",
KERNEL_INPUTS_LENGTH,
"\n\tKERNEL_OUTPUTS_LENGTH: ",
KERNEL_OUTPUTS_LENGTH,
"\n\tcalldata_size: ",
calldata.size());
reset();

return trace;
Expand Down
4 changes: 2 additions & 2 deletions bb-pilcom/bb-pil-backend/templates/circuit_builder.cpp.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
namespace bb {

{{name}}CircuitBuilder::ProverPolynomials {{name}}CircuitBuilder::compute_polynomials() const {
const size_t num_rows = get_num_gates();
const size_t circuit_subgroup_size = get_circuit_subgroup_size();
// FIXME: Either some algo or the Polynomial class seems to require this to be a power of 2.
const size_t num_rows = numeric::round_up_power_2(get_num_gates());
ASSERT(num_rows <= circuit_subgroup_size);
ProverPolynomials polys;

// Allocate mem for each column
Expand Down
1 change: 0 additions & 1 deletion bb-pilcom/bb-pil-backend/templates/prover.cpp.hbs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ void {{name}}Prover::execute_log_derivative_inverse_commitments_round()
{
// Commit to all logderivative inverse polynomials
for (auto [commitment, key_poly] : zip_view(witness_commitments.get_derived(), key->get_derived())) {
// We don't use commit_sparse here because the logderivative inverse polynomials are dense
commitment = commitment_key->commit(key_poly);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,37 +496,70 @@ contract AvmTest {
// Using "inputs" here is a bit of a hack. It's the PublicContextInputs injected
// by the public macros. It's needed by the non-nested call to an entry point function.
set_storage_single(inputs, 30);

dep::aztec::oracle::debug_log::debug_log("set_storage_list");
set_storage_list(inputs, 40, 50);
dep::aztec::oracle::debug_log::debug_log("read_storage_list");
let _ = set_storage_map(inputs, context.this_address(), 60);
dep::aztec::oracle::debug_log::debug_log("add_storage_map");
let _ = add_storage_map(inputs, context.this_address(), 10);
dep::aztec::oracle::debug_log::debug_log("read_storage_map");
let _ = read_storage_map(inputs, context.this_address());
dep::aztec::oracle::debug_log::debug_log("keccak_hash");
let _ = keccak_hash(inputs, args_u8);
dep::aztec::oracle::debug_log::debug_log("sha256_hash");
let _ = sha256_hash(inputs, args_u8);
dep::aztec::oracle::debug_log::debug_log("poseidon2_hash");
let _ = poseidon2_hash(inputs, args_field);
dep::aztec::oracle::debug_log::debug_log("pedersen_hash");
let _ = pedersen_hash(inputs, args_field);
dep::aztec::oracle::debug_log::debug_log("pedersen_hash_with_index");
let _ = pedersen_hash_with_index(inputs, args_field);
dep::aztec::oracle::debug_log::debug_log("test_get_contract_instance");
test_get_contract_instance(inputs);
dep::aztec::oracle::debug_log::debug_log("get_address");
let _ = get_address(inputs);
dep::aztec::oracle::debug_log::debug_log("get_storage_address");
let _ = get_storage_address(inputs);
dep::aztec::oracle::debug_log::debug_log("get_sender");
let _ = get_sender(inputs);
dep::aztec::oracle::debug_log::debug_log("get_function_selector");
let _ = get_function_selector(inputs);
dep::aztec::oracle::debug_log::debug_log("get_transaction_fee");
let _ = get_transaction_fee(inputs);
dep::aztec::oracle::debug_log::debug_log("get_chain_id");
let _ = get_chain_id(inputs);
dep::aztec::oracle::debug_log::debug_log("get_version");
let _ = get_version(inputs);
dep::aztec::oracle::debug_log::debug_log("get_block_number");
let _ = get_block_number(inputs);
dep::aztec::oracle::debug_log::debug_log("get_timestamp");
let _ = get_timestamp(inputs);
dep::aztec::oracle::debug_log::debug_log("get_fee_per_l2_gas");
let _ = get_fee_per_l2_gas(inputs);
dep::aztec::oracle::debug_log::debug_log("get_fee_per_da_gas");
let _ = get_fee_per_da_gas(inputs);
dep::aztec::oracle::debug_log::debug_log("get_l2_gas_left");
let _ = get_l2_gas_left(inputs);
dep::aztec::oracle::debug_log::debug_log("get_da_gas_left");
let _ = get_da_gas_left(inputs);
let _ = emit_unencrypted_log(inputs);
dep::aztec::oracle::debug_log::debug_log("emit_unencrypted_log");
// let _ = emit_unencrypted_log(inputs);
dep::aztec::oracle::debug_log::debug_log("note_hash_exists");
let _ = note_hash_exists(inputs, 1, 2);
dep::aztec::oracle::debug_log::debug_log("new_note_hash");
let _ = new_note_hash(inputs, 1);
dep::aztec::oracle::debug_log::debug_log("new_nullifier");
let _ = new_nullifier(inputs, 1);
dep::aztec::oracle::debug_log::debug_log("nullifier_exists");
let _ = nullifier_exists(inputs, 1);
dep::aztec::oracle::debug_log::debug_log("l1_to_l2_msg_exists");
let _ = l1_to_l2_msg_exists(inputs, 1, 2);
dep::aztec::oracle::debug_log::debug_log("send_l2_to_l1_msg");
let _ = send_l2_to_l1_msg(inputs, EthAddress::from_field(0x2020), 1);
dep::aztec::oracle::debug_log::debug_log("nested_call_to_add");
let _ = nested_call_to_add(inputs, 1, 2);
dep::aztec::oracle::debug_log::debug_log("nested_static_call_to_add");
let _ = nested_static_call_to_add(inputs, 1, 2);
}
}
12 changes: 5 additions & 7 deletions yarn-project/bb-prover/src/avm_proving.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { AvmCircuitInputs, AvmVerificationKeyData, FunctionSelector, Gas, GlobalVariables } from '@aztec/circuits.js';
import { Fr } from '@aztec/foundation/fields';
import { createDebugLogger } from '@aztec/foundation/log';
import { AvmSimulator, type PublicContractsDB, PublicSideEffectTrace, type WorldStateDB } from '@aztec/simulator';
import { AvmSimulator, PublicSideEffectTrace, type WorldStateDB } from '@aztec/simulator';
import {
getAvmTestContractBytecode,
initContext,
Expand All @@ -21,11 +21,10 @@ import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './b
import { getPublicInputs } from './test/test_avm.js';
import { extractAvmVkData } from './verification_key/verification_key_data.js';

const TIMEOUT = 60_000;
const TIMEOUT = 180_000;
const TIMESTAMP = new Fr(99833);

// FIXME: This fails with "main_kernel_value_out_evaluation failed".
describe.skip('AVM WitGen, proof generation and verification', () => {
describe('AVM WitGen, proof generation and verification', () => {
it(
'Should prove and verify bulk_testing',
async () => {
Expand Down Expand Up @@ -58,7 +57,7 @@ const proveAndVerifyAvmTestContract = async (
globals.timestamp = TIMESTAMP;
const environment = initExecutionEnvironment({ functionSelector, calldata, globals });

const contractsDb = mock<PublicContractsDB>();
const worldStateDB = mock<WorldStateDB>();
const contractInstance = new SerializableContractInstance({
version: 1,
salt: new Fr(0x123),
Expand All @@ -67,9 +66,8 @@ const proveAndVerifyAvmTestContract = async (
initializationHash: new Fr(0x101112),
publicKeysHash: new Fr(0x161718),
}).withAddress(environment.address);
contractsDb.getContractInstance.mockResolvedValue(Promise.resolve(contractInstance));
worldStateDB.getContractInstance.mockResolvedValue(Promise.resolve(contractInstance));

const worldStateDB = mock<WorldStateDB>();
const storageValue = new Fr(5);
worldStateDB.storageRead.mockResolvedValue(Promise.resolve(storageValue));

Expand Down

0 comments on commit 8bfc769

Please sign in to comment.