Skip to content

Commit

Permalink
feat: replace arithmetic equalities with assert equal (#8386)
Browse files Browse the repository at this point in the history
Replace arithmetic equalities with assert_equal if the 2 equal witnesses
have been both added previously into an arithmetic gate.

---------

Co-authored-by: Tom French <15848336+TomAFrench@users.noreply.github.com>
  • Loading branch information
guipublic and TomAFrench authored Sep 5, 2024
1 parent 3228e75 commit 0d8e835
Show file tree
Hide file tree
Showing 15 changed files with 120 additions and 23 deletions.
52 changes: 32 additions & 20 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/acir_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,14 @@ void build_constraints(Builder& builder,
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.bigint_to_le_bytes_constraints.at(i));
}
// assert equals
for (size_t i = 0; i < constraint_system.assert_equalities.size(); ++i) {
const auto& constraint = constraint_system.assert_equalities.at(i);

builder.assert_equal(constraint.a, constraint.b);
gate_counter.track_diff(constraint_system.gates_per_opcode,
constraint_system.original_opcode_indices.assert_equalities.at(i));
}

// RecursionConstraints
// TODO(https://github.com/AztecProtocol/barretenberg/issues/817): disable these for MegaHonk for now since we're
Expand All @@ -227,10 +235,11 @@ void build_constraints(Builder& builder,
process_plonk_recursion_constraints(builder, constraint_system, has_valid_witness_assignments, gate_counter);
process_honk_recursion_constraints(builder, constraint_system, has_valid_witness_assignments, gate_counter);

// If the circuit does not itself contain honk recursion constraints but is going to be proven with honk then
// recursively verified, add a default aggregation object
// If the circuit does not itself contain honk recursion constraints but is going to be
// proven with honk then recursively verified, add a default aggregation object
if (constraint_system.honk_recursion_constraints.empty() && honk_recursion &&
builder.is_recursive_circuit) { // Set a default aggregation object if we don't have one.
builder.is_recursive_circuit) { // Set a default aggregation object if we don't have
// one.
AggregationObjectIndices current_aggregation_object =
stdlib::recursion::init_default_agg_obj_indices<Builder>(builder);
// Make sure the verification key records the public input indices of the
Expand Down Expand Up @@ -265,31 +274,34 @@ void process_plonk_recursion_constraints(Builder& builder,
for (size_t constraint_idx = 0; constraint_idx < constraint_system.recursion_constraints.size(); ++constraint_idx) {
auto constraint = constraint_system.recursion_constraints[constraint_idx];

// A proof passed into the constraint should be stripped of its public inputs, except in the case where a
// proof contains an aggregation object itself. We refer to this as the `nested_aggregation_object`. The
// verifier circuit requires that the indices to a nested proof aggregation state are a circuit constant.
// The user tells us they how they want these constants set by keeping the nested aggregation object
// attached to the proof as public inputs. As this is the only object that can prepended to the proof if the
// proof is above the expected size (with public inputs stripped)
// A proof passed into the constraint should be stripped of its public inputs, except in
// the case where a proof contains an aggregation object itself. We refer to this as the
// `nested_aggregation_object`. The verifier circuit requires that the indices to a
// nested proof aggregation state are a circuit constant. The user tells us they how
// they want these constants set by keeping the nested aggregation object attached to
// the proof as public inputs. As this is the only object that can prepended to the
// proof if the proof is above the expected size (with public inputs stripped)
AggregationObjectPubInputIndices nested_aggregation_object = {};
// If the proof has public inputs attached to it, we should handle setting the nested aggregation object
// If the proof has public inputs attached to it, we should handle setting the nested
// aggregation object
if (constraint.proof.size() > proof_size_no_pub_inputs) {
// The public inputs attached to a proof should match the aggregation object in size
if (constraint.proof.size() - proof_size_no_pub_inputs != bb::AGGREGATION_OBJECT_SIZE) {
auto error_string = format(
"Public inputs are always stripped from proofs unless we have a recursive proof.\n"
"Thus, public inputs attached to a proof must match the recursive aggregation object in size "
"which is ",
bb::AGGREGATION_OBJECT_SIZE);
auto error_string = format("Public inputs are always stripped from proofs "
"unless we have a recursive proof.\n"
"Thus, public inputs attached to a proof must match "
"the recursive aggregation object in size "
"which is ",
bb::AGGREGATION_OBJECT_SIZE);
throw_or_abort(error_string);
}
for (size_t i = 0; i < bb::AGGREGATION_OBJECT_SIZE; ++i) {
// Set the nested aggregation object indices to the current size of the public inputs
// This way we know that the nested aggregation object indices will always be the last
// indices of the public inputs
// Set the nested aggregation object indices to the current size of the public
// inputs This way we know that the nested aggregation object indices will
// always be the last indices of the public inputs
nested_aggregation_object[i] = static_cast<uint32_t>(constraint.public_inputs.size());
// Attach the nested aggregation object to the end of the public inputs to fill in
// the slot where the nested aggregation object index will point into
// Attach the nested aggregation object to the end of the public inputs to fill
// in the slot where the nested aggregation object index will point into
constraint.public_inputs.emplace_back(constraint.proof[i]);
}
// Remove the aggregation object so that they can be handled as normal public inputs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ struct AcirFormatOriginalOpcodeIndices {
std::vector<size_t> bigint_from_le_bytes_constraints;
std::vector<size_t> bigint_to_le_bytes_constraints;
std::vector<size_t> bigint_operations;
std::vector<size_t> assert_equalities;
std::vector<size_t> poly_triple_constraints;
std::vector<size_t> quad_constraints;
// Multiple opcode indices per block:
Expand Down Expand Up @@ -98,6 +99,7 @@ struct AcirFormat {
std::vector<BigIntFromLeBytes> bigint_from_le_bytes_constraints;
std::vector<BigIntToLeBytes> bigint_to_le_bytes_constraints;
std::vector<BigIntOperation> bigint_operations;
std::vector<bb::poly_triple_<bb::curve::BN254::ScalarField>> assert_equalities;

// A standard plonk arithmetic constraint, as defined in the poly_triple struct, consists of selector values
// for q_M,q_L,q_R,q_O,q_C and indices of three variables taking the role of left, right and output wire
Expand All @@ -110,6 +112,9 @@ struct AcirFormat {
// Has length equal to num_acir_opcodes.
std::vector<size_t> gates_per_opcode = {};

// Set of constrained witnesses
std::set<uint32_t> constrained_witness = {};

// Indices of the original opcode that originated each constraint in AcirFormat.
AcirFormatOriginalOpcodeIndices original_opcode_indices;

Expand Down Expand Up @@ -139,7 +144,8 @@ struct AcirFormat {
block_constraints,
bigint_from_le_bytes_constraints,
bigint_to_le_bytes_constraints,
bigint_operations);
bigint_operations,
assert_equalities);

friend bool operator==(AcirFormat const& lhs, AcirFormat const& rhs) = default;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ TEST_F(AcirFormatTests, TestASingleConstraintNoPubInputs)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { constraint },
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -185,6 +186,7 @@ TEST_F(AcirFormatTests, TestLogicGateFromNoirCircuit)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { expr_a, expr_b, expr_c, expr_d },
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -264,6 +266,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifyPass)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { poly_triple{
.a = schnorr_constraint.result,
.b = schnorr_constraint.result,
Expand Down Expand Up @@ -370,6 +373,7 @@ TEST_F(AcirFormatTests, TestSchnorrVerifySmallRange)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { poly_triple{
.a = schnorr_constraint.result,
.b = schnorr_constraint.result,
Expand Down Expand Up @@ -489,6 +493,7 @@ TEST_F(AcirFormatTests, TestVarKeccak)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { dummy },
.quad_constraints = {},
.block_constraints = {},
Expand All @@ -510,7 +515,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
{
Keccakf1600
keccak_permutation{
.state = {
.state = {
WitnessOrConstant<bb::fr>::from_index(1),
WitnessOrConstant<bb::fr>::from_index(2),
WitnessOrConstant<bb::fr>::from_index(3),
Expand Down Expand Up @@ -568,6 +573,7 @@ TEST_F(AcirFormatTests, TestKeccakPermutation)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -644,6 +650,7 @@ TEST_F(AcirFormatTests, TestCollectsGateCounts)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { first_gate, second_gate },
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ acir_format::AcirFormatOriginalOpcodeIndices create_empty_original_opcode_indice
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -100,6 +101,9 @@ void mock_opcode_indices(acir_format::AcirFormat& constraint_system)
for (size_t i = 0; i < constraint_system.bigint_operations.size(); i++) {
constraint_system.original_opcode_indices.bigint_operations.push_back(current_opcode++);
}
for (size_t i = 0; i < constraint_system.assert_equalities.size(); i++) {
constraint_system.original_opcode_indices.assert_equalities.push_back(current_opcode++);
}
for (size_t i = 0; i < constraint_system.poly_triple_constraints.size(); i++) {
constraint_system.original_opcode_indices.poly_triple_constraints.push_back(current_opcode++);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "acir_to_constraint_buf.hpp"
#include "barretenberg/common/container.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/plonk_honk_shared/arithmetization/gate_data.hpp"
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <utility>
#ifndef __wasm__
Expand Down Expand Up @@ -167,10 +170,50 @@ mul_quad_<fr> serialize_mul_quad_gate(Program::Expression const& arg)
return quad;
}

void constrain_witnesses(Program::Opcode::AssertZero const& arg, AcirFormat& af)
{
for (const auto& linear_term : arg.value.linear_combinations) {
uint32_t witness_idx = std::get<1>(linear_term).value;
af.constrained_witness.insert(witness_idx);
}
for (const auto& linear_term : arg.value.mul_terms) {
uint32_t witness_idx = std::get<1>(linear_term).value;
af.constrained_witness.insert(witness_idx);
witness_idx = std::get<2>(linear_term).value;
af.constrained_witness.insert(witness_idx);
}
}

std::pair<uint32_t, uint32_t> is_assert_equal(Program::Opcode::AssertZero const& arg,
poly_triple const& pt,
AcirFormat const& af)
{
if (!arg.value.mul_terms.empty() || arg.value.linear_combinations.size() != 2) {
return { 0, 0 };
}
if (pt.q_l == -pt.q_r && pt.q_l != bb::fr::zero() && pt.q_c == bb::fr::zero()) {
if (af.constrained_witness.contains(pt.a) && af.constrained_witness.contains(pt.b)) {
return { pt.a, pt.b };
}
}
return { 0, 0 };
}

void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, size_t opcode_index)
{
if (arg.value.linear_combinations.size() <= 3) {
poly_triple pt = serialize_arithmetic_gate(arg.value);

auto assert_equal = is_assert_equal(arg, pt, af);
uint32_t w1 = std::get<0>(assert_equal);
uint32_t w2 = std::get<1>(assert_equal);
if (w1 != 0) {
if (w1 != w2) {
af.assert_equalities.push_back(pt);
af.original_opcode_indices.assert_equalities.push_back(opcode_index);
}
return;
}
// Even if the number of linear terms is less than 3, we might not be able to fit it into a width-3 arithmetic
// gate. This is the case if the linear terms are all disctinct witness from the multiplication term. In that
// case, the serialize_arithmetic_gate() function will return a poly_triple with all 0's, and we use a width-4
Expand All @@ -187,6 +230,7 @@ void handle_arithmetic(Program::Opcode::AssertZero const& arg, AcirFormat& af, s
af.quad_constraints.push_back(serialize_mul_quad_gate(arg.value));
af.original_opcode_indices.quad_constraints.push_back(opcode_index);
}
constrain_witnesses(arg, af);
}

uint32_t get_witness_from_function_input(Program::FunctionInput input)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ TEST_F(BigIntTests, TestBigIntConstraintMultiple)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -270,6 +271,7 @@ TEST_F(BigIntTests, TestBigIntConstraintSimple)
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1 },
.bigint_to_le_bytes_constraints = { result2_to_le_bytes },
.bigint_operations = { add_constraint },
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -327,6 +329,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -389,6 +392,7 @@ TEST_F(BigIntTests, TestBigIntConstraintReuse2)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -472,6 +476,7 @@ TEST_F(BigIntTests, TestBigIntDIV)
.bigint_from_le_bytes_constraints = { from_le_bytes_constraint_bigint1, from_le_bytes_constraint_bigint2 },
.bigint_to_le_bytes_constraints = { result3_to_le_bytes },
.bigint_operations = { div_constraint },
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ TEST_F(UltraPlonkRAM, TestBlockConstraint)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = { block },
Expand Down Expand Up @@ -216,6 +217,7 @@ TEST_F(MegaHonk, Databus)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = { block },
Expand Down Expand Up @@ -322,6 +324,7 @@ TEST_F(MegaHonk, DatabusReturn)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = { block },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ TEST_F(EcOperations, TestECOperations)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -223,6 +224,7 @@ TEST_F(EcOperations, TestECMultiScalarMul)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = { assert_equal },
.quad_constraints = {},
.block_constraints = {},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintSucceed)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -173,6 +174,7 @@ TEST_F(ECDSASecp256k1, TestECDSACompilesForVerifier)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down Expand Up @@ -222,6 +224,7 @@ TEST_F(ECDSASecp256k1, TestECDSAConstraintFail)
.bigint_from_le_bytes_constraints = {},
.bigint_to_le_bytes_constraints = {},
.bigint_operations = {},
.assert_equalities = {},
.poly_triple_constraints = {},
.quad_constraints = {},
.block_constraints = {},
Expand Down
Loading

0 comments on commit 0d8e835

Please sign in to comment.