Skip to content

Commit

Permalink
feat: proof surgery class (#8236)
Browse files Browse the repository at this point in the history
Adds a `ProofSurgeon` class that manages all proof surgery, e.g.
splitting public inputs out of proof for acir and reconstructing again
for bberg. Simplifies things quite a bit in the process.
  • Loading branch information
ledwards2225 authored Aug 28, 2024
1 parent 9f4ea9f commit 10d7edd
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 144 deletions.
22 changes: 5 additions & 17 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "barretenberg/stdlib/primitives/field/field_conversion.hpp"
#include "barretenberg/stdlib_circuit_builders/mega_circuit_builder.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
#include "proof_surgeon.hpp"
#include <cstddef>

namespace acir_format {
Expand Down Expand Up @@ -333,26 +334,13 @@ void process_honk_recursion_constraints(Builder& builder,
stdlib::recursion::init_default_agg_obj_indices<Builder>(builder);

// Add recursion constraints
for (size_t i = 0; i < constraint_system.honk_recursion_constraints.size(); ++i) {
auto& constraint = constraint_system.honk_recursion_constraints.at(i);
// A proof passed into the constraint should be stripped of its inner public inputs, but not the nested
// aggregation object itself. The verifier circuit requires that the indices to a nested proof aggregation
// state are a circuit constant. The user tells us they how they want these constants set by keeping the
// nested aggregation object attached to the proof as public inputs.
for (size_t i = 0; i < bb::AGGREGATION_OBJECT_SIZE; ++i) {
// Adding the nested aggregation object to the constraint's public inputs
constraint.public_inputs.emplace_back(constraint.proof[HONK_RECURSION_PUBLIC_INPUT_OFFSET + i]);
}
// Remove the aggregation object so that they can be handled as normal public inputs
// in they way that the recursion constraint expects
constraint.proof.erase(
constraint.proof.begin() + HONK_RECURSION_PUBLIC_INPUT_OFFSET,
constraint.proof.begin() +
static_cast<std::ptrdiff_t>(HONK_RECURSION_PUBLIC_INPUT_OFFSET + bb::AGGREGATION_OBJECT_SIZE));
size_t idx = 0;
for (auto& constraint : constraint_system.honk_recursion_constraints) {
current_aggregation_object = create_honk_recursion_constraints(
builder, constraint, current_aggregation_object, has_valid_witness_assignments);

gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.honk_recursion_constraints.at(i));
constraint_system.original_opcode_indices.honk_recursion_constraints.at(idx++));
}

// Now that the circuit has been completely built, we add the output aggregation as public
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "barretenberg/stdlib/primitives/bigfield/constants.hpp"
#include "barretenberg/stdlib/primitives/curves/bn254.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_recursive_flavor.hpp"
#include "proof_surgeon.hpp"
#include "recursion_constraint.hpp"

namespace acir_format {
Expand All @@ -22,39 +23,41 @@ using aggregation_state_ct = bb::stdlib::recursion::aggregation_state<bn254>;
* aggregation object, and commitments.
*
* @param builder
* @param input
* @param proof_size Size of proof with NO public inputs
* @param public_inputs_size Total size of public inputs including aggregation object
* @param key_fields
* @param proof_fields
*/
void create_dummy_vkey_and_proof(Builder& builder,
const RecursionConstraint& input,
std::vector<field_ct>& key_fields,
std::vector<field_ct>& proof_fields)
size_t proof_size,
size_t public_inputs_size,
const std::vector<field_ct>& key_fields,
const std::vector<field_ct>& proof_fields)
{
using Flavor = UltraRecursiveFlavor_<Builder>;
using Flavor = UltraFlavor;

// Set vkey->circuit_size correctly based on the proof size
size_t num_frs_comm = bb::field_conversion::calc_num_bn254_frs<UltraFlavor::Commitment>();
size_t num_frs_fr = bb::field_conversion::calc_num_bn254_frs<UltraFlavor::FF>();
assert((input.proof.size() - HONK_RECURSION_PUBLIC_INPUT_OFFSET - UltraFlavor::NUM_WITNESS_ENTITIES * num_frs_comm -
UltraFlavor::NUM_ALL_ENTITIES * num_frs_fr - 2 * num_frs_comm) %
(num_frs_comm + num_frs_fr * UltraFlavor::BATCHED_RELATION_PARTIAL_LENGTH) ==
size_t num_frs_comm = bb::field_conversion::calc_num_bn254_frs<Flavor::Commitment>();
size_t num_frs_fr = bb::field_conversion::calc_num_bn254_frs<Flavor::FF>();
assert((proof_size - HONK_RECURSION_PUBLIC_INPUT_OFFSET - Flavor::NUM_WITNESS_ENTITIES * num_frs_comm -
Flavor::NUM_ALL_ENTITIES * num_frs_fr - 2 * num_frs_comm) %
(num_frs_comm + num_frs_fr * Flavor::BATCHED_RELATION_PARTIAL_LENGTH) ==
0);
// Note: this computation should always result in log_circuit_size = CONST_PROOF_SIZE_LOG_N
auto log_circuit_size =
(input.proof.size() - HONK_RECURSION_PUBLIC_INPUT_OFFSET - UltraFlavor::NUM_WITNESS_ENTITIES * num_frs_comm -
UltraFlavor::NUM_ALL_ENTITIES * num_frs_fr - 2 * num_frs_comm) /
(num_frs_comm + num_frs_fr * UltraFlavor::BATCHED_RELATION_PARTIAL_LENGTH);
(proof_size - HONK_RECURSION_PUBLIC_INPUT_OFFSET - Flavor::NUM_WITNESS_ENTITIES * num_frs_comm -
Flavor::NUM_ALL_ENTITIES * num_frs_fr - 2 * num_frs_comm) /
(num_frs_comm + num_frs_fr * Flavor::BATCHED_RELATION_PARTIAL_LENGTH);
// First key field is circuit size
builder.assert_equal(builder.add_variable(1 << log_circuit_size), key_fields[0].witness_index);
// Second key field is number of public inputs
builder.assert_equal(builder.add_variable(input.public_inputs.size()), key_fields[1].witness_index);
builder.assert_equal(builder.add_variable(public_inputs_size), key_fields[1].witness_index);
// Third key field is the pub inputs offset
builder.assert_equal(builder.add_variable(UltraFlavor::has_zero_row ? 1 : 0), key_fields[2].witness_index);
builder.assert_equal(builder.add_variable(Flavor::has_zero_row ? 1 : 0), key_fields[2].witness_index);
// Fourth key field is the whether the proof contains an aggregation object.
builder.assert_equal(builder.add_variable(1), key_fields[4].witness_index);
builder.assert_equal(builder.add_variable(1), key_fields[3].witness_index);
uint32_t offset = 4;
size_t num_inner_public_inputs = input.public_inputs.size() - bb::AGGREGATION_OBJECT_SIZE;
size_t num_inner_public_inputs = public_inputs_size - bb::AGGREGATION_OBJECT_SIZE;

// We are making the assumption that the aggregation object are behind all the inner public inputs
for (size_t i = 0; i < bb::AGGREGATION_OBJECT_SIZE; i++) {
Expand All @@ -75,8 +78,8 @@ void create_dummy_vkey_and_proof(Builder& builder,
offset = HONK_RECURSION_PUBLIC_INPUT_OFFSET;
// first 3 things
builder.assert_equal(builder.add_variable(1 << log_circuit_size), proof_fields[0].witness_index);
builder.assert_equal(builder.add_variable(input.public_inputs.size()), proof_fields[1].witness_index);
builder.assert_equal(builder.add_variable(UltraFlavor::has_zero_row ? 1 : 0), proof_fields[2].witness_index);
builder.assert_equal(builder.add_variable(public_inputs_size), proof_fields[1].witness_index);
builder.assert_equal(builder.add_variable(Flavor::has_zero_row ? 1 : 0), proof_fields[2].witness_index);

// the inner public inputs
for (size_t i = 0; i < num_inner_public_inputs; i++) {
Expand Down Expand Up @@ -134,7 +137,7 @@ void create_dummy_vkey_and_proof(Builder& builder,
builder.assert_equal(builder.add_variable(frs[3]), proof_fields[offset + 3].witness_index);
offset += 4;
}
ASSERT(offset == input.proof.size() + input.public_inputs.size());
ASSERT(offset == proof_size + public_inputs_size);
}

/**
Expand Down Expand Up @@ -171,27 +174,26 @@ AggregationObjectIndices create_honk_recursion_constraints(Builder& builder,
}

std::vector<field_ct> proof_fields;
// Insert the public inputs in the middle the proof fields after 'inner_public_input_offset' because this is how the
// core barretenberg library processes proofs (with the public inputs starting at the third element and not
// separate from the rest of the proof)
proof_fields.reserve(input.proof.size() + input.public_inputs.size());
size_t i = 0;
for (const auto& idx : input.proof) {

// Create witness indices for the proof with public inputs reinserted
std::vector<uint32_t> proof_indices =
ProofSurgeon::create_indices_for_reconstructed_proof(input.proof, input.public_inputs);
proof_fields.reserve(proof_indices.size());
for (const auto& idx : proof_indices) {
auto field = field_ct::from_witness_index(&builder, idx);
proof_fields.emplace_back(field);
i++;
if (i == HONK_RECURSION_PUBLIC_INPUT_OFFSET) {
for (const auto& idx : input.public_inputs) {
auto field = field_ct::from_witness_index(&builder, idx);
proof_fields.emplace_back(field);
}
}
}
// Populate the key fields and proof fields with dummy values to prevent issues (usually with points not being on
// the curve).

// Populate the key fields and proof fields with dummy values to prevent issues (e.g. points must be on curve).
if (!has_valid_witness_assignments) {
create_dummy_vkey_and_proof(builder, input, key_fields, proof_fields);
// In the constraint, the agg object public inputs are still contained in the proof. To get the 'raw' size of
// the proof and public_inputs we subtract and add the corresponding amount from the respective sizes.
size_t size_of_proof_with_no_pub_inputs = input.proof.size() - bb::AGGREGATION_OBJECT_SIZE;
size_t total_num_public_inputs = input.public_inputs.size() + bb::AGGREGATION_OBJECT_SIZE;
create_dummy_vkey_and_proof(
builder, size_of_proof_with_no_pub_inputs, total_num_public_inputs, key_fields, proof_fields);
}

// Recursively verify the proof
auto vkey = std::make_shared<RecursiveVerificationKey>(builder, key_fields);
RecursiveVerifier verifier(&builder, vkey);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,6 @@ using Builder = bb::UltraCircuitBuilder;

using namespace bb;

// In Honk, the proof starts with circuit_size, num_public_inputs, and pub_input_offset. We use this offset to keep
// track of where the public inputs start.
static constexpr size_t HONK_RECURSION_PUBLIC_INPUT_OFFSET = 3;

AggregationObjectIndices create_honk_recursion_constraints(Builder& builder,
const RecursionConstraint& input,
AggregationObjectIndices input_aggregation_object,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "barretenberg/sumcheck/instance/prover_instance.hpp"
#include "barretenberg/ultra_honk/ultra_prover.hpp"
#include "barretenberg/ultra_honk/ultra_verifier.hpp"
#include "proof_surgeon.hpp"

#include <gtest/gtest.h>
#include <vector>
Expand Down Expand Up @@ -140,7 +141,6 @@ class AcirHonkRecursionConstraint : public ::testing::Test {
{
std::vector<RecursionConstraint> honk_recursion_constraints;

size_t witness_offset = 0;
SlabVector<fr> witness;

for (auto& inner_circuit : inner_circuits) {
Expand All @@ -151,60 +151,12 @@ class AcirHonkRecursionConstraint : public ::testing::Test {
Verifier verifier(verification_key);
auto inner_proof = prover.construct_proof();

const size_t num_inner_public_inputs = inner_circuit.get_public_inputs().size();

std::vector<fr> proof_witnesses = inner_proof;
// where the inner public inputs start (after circuit_size, num_pub_inputs, pub_input_offset)
const size_t inner_public_input_offset = HONK_RECURSION_PUBLIC_INPUT_OFFSET;
// - Save the public inputs so that we can set their values.
// - Then truncate them from the proof because the ACIR API expects proofs without public inputs
std::vector<fr> inner_public_input_values(
proof_witnesses.begin() + static_cast<std::ptrdiff_t>(inner_public_input_offset),
proof_witnesses.begin() +
static_cast<std::ptrdiff_t>(inner_public_input_offset + num_inner_public_inputs -
bb::AGGREGATION_OBJECT_SIZE));

// We want to make sure that we do not remove the nested aggregation object.
proof_witnesses.erase(proof_witnesses.begin() + static_cast<std::ptrdiff_t>(inner_public_input_offset),
proof_witnesses.begin() +
static_cast<std::ptrdiff_t>(inner_public_input_offset + num_inner_public_inputs -
bb::AGGREGATION_OBJECT_SIZE));

std::vector<bb::fr> key_witnesses = verification_key->to_field_elements();
std::vector<fr> proof_witnesses = inner_proof;
const size_t num_public_inputs = inner_circuit.get_public_inputs().size();

// This is the structure of proof_witnesses and key_witnesses concatenated, which is what we end up putting
// in witness:
// [ circuit size, num_pub_inputs, pub_input_offset, public_input_0, public_input_1, agg_obj_0,
// agg_obj_1, ..., agg_obj_15, rest of proof..., vkey_0, vkey_1, vkey_2, vkey_3...]
const uint32_t public_input_start_idx =
static_cast<uint32_t>(inner_public_input_offset + witness_offset); // points to public_input_0
const uint32_t proof_indices_start_idx = static_cast<uint32_t>(
public_input_start_idx + num_inner_public_inputs - bb::AGGREGATION_OBJECT_SIZE); // points to agg_obj_0
const uint32_t key_indices_start_idx =
static_cast<uint32_t>(proof_indices_start_idx + proof_witnesses.size() -
inner_public_input_offset); // would point to vkey_3 without the -
// inner_public_input_offset, points to vkey_0

std::vector<uint32_t> proof_indices;
std::vector<uint32_t> key_indices;
std::vector<uint32_t> inner_public_inputs;
for (size_t i = 0; i < inner_public_input_offset; ++i) { // go over circuit size, num_pub_inputs, pub_offset
proof_indices.emplace_back(static_cast<uint32_t>(i + witness_offset));
}
for (size_t i = 0; i < proof_witnesses.size() - inner_public_input_offset;
++i) { // goes over agg_obj_0, agg_obj_1, ..., agg_obj_15 and rest of proof
proof_indices.emplace_back(static_cast<uint32_t>(i + proof_indices_start_idx));
}
const size_t key_size = key_witnesses.size();
for (size_t i = 0; i < key_size; ++i) {
key_indices.emplace_back(static_cast<uint32_t>(i + key_indices_start_idx));
}
// We keep the nested aggregation object attached to the proof,
// thus we do not explicitly have to keep the public inputs while setting up the initial recursion
// constraint. They will later be attached as public inputs when creating the circuit.
for (size_t i = 0; i < num_inner_public_inputs - bb::AGGREGATION_OBJECT_SIZE; ++i) {
inner_public_inputs.push_back(static_cast<uint32_t>(i + public_input_start_idx));
}
auto [key_indices, proof_indices, inner_public_inputs] = ProofSurgeon::populate_recursion_witness_data(
witness, proof_witnesses, key_witnesses, num_public_inputs);

RecursionConstraint honk_recursion_constraint{
.key = key_indices,
Expand All @@ -214,40 +166,6 @@ class AcirHonkRecursionConstraint : public ::testing::Test {
.proof_type = HONK_RECURSION,
};
honk_recursion_constraints.push_back(honk_recursion_constraint);

// Setting the witness vector which just appends proof witnesses and key witnesses.
// We need to reconstruct the proof witnesses in the same order as the proof indices, with this structure:
// [ circuit size, num_pub_inputs, pub_input_offset, public_input_0, public_input_1, agg_obj_0,
// agg_obj_1, ..., agg_obj_15, rest of proof..., vkey_0, vkey_1, vkey_2, vkey_3...]
size_t idx = 0;
for (const auto& wit : proof_witnesses) {
witness.emplace_back(wit);
idx++;
if (idx ==
inner_public_input_offset) { // before this is true, the loop adds the first three into witness
for (size_t i = 0; i < proof_indices_start_idx - public_input_start_idx;
++i) { // adds the inner public inputs
witness.emplace_back(0);
}
} // after this, it adds the agg obj and rest of proof
}

for (const auto& wit : key_witnesses) {
witness.emplace_back(wit);
}

// Set the values for the inner public inputs
// TODO(maxim): check this is wrong I think
// Note: this is confusing, but we minus one here due to the fact that the
// witness values have not taken into account that zero is taken up by the zero_idx
//
// We once again have to check whether we have a nested proof, because if we do have one
// then we could get a segmentation fault as `inner_public_inputs` was never filled with values.
for (size_t i = 0; i < num_inner_public_inputs - bb::AGGREGATION_OBJECT_SIZE; ++i) {
witness[inner_public_inputs[i]] = inner_public_input_values[i];
}

witness_offset = key_indices_start_idx + key_witnesses.size();
}

std::vector<size_t> honk_recursion_opcode_indices(honk_recursion_constraints.size());
Expand Down
Loading

0 comments on commit 10d7edd

Please sign in to comment.