Skip to content

Commit

Permalink
refactor(avm): separate binary and bytes finalization (#8010)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcarreiro authored Aug 15, 2024
1 parent f769f84 commit 3ad6dd9
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 123 deletions.
48 changes: 43 additions & 5 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@

namespace bb::avm_trace {

std::vector<AvmBinaryTraceBuilder::BinaryTraceEntry> AvmBinaryTraceBuilder::finalize()
{
return std::move(binary_trace);
}

void AvmBinaryTraceBuilder::reset()
{
binary_trace.clear();
Expand Down Expand Up @@ -166,4 +161,47 @@ FF AvmBinaryTraceBuilder::op_xor(FF const& a, FF const& b, AvmMemoryTag instr_ta
return uint256_t::from_uint128(c_uint128);
}

void AvmBinaryTraceBuilder::finalize(std::vector<AvmFullRow<FF>>& main_trace)
{
for (size_t i = 0; i < size(); i++) {
auto const& src = binary_trace.at(i);
auto& dest = main_trace.at(i);
dest.binary_clk = src.binary_clk;
dest.binary_sel_bin = static_cast<uint8_t>(src.bin_sel);
dest.binary_acc_ia = src.acc_ia;
dest.binary_acc_ib = src.acc_ib;
dest.binary_acc_ic = src.acc_ic;
dest.binary_in_tag = src.in_tag;
dest.binary_op_id = src.op_id;
dest.binary_ia_bytes = src.bin_ia_bytes;
dest.binary_ib_bytes = src.bin_ib_bytes;
dest.binary_ic_bytes = src.bin_ic_bytes;
dest.binary_start = FF(static_cast<uint8_t>(src.start));
dest.binary_mem_tag_ctr = src.mem_tag_ctr;
dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv;
}

reset();
}

void AvmBinaryTraceBuilder::finalize_lookups(std::vector<AvmFullRow<FF>>& main_trace)
{
for (auto const& [clk, count] : byte_operation_counter) {
main_trace.at(clk).lookup_byte_operations_counts = count;
}

for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5]
main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1];
}
}

void AvmBinaryTraceBuilder::finalize_lookups_for_testing(std::vector<AvmFullRow<FF>>& main_trace)
{
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range [1,5]
main_trace.at(avm_in_tag).lookup_byte_lengths_counts = byte_length_counter[avm_in_tag + 1];
}
}

} // namespace bb::avm_trace
11 changes: 9 additions & 2 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/binary_trace.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "barretenberg/numeric/uint128/uint128.hpp"
#include "barretenberg/vm/avm/generated/full_row.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"

#include <unordered_map>
Expand Down Expand Up @@ -32,9 +33,15 @@ class AvmBinaryTraceBuilder {
std::unordered_map<uint32_t, uint32_t> byte_length_counter;

AvmBinaryTraceBuilder() = default;

size_t size() const { return binary_trace.size(); }
void reset();
// Finalize the trace
std::vector<BinaryTraceEntry> finalize();

// These two have to be separate because the lookups need to be finalized
// after the extra first row is inserted in the main trace.
void finalize(std::vector<AvmFullRow<FF>>& main_trace);
void finalize_lookups(std::vector<AvmFullRow<FF>>& main_trace);
void finalize_lookups_for_testing(std::vector<AvmFullRow<FF>>& main_trace);

FF op_and(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk);
FF op_or(FF const& a, FF const& b, AvmMemoryTag instr_tag, uint32_t clk);
Expand Down
92 changes: 92 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "barretenberg/vm/avm/trace/fixed_bytes.hpp"

namespace bb::avm_trace {

// Singleton.
const FixedBytesTable& FixedBytesTable::get()
{
static FixedBytesTable table;
return table;
}

void FixedBytesTable::finalize(std::vector<AvmFullRow<FF>>& main_trace) const
{
if (main_trace.size() < 3 * (1 << 16)) {
main_trace.resize(3 * (1 << 16));
}
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (uint32_t op_id = 0; op_id < 3; op_id++) {
for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = (op_id << 16) + (input_a << 8) + b;

main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1);
main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).byte_lookup_table_input_b = b;
}
}
}

