Skip to content

Commit

Permalink
Lde/deprecate polynomial cache (#235)
Browse files Browse the repository at this point in the history
* added simple polynomial container class

* replace PolynomialCache with much simpler PolynomialStore

* default methods, range based for, deleting lagrange selectors

* update polynomial class to disallow operations on empty polynomials
  • Loading branch information
ledwards2225 authored Mar 14, 2023
1 parent 1c002af commit 11ca0bf
Show file tree
Hide file tree
Showing 36 changed files with 495 additions and 802 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void construct_lagrange_selector_forms(const CircuitConstructor& circuit_constru
// iterations of the Sumcheck, we will be able to efficiently cancel out any checks in the last 2^k rows, so any
// randomness or unique values should be placed there.

circuit_proving_key->polynomial_cache.put(circuit_constructor.selector_names_[j] + "_lagrange",
circuit_proving_key->polynomial_store.put(circuit_constructor.selector_names_[j] + "_lagrange",
std::move(selector_poly_lagrange));
}
}
Expand All @@ -78,6 +78,9 @@ void construct_lagrange_selector_forms(const CircuitConstructor& circuit_constru
* @brief Retrieve lagrange forms of selector polynomials and compute monomial and coset-monomial forms and put into
* cache
*
* @note This function also deletes the lagrange forms of the selectors from memory since they are not needed
* for proof construction once the monomial and coset forms have been computed
*
* @param key Pointer to the proving key
* @param selector_properties Names of selectors
*/
Expand All @@ -88,7 +91,7 @@ void compute_monomial_and_coset_selector_forms(bonk::proving_key* circuit_provin
// Compute monomial form of selector polynomial

auto& selector_poly_lagrange =
circuit_proving_key->polynomial_cache.get(selector_properties[i].name + "_lagrange");
circuit_proving_key->polynomial_store.get(selector_properties[i].name + "_lagrange");
barretenberg::polynomial selector_poly(circuit_proving_key->circuit_size);
barretenberg::polynomial_arithmetic::ifft(
&selector_poly_lagrange[0], &selector_poly[0], circuit_proving_key->small_domain);
Expand All @@ -97,11 +100,11 @@ void compute_monomial_and_coset_selector_forms(bonk::proving_key* circuit_provin
barretenberg::polynomial selector_poly_fft(selector_poly, circuit_proving_key->circuit_size * 4 + 4);
selector_poly_fft.coset_fft(circuit_proving_key->large_domain);

// TODO(#215)(Luke/Kesha): Lagrange polynomials could be deleted from cache here since they are no longer
// needed.
// Remove the selector lagrange forms since they will not be needed beyond this point
circuit_proving_key->polynomial_store.remove(selector_properties[i].name + "_lagrange");

circuit_proving_key->polynomial_cache.put(selector_properties[i].name, std::move(selector_poly));
circuit_proving_key->polynomial_cache.put(selector_properties[i].name + "_fft", std::move(selector_poly_fft));
circuit_proving_key->polynomial_store.put(selector_properties[i].name, std::move(selector_poly));
circuit_proving_key->polynomial_store.put(selector_properties[i].name + "_fft", std::move(selector_poly_fft));
}
}

Expand Down Expand Up @@ -193,7 +196,7 @@ std::shared_ptr<bonk::verification_key> compute_verification_key_base_common(

// Commit to the constraint selector polynomial and insert the commitment in the verification key.

auto poly_commitment = commitment_key.commit(proving_key->polynomial_cache.get(poly_label));
auto poly_commitment = commitment_key.commit(proving_key->polynomial_store.get(poly_label));
circuit_verification_key->commitments.insert({ selector_commitment_label, poly_commitment });
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ void compute_honk_style_sigma_lagrange_polynomials_from_mapping(
// Save to polynomial cache
for (size_t j = 0; j < program_width; j++) {
std::string index = std::to_string(j + 1);
key->polynomial_cache.put("sigma_" + index + "_lagrange", std::move(sigma[j]));
key->polynomial_store.put("sigma_" + index + "_lagrange", std::move(sigma[j]));
}

} // namespace honk
Expand Down Expand Up @@ -323,7 +323,7 @@ void compute_standard_plonk_sigma_lagrange_polynomials_from_mapping(
std::string index = std::to_string(i + 1);
barretenberg::polynomial sigma_polynomial_lagrange(key->circuit_size);
compute_standard_plonk_lagrange_polynomial(sigma_polynomial_lagrange, sigma_mappings[i], key->small_domain);
key->polynomial_cache.put("sigma_" + index + "_lagrange", std::move(sigma_polynomial_lagrange));
key->polynomial_store.put("sigma_" + index + "_lagrange", std::move(sigma_polynomial_lagrange));
}
}

Expand All @@ -343,7 +343,7 @@ template <size_t program_width> void compute_sigma_polynomials_monomial_and_cose
// Construct permutation polynomials in lagrange base
std::string index = std::to_string(i + 1);

barretenberg::polynomial sigma_polynomial_lagrange = key->polynomial_cache.get("sigma_" + index + "_lagrange");
barretenberg::polynomial sigma_polynomial_lagrange = key->polynomial_store.get("sigma_" + index + "_lagrange");
// Compute permutation polynomial monomial form
barretenberg::polynomial sigma_polynomial(key->circuit_size);
barretenberg::polynomial_arithmetic::ifft(
Expand All @@ -353,8 +353,8 @@ template <size_t program_width> void compute_sigma_polynomials_monomial_and_cose
barretenberg::polynomial sigma_fft(sigma_polynomial, key->large_domain.size);
sigma_fft.coset_fft(key->large_domain);

key->polynomial_cache.put("sigma_" + index, std::move(sigma_polynomial));
key->polynomial_cache.put("sigma_" + index + "_fft", std::move(sigma_fft));
key->polynomial_store.put("sigma_" + index, std::move(sigma_polynomial));
key->polynomial_store.put("sigma_" + index + "_fft", std::move(sigma_fft));
}
}

Expand All @@ -381,7 +381,7 @@ void compute_standard_honk_id_polynomials(auto key) // proving_key* and shared_p
id_j[i] = (j * n + i);
}
std::string index = std::to_string(j + 1);
key->polynomial_cache.put("id_" + index + "_lagrange", std::move(id_j));
key->polynomial_store.put("id_" + index + "_lagrange", std::move(id_j));
}
}

