Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: databus allows arbitrarily many reads per index #6524

Merged
merged 17 commits into from
Jul 14, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ template <typename FF_> class DatabusLookupRelationImpl {
static auto& inverses(AllEntities& in) { return in.calldata_inverses; }
static auto& inverses(const AllEntities& in) { return in.calldata_inverses; } // const version
static auto& read_counts(const AllEntities& in) { return in.calldata_read_counts; }
static auto& read_tags(const AllEntities& in) { return in.calldata_read_tags; }
};

// Specialization for return data (bus_idx = 1)
Expand All @@ -88,6 +89,7 @@ template <typename FF_> class DatabusLookupRelationImpl {
static auto& inverses(AllEntities& in) { return in.return_data_inverses; }
static auto& inverses(const AllEntities& in) { return in.return_data_inverses; } // const version
static auto& read_counts(const AllEntities& in) { return in.return_data_read_counts; }
static auto& read_tags(const AllEntities& in) { return in.return_data_read_tags; }
};

/**
Expand All @@ -101,8 +103,8 @@ template <typename FF_> class DatabusLookupRelationImpl {
template <size_t bus_idx, typename AllValues> static bool operation_exists_at_row(const AllValues& row)
{
auto read_selector = get_read_selector<FF, bus_idx>(row);
auto read_counts = BusData<bus_idx, AllValues>::read_counts(row);
return (read_selector == 1 || read_counts > 0);
auto read_tag = BusData<bus_idx, AllValues>::read_tags(row);
return (read_selector == 1 || read_tag == 1);
}

/**
Expand All @@ -117,10 +119,10 @@ template <typename FF_> class DatabusLookupRelationImpl {
{
using View = typename Accumulator::View;

const auto is_read_gate = get_read_selector<Accumulator, bus_idx>(in);
const auto read_counts = View(BusData<bus_idx, AllEntities>::read_counts(in));
const auto is_read_gate = get_read_selector<Accumulator, bus_idx>(in); // is this a read gate
const auto read_tag = View(BusData<bus_idx, AllEntities>::read_tags(in)); // does row contain data being read

return is_read_gate + read_counts - (is_read_gate * read_counts);
return is_read_gate + read_tag - (is_read_gate * read_tag);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing a value that was boolean with the boolean old value was at least 1

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,16 @@ void ProtoGalaxyRecursiveVerifier_<VerifierInstances>::receive_and_finalise_inst
witness_commitments.w_o = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.w_o);

if constexpr (IsGoblinFlavor<Flavor>) {
witness_commitments.ecc_op_wire_1 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_1);
witness_commitments.ecc_op_wire_2 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_2);
witness_commitments.ecc_op_wire_3 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_3);
witness_commitments.ecc_op_wire_4 =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.ecc_op_wire_4);
witness_commitments.calldata =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.calldata);
witness_commitments.calldata_read_counts =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.calldata_read_counts);
witness_commitments.return_data =
transcript->template receive_from_prover<Commitment>(domain_separator + "_" + labels.return_data);
witness_commitments.return_data_read_counts = transcript->template receive_from_prover<Commitment>(
domain_separator + "_" + labels.return_data_read_counts);
// Receive ECC op wire commitments
for (auto [commitment, label] : zip_view(witness_commitments.get_ecc_op_wires(), labels.get_ecc_op_wires())) {
commitment = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + label);
}

// Receive DataBus related polynomial commitments
for (auto [commitment, label] :
zip_view(witness_commitments.get_databus_entities(), labels.get_databus_entities())) {
commitment = transcript->template receive_from_prover<Commitment>(domain_separator + "_" + label);
}
}

// Get challenge for sorted list batching and wire four memory records commitment
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,17 @@ std::array<typename Flavor::GroupElement, 2> UltraRecursiveVerifier_<Flavor>::ve

// If Goblin, get commitments to ECC op wire polynomials and DataBus columns
if constexpr (IsGoblinFlavor<Flavor>) {
commitments.ecc_op_wire_1 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_1);
commitments.ecc_op_wire_2 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_2);
commitments.ecc_op_wire_3 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_3);
commitments.ecc_op_wire_4 =
transcript->template receive_from_prover<Commitment>(commitment_labels.ecc_op_wire_4);
commitments.calldata = transcript->template receive_from_prover<Commitment>(commitment_labels.calldata);
commitments.calldata_read_counts =
transcript->template receive_from_prover<Commitment>(commitment_labels.calldata_read_counts);
commitments.return_data = transcript->template receive_from_prover<Commitment>(commitment_labels.return_data);
commitments.return_data_read_counts =
transcript->template receive_from_prover<Commitment>(commitment_labels.return_data_read_counts);
// Receive ECC op wire commitments
for (auto [commitment, label] :
zip_view(commitments.get_ecc_op_wires(), commitment_labels.get_ecc_op_wires())) {
commitment = transcript->template receive_from_prover<Commitment>(label);
}

// Receive DataBus related polynomial commitments
for (auto [commitment, label] :
zip_view(commitments.get_databus_entities(), commitment_labels.get_databus_entities())) {
commitment = transcript->template receive_from_prover<Commitment>(label);
}
}

// Get challenge for sorted list batching and wire four memory records
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,46 @@ TEST(Databus, BadCopyFailure)
// Since the output data is not a copy of the input, the checker should fail
EXPECT_FALSE(CircuitChecker::check(builder));
}

/**
* @brief Check that multiple reads from the same index results in a valid circuit
*
*/
TEST(Databus, DuplicateRead)
{
Builder builder;
databus_ct databus;

// Define some arbitrary bus data
std::array<bb::fr, 3> raw_calldata_values = { 5, 1, 2 };
std::array<bb::fr, 3> raw_return_data_values = { 25, 6, 3 };

// Populate the calldata in the databus
std::vector<field_ct> calldata_values;
for (auto& value : raw_calldata_values) {
calldata_values.emplace_back(witness_ct(&builder, value));
}
databus.calldata.set_values(calldata_values);

// Populate the return data in the databus
std::vector<field_ct> return_data_values;
for (auto& value : raw_return_data_values) {
return_data_values.emplace_back(witness_ct(&builder, value));
}
databus.return_data.set_values(return_data_values);

// Perform some arbitrary reads from both calldata and return data with some repeated indices
field_ct idx_1(witness_ct(&builder, 1));
field_ct idx_2(witness_ct(&builder, 2));

databus.calldata[idx_1];
databus.calldata[idx_1];
databus.calldata[idx_1];
databus.calldata[idx_2];

databus.return_data[idx_2];
databus.return_data[idx_2];
databus.return_data[idx_1];

EXPECT_TRUE(CircuitChecker::check(builder));
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ uint32_t MegaCircuitBuilder_<FF>::read_bus_vector(BusId bus_idx, const uint32_t&
const uint32_t read_idx = static_cast<uint32_t>(uint256_t(this->get_variable(read_idx_witness_idx)));

ASSERT(read_idx < bus_vector.size()); // Ensure that the read index is valid
// NOTE(https://github.com/AztecProtocol/barretenberg/issues/937): Multiple reads at same index is not supported.
ASSERT(bus_vector.get_read_count(read_idx) < 1);

// Create a variable corresponding to the result of the read. Note that we do not in general connect reads from
// databus via copy constraints (i.e. we create a unique variable for the result of each read)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ class MegaFlavor {
// The number of multivariate polynomials on which a sumcheck prover sumcheck operates (including shifts). We often
// need containers of this size to hold related data, so we choose a name more agnostic than `NUM_POLYNOMIALS`.
// Note: this number does not include the individual sorted list polynomials.
static constexpr size_t NUM_ALL_ENTITIES = 58;
static constexpr size_t NUM_ALL_ENTITIES = 60;
// The number of polynomials precomputed to describe a circuit and to aid a prover in constructing a satisfying
// assignment of witnesses. We again choose a neutral name.
static constexpr size_t NUM_PRECOMPUTED_ENTITIES = 30;
// The total number of witness entities not including shifts.
static constexpr size_t NUM_WITNESS_ENTITIES = 17;
static constexpr size_t NUM_WITNESS_ENTITIES = 19;
// Total number of folded polynomials, which is just all polynomials except the shifts
static constexpr size_t NUM_FOLDED_ENTITIES = NUM_PRECOMPUTED_ENTITIES + NUM_WITNESS_ENTITIES;

Expand Down Expand Up @@ -188,10 +188,12 @@ class MegaFlavor {
ecc_op_wire_4, // column 10
calldata, // column 11
calldata_read_counts, // column 12
calldata_inverses, // column 13
return_data, // column 14
return_data_read_counts, // column 15
return_data_inverses); // column 16
calldata_read_tags, // column 13
calldata_inverses, // column 14
return_data, // column 15
return_data_read_counts, // column 16
return_data_read_tags, // column 17
return_data_inverses); // column 18
};

