Skip to content

Commit

Permalink
chore: lookups cleanup/documentation (#7002)
Browse files Browse the repository at this point in the history
Adding some comments and improving some naming in code having to do with
lookups.

Note: the BasicTable struct had a `size` member that had to be set
manually. This seems extremely error prone. I updated this to use a
`size()` method that checks the size of the first table column.

(This work stems from notes-to-self that I made while perusing the
lookup code in preparation to convert to a log-derivative argument).
  • Loading branch information
ledwards2225 authored Jun 14, 2024
1 parent a20d845 commit 92b1349
Show file tree
Hide file tree
Showing 22 changed files with 186 additions and 144 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ template <typename Builder> bool UltraCircuitChecker::check(const Builder& build
LookupHashTable lookup_hash_table;
for (const auto& table : builder.lookup_tables) {
const FF table_index(table.table_index);
for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table.size(); ++i) {
lookup_hash_table.insert({ table.column_1[i], table.column_2[i], table.column_3[i], table_index });
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ void construct_lookup_table_polynomials(RefArray<typename Flavor::Polynomial, 4>
for (const auto& table : circuit.lookup_tables) {
const fr table_index(table.table_index);

for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table.size(); ++i) {
table_polynomials[0][offset] = table.column_1[i];
table_polynomials[1][offset] = table.column_2[i];
table_polynomials[2][offset] = table.column_3[i];
Expand Down Expand Up @@ -67,7 +67,7 @@ std::array<typename Flavor::Polynomial, 4> construct_sorted_list_polynomials(typ
for (auto& table : circuit.lookup_tables) {
const fr table_index(table.table_index);
auto& lookup_gates = table.lookup_gates;
for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table.size(); ++i) {
if (table.use_twin_keys) {
lookup_gates.push_back({
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ field_t<Builder> keccak<Builder>::normalize_and_rotate(const field_ct& limb, fie
// We need to provide a key/value object for this lookup in order for the Builder
// to compute the plookup sorted list commitment
const auto [input_quotient, input_slice] = input.divmod(divisor);
lookup.key_entries.push_back(
lookup.lookup_entries.push_back(
{ { static_cast<uint64_t>(input_slice), 0 }, { normalized_slice, normalized_msb } });

// reduce the input and output by 11^{bit_slice}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ field_t<Builder> plookup_read<Builder>::read_from_2_to_1_table(const MultiTableI
const field_t<Builder>& key_a,
const field_t<Builder>& key_b)
{
const auto lookup = get_lookup_accumulators(id, key_a, key_b, true);
const auto lookup = get_lookup_accumulators(id, key_a, key_b, /*is_2_to_1_lookup=*/true);

return lookup[ColumnIdx::C3][0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ inline BasicTable generate_aes_sparse_table(BasicTableId id, const size_t table_
BasicTable table;
table.id = id;
table.table_index = table_index;
table.size = 256;
size_t table_size = 256;
table.use_twin_keys = true;
for (uint64_t i = 0; i < table.size; ++i) {
for (uint64_t i = 0; i < table_size; ++i) {
uint64_t left = i;
const auto right = numeric::map_into_sparse_form<AES_BASE>((uint8_t)i);
table.column_1.emplace_back(bb::fr(left));
Expand Down Expand Up @@ -74,7 +74,6 @@ inline BasicTable generate_aes_sparse_normalization_table(BasicTableId id, const
}
}
}
table.size = table.column_1.size();
table.use_twin_keys = false;
table.get_values_from_key = &get_aes_sparse_normalization_values_from_key;

Expand Down Expand Up @@ -102,7 +101,7 @@ inline MultiTable get_aes_normalization_table(const MultiTableId id = AES_NORMAL
table.id = id;
for (size_t i = 0; i < num_entries; ++i) {
table.slice_sizes.emplace_back(AES_BASE * AES_BASE * AES_BASE * AES_BASE);
table.lookup_ids.emplace_back(AES_SPARSE_NORMALIZE);
table.basic_table_ids.emplace_back(AES_SPARSE_NORMALIZE);
table.get_table_values.emplace_back(&get_aes_sparse_normalization_values_from_key);
}
return table;
Expand All @@ -117,7 +116,7 @@ inline MultiTable get_aes_input_table(const MultiTableId id = AES_INPUT)
table.id = id;
for (size_t i = 0; i < num_entries; ++i) {
table.slice_sizes.emplace_back(256);
table.lookup_ids.emplace_back(AES_SPARSE_MAP);
table.basic_table_ids.emplace_back(AES_SPARSE_MAP);
table.get_table_values.emplace_back(&sparse_tables::get_sparse_table_with_rotation_values<AES_BASE, 0>);
}
return table;
Expand All @@ -137,9 +136,9 @@ inline BasicTable generate_aes_sbox_table(BasicTableId id, const size_t table_in
BasicTable table;
table.id = id;
table.table_index = table_index;
table.size = 256;
size_t table_size = 256;
table.use_twin_keys = false;
for (uint64_t i = 0; i < table.size; ++i) {
for (uint64_t i = 0; i < table_size; ++i) {
const auto first = numeric::map_into_sparse_form<AES_BASE>((uint8_t)i);
uint8_t sbox_value = crypto::aes128_sbox[(uint8_t)i];
uint8_t swizzled = ((uint8_t)(sbox_value << 1) ^ (uint8_t)(((sbox_value >> 7) & 1) * 0x1b));
Expand Down Expand Up @@ -167,7 +166,7 @@ inline MultiTable get_aes_sbox_table(const MultiTableId id = AES_SBOX)
table.id = id;
for (size_t i = 0; i < num_entries; ++i) {
table.slice_sizes.emplace_back(numeric::pow64(AES_BASE, 8));
table.lookup_ids.emplace_back(AES_SBOX_MAP);
table.basic_table_ids.emplace_back(AES_SBOX_MAP);
table.get_table_values.emplace_back(&get_aes_sbox_values_from_key);
}
return table;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ inline BasicTable generate_xor_rotate_table(BasicTableId id, const size_t table_
BasicTable table;
table.id = id;
table.table_index = table_index;
table.size = base * base;
table.use_twin_keys = true;

for (uint64_t i = 0; i < base; ++i) {
Expand Down Expand Up @@ -95,12 +94,12 @@ inline MultiTable get_blake2s_xor_table(const MultiTableId id = BLAKE_XOR)
table.id = id;
for (size_t i = 0; i < num_entries - 1; ++i) {
table.slice_sizes.emplace_back(base);
table.lookup_ids.emplace_back(BLAKE_XOR_ROTATE0);
table.basic_table_ids.emplace_back(BLAKE_XOR_ROTATE0);
table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 0>);
}

table.slice_sizes.emplace_back(SIZE_OF_LAST_SLICE);
table.lookup_ids.emplace_back(BLAKE_XOR_ROTATE0_SLICE5_MOD4);
table.basic_table_ids.emplace_back(BLAKE_XOR_ROTATE0_SLICE5_MOD4);
table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<BITS_IN_LAST_SLICE, 0, true>);

return table;
Expand Down Expand Up @@ -128,8 +127,8 @@ inline MultiTable get_blake2s_xor_rotate_16_table(const MultiTableId id = BLAKE_

table.id = id;
table.slice_sizes = { base, base, base, base, base, SIZE_OF_LAST_SLICE };
table.lookup_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE4,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };
table.basic_table_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE4,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };

table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 0>);
table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 0>);
Expand Down Expand Up @@ -163,8 +162,8 @@ inline MultiTable get_blake2s_xor_rotate_8_table(const MultiTableId id = BLAKE_X

table.id = id;
table.slice_sizes = { base, base, base, base, base, SIZE_OF_LAST_SLICE };
table.lookup_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE2, BLAKE_XOR_ROTATE0,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };
table.basic_table_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE2, BLAKE_XOR_ROTATE0,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };

table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 0>);
table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 2>);
Expand Down Expand Up @@ -198,8 +197,8 @@ inline MultiTable get_blake2s_xor_rotate_7_table(const MultiTableId id = BLAKE_X

table.id = id;
table.slice_sizes = { base, base, base, base, base, SIZE_OF_LAST_SLICE };
table.lookup_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE1, BLAKE_XOR_ROTATE0,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };
table.basic_table_ids = { BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE1, BLAKE_XOR_ROTATE0,
BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0, BLAKE_XOR_ROTATE0_SLICE5_MOD4 };

table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 0>);
table.get_table_values.emplace_back(&get_xor_rotate_values_from_key<6, 1>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ inline BasicTable generate_honk_dummy_table(const BasicTableId id, const size_t
BasicTable table;
table.id = id;
table.table_index = table_index;
table.size = base * base;
table.use_twin_keys = true;
for (uint64_t i = 0; i < base; ++i) {
for (uint64_t j = 0; j < base; ++j) {
Expand Down Expand Up @@ -85,10 +84,10 @@ inline MultiTable get_honk_dummy_multitable()
number_of_lookups);
table.id = id;
table.slice_sizes.emplace_back(number_of_elements_in_argument);
table.lookup_ids.emplace_back(HONK_DUMMY_BASIC1);
table.basic_table_ids.emplace_back(HONK_DUMMY_BASIC1);
table.get_table_values.emplace_back(&get_value_from_key<HONK_DUMMY_BASIC1>);
table.slice_sizes.emplace_back(number_of_elements_in_argument);
table.lookup_ids.emplace_back(HONK_DUMMY_BASIC2);
table.basic_table_ids.emplace_back(HONK_DUMMY_BASIC2);
table.get_table_values.emplace_back(&get_value_from_key<HONK_DUMMY_BASIC2>);
return table;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,11 @@ BasicTable table::generate_basic_fixed_base_table(BasicTableId id, size_t basic_
BasicTable table;
table.id = id;
table.table_index = basic_table_index;
table.size = table_size;
table.use_twin_keys = false;

const auto& basic_table = fixed_base_tables[multitable_index][table_index];

for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table_size; ++i) {
table.column_1.emplace_back(i);
table.column_2.emplace_back(basic_table[i].x);
table.column_3.emplace_back(basic_table[i].y);
Expand All @@ -213,7 +212,7 @@ BasicTable table::generate_basic_fixed_base_table(BasicTableId id, size_t basic_
table.get_values_from_key = get_values_from_key_table[multitable_index][table_index];

ASSERT(table.get_values_from_key != nullptr);
table.column_1_step_size = table.size;
table.column_1_step_size = table_size;
table.column_2_step_size = 0;
table.column_3_step_size = 0;

Expand Down Expand Up @@ -243,13 +242,13 @@ template <size_t multitable_index, size_t num_bits> MultiTable table::get_fixed_
MultiTable table(MAX_TABLE_SIZE, 0, 0, NUM_TABLES);
table.id = id;
table.get_table_values.resize(NUM_TABLES);
table.lookup_ids.resize(NUM_TABLES);
table.basic_table_ids.resize(NUM_TABLES);
for (size_t i = 0; i < NUM_TABLES; ++i) {
table.slice_sizes.emplace_back(MAX_TABLE_SIZE);
table.get_table_values[i] = get_values_from_key_table[multitable_index][i];
static_assert(multitable_index < NUM_FIXED_BASE_MULTI_TABLES);
size_t idx = i + static_cast<size_t>(basic_table_ids[multitable_index]);
table.lookup_ids[i] = static_cast<plookup::BasicTableId>(idx);
table.basic_table_ids[i] = static_cast<plookup::BasicTableId>(idx);
}
return table;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,11 @@ class Chi {
table.id = id;
table.table_index = table_index;
table.use_twin_keys = false;
table.size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);
auto table_size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);

std::array<size_t, TABLE_BITS> counts{};
std::array<uint64_t, 3> column_values{ 0, 0, 0 };
for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table_size; ++i) {
table.column_1.emplace_back(column_values[0]);
table.column_2.emplace_back(column_values[1]);
table.column_3.emplace_back(column_values[2]);
Expand Down Expand Up @@ -242,7 +242,7 @@ class Chi {
table.id = id;
for (size_t i = 0; i < num_tables_per_multitable; ++i) {
table.slice_sizes.emplace_back(numeric::pow64(BASE, TABLE_BITS));
table.lookup_ids.emplace_back(KECCAK_CHI);
table.basic_table_ids.emplace_back(KECCAK_CHI);
table.get_table_values.emplace_back(&get_chi_renormalization_values);
}
return table;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ class KeccakInput {
BasicTable table;
table.id = id;
table.table_index = table_index;
table.size = (1U << TABLE_BITS);
auto table_size = (1U << TABLE_BITS);
table.use_twin_keys = false;
constexpr size_t msb_shift = (64 % TABLE_BITS == 0) ? TABLE_BITS - 1 : (64 % TABLE_BITS) - 1;

for (uint64_t i = 0; i < table.size; ++i) {
for (uint64_t i = 0; i < table_size; ++i) {
const uint64_t source = i;
const auto target = numeric::map_into_sparse_form<BASE>(source);
table.column_1.emplace_back(bb::fr(source));
Expand Down Expand Up @@ -132,7 +132,7 @@ class KeccakInput {
table.id = id;
for (size_t i = 0; i < num_entries; ++i) {
table.slice_sizes.emplace_back(1 << 8);
table.lookup_ids.emplace_back(KECCAK_INPUT);
table.basic_table_ids.emplace_back(KECCAK_INPUT);
table.get_table_values.emplace_back(&get_keccak_input_values);
}
return table;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ class KeccakOutput {
table.id = id;
table.table_index = table_index;
table.use_twin_keys = false;
table.size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);
auto table_size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);

std::array<size_t, TABLE_BITS> counts{};
std::array<uint64_t, 2> column_values{ 0, 0 };

for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table_size; ++i) {
table.column_1.emplace_back(column_values[0]);
table.column_2.emplace_back(column_values[1]);
table.column_3.emplace_back(0);
Expand Down Expand Up @@ -162,7 +162,7 @@ class KeccakOutput {
table.id = id;
for (size_t i = 0; i < num_tables_per_multitable; ++i) {
table.slice_sizes.emplace_back(numeric::pow64(BASE, TABLE_BITS));
table.lookup_ids.emplace_back(KECCAK_OUTPUT);
table.basic_table_ids.emplace_back(KECCAK_OUTPUT);
table.get_table_values.emplace_back(
&sparse_tables::get_sparse_normalization_values<BASE, OUTPUT_NORMALIZATION_TABLE>);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ template <size_t TABLE_BITS = 0, size_t LANE_INDEX = 0> class Rho {
table.id = id;
table.table_index = table_index;
table.use_twin_keys = false;
table.size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);
auto table_size = numeric::pow64(static_cast<uint64_t>(EFFECTIVE_BASE), TABLE_BITS);

std::array<size_t, TABLE_BITS> counts{};
std::array<uint64_t, 3> column_values{ 0, 0, 0 };

for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table_size; ++i) {
table.column_1.emplace_back(column_values[0]);
table.column_2.emplace_back(column_values[1]);
table.column_3.emplace_back(column_values[2]);
Expand Down Expand Up @@ -264,7 +264,7 @@ template <size_t TABLE_BITS = 0, size_t LANE_INDEX = 0> class Rho {

table.slice_sizes.push_back(scaled_base);
table.get_table_values.emplace_back(&get_rho_renormalization_values);
table.lookup_ids.push_back((BasicTableId)((size_t)KECCAK_RHO_1 + (bit_slice - 1)));
table.basic_table_ids.push_back((BasicTableId)((size_t)KECCAK_RHO_1 + (bit_slice - 1)));
});

// generate table selector values for the 'left' slice
Expand All @@ -284,7 +284,7 @@ template <size_t TABLE_BITS = 0, size_t LANE_INDEX = 0> class Rho {

table.slice_sizes.push_back(scaled_base);
table.get_table_values.emplace_back(&get_rho_renormalization_values);
table.lookup_ids.push_back((BasicTableId)((size_t)KECCAK_RHO_1 + (bit_slice - 1)));
table.basic_table_ids.push_back((BasicTableId)((size_t)KECCAK_RHO_1 + (bit_slice - 1)));
});

return table;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,12 +171,12 @@ class Theta {
table.id = id;
table.table_index = table_index;
table.use_twin_keys = false;
table.size = numeric::pow64(static_cast<uint64_t>(BASE), TABLE_BITS);
auto table_size = numeric::pow64(static_cast<uint64_t>(BASE), TABLE_BITS);

std::array<size_t, TABLE_BITS> counts{};
std::array<uint64_t, 2> column_values{ 0, 0 };

for (size_t i = 0; i < table.size; ++i) {
for (size_t i = 0; i < table_size; ++i) {
table.column_1.emplace_back(column_values[0]);
table.column_2.emplace_back(column_values[1]);
table.column_3.emplace_back(0);
Expand Down Expand Up @@ -244,7 +244,7 @@ class Theta {
table.id = id;
for (size_t i = 0; i < num_tables_per_multitable; ++i) {
table.slice_sizes.emplace_back(numeric::pow64(BASE, TABLE_BITS));
table.lookup_ids.emplace_back(KECCAK_THETA);
table.basic_table_ids.emplace_back(KECCAK_THETA);
table.get_table_values.emplace_back(&get_theta_renormalization_values);
}
return table;
Expand Down
Loading

0 comments on commit 92b1349

Please sign in to comment.