Expand Down Expand Up @@ -438,8 +438,8 @@ inline void compute_first_and_last_lagrange_polynomials(auto key) // proving_key
barretenberg::polynomial lagrange_polynomial_n_min_1(n);
lagrange_polynomial_0[0] = 1;
lagrange_polynomial_n_min_1[n - 1] = 1;
key->polynomial_cache.put("L_first_lagrange", std::move(lagrange_polynomial_0));
key->polynomial_cache.put("L_last_lagrange", std::move(lagrange_polynomial_n_min_1));
key->polynomial_store.put("L_first_lagrange", std::move(lagrange_polynomial_0));
key->polynomial_store.put("L_last_lagrange", std::move(lagrange_polynomial_n_min_1));
}

} // namespace bonk
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ std::shared_ptr<bonk::verification_key> StandardHonkComposerHelper<CircuitConstr
auto commitment_key = pcs::kzg::CommitmentKey(proving_key->circuit_size, "../srs_db/ignition");

// Compute and store commitments to all precomputed polynomials
key->commitments["Q_M"] = commitment_key.commit(proving_key->polynomial_cache.get("q_m_lagrange"));
key->commitments["Q_1"] = commitment_key.commit(proving_key->polynomial_cache.get("q_1_lagrange"));
key->commitments["Q_2"] = commitment_key.commit(proving_key->polynomial_cache.get("q_2_lagrange"));
key->commitments["Q_3"] = commitment_key.commit(proving_key->polynomial_cache.get("q_3_lagrange"));
key->commitments["Q_C"] = commitment_key.commit(proving_key->polynomial_cache.get("q_c_lagrange"));
key->commitments["SIGMA_1"] = commitment_key.commit(proving_key->polynomial_cache.get("sigma_1_lagrange"));
key->commitments["SIGMA_2"] = commitment_key.commit(proving_key->polynomial_cache.get("sigma_2_lagrange"));
key->commitments["SIGMA_3"] = commitment_key.commit(proving_key->polynomial_cache.get("sigma_3_lagrange"));
key->commitments["ID_1"] = commitment_key.commit(proving_key->polynomial_cache.get("id_1_lagrange"));
key->commitments["ID_2"] = commitment_key.commit(proving_key->polynomial_cache.get("id_2_lagrange"));
key->commitments["ID_3"] = commitment_key.commit(proving_key->polynomial_cache.get("id_3_lagrange"));
key->commitments["LAGRANGE_FIRST"] = commitment_key.commit(proving_key->polynomial_cache.get("L_first_lagrange"));
key->commitments["LAGRANGE_LAST"] = commitment_key.commit(proving_key->polynomial_cache.get("L_last_lagrange"));
key->commitments["Q_M"] = commitment_key.commit(proving_key->polynomial_store.get("q_m_lagrange"));
key->commitments["Q_1"] = commitment_key.commit(proving_key->polynomial_store.get("q_1_lagrange"));
key->commitments["Q_2"] = commitment_key.commit(proving_key->polynomial_store.get("q_2_lagrange"));
key->commitments["Q_3"] = commitment_key.commit(proving_key->polynomial_store.get("q_3_lagrange"));
key->commitments["Q_C"] = commitment_key.commit(proving_key->polynomial_store.get("q_c_lagrange"));
key->commitments["SIGMA_1"] = commitment_key.commit(proving_key->polynomial_store.get("sigma_1_lagrange"));
key->commitments["SIGMA_2"] = commitment_key.commit(proving_key->polynomial_store.get("sigma_2_lagrange"));
key->commitments["SIGMA_3"] = commitment_key.commit(proving_key->polynomial_store.get("sigma_3_lagrange"));
key->commitments["ID_1"] = commitment_key.commit(proving_key->polynomial_store.get("id_1_lagrange"));
key->commitments["ID_2"] = commitment_key.commit(proving_key->polynomial_store.get("id_2_lagrange"));
key->commitments["ID_3"] = commitment_key.commit(proving_key->polynomial_store.get("id_3_lagrange"));
key->commitments["LAGRANGE_FIRST"] = commitment_key.commit(proving_key->polynomial_store.get("L_first_lagrange"));
key->commitments["LAGRANGE_LAST"] = commitment_key.commit(proving_key->polynomial_store.get("L_last_lagrange"));

