feat: better initialization for permutation mapping components #10750

Merged: 9 commits, Dec 17, 2024
@@ -27,15 +27,15 @@
namespace bb {

/**
* @brief cycle_node represents the index of a value of the circuit.
* @brief cycle_node represents the idx of a value of the circuit.
* It will belong to a CyclicPermutation, such that all nodes in a CyclicPermutation
* must have the value.
* The total number of constraints is always <2^32 since that is the type used to represent variables, so we can save
* space by using a type smaller than size_t.
*/
struct cycle_node {
uint32_t wire_index;
uint32_t gate_index;
uint32_t wire_idx;
Contributor Author

This file was using both idx and index arbitrarily so I updated to use only idx everywhere

uint32_t gate_idx;
};

/**
@@ -45,40 +45,77 @@ struct cycle_node {
*
*/
struct permutation_subgroup_element {
Contributor

can this be deleted now?

Contributor Author

yes, thanks

Contributor Author

Sadly, this has to stay. It's used in a bunch of plonk methods/tests that I cannot bring myself to spend time updating.

uint32_t row_index = 0;
uint8_t column_index = 0;
uint32_t row_idx = 0;
uint8_t column_idx = 0;
bool is_public_input = false;
bool is_tag = false;
};

/**
* @brief Stores permutation mapping data for a single wire column
*
*/
struct Mapping {
Contributor

Do you think the new structure of 4 arrays is better than the original structure of a vector of structs? My thought is that it's probably slightly more memory-efficient, and therefore slightly more cache-friendly as a result.

Contributor Author

The constructor went from 500ms to 30ms, so that's all the proof I need. Having a vector/array of structs makes it impossible to use the slab allocator to allocate memory, and difficult to multithread the initialization without either zero-initializing or doing lots of copying to "join" thread-local data. I tried to find a slightly cleaner solution, but this one was by far the fastest I could come up with.

Contributor

got it, yeah. Is _allocate_aligned_memory multithreaded?

Contributor Author

It doesn't appear to be multithreaded, no, but it does seem to allow reuse of a previous slab that was allocated for something else and is no longer needed. I also get quite a good speedup from naively allocating a raw buffer, but the slab allocator was noticeably better.
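
A minimal sketch of the layout trade-off discussed in this thread, under stated assumptions: `SoaMapping` and `init_identity` are hypothetical names, plain `new[]` stands in for `_allocate_aligned_memory`, and `std::thread` stands in for `parallel_for`. The sketch also uses ceiling division to split rows across threads, which is its own choice and not taken from the PR. The point it illustrates is that with one flat array per field, each thread can write its own disjoint row range directly, with no zero-initialization pass and no merge of thread-local data.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <thread>
#include <vector>

// Hypothetical stand-in for the PR's Mapping: one flat array per field.
struct SoaMapping {
    std::unique_ptr<uint32_t[]> row_idx;
    std::unique_ptr<uint8_t[]> col_idx;
    std::unique_ptr<bool[]> is_public_input;
    std::unique_ptr<bool[]> is_tag;
    size_t size = 0;

    // `new T[n]` without parentheses leaves these trivially constructible
    // elements uninitialized, so allocation does no per-element work.
    explicit SoaMapping(size_t n)
        : row_idx(new uint32_t[n])
        , col_idx(new uint8_t[n])
        , is_public_input(new bool[n])
        , is_tag(new bool[n])
        , size(n)
    {}
};

// Make every entry of one column point to itself, splitting rows across threads.
inline void init_identity(SoaMapping& m, uint8_t column, size_t num_threads)
{
    assert(num_threads > 0);
    // Ceiling division so the final chunk also covers any remainder rows.
    const size_t chunk = (m.size + num_threads - 1) / num_threads;
    std::vector<std::thread> workers;
    for (size_t t = 0; t < num_threads; ++t) {
        const size_t start = std::min(t * chunk, m.size);
        const size_t end = std::min(start + chunk, m.size);
        // Ranges [start, end) are disjoint, so threads never touch the same element.
        workers.emplace_back([&m, column, start, end] {
            for (size_t row = start; row < end; ++row) {
                m.row_idx[row] = static_cast<uint32_t>(row);
                m.col_idx[row] = column;
                m.is_public_input[row] = false;
                m.is_tag[row] = false;
            }
        });
    }
    for (auto& worker : workers) {
        worker.join();
    }
}
```

For example, `SoaMapping m(1000); init_identity(m, 2, 3);` leaves `m.row_idx[i] == i` and `m.col_idx[i] == 2` for every row.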

std::shared_ptr<uint32_t[]> row_idx; // row idx of next entry in copy cycle
std::shared_ptr<uint8_t[]> col_idx; // column idx of next entry in copy cycle
std::shared_ptr<bool[]> is_public_input;
std::shared_ptr<bool[]> is_tag;
size_t _size = 0;

Mapping() = default;

size_t size() const { return _size; }

Mapping(size_t n)
: row_idx(_allocate_aligned_memory<uint32_t>(n))
, col_idx(_allocate_aligned_memory<uint8_t>(n))
, is_public_input(_allocate_aligned_memory<bool>(n))
, is_tag(_allocate_aligned_memory<bool>(n))
, _size(n)
{}
};

template <size_t NUM_WIRES, bool generalized> struct PermutationMapping {
using Mapping = std::array<std::vector<permutation_subgroup_element>, NUM_WIRES>;
Mapping sigmas;
Mapping ids;
std::array<Mapping, NUM_WIRES> sigmas;
std::array<Mapping, NUM_WIRES> ids;

/**
* @brief Construct a permutation mapping default initialized so every element is in a cycle by itself
*
*/
PermutationMapping(size_t circuit_size)
{

PROFILE_THIS_NAME("PermutationMapping constructor");

for (uint8_t col_idx = 0; col_idx < NUM_WIRES; ++col_idx) {
sigmas[col_idx].reserve(circuit_size);
Contributor

I thought reserve doesn't actually initialize memory, just allocates it

Contributor Author

Yeah, that's true. I thought I had worked out that there was an extra zero init happening somewhere, but now I don't see it. Maybe it was just that with the old approach it was very difficult to implement multithreading without either zero-initializing or copy overhead.
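
A small sketch of the `reserve` point raised here, using a hypothetical `Element` struct in place of `permutation_subgroup_element`. `std::vector::reserve` only allocates capacity and leaves `size()` at zero, so elements must then be appended in order (for example with `emplace_back`), which is hard to parallelize. Writing at arbitrary indices from several threads would need `resize`, which value-initializes every element first.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical element type, mirroring the fields of permutation_subgroup_element.
struct Element {
    uint32_t row_idx = 0;
    uint8_t column_idx = 0;
    bool is_public_input = false;
    bool is_tag = false;
};

// reserve: allocates storage for n elements but constructs none of them.
inline std::vector<Element> make_reserved(size_t n)
{
    std::vector<Element> v;
    v.reserve(n);
    assert(v.capacity() >= n);
    return v;
}

// resize: allocates and constructs n elements using the default member initializers.
inline std::vector<Element> make_resized(size_t n)
{
    std::vector<Element> v;
    v.resize(n);
    return v;
}
```

With these helpers, `make_reserved(8).size()` is 0 while `make_resized(8).size()` is 8, and only the resized vector may be indexed with `v[i]`.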

if constexpr (generalized) {
ids[col_idx].reserve(circuit_size);
}
for (size_t wire_idx = 0; wire_idx < NUM_WIRES; ++wire_idx) {
sigmas[wire_idx] = Mapping(circuit_size);
ids[wire_idx] = Mapping(circuit_size);
}

const size_t num_threads = calculate_num_threads_pow2(circuit_size, /*min_iterations_per_thread=*/1 << 10);
size_t iterations_per_thread = circuit_size / num_threads; // actual iterations per thread

parallel_for(num_threads, [&](size_t thread_idx) {
uint32_t start = static_cast<uint32_t>(thread_idx * iterations_per_thread);
uint32_t end = static_cast<uint32_t>((thread_idx + 1) * iterations_per_thread);

// Initialize every element to point to itself
for (uint32_t row_idx = 0; row_idx < circuit_size; ++row_idx) {
permutation_subgroup_element self{ row_idx, col_idx };
sigmas[col_idx].emplace_back(self);
if constexpr (generalized) {
ids[col_idx].emplace_back(self);
for (uint8_t col_idx = 0; col_idx < NUM_WIRES; ++col_idx) {
for (uint32_t row_idx = start; row_idx < end; ++row_idx) {
auto idx = static_cast<ptrdiff_t>(row_idx);
sigmas[col_idx].row_idx[idx] = row_idx;
sigmas[col_idx].col_idx[idx] = col_idx;
sigmas[col_idx].is_public_input[idx] = false;
sigmas[col_idx].is_tag[idx] = false;
if constexpr (generalized) {
ids[col_idx].row_idx[idx] = row_idx;
ids[col_idx].col_idx[idx] = col_idx;
ids[col_idx].is_public_input[idx] = false;
ids[col_idx].is_tag[idx] = false;
}
}
}
}
});
}
};

@@ -105,48 +142,47 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping(
{

// Initialize the table of permutations so that every element points to itself
PermutationMapping<Flavor::NUM_WIRES, generalized> mapping{ proving_key->circuit_size };
PermutationMapping<Flavor::NUM_WIRES, generalized> mapping(proving_key->circuit_size);

// Represents the index of a variable in circuit_constructor.variables (needed only for generalized)
// Represents the idx of a variable in circuit_constructor.variables (needed only for generalized)
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags;

// Go through each cycle
size_t cycle_index = 0;
for (const auto& copy_cycle : wire_copy_cycles) {
for (size_t node_idx = 0; node_idx < copy_cycle.size(); ++node_idx) {
// Get the indices of the current node and next node in the cycle
const cycle_node& current_cycle_node = copy_cycle[node_idx];
// If current node is the last one in the cycle, then the next one is the first one
size_t next_cycle_node_index = (node_idx == copy_cycle.size() - 1 ? 0 : node_idx + 1);
const cycle_node& next_cycle_node = copy_cycle[next_cycle_node_index];
const auto current_row = current_cycle_node.gate_index;
const auto next_row = next_cycle_node.gate_index;

const auto current_column = current_cycle_node.wire_index;
const auto next_column = static_cast<uint8_t>(next_cycle_node.wire_index);
for (size_t cycle_idx = 0; cycle_idx < wire_copy_cycles.size(); ++cycle_idx) {
const CyclicPermutation& cycle = wire_copy_cycles[cycle_idx];
for (size_t node_idx = 0; node_idx < cycle.size(); ++node_idx) {
// Get the indices (column, row) of the current node in the cycle
const cycle_node& current_node = cycle[node_idx];
const auto current_row = static_cast<ptrdiff_t>(current_node.gate_idx);
const auto current_column = current_node.wire_idx;

// Get indices of next node; If the current node is last in the cycle, then the next is the first one
size_t next_node_idx = (node_idx == cycle.size() - 1 ? 0 : node_idx + 1);
const cycle_node& next_node = cycle[next_node_idx];
const auto next_row = next_node.gate_idx;
const auto next_column = static_cast<uint8_t>(next_node.wire_idx);

// Point current node to the next node
mapping.sigmas[current_column][current_row] = {
.row_index = next_row, .column_index = next_column, .is_public_input = false, .is_tag = false
};
mapping.sigmas[current_column].row_idx[current_row] = next_row;
mapping.sigmas[current_column].col_idx[current_row] = next_column;

if constexpr (generalized) {
bool first_node = (node_idx == 0);
bool last_node = (next_cycle_node_index == 0);
const bool first_node = (node_idx == 0);
const bool last_node = (next_node_idx == 0);

if (first_node) {
mapping.ids[current_column][current_row].is_tag = true;
mapping.ids[current_column][current_row].row_index = (real_variable_tags[cycle_index]);
mapping.ids[current_column].is_tag[current_row] = true;
mapping.ids[current_column].row_idx[current_row] = real_variable_tags[cycle_idx];
}
if (last_node) {
mapping.sigmas[current_column][current_row].is_tag = true;
mapping.sigmas[current_column].is_tag[current_row] = true;

// TODO(Zac): yikes, std::maps (tau) are expensive. Can we find a way to get rid of this?
mapping.sigmas[current_column][current_row].row_index =
circuit_constructor.tau.at(real_variable_tags[cycle_index]);
mapping.sigmas[current_column].row_idx[current_row] =
circuit_constructor.tau.at(real_variable_tags[cycle_idx]);
}
}
}
cycle_index++;
}

// Add information about public inputs so that the cycles can be altered later; See the construction of the
@@ -158,11 +194,11 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping(
pub_inputs_offset = proving_key->pub_inputs_offset;
}
for (size_t i = 0; i < num_public_inputs; ++i) {
size_t idx = i + pub_inputs_offset;
mapping.sigmas[0][idx].row_index = static_cast<uint32_t>(idx);
mapping.sigmas[0][idx].column_index = 0;
mapping.sigmas[0][idx].is_public_input = true;
if (mapping.sigmas[0][idx].is_tag) {
uint32_t idx = static_cast<uint32_t>(i + pub_inputs_offset);
mapping.sigmas[0].row_idx[static_cast<ptrdiff_t>(idx)] = idx;
mapping.sigmas[0].col_idx[static_cast<ptrdiff_t>(idx)] = 0;
mapping.sigmas[0].is_public_input[static_cast<ptrdiff_t>(idx)] = true;
if (mapping.sigmas[0].is_tag[static_cast<ptrdiff_t>(idx)]) {
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl;
}
}
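
The core of the cycle loop in this hunk (each node points to the next node in its copy cycle, and the last node wraps around to the first) can be sketched in isolation. This is a simplified illustration, not the PR's code: `Node`, `Sigma`, and `build_sigma` are hypothetical names, `std::vector` replaces the slab-allocated arrays, and the tag and public-input handling of the generalized case is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical counterpart of cycle_node: a (column, row) position in the circuit.
struct Node {
    uint32_t wire_idx;
    uint32_t gate_idx;
};

// Hypothetical sigma table: row[c][r] and col[c][r] give the position that (c, r) maps to.
struct Sigma {
    std::vector<std::vector<uint32_t>> row;
    std::vector<std::vector<uint8_t>> col;
};

inline Sigma build_sigma(size_t num_wires, size_t num_rows, const std::vector<std::vector<Node>>& cycles)
{
    Sigma sigma;
    sigma.row.assign(num_wires, std::vector<uint32_t>(num_rows));
    sigma.col.assign(num_wires, std::vector<uint8_t>(num_rows));

    // Start from the identity: every element is in a cycle by itself.
    for (size_t c = 0; c < num_wires; ++c) {
        for (size_t r = 0; r < num_rows; ++r) {
            sigma.row[c][r] = static_cast<uint32_t>(r);
            sigma.col[c][r] = static_cast<uint8_t>(c);
        }
    }

    // Point each node to the next node in its cycle, wrapping from last to first.
    for (const auto& cycle : cycles) {
        for (size_t i = 0; i < cycle.size(); ++i) {
            const Node& current = cycle[i];
            const Node& next = cycle[(i + 1) % cycle.size()];
            assert(current.wire_idx < num_wires && current.gate_idx < num_rows);
            sigma.row[current.wire_idx][current.gate_idx] = next.gate_idx;
            sigma.col[current.wire_idx][current.gate_idx] = static_cast<uint8_t>(next.wire_idx);
        }
    }
    return sigma;
}
```

For a single cycle `(0,0) -> (1,2) -> (2,1)` written as (column, row) pairs, position `(0,0)` ends up mapped to `(1,2)`, `(1,2)` to `(2,1)`, and `(2,1)` back to `(0,0)`, while every other position keeps pointing to itself.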
@@ -182,37 +218,40 @@ PermutationMapping<Flavor::NUM_WIRES, generalized> compute_permutation_mapping(
template <typename Flavor>
void compute_honk_style_permutation_lagrange_polynomials_from_mapping(
const RefSpan<typename Flavor::Polynomial>& permutation_polynomials, // sigma or ID poly
const std::array<std::vector<permutation_subgroup_element>, Flavor::NUM_WIRES>& permutation_mappings,
const std::array<Mapping, Flavor::NUM_WIRES>& permutation_mappings,
typename Flavor::ProvingKey* proving_key)
{
using FF = typename Flavor::FF;
const size_t num_gates = proving_key->circuit_size;

size_t wire_index = 0;
size_t wire_idx = 0;
for (auto& current_permutation_poly : permutation_polynomials) {
ITERATE_OVER_DOMAIN_START(proving_key->evaluation_domain);
const auto& current_mapping = permutation_mappings[wire_index][i];
if (current_mapping.is_public_input) {
auto idx = static_cast<ptrdiff_t>(i);
const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx];
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx];
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx];
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx];
if (current_is_public_input) {
// We intentionally want to break the cycles of the public input variables.
// During the witness generation, the left and right wire polynomials at index i contain the i-th public
// During the witness generation, the left and right wire polynomials at idx i contain the i-th public
// input. The CyclicPermutation created for these variables always start with (i) -> (n+i), followed by
// the indices of the variables in the "real" gates. We make i point to -(i+1), so that the only way of
// repairing the cycle is add the mapping
// -(i+1) -> (n+i)
// These indices are chosen so they can easily be computed by the verifier. They can expect the running
// product to be equal to the "public input delta" that is computed in <honk/utils/grand_product_delta.hpp>
current_permutation_poly.at(i) =
-FF(current_mapping.row_index + 1 + num_gates * current_mapping.column_index);
} else if (current_mapping.is_tag) {
current_permutation_poly.at(i) = -FF(current_row_idx + 1 + num_gates * current_col_idx);
} else if (current_is_tag) {
// Set evaluations to (arbitrary) values disjoint from non-tag values
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_mapping.row_index;
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_row_idx;
} else {
// For the regular permutation we simply point to the next location by setting the evaluation to its
// index
current_permutation_poly.at(i) = FF(current_mapping.row_index + num_gates * current_mapping.column_index);
// idx
current_permutation_poly.at(i) = FF(current_row_idx + num_gates * current_col_idx);
}
ITERATE_OVER_DOMAIN_END;
wire_index++;
wire_idx++;
}
}
} // namespace
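
The three evaluation cases in the function above can be summarized in a small sketch. It is an illustration only: `encode_entry` is a hypothetical helper, and a signed 64-bit integer stands in for the field type `FF`, so a negative result here corresponds to field negation in the real code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the three branches of the honk-style encoding.
// int64_t replaces the field element type; negative values model -FF(...).
inline int64_t encode_entry(uint32_t row_idx,
                            uint8_t col_idx,
                            bool is_public_input,
                            bool is_tag,
                            int64_t num_gates,
                            int64_t num_wires)
{
    assert(col_idx < num_wires);
    const int64_t row = static_cast<int64_t>(row_idx);
    const int64_t col = static_cast<int64_t>(col_idx);
    if (is_public_input) {
        // Break the cycle: map to the negation of (position + 1).
        return -(row + 1 + num_gates * col);
    }
    if (is_tag) {
        // Offset past all regular positions so tag values cannot collide with them.
        return num_gates * num_wires + row;
    }
    // Regular entry: the flattened (column, row) position of the next node.
    return row + num_gates * col;
}
```

With `num_gates = 8` and `num_wires = 4`, regular positions occupy 0 through 31, a regular entry at row 3 of column 2 encodes to 19, a public input at row 1 of column 0 encodes to -2, and a tag with row value 5 encodes to 37.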
@@ -226,7 +265,7 @@ void compute_honk_style_permutation_lagrange_polynomials_from_mapping(
*
* */
inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output,
const std::vector<permutation_subgroup_element>& permutation,
const Mapping& permutation,
const bb::evaluation_domain& small_domain)
{
if (output.size() < permutation.size()) {
@@ -245,19 +284,19 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output,

ITERATE_OVER_DOMAIN_START(small_domain);

// `permutation[i]` will specify the 'index' that this wire value will map to.
// Here, 'index' refers to an element of our subgroup H.
// We can almost use `permutation[i]` to directly index our `roots` array, which contains our subgroup elements.
// `permutation[i]` will specify the 'idx' that this wire value will map to.
// Here, 'idx' refers to an element of our subgroup H.
// We can almost use `permutation[i]` to directly idx our `roots` array, which contains our subgroup elements.
// We first have to accommodate for the fact that `roots` only contains *half* of our subgroup elements. This is
// because ω^{n/2} = -ω and we don't want to perform redundant work computing roots of unity.

size_t raw_idx = permutation[i].row_index;
size_t raw_idx = permutation.row_idx[static_cast<ptrdiff_t>(i)];

// Step 1: is `raw_idx` >= (n / 2)? if so, we will need to index `-roots[raw_idx - subgroup_size / 2]` instead
// Step 1: is `raw_idx` >= (n / 2)? if so, we will need to idx `-roots[raw_idx - subgroup_size / 2]` instead
// of `roots[raw_idx]`
const bool negative_idx = raw_idx >= root_size;

// Step 2: compute the index of the subgroup element we'll be accessing.
// Step 2: compute the idx of the subgroup element we'll be accessing.
// To avoid a conditional branch, we can subtract `negative_idx << log2_root_size` from `raw_idx`.
// Here, `log2_root_size = numeric::get_msb(subgroup_size / 2)` (we know our subgroup size will be a power of 2,
// so we lose no precision here)
@@ -269,23 +308,23 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output,
// The output will similarly be overloaded (containing either 2 * modulus - w, or modulus - w)
output[i] = roots[idx].conditionally_subtract_from_double_modulus(static_cast<uint64_t>(negative_idx));

// Finally, if our permutation maps to an index in either the right wire vector, or the output wire vector, we
// Finally, if our permutation maps to an idx in either the right wire vector, or the output wire vector, we
// need to multiply our result by one of two quadratic non-residues. (This ensures that mapping into the left
// wires gives unique values that are not repeated in the right or output wire permutations) (ditto for right
// wire and output wire mappings)

if (permutation[i].is_public_input) {
if (permutation.is_public_input[static_cast<ptrdiff_t>(i)]) {
// As per the paper which modifies plonk to include the public inputs in a permutation argument, the permutation
// `σ` is modified to `σ'`, where `σ'` maps all public inputs to a set of l distinct ζ elements which are
// disjoint from H ∪ k1·H ∪ k2·H.
output[i] *= bb::fr::external_coset_generator();
} else if (permutation[i].is_tag) {
} else if (permutation.is_tag[static_cast<ptrdiff_t>(i)]) {
output[i] *= bb::fr::tag_coset_generator();
} else {
{
const uint32_t column_index = permutation[i].column_index;
if (column_index > 0) {
output[i] *= bb::fr::coset_generator(column_index - 1);
const uint32_t column_idx = permutation.col_idx[static_cast<ptrdiff_t>(i)];
if (column_idx > 0) {
output[i] *= bb::fr::coset_generator(column_idx - 1);
}
}
}
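
The half-table root lookup described in Steps 1 and 2 of the comments above can be checked on a toy example. The sketch below is an assumption-laden illustration, not the library's code: it works modulo the small prime 17 with ω = 2, which has order n = 8, and all names (`P`, `ROOTS`, `subgroup_element`, `pow_mod`) are hypothetical. It relies on the identity ω^{i + n/2} = -ω^i, which holds because ω^{n/2} = -1. A plain conditional replaces `conditionally_subtract_from_double_modulus` for readability.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr uint64_t P = 17;            // toy prime modulus
constexpr size_t N = 8;               // subgroup size: 2 has order 8 modulo 17
constexpr size_t ROOT_SIZE = N / 2;   // only half of the subgroup is stored
constexpr size_t LOG2_ROOT_SIZE = 2;  // log2(ROOT_SIZE)
constexpr std::array<uint64_t, ROOT_SIZE> ROOTS = { 1, 2, 4, 8 }; // 2^0 .. 2^3 mod 17

// Return ω^raw_idx using only the half table.
inline uint64_t subgroup_element(size_t raw_idx)
{
    assert(raw_idx < N);
    // Step 1: does raw_idx fall in the upper half of the subgroup?
    const bool negative_idx = raw_idx >= ROOT_SIZE;
    // Step 2: subtract ROOT_SIZE without branching when it does.
    const size_t idx = raw_idx - (static_cast<size_t>(negative_idx) << LOG2_ROOT_SIZE);
    const uint64_t root = ROOTS[idx];
    // Upper-half elements are the negations of the stored ones.
    return negative_idx ? (P - root) : root;
}

// Reference implementation by repeated multiplication, used only for checking.
inline uint64_t pow_mod(uint64_t base, size_t exp)
{
    uint64_t result = 1;
    for (size_t k = 0; k < exp; ++k) {
        result = (result * base) % P;
    }
    return result;
}
```

Comparing `subgroup_element(i)` against `pow_mod(2, i)` for every `i` from 0 to 7 confirms that the half table plus a sign flip reproduces the whole subgroup.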
@@ -301,16 +340,15 @@ inline void compute_standard_plonk_lagrange_polynomial(bb::polynomial& output,
* @param key
*/
template <size_t program_width>
void compute_plonk_permutation_lagrange_polynomials_from_mapping(
std::string label,
std::array<std::vector<permutation_subgroup_element>, program_width>& mappings,
plonk::proving_key* key)
void compute_plonk_permutation_lagrange_polynomials_from_mapping(std::string label,
std::array<Mapping, program_width>& mappings,
plonk::proving_key* key)
{
for (size_t i = 0; i < program_width; i++) {
std::string index = std::to_string(i + 1);
std::string idx = std::to_string(i + 1);
bb::polynomial polynomial_lagrange(key->circuit_size);
compute_standard_plonk_lagrange_polynomial(polynomial_lagrange, mappings[i], key->small_domain);
key->polynomial_store.put(label + "_" + index + "_lagrange", polynomial_lagrange.share());
key->polynomial_store.put(label + "_" + idx + "_lagrange", polynomial_lagrange.share());
}
}

@@ -327,8 +365,8 @@ template <size_t program_width>
void compute_monomial_and_coset_fft_polynomials_from_lagrange(std::string label, plonk::proving_key* key)
{
for (size_t i = 0; i < program_width; ++i) {
std::string index = std::to_string(i + 1);
std::string prefix = label + "_" + index;
std::string idx = std::to_string(i + 1);
std::string prefix = label + "_" + idx;

// Construct permutation polynomials in lagrange base
auto sigma_polynomial_lagrange = key->polynomial_store.get(prefix + "_lagrange");