-
Notifications
You must be signed in to change notification settings - Fork 284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: better initialization for permutation mapping components #10750
Changes from 7 commits
624ba85
1022c11
5068390
57810a4
7883d1d
0bf0ad0
6d7c8c5
46564ce
41c5d71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,15 +27,15 @@ | |
namespace bb { | ||
|
||
/** | ||
* @brief cycle_node represents the index of a value of the circuit. | ||
* @brief cycle_node represents the idx of a value of the circuit. | ||
* It will belong to a CyclicPermutation, such that all nodes in a CyclicPermutation | ||
* must have the value. | ||
* The total number of constraints is always <2^32 since that is the type used to represent variables, so we can save | ||
* space by using a type smaller than size_t. | ||
*/ | ||
struct cycle_node { | ||
uint32_t wire_index; | ||
uint32_t gate_index; | ||
uint32_t wire_idx; | ||
uint32_t gate_idx; | ||
}; | ||
|
||
/** | ||
|
@@ -45,40 +45,77 @@ struct cycle_node { | |
* | ||
*/ | ||
struct permutation_subgroup_element { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this be deleted now? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, thanks There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sadly, this has to stay - its used in a bunch of plonk methods/tests that I cannot bring myself to spend time updating |
||
uint32_t row_index = 0; | ||
uint8_t column_index = 0; | ||
uint32_t row_idx = 0; | ||
uint8_t column_idx = 0; | ||
bool is_public_input = false; | ||
bool is_tag = false; | ||
}; | ||
|
||
/** | ||
* @brief Stores permutation mapping data for a single wire column | ||
* | ||
*/ | ||
struct Mapping { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you think the new structure of 4 arrays is better than the original structure of a vector of structs? My thought is its probably slightly more efficient for memory, and therefore slightly better for cache efficiency as a result. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The constructor went from 500ms to 30ms so thats all the proof I need. Having a vector/array of structs makes it impossible to use the slab allocator to allocate memory and difficult to multithread the initialization without either 0 initializing or doing lots of copying to "join" thread local data. I tried to find a slightly cleaner solution but this one was by far the fastest I could come up with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. got it, yeah. Is _allocate_aligned_memory multithreaded? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it doesnt appear to be multithreaded no, but it does seem to allow reuse of a previous slab that was allocated for something else then no longer needed. I also get quite a good speedup from naively allocating a raw buffer but the slab allocator was noticeably better |
||
std::shared_ptr<uint32_t[]> row_idx; // row idx of next entry in copy cycle | ||
std::shared_ptr<uint8_t[]> col_idx; // column idx of next entry in copy cycle | ||
std::shared_ptr<bool[]> is_public_input; | ||
std::shared_ptr<bool[]> is_tag; | ||
size_t _size = 0; | ||
|
||
Mapping() = default; | ||
|
||
size_t size() const { return _size; } | ||
|
||
Mapping(size_t n) | ||
: row_idx(_allocate_aligned_memory<uint32_t>(n)) | ||
, col_idx(_allocate_aligned_memory<uint8_t>(n)) | ||
, is_public_input(_allocate_aligned_memory<bool>(n)) | ||
, is_tag(_allocate_aligned_memory<bool>(n)) | ||
, _size(n) | ||
{} | ||
}; | ||
|
||
template <size_t NUM_WIRES, bool generalized> struct PermutationMapping { | ||
using Mapping = std::array<std::vector<permutation_subgroup_element>, NUM_WIRES>; | ||
Mapping sigmas; | ||
Mapping ids; | ||
std::array<Mapping, NUM_WIRES> sigmas; | ||
std::array<Mapping, NUM_WIRES> ids; | ||
|
||
/** | ||
* @brief Construct a permutation mapping default initialized so every element is in a cycle by itself | ||
* | ||
*/ | ||
PermutationMapping(size_t circuit_size) | ||
{ | ||
|
||
PROFILE_THIS_NAME("PermutationMapping constructor"); | ||
|
||
for (uint8_t col_idx = 0; col_idx < NUM_WIRES; ++col_idx) { | ||
sigmas[col_idx].reserve(circuit_size); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought reserve doesn't actually initialize memory, just allocates it There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah that's true, I thought I worked out that there was an extra zero init happening somewhere but now I dont see it. Maybe it was just that with this old approach is was very difficult to implement multithreading without one of either zero initializing or copy overhead |
||
if constexpr (generalized) { | ||
ids[col_idx].reserve(circuit_size); | ||
} | ||
for (size_t wire_idx = 0; wire_idx < NUM_WIRES; ++wire_idx) { | ||
sigmas[wire_idx] = Mapping(circuit_size); | ||
ids[wire_idx] = Mapping(circuit_size); | ||
} | ||
|
||
const size_t num_threads = calculate_num_threads_pow2(circuit_size, /*min_iterations_per_thread=*/1 << 10); | ||
size_t iterations_per_thread = circuit_size / num_threads; // actual iterations per thread | ||
|
||
parallel_for(num_threads, [&](size_t thread_idx) { | ||
uint32_t start = static_cast<uint32_t>(thread_idx * iterations_per_thread); | ||
uint32_t end = static_cast<uint32_t>((thread_idx + 1) * iterations_per_thread); | ||
|
||
// Initialize every element to point to itself | ||
for (uint32_t row_idx = 0; row_idx < circuit_size; ++row_idx) { | ||
permutation_subgroup_element self{ row_idx, col_idx }; | ||
sigmas[col_idx].emplace_back(self); | ||
if constexpr (generalized) { | ||
ids[col_idx].emplace_back(self); | ||
for (uint8_t col_idx = 0; col_idx < NUM_WIRES; ++col_idx) { | ||
for (uint32_t row_idx = start; row_idx < end; ++row_idx) { | ||
auto idx = static_cast<ptrdiff_t>(row_idx); | ||
sigmas[col_idx].row_idx[idx] = row_idx; | ||
sigmas[col_idx].col_idx[idx] = col_idx; | ||
sigmas[col_idx].is_public_input[idx] = false; | ||
sigmas[col_idx].is_tag[idx] = false; | ||
if constexpr (generalized) { | ||
ids[col_idx].row_idx[idx] = row_idx; | ||
ids[col_idx].col_idx[idx] = col_idx; | ||
ids[col_idx].is_public_input[idx] = false; | ||
ids[col_idx].is_tag[idx] = false; | ||
} | ||
} | ||
} | ||
} | ||
}); | ||
} | ||
}; | ||
|
||
|
@@ -105,48 +142,47 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping( | |
{ | ||
|
||
// Initialize the table of permutations so that every element points to itself | ||
PermutationMapping<Flavor::NUM_WIRES, generalized> mapping{ proving_key->circuit_size }; | ||
PermutationMapping<Flavor::NUM_WIRES, generalized> mapping(proving_key->circuit_size); | ||
|
||
// Represents the index of a variable in circuit_constructor.variables (needed only for generalized) | ||
// Represents the idx of a variable in circuit_constructor.variables (needed only for generalized) | ||
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags; | ||
|
||
// Go through each cycle | ||
size_t cycle_index = 0; | ||
for (const auto& copy_cycle : wire_copy_cycles) { | ||
for (size_t node_idx = 0; node_idx < copy_cycle.size(); ++node_idx) { | ||
// Get the indices of the current node and next node in the cycle | ||
const cycle_node& current_cycle_node = copy_cycle[node_idx]; | ||
// If current node is the last one in the cycle, then the next one is the first one | ||
size_t next_cycle_node_index = (node_idx == copy_cycle.size() - 1 ? 0 : node_idx + 1); | ||
const cycle_node& next_cycle_node = copy_cycle[next_cycle_node_index]; | ||
const auto current_row = current_cycle_node.gate_index; | ||
const auto next_row = next_cycle_node.gate_index; | ||
|
||
const auto current_column = current_cycle_node.wire_index; | ||
const auto next_column = static_cast<uint8_t>(next_cycle_node.wire_index); | ||
for (size_t cycle_idx = 0; cycle_idx < wire_copy_cycles.size(); ++cycle_idx) { | ||
const CyclicPermutation& cycle = wire_copy_cycles[cycle_idx]; | ||
for (size_t node_idx = 0; node_idx < cycle.size(); ++node_idx) { | ||
// Get the indices (column, row) of the current node in the cycle | ||
const cycle_node& current_node = cycle[node_idx]; | ||
const auto current_row = static_cast<ptrdiff_t>(current_node.gate_idx); | ||
const auto current_column = current_node.wire_idx; | ||
|
||
// Get indices of next node; If the current node is last in the cycle, then the next is the first one | ||
size_t next_node_idx = (node_idx == cycle.size() - 1 ? 0 : node_idx + 1); | ||
const cycle_node& next_node = cycle[next_node_idx]; | ||
const auto next_row = next_node.gate_idx; | ||
const auto next_column = static_cast<uint8_t>(next_node.wire_idx); | ||
|
||
// Point current node to the next node | ||
mapping.sigmas[current_column][current_row] = { | ||
.row_index = next_row, .column_index = next_column, .is_public_input = false, .is_tag = false | ||
}; | ||
mapping.sigmas[current_column].row_idx[current_row] = next_row; | ||
mapping.sigmas[current_column].col_idx[current_row] = next_column; | ||
|
||
if constexpr (generalized) { | ||
bool first_node = (node_idx == 0); | ||
bool last_node = (next_cycle_node_index == 0); | ||
const bool first_node = (node_idx == 0); | ||
const bool last_node = (next_node_idx == 0); | ||
|
||
if (first_node) { | ||
mapping.ids[current_column][current_row].is_tag = true; | ||
mapping.ids[current_column][current_row].row_index = (real_variable_tags[cycle_index]); | ||
mapping.ids[current_column].is_tag[current_row] = true; | ||
mapping.ids[current_column].row_idx[current_row] = real_variable_tags[cycle_idx]; | ||
} | ||
if (last_node) { | ||
mapping.sigmas[current_column][current_row].is_tag = true; | ||
mapping.sigmas[current_column].is_tag[current_row] = true; | ||
|
||
// TODO(Zac): yikes, std::maps (tau) are expensive. Can we find a way to get rid of this? | ||
mapping.sigmas[current_column][current_row].row_index = | ||
circuit_constructor.tau.at(real_variable_tags[cycle_index]); | ||
mapping.sigmas[current_column].row_idx[current_row] = | ||
circuit_constructor.tau.at(real_variable_tags[cycle_idx]); | ||
} | ||
} | ||
} | ||
cycle_index++; | ||
} | ||
|
||
// Add information about public inputs so that the cycles can be altered later; See the construction of the | ||
|
@@ -158,11 +194,11 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping( | |
pub_inputs_offset = proving_key->pub_inputs_offset; | ||
} | ||
for (size_t i = 0; i < num_public_inputs; ++i) { | ||
size_t idx = i + pub_inputs_offset; | ||
mapping.sigmas[0][idx].row_index = static_cast<uint32_t>(idx); | ||
mapping.sigmas[0][idx].column_index = 0; | ||
mapping.sigmas[0][idx].is_public_input = true; | ||
if (mapping.sigmas[0][idx].is_tag) { | ||
uint32_t idx = static_cast<uint32_t>(i + pub_inputs_offset); | ||
mapping.sigmas[0].row_idx[static_cast<ptrdiff_t>(idx)] = idx; | ||
mapping.sigmas[0].col_idx[static_cast<ptrdiff_t>(idx)] = 0; | ||
mapping.sigmas[0].is_public_input[static_cast<ptrdiff_t>(idx)] = true; | ||
if (mapping.sigmas[0].is_tag[static_cast<ptrdiff_t>(idx)]) { | ||
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl; | ||
} | ||
} | ||
|
@@ -182,37 +218,40 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping( | |
template <typename Flavor> | ||
void compute_honk_style_permutation_lagrange_polynomials_from_mapping( | ||
const RefSpan<typename Flavor::Polynomial>& permutation_polynomials, // sigma or ID poly | ||
const std::array<std::vector<permutation_subgroup_element>, Flavor::NUM_WIRES>& permutation_mappings, | ||
const std::array<Mapping, Flavor::NUM_WIRES>& permutation_mappings, | ||
typename Flavor::ProvingKey* proving_key) | ||
{ | ||
using FF = typename Flavor::FF; | ||
const size_t num_gates = proving_key->circuit_size; | ||
|
||
size_t wire_index = 0; | ||
size_t wire_idx = 0; | ||
for (auto& current_permutation_poly : permutation_polynomials) { | ||
ITERATE_OVER_DOMAIN_START(proving_key->evaluation_domain); | ||
const auto& current_mapping = permutation_mappings[wire_index][i]; | ||
if (current_mapping.is_public_input) { | ||
auto idx = static_cast<ptrdiff_t>(i); | ||
const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx]; | ||
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx]; | ||
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx]; | ||
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx]; | ||
if (current_is_public_input) { | ||
// We intentionally want to break the cycles of the public input variables. | ||
// During the witness generation, the left and right wire polynomials at index i contain the i-th public | ||
// During the witness generation, the left and right wire polynomials at idx i contain the i-th public | ||
// input. The CyclicPermutation created for these variables always start with (i) -> (n+i), followed by | ||
// the indices of the variables in the "real" gates. We make i point to -(i+1), so that the only way of | ||
// repairing the cycle is add the mapping | ||
// -(i+1) -> (n+i) | ||
// These indices are chosen so they can easily be computed by the verifier. They can expect the running | ||
// product to be equal to the "public input delta" that is computed in <honk/utils/grand_product_delta.hpp> | ||
current_permutation_poly.at(i) = | ||
-FF(current_mapping.row_index + 1 + num_gates * current_mapping.column_index); | ||
} else if (current_mapping.is_tag) { | ||
current_permutation_poly.at(i) = -FF(current_row_idx + 1 + num_gates * current_col_idx); | ||
} else if (current_is_tag) { | ||
// Set evaluations to (arbitrary) values disjoint from non-tag values | ||
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_mapping.row_index; | ||
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_row_idx; | ||
} else { | ||
// For the regular permutation we simply point to the next location by setting the evaluation to its | ||
// index | ||
current_permutation_poly.at(i) = FF(current_mapping.row_index + num_gates * current_mapping.column_index); | ||
// idx | ||
current_permutation_poly.at(i) = FF(current_row_idx + num_gates * current_col_idx); | ||
} | ||
ITERATE_OVER_DOMAIN_END; | ||
wire_index++; | ||
wire_idx++; | ||
} | ||
} | ||
} // namespace | ||
|
@@ -226,7 +265,7 @@ void compute_honk_style_permutation_lagrange_polynomials_from_mapping( | |
* | ||
* */ | ||
inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output, | ||
const std::vector<permutation_subgroup_element>& permutation, | ||
const Mapping& permutation, | ||
const bb::evaluation_domain& small_domain) | ||
{ | ||
if (output.size() < permutation.size()) { | ||
|
@@ -245,19 +284,19 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output, | |
|
||
ITERATE_OVER_DOMAIN_START(small_domain); | ||
|
||
// `permutation[i]` will specify the 'index' that this wire value will map to. | ||
// Here, 'index' refers to an element of our subgroup H. | ||
// We can almost use `permutation[i]` to directly index our `roots` array, which contains our subgroup elements. | ||
// `permutation[i]` will specify the 'idx' that this wire value will map to. | ||
// Here, 'idx' refers to an element of our subgroup H. | ||
// We can almost use `permutation[i]` to directly idx our `roots` array, which contains our subgroup elements. | ||
// We first have to accommodate for the fact that `roots` only contains *half* of our subgroup elements. This is | ||
// because ω^{n/2} = -ω and we don't want to perform redundant work computing roots of unity. | ||
|
||
size_t raw_idx = permutation[i].row_index; | ||
size_t raw_idx = permutation.row_idx[static_cast<ptrdiff_t>(i)]; | ||
|
||
// Step 1: is `raw_idx` >= (n / 2)? if so, we will need to index `-roots[raw_idx - subgroup_size / 2]` instead | ||
// Step 1: is `raw_idx` >= (n / 2)? if so, we will need to idx `-roots[raw_idx - subgroup_size / 2]` instead | ||
// of `roots[raw_idx]` | ||
const bool negative_idx = raw_idx >= root_size; | ||
|
||
// Step 2: compute the index of the subgroup element we'll be accessing. | ||
// Step 2: compute the idx of the subgroup element we'll be accessing. | ||
// To avoid a conditional branch, we can subtract `negative_idx << log2_root_size` from `raw_idx`. | ||
// Here, `log2_root_size = numeric::get_msb(subgroup_size / 2)` (we know our subgroup size will be a power of 2, | ||
// so we lose no precision here) | ||
|
@@ -269,23 +308,23 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output, | |
// The output will similarly be overloaded (containing either 2 * modulus - w, or modulus - w) | ||
output[i] = roots[idx].conditionally_subtract_from_double_modulus(static_cast<uint64_t>(negative_idx)); | ||
|
||
// Finally, if our permutation maps to an index in either the right wire vector, or the output wire vector, we | ||
// Finally, if our permutation maps to an idx in either the right wire vector, or the output wire vector, we | ||
// need to multiply our result by one of two quadratic non-residues. (This ensures that mapping into the left | ||
// wires gives unique values that are not repeated in the right or output wire permutations) (ditto for right | ||
// wire and output wire mappings) | ||
|
||
if (permutation[i].is_public_input) { | ||
if (permutation.is_public_input[static_cast<ptrdiff_t>(i)]) { | ||
// As per the paper which modifies plonk to include the public inputs in a permutation argument, the permutation | ||
// `σ` is modified to `σ'`, where `σ'` maps all public inputs to a set of l distinct ζ elements which are | ||
// disjoint from H ∪ k1·H ∪ k2·H. | ||
output[i] *= bb::fr::external_coset_generator(); | ||
} else if (permutation[i].is_tag) { | ||
} else if (permutation.is_tag[static_cast<ptrdiff_t>(i)]) { | ||
output[i] *= bb::fr::tag_coset_generator(); | ||
} else { | ||
{ | ||
const uint32_t column_index = permutation[i].column_index; | ||
if (column_index > 0) { | ||
output[i] *= bb::fr::coset_generator(column_index - 1); | ||
const uint32_t column_idx = permutation.col_idx[static_cast<ptrdiff_t>(i)]; | ||
if (column_idx > 0) { | ||
output[i] *= bb::fr::coset_generator(column_idx - 1); | ||
} | ||
} | ||
} | ||
|
@@ -301,16 +340,15 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output, | |
* @param key | ||
*/ | ||
template <size_t program_width> | ||
void compute_plonk_permutation_lagrange_polynomials_from_mapping( | ||
std::string label, | ||
std::array<std::vector<permutation_subgroup_element>, program_width>& mappings, | ||
plonk::proving_key* key) | ||
void compute_plonk_permutation_lagrange_polynomials_from_mapping(std::string label, | ||
std::array<Mapping, program_width>& mappings, | ||
plonk::proving_key* key) | ||
{ | ||
for (size_t i = 0; i < program_width; i++) { | ||
std::string index = std::to_string(i + 1); | ||
std::string idx = std::to_string(i + 1); | ||
bb::polynomial polynomial_lagrange(key->circuit_size); | ||
compute_standard_plonk_lagrange_polynomial(polynomial_lagrange, mappings[i], key->small_domain); | ||
key->polynomial_store.put(label + "_" + index + "_lagrange", polynomial_lagrange.share()); | ||
key->polynomial_store.put(label + "_" + idx + "_lagrange", polynomial_lagrange.share()); | ||
} | ||
} | ||
|
||
|
@@ -327,8 +365,8 @@ template <size_t program_width> | |
void compute_monomial_and_coset_fft_polynomials_from_lagrange(std::string label, plonk::proving_key* key) | ||
{ | ||
for (size_t i = 0; i < program_width; ++i) { | ||
std::string index = std::to_string(i + 1); | ||
std::string prefix = label + "_" + index; | ||
std::string idx = std::to_string(i + 1); | ||
std::string prefix = label + "_" + idx; | ||
|
||
// Construct permutation polynomials in lagrange base | ||
auto sigma_polynomial_lagrange = key->polynomial_store.get(prefix + "_lagrange"); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file was using both
idx
andindex
arbitrarily so I updated to use onlyidx
everywhere