return key;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void StandardPlonkComposerHelper<CircuitConstructor>::compute_witness(const Circ

for (size_t j = 0; j < program_width; ++j) {
std::string index = std::to_string(j + 1);
circuit_proving_key->polynomial_cache.put("w_" + index + "_lagrange",
circuit_proving_key->polynomial_store.put("w_" + index + "_lagrange",
std::move(wire_polynomial_evaluations[j]));
}
computed_witness = true;
Expand Down
38 changes: 19 additions & 19 deletions cpp/src/barretenberg/honk/composer/standard_honk_composer.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ TEST(StandardHonkComposer, SigmaIDCorrectness)
// Let's check that indices are the same and nothing is lost, first
for (size_t j = 0; j < composer.program_width; ++j) {
std::string index = std::to_string(j + 1);
const auto& sigma_j = proving_key->polynomial_cache.get("sigma_" + index + "_lagrange");
const auto& sigma_j = proving_key->polynomial_store.get("sigma_" + index + "_lagrange");
for (size_t i = 0; i < n; ++i) {
left *= (gamma + j * n + i);
right *= (gamma + sigma_j[i]);
Expand All @@ -68,9 +68,9 @@ TEST(StandardHonkComposer, SigmaIDCorrectness)

for (size_t j = 0; j < composer.program_width; ++j) {
std::string index = std::to_string(j + 1);
const auto& permutation_polynomial = proving_key->polynomial_cache.get("sigma_" + index + "_lagrange");
const auto& permutation_polynomial = proving_key->polynomial_store.get("sigma_" + index + "_lagrange");
const auto& witness_polynomial = composer.composer_helper.wire_polynomials[j];
const auto& id_polynomial = proving_key->polynomial_cache.get("id_" + index + "_lagrange");
const auto& id_polynomial = proving_key->polynomial_store.get("id_" + index + "_lagrange");
// left = ∏ᵢ,ⱼ(ωᵢ,ⱼ + β⋅ind(i,j) + γ)
// right = ∏ᵢ,ⱼ(ωᵢ,ⱼ + β⋅σ(i,j) + γ)
for (size_t i = 0; i < proving_key->circuit_size; ++i) {
Expand Down Expand Up @@ -156,15 +156,15 @@ TEST(StandardHonkComposer, LagrangeCorrectness)
random_polynomial[i] = barretenberg::fr::random_element();
}
// Compute inner product of random polynomial and the first lagrange polynomial
barretenberg::polynomial first_lagrange_polynomial = proving_key->polynomial_cache.get("L_first_lagrange");
barretenberg::polynomial first_lagrange_polynomial = proving_key->polynomial_store.get("L_first_lagrange");
barretenberg::fr first_product(0);
for (size_t i = 0; i < proving_key->circuit_size; i++) {
first_product += random_polynomial[i] * first_lagrange_polynomial[i];
}
EXPECT_EQ(first_product, random_polynomial[0]);

// Compute inner product of random polynomial and the last lagrange polynomial
barretenberg::polynomial last_lagrange_polynomial = proving_key->polynomial_cache.get("L_last_lagrange");
barretenberg::polynomial last_lagrange_polynomial = proving_key->polynomial_store.get("L_last_lagrange");
barretenberg::fr last_product(0);
for (size_t i = 0; i < proving_key->circuit_size; i++) {
last_product += random_polynomial[i] * last_lagrange_polynomial[i];
Expand Down Expand Up @@ -213,7 +213,7 @@ TEST(StandardHonkComposer, AssertEquals)
// Put the sigma polynomials into a vector for easy access
for (size_t i = 0; i < composer.program_width; i++) {
std::string index = std::to_string(i + 1);
sigma_polynomials.push_back(proving_key->polynomial_cache.get("sigma_" + index + "_lagrange"));
sigma_polynomials.push_back(proving_key->polynomial_store.get("sigma_" + index + "_lagrange"));
}

// Let's compute the maximum cycle
Expand Down Expand Up @@ -366,19 +366,19 @@ TEST(StandardHonkComposer, SumcheckRelationCorrectness)
evaluations_array[POLYNOMIAL::W_O] = prover.wire_polynomials[2];
evaluations_array[POLYNOMIAL::Z_PERM] = z_perm_poly;
evaluations_array[POLYNOMIAL::Z_PERM_SHIFT] = z_perm_poly.shifted();
evaluations_array[POLYNOMIAL::Q_M] = prover.key->polynomial_cache.get("q_m_lagrange");
evaluations_array[POLYNOMIAL::Q_L] = prover.key->polynomial_cache.get("q_1_lagrange");
evaluations_array[POLYNOMIAL::Q_R] = prover.key->polynomial_cache.get("q_2_lagrange");
evaluations_array[POLYNOMIAL::Q_O] = prover.key->polynomial_cache.get("q_3_lagrange");
evaluations_array[POLYNOMIAL::Q_C] = prover.key->polynomial_cache.get("q_c_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_1] = prover.key->polynomial_cache.get("sigma_1_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_2] = prover.key->polynomial_cache.get("sigma_2_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_3] = prover.key->polynomial_cache.get("sigma_3_lagrange");
evaluations_array[POLYNOMIAL::ID_1] = prover.key->polynomial_cache.get("id_1_lagrange");
evaluations_array[POLYNOMIAL::ID_2] = prover.key->polynomial_cache.get("id_2_lagrange");
evaluations_array[POLYNOMIAL::ID_3] = prover.key->polynomial_cache.get("id_3_lagrange");
evaluations_array[POLYNOMIAL::LAGRANGE_FIRST] = prover.key->polynomial_cache.get("L_first_lagrange");
evaluations_array[POLYNOMIAL::LAGRANGE_LAST] = prover.key->polynomial_cache.get("L_last_lagrange");
evaluations_array[POLYNOMIAL::Q_M] = prover.key->polynomial_store.get("q_m_lagrange");
evaluations_array[POLYNOMIAL::Q_L] = prover.key->polynomial_store.get("q_1_lagrange");
evaluations_array[POLYNOMIAL::Q_R] = prover.key->polynomial_store.get("q_2_lagrange");
evaluations_array[POLYNOMIAL::Q_O] = prover.key->polynomial_store.get("q_3_lagrange");
evaluations_array[POLYNOMIAL::Q_C] = prover.key->polynomial_store.get("q_c_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_1] = prover.key->polynomial_store.get("sigma_1_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_2] = prover.key->polynomial_store.get("sigma_2_lagrange");
evaluations_array[POLYNOMIAL::SIGMA_3] = prover.key->polynomial_store.get("sigma_3_lagrange");
evaluations_array[POLYNOMIAL::ID_1] = prover.key->polynomial_store.get("id_1_lagrange");
evaluations_array[POLYNOMIAL::ID_2] = prover.key->polynomial_store.get("id_2_lagrange");
evaluations_array[POLYNOMIAL::ID_3] = prover.key->polynomial_store.get("id_3_lagrange");
evaluations_array[POLYNOMIAL::LAGRANGE_FIRST] = prover.key->polynomial_store.get("L_first_lagrange");
evaluations_array[POLYNOMIAL::LAGRANGE_LAST] = prover.key->polynomial_store.get("L_last_lagrange");

// Construct the round for applying sumcheck relations and results for storing computed results
auto relations = std::tuple(honk::sumcheck::ArithmeticRelation<fr>(),
Expand Down
28 changes: 14 additions & 14 deletions cpp/src/barretenberg/honk/proof_system/prover.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ Prover<settings>::Prover(std::vector<barretenberg::polynomial>&& wire_polys,
{
// Note(luke): This could be done programmatically with some hacks but this isnt too bad and its nice to see the
// polys laid out explicitly.
prover_polynomials[POLYNOMIAL::Q_C] = key->polynomial_cache.get("q_c_lagrange");
prover_polynomials[POLYNOMIAL::Q_L] = key->polynomial_cache.get("q_1_lagrange");
prover_polynomials[POLYNOMIAL::Q_R] = key->polynomial_cache.get("q_2_lagrange");
prover_polynomials[POLYNOMIAL::Q_O] = key->polynomial_cache.get("q_3_lagrange");
prover_polynomials[POLYNOMIAL::Q_M] = key->polynomial_cache.get("q_m_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_1] = key->polynomial_cache.get("sigma_1_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_2] = key->polynomial_cache.get("sigma_2_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_3] = key->polynomial_cache.get("sigma_3_lagrange");
prover_polynomials[POLYNOMIAL::ID_1] = key->polynomial_cache.get("id_1_lagrange");
prover_polynomials[POLYNOMIAL::ID_2] = key->polynomial_cache.get("id_2_lagrange");
prover_polynomials[POLYNOMIAL::ID_3] = key->polynomial_cache.get("id_3_lagrange");
prover_polynomials[POLYNOMIAL::LAGRANGE_FIRST] = key->polynomial_cache.get("L_first_lagrange");
prover_polynomials[POLYNOMIAL::LAGRANGE_LAST] = key->polynomial_cache.get("L_last_lagrange");
prover_polynomials[POLYNOMIAL::Q_C] = key->polynomial_store.get("q_c_lagrange");
prover_polynomials[POLYNOMIAL::Q_L] = key->polynomial_store.get("q_1_lagrange");
prover_polynomials[POLYNOMIAL::Q_R] = key->polynomial_store.get("q_2_lagrange");
prover_polynomials[POLYNOMIAL::Q_O] = key->polynomial_store.get("q_3_lagrange");
prover_polynomials[POLYNOMIAL::Q_M] = key->polynomial_store.get("q_m_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_1] = key->polynomial_store.get("sigma_1_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_2] = key->polynomial_store.get("sigma_2_lagrange");
prover_polynomials[POLYNOMIAL::SIGMA_3] = key->polynomial_store.get("sigma_3_lagrange");
prover_polynomials[POLYNOMIAL::ID_1] = key->polynomial_store.get("id_1_lagrange");
prover_polynomials[POLYNOMIAL::ID_2] = key->polynomial_store.get("id_2_lagrange");
prover_polynomials[POLYNOMIAL::ID_3] = key->polynomial_store.get("id_3_lagrange");
prover_polynomials[POLYNOMIAL::LAGRANGE_FIRST] = key->polynomial_store.get("L_first_lagrange");
prover_polynomials[POLYNOMIAL::LAGRANGE_LAST] = key->polynomial_store.get("L_last_lagrange");
prover_polynomials[POLYNOMIAL::W_L] = wire_polynomials[0];
prover_polynomials[POLYNOMIAL::W_R] = wire_polynomials[1];
prover_polynomials[POLYNOMIAL::W_O] = wire_polynomials[2];
Expand Down Expand Up @@ -128,7 +128,7 @@ template <typename settings> Polynomial Prover<settings>::compute_grand_product_
for (size_t i = 0; i < program_width; ++i) {
std::string sigma_id = "sigma_" + std::to_string(i + 1) + "_lagrange";
wires[i] = wire_polynomials[i];
sigmas[i] = key->polynomial_cache.get(sigma_id);
sigmas[i] = key->polynomial_store.get(sigma_id);
}

// Step (1)
Expand Down
Loading

0 comments on commit 11ca0bf

Please sign in to comment.