finalize_byte_length(main_trace);
}

void FixedBytesTable::finalize_for_testing(std::vector<AvmFullRow<FF>>& main_trace,
const std::unordered_map<uint32_t, uint32_t>& byte_operation_counter) const
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (auto const& [clk, count] : byte_operation_counter) {
// from the clk we can derive the a and b inputs
auto b = static_cast<uint8_t>(clk);
auto a = static_cast<uint8_t>(clk >> 8);
auto op_id = static_cast<uint8_t>(clk >> 16);
uint8_t bit_op = 0;
if (op_id == 0) {
bit_op = a & b;
} else if (op_id == 1) {
bit_op = a | b;
} else {
bit_op = a ^ b;
}
if (clk > (main_trace.size() - 1)) {
main_trace.push_back(AvmFullRow<FF>{
.byte_lookup_sel_bin = FF(1),
.byte_lookup_table_input_a = a,
.byte_lookup_table_input_b = b,
.byte_lookup_table_op_id = op_id,
.byte_lookup_table_output = bit_op,
.main_clk = FF(clk),
.lookup_byte_operations_counts = count,
});
} else {
main_trace.at(clk).lookup_byte_operations_counts = count;
main_trace.at(clk).byte_lookup_sel_bin = FF(1);
main_trace.at(clk).byte_lookup_table_op_id = op_id;
main_trace.at(clk).byte_lookup_table_input_a = a;
main_trace.at(clk).byte_lookup_table_input_b = b;
main_trace.at(clk).byte_lookup_table_output = bit_op;
}
// Add the counter value stored throughout the execution
}

finalize_byte_length(main_trace);
}

void FixedBytesTable::finalize_byte_length(std::vector<AvmFullRow<FF>>& main_trace)
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range 1,5]
main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1);
main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1;
main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast<uint8_t>(1 << avm_in_tag);
}
}

} // namespace bb::avm_trace
25 changes: 25 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/fixed_bytes.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <cstddef>
#include <cstdint>

#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"
#include "barretenberg/vm/avm/trace/opcode.hpp"

