Skip to content

Commit

Permalink
feat: databus allows arbitrarily many reads per index (#6524)
Browse files Browse the repository at this point in the history
TLDR: Up until now we were limited to only 1 read per entry of a databus
column (see explanation of why below). This PR removes this limitation
so that we can read from any row arbitrarily many times at the cost of
adding one polynomial/commitment per databus column.

Note: this PR also cleans up some of the handling of ecc op wires and
databus polys in various places by making better use of Flavor style
getters.

Explanation: The log derivative lookup relation involves a polynomial
that contains inverses, i.e. I_i = (read_term_i*write_term_i)^{-1}.
These inverses only need to be computed when the relation is "active",
i.e. when the row in question either contains a databus read gate or
data that is being read. At all other rows, we simply set the value of
the inverse polynomial to 0. This allows a subrelation of the form:

`read_term * write_term * inverses - inverse_exists`

Where `inverse_exists` is a polynomial that takes 1 if the relation is
active (or equivalently, if the inverse has been computed) and 0
otherwise. Therefore, if the inverse has been computed, we check that it
is indeed equal to the inverse of `read_term * write_term`, otherwise,
the subrelation contribution is trivially 0. If we only allow a single
read from each row of a bus column, the term `inverse_exists` can be
computed as an algebraic OR of the form:

`is_read_gate + read_counts - (is_read_gate * read_counts)`

since both `is_read_gate` and `read_counts` are both boolean. If
`read_counts` is no longer boolean, no such algebraic expression exists.
The solution is to introduce a dedicated boolean polynomial `read_tag`
whose values are given by `min(1, read_counts)`, i.e. 1 if one or more
reads have been performed at that row, and 0 otherwise.

Closes #937
  • Loading branch information
ledwards2225 authored and AztecBot committed Jul 15, 2024
1 parent ce124fe commit 5d27e66
Show file tree
Hide file tree
Showing 11 changed files with 172 additions and 88 deletions.
12 changes: 7 additions & 5 deletions cpp/src/barretenberg/relations/databus_lookup_relation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ template <typename FF_> class DatabusLookupRelationImpl {
static auto& inverses(AllEntities& in) { return in.calldata_inverses; }
static auto& inverses(const AllEntities& in) { return in.calldata_inverses; } // const version
static auto& read_counts(const AllEntities& in) { return in.calldata_read_counts; }
static auto& read_tags(const AllEntities& in) { return in.calldata_read_tags; }
};

// Specialization for return data (bus_idx = 1)
Expand All @@ -88,6 +89,7 @@ template <typename FF_> class DatabusLookupRelationImpl {
static auto& inverses(AllEntities& in) { return in.return_data_inverses; }
static auto& inverses(const AllEntities& in) { return in.return_data_inverses; } // const version
static auto& read_counts(const AllEntities& in) { return in.return_data_read_counts; }
static auto& read_tags(const AllEntities& in) { return in.return_data_read_tags; }
};

/**
Expand All @@ -101,8 +103,8 @@ template <typename FF_> class DatabusLookupRelationImpl {
template <size_t bus_idx, typename AllValues> static bool operation_exists_at_row(const AllValues& row)
{
auto read_selector = get_read_selector<FF, bus_idx>(row);
auto read_counts = BusData<bus_idx, AllValues>::read_counts(row);
return (read_selector == 1 || read_counts > 0);
auto read_tag = BusData<bus_idx, AllValues>::read_tags(row);
return (read_selector == 1 || read_tag == 1);
}

/**
Expand All @@ -117,10 +119,10 @@ template <typename FF_> class DatabusLookupRelationImpl {
{
using View = typename Accumulator::View;

const auto is_read_gate = get_read_selector<Accumulator, bus_idx>(in);
const auto read_counts = View(BusData<bus_idx, AllEntities>::read_counts(in));
const auto is_read_gate = get_read_selector<Accumulator, bus_idx>(in); // is this a read gate
const auto read_tag = View(BusData<bus_idx, AllEntities>::read_tags(in)); // does row contain data being read

return is_read_gate + read_counts - (is_read_gate * read_counts);
return is_read_gate + read_tag - (is_read_gate * read_tag);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,16 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
witness_commitments.w_o = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.w_o);

if constexpr (IsGoblinFlavor<Flavor>) {
witness_commitments.ecc_op_wire_1 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_4);
witness_commitments.calldata =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.calldata);
witness_commitments.calldata_read_counts =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.calldata_read_counts);
witness_commitments.return_data =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.return_data);
witness_commitments.return_data_read_counts = transcript->template receive_from_prover<Commitment>(
domain_separator + "_" + labels.return_data_read_counts);
// Receive ECC op wire commitments
for (auto [commitment, label] : zip_view(witness_commitments.get_ecc_op_wires(), labels.get_ecc_op_wires())) {
commitment = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + label);
}

// Receive DataBus related polynomial commitments
for (auto [commitment, label] :
zip_view(witness_commitments.get_databus_entities(), labels.get_databus_entities())) {
commitment = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + label);
}
}

// Get eta challenges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,20 +75,17 @@ std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::ve

// If Goblin, get commitments to ECC op wire polynomials and DataBus columns
if constexpr (IsGoblinFlavor<Flavor>) {
commitments.ecc_op_wire_1 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_1);
commitments.ecc_op_wire_2 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_2);
commitments.ecc_op_wire_3 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_3);
commitments.ecc_op_wire_4 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_4);
commitments.calldata = transcript->template receive_from_prover<Commitment>(commitment_labels.calldata);
commitments.calldata_read_counts =
transcript->template receive_from_prover<Commitment>(commitment_labels.calldata_read_counts);
commitments.return_data = transcript->template receive_from_prover<Commitment>(commitment_labels.return_data);
commitments.return_data_read_counts =
transcript->template receive_from_prover<Commitment>(commitment_labels.return_data_read_counts);
// Receive ECC op wire commitments
for (auto [commitment, label] :
zip_view(commitments.get_ecc_op_wires(), commitment_labels.get_ecc_op_wires())) {
commitment = transcript->template receive_from_prover<Commitment>(label);
}

// Receive DataBus related polynomial commitments
for (auto [commitment, label] :
zip_view(commitments.get_databus_entities(), commitment_labels.get_databus_entities())) {
commitment = transcript->template receive_from_prover<Commitment>(label);
}
}

// Get eta challenges; used in RAM/ROM memory records and log derivative lookup argument
Expand Down
43 changes: 43 additions & 0 deletions cpp/src/barretenberg/stdlib/primitives/databus/databus.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,46 @@ TEST(Databus, BadCopyFailure)
// Since the output data is not a copy of the input, the checker should fail
EXPECT_FALSE(CircuitChecker::check(builder));
}

/**
* @brief Check that multiple reads from the same index results in a valid circuit
*
*/
TEST(Databus, DuplicateRead)
{
Builder builder;
databus_ct databus;

// Define some arbitrary bus data
std::array<bb::fr, 3> raw_calldata_values = { 5, 1, 2 };
std::array<bb::fr, 3> raw_return_data_values = { 25, 6, 3 };

// Populate the calldata in the databus
std::vector<field_ct> calldata_values;
for (auto& value : raw_calldata_values) {
calldata_values.emplace_back(witness_ct(&builder, value));
}
databus.calldata.set_values(calldata_values);

// Populate the return data in the databus
std::vector<field_ct> return_data_values;
for (auto& value : raw_return_data_values) {
return_data_values.emplace_back(witness_ct(&builder, value));
}
databus.return_data.set_values(return_data_values);

// Perform some arbitrary reads from both calldata and return data with some repeated indices
field_ct idx_1(witness_ct(&builder, 1));
field_ct idx_2(witness_ct(&builder, 2));

databus.calldata[idx_1];
databus.calldata[idx_1];
databus.calldata[idx_1];
databus.calldata[idx_2];

databus.return_data[idx_2];
databus.return_data[idx_2];
databus.return_data[idx_1];

EXPECT_TRUE(CircuitChecker::check(builder));
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ uint32_t MegaCircuitBuilder_<FF>::read_bus_vector(BusId bus_idx, const uint32_t&
const uint32_t read_idx = static_cast<uint32_t>(uint256_t(this->get_variable(read_idx_witness_idx)));

ASSERT(read_idx < bus_vector.size()); // Ensure that the read index is valid
// NOTE(https://github.com/AztecProtocol/barretenberg/issues/937): Multiple reads at same index is not supported.
ASSERT(bus_vector.get_read_count(read_idx) < 1);

// Create a variable corresponding to the result of the read. Note that we do not in general connect reads from
// databus via copy constraints (i.e. we create a unique variable for the result of each read)
Expand Down
35 changes: 25 additions & 10 deletions cpp/src/barretenberg/stdlib_circuit_builders/mega_flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ class MegaFlavor {
static constexpr size_t NUM_WIRES = CircuitBuilder::NUM_WIRES;
// The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often
// need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`.
static constexpr size_t NUM_ALL_ENTITIES = 57;
static constexpr size_t NUM_ALL_ENTITIES = 59;
// The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying
// assignment of witnesses. We again choose a neutral name.
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 30;
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 18;
static constexpr size_t NUM_WITNESS_ENTITIES = 20;
// Total number of folded polynomials, which is just all polynomials except the shifts
static constexpr size_t NUM_FOLDED_ENTITIES = NUM_PRECOMPUTED_ENTITIES + NUM_WITNESS_ENTITIES;

Expand Down Expand Up @@ -186,10 +186,12 @@ class MegaFlavor {
ecc_op_wire_4, // column 11
calldata, // column 12
calldata_read_counts, // column 13
calldata_inverses, // column 14
return_data, // column 15
return_data_read_counts, // column 16
return_data_inverses); // column 17
calldata_read_tags, // column 14
calldata_inverses, // column 15
return_data, // column 16
return_data_read_counts, // column 17
return_data_read_tags, // column 18
return_data_inverses); // column 19
};

/**
Expand All @@ -207,6 +209,11 @@ class MegaFlavor {
{
return RefArray{ this->ecc_op_wire_1, this->ecc_op_wire_2, this->ecc_op_wire_3, this->ecc_op_wire_4 };
}
auto get_databus_entities() // Excludes the derived inverse polynomials
{
return RefArray{ this->calldata, this->calldata_read_counts, this->calldata_read_tags,
this->return_data, this->return_data_read_counts, this->return_data_read_tags };
}

MSGPACK_FIELDS(this->w_l,
this->w_r,
Expand All @@ -222,9 +229,11 @@ class MegaFlavor {
this->ecc_op_wire_4,
this->calldata,
this->calldata_read_counts,
this->calldata_read_tags,
this->calldata_inverses,
this->return_data,
this->return_data_read_counts,
this->return_data_read_tags,
this->return_data_inverses);
};

Expand Down Expand Up @@ -264,10 +273,6 @@ class MegaFlavor {
auto get_sigmas() { return RefArray{ this->sigma_1, this->sigma_2, this->sigma_3, this->sigma_4 }; };
auto get_ids() { return RefArray{ this->id_1, this->id_2, this->id_3, this->id_4 }; };
auto get_tables() { return RefArray{ this->table_1, this->table_2, this->table_3, this->table_4 }; };
auto get_ecc_op_wires()
{
return RefArray{ this->ecc_op_wire_1, this->ecc_op_wire_2, this->ecc_op_wire_3, this->ecc_op_wire_4 };
};
// Gemini-specific getters.
auto get_unshifted()
{
Expand Down Expand Up @@ -622,9 +627,11 @@ class MegaFlavor {
ecc_op_wire_4 = "ECC_OP_WIRE_4";
calldata = "CALLDATA";
calldata_read_counts = "CALLDATA_READ_COUNTS";
calldata_read_tags = "CALLDATA_READ_TAGS";
calldata_inverses = "CALLDATA_INVERSES";
return_data = "RETURN_DATA";
return_data_read_counts = "RETURN_DATA_READ_COUNTS";
return_data_read_tags = "RETURN_DATA_READ_TAGS";
return_data_inverses = "RETURN_DATA_INVERSES";

q_c = "Q_C";
Expand Down Expand Up @@ -715,9 +722,11 @@ class MegaFlavor {
this->ecc_op_wire_4 = commitments.ecc_op_wire_4;
this->calldata = commitments.calldata;
this->calldata_read_counts = commitments.calldata_read_counts;
this->calldata_read_tags = commitments.calldata_read_tags;
this->calldata_inverses = commitments.calldata_inverses;
this->return_data = commitments.return_data;
this->return_data_read_counts = commitments.return_data_read_counts;
this->return_data_read_tags = commitments.return_data_read_tags;
this->return_data_inverses = commitments.return_data_inverses;
}
}
Expand Down Expand Up @@ -745,9 +754,11 @@ class MegaFlavor {
Commitment ecc_op_wire_4_comm;
Commitment calldata_comm;
Commitment calldata_read_counts_comm;
Commitment calldata_read_tags_comm;
Commitment calldata_inverses_comm;
Commitment return_data_comm;
Commitment return_data_read_counts_comm;
Commitment return_data_read_tags_comm;
Commitment return_data_inverses_comm;
Commitment w_4_comm;
Commitment z_perm_comm;
Expand Down Expand Up @@ -801,9 +812,11 @@ class MegaFlavor {
ecc_op_wire_4_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_read_counts_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_read_tags_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_inverses_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_read_counts_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_read_tags_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_inverses_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
lookup_read_counts_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
lookup_read_tags_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
Expand Down Expand Up @@ -842,9 +855,11 @@ class MegaFlavor {
serialize_to_buffer(ecc_op_wire_4_comm, proof_data);
serialize_to_buffer(calldata_comm, proof_data);
serialize_to_buffer(calldata_read_counts_comm, proof_data);
serialize_to_buffer(calldata_read_tags_comm, proof_data);
serialize_to_buffer(calldata_inverses_comm, proof_data);
serialize_to_buffer(return_data_comm, proof_data);
serialize_to_buffer(return_data_read_counts_comm, proof_data);
serialize_to_buffer(return_data_read_tags_comm, proof_data);
serialize_to_buffer(return_data_inverses_comm, proof_data);
serialize_to_buffer(lookup_read_counts_comm, proof_data);
serialize_to_buffer(lookup_read_tags_comm, proof_data);
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/barretenberg/sumcheck/instance/prover_instance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,24 @@ void ProverInstance_<Flavor>::construct_databus_polynomials(Circuit& circuit)
{
auto& public_calldata = proving_key.polynomials.calldata;
auto& calldata_read_counts = proving_key.polynomials.calldata_read_counts;
auto& calldata_read_tags = proving_key.polynomials.calldata_read_tags;
auto& public_return_data = proving_key.polynomials.return_data;
auto& return_data_read_counts = proving_key.polynomials.return_data_read_counts;
auto& return_data_read_tags = proving_key.polynomials.return_data_read_tags;

auto calldata = circuit.get_calldata();
auto return_data = circuit.get_return_data();

// Note: We do not utilize a zero row for databus columns
for (size_t idx = 0; idx < calldata.size(); ++idx) {
public_calldata[idx] = circuit.get_variable(calldata[idx]);
calldata_read_counts[idx] = calldata.get_read_count(idx);
public_calldata[idx] = circuit.get_variable(calldata[idx]); // calldata values
calldata_read_counts[idx] = calldata.get_read_count(idx); // read counts
calldata_read_tags[idx] = calldata_read_counts[idx] > 0 ? 1 : 0; // has row been read or not
}
for (size_t idx = 0; idx < return_data.size(); ++idx) {
public_return_data[idx] = circuit.get_variable(return_data[idx]);
return_data_read_counts[idx] = return_data.get_read_count(idx);
public_return_data[idx] = circuit.get_variable(return_data[idx]); // return data values
return_data_read_counts[idx] = return_data.get_read_count(idx); // read counts
return_data_read_tags[idx] = return_data_read_counts[idx] > 0 ? 1 : 0; // has row been read or not
}

auto& databus_id = proving_key.polynomials.databus_id;
Expand Down
43 changes: 43 additions & 0 deletions cpp/src/barretenberg/ultra_honk/databus.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,46 @@ TEST_F(DataBusTests, CallDataAndReturnData)
bool result = construct_and_verify_proof(builder);
EXPECT_TRUE(result);
}

/**
* @brief Test proof construction/verification for a circuit with duplicate calldata reads
*
*/
TEST_F(DataBusTests, CallDataDuplicateRead)
{
// Construct a circuit and add some ecc op gates and arithmetic gates
auto builder = construct_test_builder();

// Add some values to calldata
std::vector<FF> calldata_values = { 7, 10, 3, 12, 1 };
for (auto& val : calldata_values) {
builder.add_public_calldata(builder.add_variable(val));
}

// Define some read indices with a duplicate
std::vector<uint32_t> read_indices = { 1, 4, 1 };

// Create some calldata read gates and store the variable indices of the result for later
std::vector<uint32_t> result_witness_indices;
for (uint32_t& read_idx : read_indices) {
// Create a variable corresponding to the index at which we want to read into calldata
uint32_t read_idx_witness_idx = builder.add_variable(read_idx);

auto value_witness_idx = builder.read_calldata(read_idx_witness_idx);
result_witness_indices.emplace_back(value_witness_idx);
}

// Check that the read result is as expected and that the duplicate reads produce the same result
auto expected_read_result_at_1 = calldata_values[1];
auto expected_read_result_at_4 = calldata_values[4];
auto duplicate_read_result_0 = builder.get_variable(result_witness_indices[0]);
auto duplicate_read_result_1 = builder.get_variable(result_witness_indices[1]);
auto duplicate_read_result_2 = builder.get_variable(result_witness_indices[2]);
EXPECT_EQ(duplicate_read_result_0, expected_read_result_at_1);
EXPECT_EQ(duplicate_read_result_1, expected_read_result_at_4);
EXPECT_EQ(duplicate_read_result_2, expected_read_result_at_1);

// Construct and verify Honk proof
bool result = construct_and_verify_proof(builder);
EXPECT_TRUE(result);
}
Loading

0 comments on commit 5d27e66

Please sign in to comment.