/**
Expand All @@ -209,6 +211,11 @@ class MegaFlavor {
{
return RefArray{ this->ecc_op_wire_1, this->ecc_op_wire_2, this->ecc_op_wire_3, this->ecc_op_wire_4 };
}
auto get_databus_entities() // Excludes the derived inverse polynomials
{
return RefArray{ this->calldata, this->calldata_read_counts, this->calldata_read_tags,
this->return_data, this->return_data_read_counts, this->return_data_read_tags };
}
};

template <typename DataType> class ShiftedEntities {
Expand Down Expand Up @@ -250,10 +257,6 @@ class MegaFlavor {
auto get_sigmas() { return RefArray{ this->sigma_1, this->sigma_2, this->sigma_3, this->sigma_4 }; };
auto get_ids() { return RefArray{ this->id_1, this->id_2, this->id_3, this->id_4 }; };
auto get_tables() { return RefArray{ this->table_1, this->table_2, this->table_3, this->table_4 }; };
auto get_ecc_op_wires()
{
return RefArray{ this->ecc_op_wire_1, this->ecc_op_wire_2, this->ecc_op_wire_3, this->ecc_op_wire_4 };
};
// Gemini-specific getters.
auto get_unshifted()
{
Expand Down Expand Up @@ -664,9 +667,11 @@ class MegaFlavor {
ecc_op_wire_4 = "ECC_OP_WIRE_4";
calldata = "CALLDATA";
calldata_read_counts = "CALLDATA_READ_COUNTS";
calldata_read_tags = "CALLDATA_READ_TAGS";
calldata_inverses = "CALLDATA_INVERSES";
return_data = "RETURN_DATA";
return_data_read_counts = "RETURN_DATA_READ_COUNTS";
return_data_read_tags = "RETURN_DATA_READ_TAGS";
return_data_inverses = "RETURN_DATA_INVERSES";

q_c = "Q_C";
Expand Down Expand Up @@ -756,9 +761,11 @@ class MegaFlavor {
this->ecc_op_wire_4 = commitments.ecc_op_wire_4;
this->calldata = commitments.calldata;
this->calldata_read_counts = commitments.calldata_read_counts;
this->calldata_read_tags = commitments.calldata_read_tags;
this->calldata_inverses = commitments.calldata_inverses;
this->return_data = commitments.return_data;
this->return_data_read_counts = commitments.return_data_read_counts;
this->return_data_read_tags = commitments.return_data_read_tags;
this->return_data_inverses = commitments.return_data_inverses;
}
}
Expand Down Expand Up @@ -786,9 +793,11 @@ class MegaFlavor {
Commitment ecc_op_wire_4_comm;
Commitment calldata_comm;
Commitment calldata_read_counts_comm;
Commitment calldata_read_tags_comm;
Commitment calldata_inverses_comm;
Commitment return_data_comm;
Commitment return_data_read_counts_comm;
Commitment return_data_read_tags_comm;
Commitment return_data_inverses_comm;
Commitment sorted_accum_comm;
Commitment w_4_comm;
Expand Down Expand Up @@ -842,9 +851,11 @@ class MegaFlavor {
ecc_op_wire_4_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_read_counts_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_read_tags_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
calldata_inverses_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_read_counts_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_read_tags_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
return_data_inverses_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
sorted_accum_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
w_4_comm = deserialize_from_buffer<Commitment>(proof_data, num_frs_read);
Expand Down Expand Up @@ -883,9 +894,11 @@ class MegaFlavor {
serialize_to_buffer(ecc_op_wire_4_comm, proof_data);
serialize_to_buffer(calldata_comm, proof_data);
serialize_to_buffer(calldata_read_counts_comm, proof_data);
serialize_to_buffer(calldata_read_tags_comm, proof_data);
serialize_to_buffer(calldata_inverses_comm, proof_data);
serialize_to_buffer(return_data_comm, proof_data);
serialize_to_buffer(return_data_read_counts_comm, proof_data);
serialize_to_buffer(return_data_read_tags_comm, proof_data);
serialize_to_buffer(return_data_inverses_comm, proof_data);
serialize_to_buffer(sorted_accum_comm, proof_data);
serialize_to_buffer(w_4_comm, proof_data);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,20 +43,24 @@ void ProverInstance_<Flavor>::construct_databus_polynomials(Circuit& circuit)
{
auto& public_calldata = proving_key.polynomials.calldata;
auto& calldata_read_counts = proving_key.polynomials.calldata_read_counts;
auto& calldata_read_tags = proving_key.polynomials.calldata_read_tags;
auto& public_return_data = proving_key.polynomials.return_data;
auto& return_data_read_counts = proving_key.polynomials.return_data_read_counts;
auto& return_data_read_tags = proving_key.polynomials.return_data_read_tags;

auto calldata = circuit.get_calldata();
auto return_data = circuit.get_return_data();

// Note: We do not utilize a zero row for databus columns
for (size_t idx = 0; idx < calldata.size(); ++idx) {
public_calldata[idx] = circuit.get_variable(calldata[idx]);
calldata_read_counts[idx] = calldata.get_read_count(idx);
public_calldata[idx] = circuit.get_variable(calldata[idx]); // calldata values
calldata_read_counts[idx] = calldata.get_read_count(idx); // read counts
calldata_read_tags[idx] = calldata_read_counts[idx] > 0 ? 1 : 0; // has row been read or not
}
for (size_t idx = 0; idx < return_data.size(); ++idx) {
public_return_data[idx] = circuit.get_variable(return_data[idx]);
return_data_read_counts[idx] = return_data.get_read_count(idx);
public_return_data[idx] = circuit.get_variable(return_data[idx]); // return data values
return_data_read_counts[idx] = return_data.get_read_count(idx); // read counts
return_data_read_tags[idx] = return_data_read_counts[idx] > 0 ? 1 : 0; // has row been read or not
}

auto& databus_id = proving_key.polynomials.databus_id;
Expand Down
40 changes: 40 additions & 0 deletions barretenberg/cpp/src/barretenberg/ultra_honk/databus.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,43 @@ TEST_F(DataBusTests, CallDataAndReturnData)
bool result = construct_and_verify_proof(builder);
EXPECT_TRUE(result);
}

/**
* @brief Test proof construction/verification for a circuit with duplicate calldata reads
*
*/
TEST_F(DataBusTests, CallDataDuplicateRead)
{
// Construct a circuit and add some ecc op gates and arithmetic gates
auto builder = construct_test_builder();

// Add some values to calldata
std::vector<FF> calldata_values = { 7, 10, 3, 12, 1 };
for (auto& val : calldata_values) {
builder.add_public_calldata(builder.add_variable(val));
}

// Define some read indices with a duplicate
std::vector<uint32_t> read_indices = { 1, 4, 1 };

// Create some calldata read gates and store the variable indices of the result for later
std::vector<uint32_t> result_witness_indices;
for (uint32_t& read_idx : read_indices) {
// Create a variable corresponding to the index at which we want to read into calldata
uint32_t read_idx_witness_idx = builder.add_variable(read_idx);

auto value_witness_idx = builder.read_calldata(read_idx_witness_idx);
result_witness_indices.emplace_back(value_witness_idx);
}

// Check that the read result is as axpected and that the duplicate reads produce the same result
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo: "expected"

auto expected_read_result = calldata_values[1];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an illustration of the functionality this could be a tad more explicit; will update.

auto duplicate_read_result_1 = builder.get_variable(result_witness_indices[0]);
auto duplicate_read_result_2 = builder.get_variable(result_witness_indices[2]);
EXPECT_EQ(duplicate_read_result_1, expected_read_result);
EXPECT_EQ(duplicate_read_result_1, duplicate_read_result_2);

// Construct and verify Honk proof
bool result = construct_and_verify_proof(builder);
EXPECT_TRUE(result);
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class MegaTranscriptTests : public ::testing::Test {
manifest_expected.add_entry(round, "ECC_OP_WIRE_4", frs_per_G);
manifest_expected.add_entry(round, "CALLDATA", frs_per_G);
manifest_expected.add_entry(round, "CALLDATA_READ_COUNTS", frs_per_G);
manifest_expected.add_entry(round, "CALLDATA_READ_TAGS", frs_per_G);
manifest_expected.add_entry(round, "RETURN_DATA", frs_per_G);
manifest_expected.add_entry(round, "RETURN_DATA_READ_COUNTS", frs_per_G);
manifest_expected.add_entry(round, "RETURN_DATA_READ_TAGS", frs_per_G);
manifest_expected.add_challenge(round, "eta", "eta_two", "eta_three");

round++;
Expand Down
Loading
Loading