namespace bb::avm_trace {

class FixedBytesTable {
public:
static const FixedBytesTable& get();

void finalize(std::vector<AvmFullRow<FF>>& main_trace) const;
void finalize_for_testing(std::vector<AvmFullRow<FF>>& main_trace,
const std::unordered_map<uint32_t, uint32_t>& byte_operation_counter) const;

private:
FixedBytesTable() = default;
static void finalize_byte_length(std::vector<AvmFullRow<FF>>& main_trace);
};

} // namespace bb::avm_trace
139 changes: 23 additions & 116 deletions barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/polynomials/univariate.hpp"
#include "barretenberg/vm/avm/trace/common.hpp"
#include "barretenberg/vm/avm/trace/fixed_bytes.hpp"
#include "barretenberg/vm/avm/trace/fixed_gas.hpp"
#include "barretenberg/vm/avm/trace/fixed_powers.hpp"
#include "barretenberg/vm/avm/trace/gadgets/slice_trace.hpp"
Expand All @@ -34,47 +35,6 @@ namespace bb::avm_trace {
* HELPERS IN ANONYMOUS NAMESPACE
**************************************************************************************************/
namespace {
// WARNING: FOR TESTING ONLY
// Generates the minimal lookup table for the binary trace
uint32_t finalize_bin_trace_lookup_for_testing(std::vector<Row>& main_trace, AvmBinaryTraceBuilder& bin_trace_builder)
{
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (auto const& [clk, count] : bin_trace_builder.byte_operation_counter) {
// from the clk we can derive the a and b inputs
auto b = static_cast<uint8_t>(clk);
auto a = static_cast<uint8_t>(clk >> 8);
auto op_id = static_cast<uint8_t>(clk >> 16);
uint8_t bit_op = 0;
if (op_id == 0) {
bit_op = a & b;
} else if (op_id == 1) {
bit_op = a | b;
} else {
bit_op = a ^ b;
}
if (clk > (main_trace.size() - 1)) {
main_trace.push_back(Row{
.byte_lookup_sel_bin = FF(1),
.byte_lookup_table_input_a = a,
.byte_lookup_table_input_b = b,
.byte_lookup_table_op_id = op_id,
.byte_lookup_table_output = bit_op,
.main_clk = FF(clk),
.lookup_byte_operations_counts = count,
});
} else {
main_trace.at(clk).lookup_byte_operations_counts = count;
main_trace.at(clk).byte_lookup_sel_bin = FF(1);
main_trace.at(clk).byte_lookup_table_op_id = op_id;
main_trace.at(clk).byte_lookup_table_input_a = a;
main_trace.at(clk).byte_lookup_table_input_b = b;
main_trace.at(clk).byte_lookup_table_output = bit_op;
}
// Add the counter value stored throughout the execution
}
return static_cast<uint32_t>(main_trace.size());
}

constexpr size_t L2_HI_GAS_COUNTS_IDX = 0;
constexpr size_t L2_LO_GAS_COUNTS_IDX = 1;
Expand Down Expand Up @@ -3459,7 +3419,6 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
auto poseidon2_trace = poseidon2_trace_builder.finalize();
auto keccak_trace = keccak_trace_builder.finalize();
auto pedersen_trace = pedersen_trace_builder.finalize();
auto bin_trace = bin_trace_builder.finalize();
auto gas_trace = gas_trace_builder.finalize();
auto slice_trace = slice_trace_builder.finalize();
const auto& fixed_gas_table = FixedGasTable::get();
Expand All @@ -3471,7 +3430,7 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
size_t poseidon2_trace_size = poseidon2_trace.size();
size_t keccak_trace_size = keccak_trace.size();
size_t pedersen_trace_size = pedersen_trace.size();
size_t bin_trace_size = bin_trace.size();
size_t bin_trace_size = bin_trace_builder.size();
size_t gas_trace_size = gas_trace.size();
size_t slice_trace_size = slice_trace.size();

Expand All @@ -3480,18 +3439,14 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
std::unordered_map<uint16_t, uint32_t> mem_rng_check_mid_counts;
std::unordered_map<uint8_t, uint32_t> mem_rng_check_hi_counts;

// Main Trace needs to be at least as big as the biggest subtrace.
// If the bin_trace_size has entries, we need the main_trace to be as big as our byte lookup table (3 *
// 2**16 long)
size_t const lookup_table_size = (bin_trace_size > 0 && range_check_required) ? 3 * (1 << 16) : 0;
// Range check size is 1 less than it needs to be since we insert a "first row" at the top of the trace at the
// end, with clk 0 (this doubles as our range check)
size_t const range_check_size = range_check_required ? UINT16_MAX : 0;
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size,
range_check_size, conv_trace_size, lookup_table_size,
sha256_trace_size, poseidon2_trace_size, pedersen_trace_size,
gas_trace_size + 1, KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH,
fixed_gas_table.size(), slice_trace_size, calldata.size() };
std::vector<size_t> trace_sizes = { mem_trace_size, main_trace_size, alu_trace_size,
range_check_size, conv_trace_size, sha256_trace_size,
poseidon2_trace_size, pedersen_trace_size, gas_trace_size + 1,
KERNEL_INPUTS_LENGTH, KERNEL_OUTPUTS_LENGTH, fixed_gas_table.size(),
slice_trace_size, calldata.size() };
vinfo("Trace sizes before padding:",
"\n\tmain_trace_size: ",
main_trace_size,
Expand Down Expand Up @@ -3870,70 +3825,7 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
* BINARY TRACE INCLUSION
**********************************************************************************************/

// Add Binary Trace table
for (size_t i = 0; i < bin_trace_size; i++) {
auto const& src = bin_trace.at(i);
auto& dest = main_trace.at(i);
dest.binary_clk = src.binary_clk;
dest.binary_sel_bin = static_cast<uint8_t>(src.bin_sel);
dest.binary_acc_ia = src.acc_ia;
dest.binary_acc_ib = src.acc_ib;
dest.binary_acc_ic = src.acc_ic;
dest.binary_in_tag = src.in_tag;
dest.binary_op_id = src.op_id;
dest.binary_ia_bytes = src.bin_ia_bytes;
dest.binary_ib_bytes = src.bin_ib_bytes;
dest.binary_ic_bytes = src.bin_ic_bytes;
dest.binary_start = FF(static_cast<uint8_t>(src.start));
dest.binary_mem_tag_ctr = src.mem_tag_ctr;
dest.binary_mem_tag_ctr_inv = src.mem_tag_ctr_inv;
}

// Only generate precomputed byte tables if we are actually going to use them in this main trace.
if (bin_trace_size > 0) {
if (!range_check_required) {
finalize_bin_trace_lookup_for_testing(main_trace, bin_trace_builder);
} else {
// Generate Lookup Table of all combinations of 2, 8-bit numbers and op_id.
for (uint32_t op_id = 0; op_id < 3; op_id++) {
for (uint32_t input_a = 0; input_a <= UINT8_MAX; input_a++) {
for (uint32_t input_b = 0; input_b <= UINT8_MAX; input_b++) {
auto a = static_cast<uint8_t>(input_a);
auto b = static_cast<uint8_t>(input_b);

// Derive a unique row index given op_id, a, and b.
auto main_trace_index = (op_id << 16) + (input_a << 8) + b;

main_trace.at(main_trace_index).byte_lookup_sel_bin = FF(1);
main_trace.at(main_trace_index).byte_lookup_table_op_id = op_id;
main_trace.at(main_trace_index).byte_lookup_table_input_a = a;
main_trace.at(main_trace_index).byte_lookup_table_input_b = b;
// Add the counter value stored throughout the execution
main_trace.at(main_trace_index).lookup_byte_operations_counts =
bin_trace_builder.byte_operation_counter[main_trace_index];
if (op_id == 0) {
main_trace.at(main_trace_index).byte_lookup_table_output = a & b;
} else if (op_id == 1) {
main_trace.at(main_trace_index).byte_lookup_table_output = a | b;
} else {
main_trace.at(main_trace_index).byte_lookup_table_output = a ^ b;
}
}
}
}
}
// Generate ByteLength Lookup table of instruction tags to the number of bytes
// {U8: 1, U16: 2, U32: 4, U64: 8, U128: 16}
for (uint8_t avm_in_tag = 0; avm_in_tag < 5; avm_in_tag++) {
// The +1 here is because the instruction tags we care about (i.e excl U0 and FF) has the range
// [1,5]
main_trace.at(avm_in_tag).byte_lookup_sel_bin = FF(1);
main_trace.at(avm_in_tag).byte_lookup_table_in_tags = avm_in_tag + 1;
main_trace.at(avm_in_tag).byte_lookup_table_byte_lengths = static_cast<uint8_t>(pow(2, avm_in_tag));
main_trace.at(avm_in_tag).lookup_byte_lengths_counts =
bin_trace_builder.byte_length_counter[avm_in_tag + 1];
}
}
bin_trace_builder.finalize(main_trace);

/**********************************************************************************************
* GAS TRACE INCLUSION
Expand Down Expand Up @@ -4015,6 +3907,21 @@ std::vector<Row> AvmTraceBuilder::finalize(bool range_check_required)
Row first_row = Row{ .main_sel_first = FF(1), .mem_lastAccess = FF(1) };
main_trace.insert(main_trace.begin(), first_row);

/**********************************************************************************************
* BYTES TRACE INCLUSION
**********************************************************************************************/

// Only generate precomputed byte tables if we are actually going to use them in this main trace.
if (bin_trace_size > 0) {
if (!range_check_required) {
FixedBytesTable::get().finalize_for_testing(main_trace, bin_trace_builder.byte_operation_counter);
bin_trace_builder.finalize_lookups_for_testing(main_trace);
} else {
FixedBytesTable::get().finalize(main_trace);
bin_trace_builder.finalize_lookups(main_trace);
}
}

/**********************************************************************************************
* RANGE CHECKS AND SELECTORS INCLUSION
**********************************************************************************************/
Expand Down

0 comments on commit 3ad6dd9

Please sign in to comment.