-
Notifications
You must be signed in to change notification settings - Fork 298
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: permutation argument optimizations #10960
Changes from 24 commits
64317bc
0ab2078
35d1c8d
418f7bc
09ef41a
a258127
72b215f
c36feaa
bd8a511
dd50b47
d9b43bb
deb1d65
339a5c9
5471504
4d7ea03
7f348aa
9736925
8911c2f
46e378c
d0ea21a
45156f7
0ebd607
b2539a7
bf1394d
b104bc6
e702d00
d34fe9c
cd6e7b6
e442a51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -224,33 +224,43 @@ void compute_honk_style_permutation_lagrange_polynomials_from_mapping( | |
using FF = typename Flavor::FF; | ||
const size_t num_gates = proving_key->circuit_size; | ||
|
||
size_t domain_size = proving_key->active_region_data.idxs.size(); | ||
|
||
const MultithreadData thread_data = calculate_thread_data(domain_size); | ||
|
||
size_t wire_idx = 0; | ||
for (auto& current_permutation_poly : permutation_polynomials) { | ||
ITERATE_OVER_DOMAIN_START(proving_key->evaluation_domain); | ||
auto idx = static_cast<ptrdiff_t>(i); | ||
const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx]; | ||
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx]; | ||
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx]; | ||
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx]; | ||
if (current_is_public_input) { | ||
// We intentionally want to break the cycles of the public input variables. | ||
// During the witness generation, the left and right wire polynomials at idx i contain the i-th public | ||
// input. The CyclicPermutation created for these variables always start with (i) -> (n+i), followed by | ||
// the indices of the variables in the "real" gates. We make i point to -(i+1), so that the only way of | ||
// repairing the cycle is add the mapping | ||
// -(i+1) -> (n+i) | ||
// These indices are chosen so they can easily be computed by the verifier. They can expect the running | ||
// product to be equal to the "public input delta" that is computed in <honk/utils/grand_product_delta.hpp> | ||
current_permutation_poly.at(i) = -FF(current_row_idx + 1 + num_gates * current_col_idx); | ||
} else if (current_is_tag) { | ||
// Set evaluations to (arbitrary) values disjoint from non-tag values | ||
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_row_idx; | ||
} else { | ||
// For the regular permutation we simply point to the next location by setting the evaluation to its | ||
// idx | ||
current_permutation_poly.at(i) = FF(current_row_idx + num_gates * current_col_idx); | ||
} | ||
ITERATE_OVER_DOMAIN_END; | ||
parallel_for(thread_data.num_threads, [&](size_t j) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This loop now iterates over only the active domain instead of the entire poly domain. Prior to this change, the sigma/id polynomials took non-zero values across the entire domain. Now, they are non-zero only in the active regions of the trace and 0 elsewhere (previously we had sigma_i == id_i in these regions). These values don't contribute to the computation of the grand product anyway so there's no reason to compute them. |
||
const size_t start = thread_data.start[j]; | ||
const size_t end = thread_data.end[j]; | ||
for (size_t i = start; i < end; ++i) { | ||
size_t poly_idx = proving_key->active_region_data.idxs[i]; | ||
auto idx = static_cast<ptrdiff_t>(poly_idx); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this cast needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also this can be a const and the one above too There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The cast is needed since |
||
const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx]; | ||
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx]; | ||
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx]; | ||
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx]; | ||
if (current_is_public_input) { | ||
// We intentionally want to break the cycles of the public input variables. | ||
// During the witness generation, the left and right wire polynomials at idx i contain the i-th | ||
// public input. The CyclicPermutation created for these variables always start with (i) -> (n+i), | ||
// followed by the indices of the variables in the "real" gates. We make i point to | ||
// -(i+1), so that the only way of repairing the cycle is add the mapping | ||
// -(i+1) -> (n+i) | ||
// These indices are chosen so they can easily be computed by the verifier. They can expect | ||
// the running product to be equal to the "public input delta" that is computed | ||
// in <honk/utils/grand_product_delta.hpp> | ||
current_permutation_poly.at(poly_idx) = -FF(current_row_idx + 1 + num_gates * current_col_idx); | ||
} else if (current_is_tag) { | ||
// Set evaluations to (arbitrary) values disjoint from non-tag values | ||
current_permutation_poly.at(poly_idx) = num_gates * Flavor::NUM_WIRES + current_row_idx; | ||
} else { | ||
// For the regular permutation we simply point to the next location by setting the | ||
// evaluation to its idx | ||
current_permutation_poly.at(poly_idx) = FF(current_row_idx + num_gates * current_col_idx); | ||
} | ||
} | ||
}); | ||
wire_idx++; | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,8 +3,10 @@ | |
#include "barretenberg/common/debug_log.hpp" | ||
#include "barretenberg/common/thread.hpp" | ||
#include "barretenberg/common/zip_view.hpp" | ||
#include "barretenberg/flavor/flavor.hpp" | ||
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp" | ||
#include "barretenberg/relations/relation_parameters.hpp" | ||
#include "barretenberg/trace_to_polynomials/trace_to_polynomials.hpp" | ||
#include <typeinfo> | ||
|
||
namespace bb { | ||
|
@@ -47,74 +49,74 @@ namespace bb { | |
* | ||
* Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with | ||
* | ||
* @note This method makes use of the fact that there are at most as many unique entries in the grand product as active | ||
* rows in the execution trace to efficiently compute the grand product when a structured trace is in use. I.e. the | ||
* computation peformed herein is proportional to the number of active rows in the trace and the constant values in the | ||
* inactive regions are simply populated from known values on the last step. | ||
* | ||
* @tparam Flavor | ||
* @tparam GrandProdRelation | ||
* @param full_polynomials | ||
* @param relation_parameters | ||
* @param size_override optional size of the domain; otherwise based on dyadic polynomial domain | ||
* @param active_region_data optional specification of active region of execution trace | ||
*/ | ||
template <typename Flavor, typename GrandProdRelation> | ||
void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials, | ||
bb::RelationParameters<typename Flavor::FF>& relation_parameters, | ||
size_t size_override = 0, | ||
std::vector<std::pair<size_t, size_t>> active_block_ranges = {}) | ||
const ActiveRegionData& active_region_data = ActiveRegionData{}) | ||
{ | ||
PROFILE_THIS_NAME("compute_grand_product"); | ||
|
||
using FF = typename Flavor::FF; | ||
using Polynomial = typename Flavor::Polynomial; | ||
using Accumulator = std::tuple_element_t<0, typename GrandProdRelation::SumcheckArrayOfValuesOverSubrelations>; | ||
|
||
const bool active_region_specified = !active_region_data.ranges.empty(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hah I started with that but thought it was misleading because if false it implies that there are NO active regions when really its just that they are implicit and haven't been specified. You're probably right tho that |
||
|
||
// Set the domain over which the grand product must be computed. This may be less than the dyadic circuit size, e.g | ||
// the permutation grand product does not need to be computed beyond the index of the last active wire | ||
size_t domain_size = size_override == 0 ? full_polynomials.get_polynomial_size() : size_override; | ||
|
||
const size_t num_threads = domain_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1; | ||
const size_t block_size = domain_size / num_threads; | ||
const size_t final_idx = domain_size - 1; | ||
|
||
// Cumpute the index bounds for each thread for reuse in the computations below | ||
std::vector<std::pair<size_t, size_t>> idx_bounds; | ||
idx_bounds.reserve(num_threads); | ||
for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) { | ||
const size_t start = thread_idx * block_size; | ||
const size_t end = (thread_idx == num_threads - 1) ? final_idx : (thread_idx + 1) * block_size; | ||
idx_bounds.push_back(std::make_pair(start, end)); | ||
} | ||
// Returns the ith active index if specified, otherwise acts as the identity map on the input | ||
auto get_active_range_poly_idx = [&](size_t i) { | ||
if (active_region_specified) { | ||
return active_region_data.idxs[i]; | ||
} | ||
return i; | ||
}; | ||
|
||
size_t active_domain_size = active_region_specified ? active_region_data.idxs.size() : domain_size; | ||
|
||
// The size of the iteration domain is one less than the number of active rows since the final value of the | ||
// grand product is constructed only in the relation and not explicitly in the polynomial | ||
const MultithreadData active_range_thread_data = calculate_thread_data(active_domain_size - 1); | ||
|
||
// Allocate numerator/denominator polynomials that will serve as scratch space | ||
// TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability | ||
Polynomial numerator{ domain_size, domain_size }; | ||
Polynomial denominator{ domain_size, domain_size }; | ||
|
||
auto check_is_active = [&](size_t idx) { | ||
if (active_block_ranges.empty()) { | ||
return true; | ||
} | ||
return std::any_of(active_block_ranges.begin(), active_block_ranges.end(), [idx](const auto& range) { | ||
return idx >= range.first && idx < range.second; | ||
}); | ||
}; | ||
Polynomial numerator{ active_domain_size }; | ||
Polynomial denominator{ active_domain_size }; | ||
|
||
// Step (1) | ||
// Populate `numerator` and `denominator` with the algebra described by Relation | ||
FF gamma_fourth = relation_parameters.gamma.pow(4); | ||
parallel_for(num_threads, [&](size_t thread_idx) { | ||
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) { | ||
const size_t start = active_range_thread_data.start[thread_idx]; | ||
const size_t end = active_range_thread_data.end[thread_idx]; | ||
typename Flavor::AllValues row; | ||
const size_t start = idx_bounds[thread_idx].first; | ||
const size_t end = idx_bounds[thread_idx].second; | ||
for (size_t i = start; i < end; ++i) { | ||
if (check_is_active(i)) { | ||
// TODO(https://github.com/AztecProtocol/barretenberg/issues/940):consider avoiding get_row if possible. | ||
row = full_polynomials.get_row(i); | ||
numerator.at(i) = | ||
GrandProdRelation::template compute_grand_product_numerator<Accumulator>(row, relation_parameters); | ||
denominator.at(i) = GrandProdRelation::template compute_grand_product_denominator<Accumulator>( | ||
row, relation_parameters); | ||
// TODO(https://github.com/AztecProtocol/barretenberg/issues/940):consider avoiding get_row if | ||
// possible. | ||
auto row_idx = get_active_range_poly_idx(i); | ||
if constexpr (IsUltraFlavor<Flavor>) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's not quite the same thing though because this code is also used by the ECCVM/Translator which need to be excluded. I think this just comes down to the fact that we need better concepts. Probably There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suppose I could just add methods to the ECCVM/Trans flavors that just call get_row from get_row_for_permutation_arg. Not sure what's better |
||
row = full_polynomials.get_row_for_permutation_arg(row_idx); | ||
} else { | ||
numerator.at(i) = gamma_fourth; | ||
denominator.at(i) = gamma_fourth; | ||
row = full_polynomials.get_row(row_idx); | ||
} | ||
numerator.at(i) = | ||
GrandProdRelation::template compute_grand_product_numerator<Accumulator>(row, relation_parameters); | ||
denominator.at(i) = | ||
GrandProdRelation::template compute_grand_product_denominator<Accumulator>(row, relation_parameters); | ||
} | ||
}); | ||
|
||
|
@@ -133,12 +135,12 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials, | |
// (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 } | ||
// (iii) Each thread j computes N[i][j]*P[j]= | ||
// {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}} | ||
std::vector<FF> partial_numerators(num_threads); | ||
std::vector<FF> partial_denominators(num_threads); | ||
std::vector<FF> partial_numerators(active_range_thread_data.num_threads); | ||
std::vector<FF> partial_denominators(active_range_thread_data.num_threads); | ||
|
||
parallel_for(num_threads, [&](size_t thread_idx) { | ||
const size_t start = idx_bounds[thread_idx].first; | ||
const size_t end = idx_bounds[thread_idx].second; | ||
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) { | ||
const size_t start = active_range_thread_data.start[thread_idx]; | ||
const size_t end = active_range_thread_data.end[thread_idx]; | ||
for (size_t i = start; i < end - 1; ++i) { | ||
numerator.at(i + 1) *= numerator[i]; | ||
denominator.at(i + 1) *= denominator[i]; | ||
|
@@ -150,9 +152,9 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials, | |
DEBUG_LOG_ALL(partial_numerators); | ||
DEBUG_LOG_ALL(partial_denominators); | ||
|
||
parallel_for(num_threads, [&](size_t thread_idx) { | ||
const size_t start = idx_bounds[thread_idx].first; | ||
const size_t end = idx_bounds[thread_idx].second; | ||
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) { | ||
const size_t start = active_range_thread_data.start[thread_idx]; | ||
const size_t end = active_range_thread_data.end[thread_idx]; | ||
if (thread_idx > 0) { | ||
FF numerator_scaling = 1; | ||
FF denominator_scaling = 1; | ||
|
@@ -179,14 +181,45 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials, | |
// We have a 'virtual' 0 at the start (as this is a to-be-shifted polynomial) | ||
ASSERT(grand_product_polynomial.start_index() == 1); | ||
|
||
parallel_for(num_threads, [&](size_t thread_idx) { | ||
const size_t start = idx_bounds[thread_idx].first; | ||
const size_t end = idx_bounds[thread_idx].second; | ||
if constexpr (IsUltraFlavor<Flavor>) { | ||
grand_product_polynomial.at(1) = 1; | ||
} | ||
|
||
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) { | ||
const size_t start = active_range_thread_data.start[thread_idx]; | ||
const size_t end = active_range_thread_data.end[thread_idx]; | ||
for (size_t i = start; i < end; ++i) { | ||
grand_product_polynomial.at(i + 1) = numerator[i] * denominator[i]; | ||
auto poly_idx = get_active_range_poly_idx(i + 1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
grand_product_polynomial.at(poly_idx) = numerator[i] * denominator[i]; | ||
} | ||
}); | ||
|
||
// Lambda to set the constant inactive regions of the grand product if they exist | ||
auto set_constant_value_if_inactive = [&](size_t i) { | ||
for (size_t j = 0; j < active_region_data.ranges.size() - 1; ++j) { | ||
size_t previous_range_end = active_region_data.ranges[j].second; | ||
size_t next_range_start = active_region_data.ranges[j + 1].first; | ||
if (i >= previous_range_end && i < next_range_start) { | ||
grand_product_polynomial.at(i) = grand_product_polynomial[next_range_start]; | ||
break; | ||
} | ||
} | ||
}; | ||
|
||
// Final step: The grand product is constant in the inactive regions of the trace (if they exist) where no copy | ||
// constraints are present. These constant values have already been computed and are equal to the first value in the | ||
// subsequent active region. | ||
if (active_region_specified) { | ||
MultithreadData full_domain_thread_data = calculate_thread_data(domain_size); | ||
parallel_for(full_domain_thread_data.num_threads, [&](size_t thread_idx) { | ||
const size_t start = full_domain_thread_data.start[thread_idx]; | ||
const size_t end = full_domain_thread_data.end[thread_idx]; | ||
for (size_t i = start; i < end; ++i) { | ||
set_constant_value_if_inactive(i); | ||
} | ||
}); | ||
} | ||
|
||
DEBUG_LOG_ALL(grand_product_polynomial.coeffs()); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't these ranges need to be non-overlapping and in increasing order? maybe there should be a comment of some sort of check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its a good point. I could add a check on add_range that the input has start >= the previous end. To be safe I suppose I'd also want to make the members private and add getters